-
Notifications
You must be signed in to change notification settings - Fork 743
[Feature] Add TritonBF16MoEMethod for BF16 MoE inference #7734
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
46fdad2
f1b5847
21923b0
e308032
f643723
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,17 @@ | |
| from paddle import nn | ||
|
|
||
| import fastdeploy | ||
| from fastdeploy.model_executor.layers.moe.moe import get_moe_scores | ||
| from fastdeploy.model_executor.layers.moe.triton_moe_kernels import ( | ||
| fused_moe_kernel_bf16, | ||
| fused_moe_kernel_paddle, | ||
| ) | ||
| from fastdeploy.model_executor.layers.quantization.fp8_utils import ( | ||
| fused_stack_transpose_quant, | ||
| quant_weight_ue8m0, | ||
| transform_scale_ue8m0, | ||
| ) | ||
| from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant | ||
| from fastdeploy.model_executor.layers.utils import get_tensor | ||
| from fastdeploy.model_executor.utils import ( | ||
| TensorTracker, | ||
|
|
@@ -32,21 +43,15 @@ | |
| from fastdeploy.platforms import current_platform | ||
| from fastdeploy.utils import ceil_div, register_custom_python_op | ||
|
|
||
| from ..quantization.quant_base import QuantMethodBase | ||
|
|
||
| try: | ||
| from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func | ||
| import triton.language as tl | ||
|
|
||
| from .triton_moe_kernels import fused_moe_kernel_paddle | ||
| except ImportError: | ||
| from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func | ||
| except: | ||
| pass | ||
| from fastdeploy.model_executor.layers.moe.moe import get_moe_scores | ||
| from fastdeploy.model_executor.layers.quantization.fp8_utils import ( | ||
| fused_stack_transpose_quant, | ||
| quant_weight_ue8m0, | ||
| transform_scale_ue8m0, | ||
| ) | ||
| from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant | ||
|
|
||
| from ..quantization.quant_base import QuantMethodBase | ||
| from .fused_moe_backend_base import UnquantizedFusedMoEMethod | ||
|
|
||
|
|
||
| class TritonWeightOnlyMoEMethod(QuantMethodBase): | ||
|
|
@@ -780,8 +785,8 @@ def apply( | |
| stride_am=x_q.strides[0], | ||
| stride_ak=x_q.strides[1], | ||
| stride_be=layer.up_gate_proj_weight.strides[0], | ||
| stride_bk=layer.up_gate_proj_weight.strides[2], | ||
| stride_bn=layer.up_gate_proj_weight.strides[1], | ||
| stride_bk=layer.up_gate_proj_weight.strides[1], | ||
| stride_bn=layer.up_gate_proj_weight.strides[2], | ||
| stride_cm=up_gate_proj_out.strides[0], | ||
| stride_cn=up_gate_proj_out.strides[1], | ||
| # | ||
|
|
@@ -1885,3 +1890,244 @@ def apply( | |
| fc1_latent_proj, | ||
| fc2_latent_proj, | ||
| ) | ||
|
|
||
|
|
||
| class TritonMoEMethod(UnquantizedFusedMoEMethod): | ||
| """ | ||
| Use Triton Group Gemm (BF16 unquantized) to compute Fused MoE. | ||
|
|
||
| Activated via: export FD_MOE_BACKEND=triton | ||
| Weight layout (CUDA path): [E, K, 2N] for up_gate_proj, [E, N, K] for down_proj. | ||
| This matches UnquantizedFusedMoEMethod.create_weights layout on CUDA. | ||
| """ | ||
|
|
||
| def __init__(self, quant_config=None): | ||
| super().__init__(quant_config) | ||
|
|
||
| def process_loaded_weights(self, layer: nn.Layer, state_dict): | ||
| """Stack individual expert weights into the stacked parameter.""" | ||
| up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) | ||
| layer.up_gate_proj_weight.set_value(paddle.stack(up_gate_proj_weights, axis=0)) | ||
| layer.down_proj_weight.set_value(paddle.stack(down_proj_weights, axis=0)) | ||
|
|
||
| def _get_default_config(self, M: int, N: int, K: int, num_experts: int = 64) -> dict: | ||
| """ | ||
| Heuristic tile config for BF16 MoE, aligned with vLLM's get_default_config logic. | ||
| M: number of token-expert pairs | ||
| N: output dimension of the GEMM | ||
| K: input dimension of the GEMM | ||
| num_experts: number of local experts (for GROUP_SIZE_M heuristic) | ||
| """ | ||
| if M <= 32: | ||
| block_m, block_n, block_k = 16, 64, 128 | ||
| num_warps, num_stages = 4, 4 | ||
| elif M <= 96: | ||
| block_m, block_n, block_k = 32, 64, 128 | ||
| num_warps, num_stages = 4, 3 | ||
| elif M <= 512: | ||
| block_m, block_n, block_k = 64, 128, 64 | ||
| num_warps, num_stages = 8, 3 | ||
| else: | ||
| block_m, block_n, block_k = 128, 128, 64 | ||
| num_warps, num_stages = 8, 3 | ||
|
|
||
| tokens_per_expert = M // max(num_experts, 1) | ||
| group_m = 16 if tokens_per_expert > 128 else 1 | ||
|
|
||
| return { | ||
| "BLOCK_SIZE_M": block_m, | ||
| "BLOCK_SIZE_N": block_n, | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| "BLOCK_SIZE_K": block_k, | ||
| "GROUP_SIZE_M": group_m, | ||
| "num_warps": num_warps, | ||
| "num_stages": num_stages, | ||
| } | ||
|
|
||
| def apply_tp( | ||
| self, | ||
| layer: nn.Layer, | ||
| x: paddle.Tensor, | ||
| gate: nn.Layer, | ||
| topk_ids_hookfunc: Callable = None, | ||
| fc1_latent_proj: nn.Layer = None, | ||
| fc2_latent_proj: nn.Layer = None, | ||
| ) -> paddle.Tensor: | ||
| """ | ||
| BF16 Triton Fused MoE forward. | ||
|
|
||
| Pipeline: | ||
| 1. Gate + topk routing | ||
| 2. tritonmoe_preprocess -> sorted_token_ids, expert_ids, num_tokens_post_padded | ||
| 3. fused_moe_kernel_bf16 GEMM1: [tokens*topk, K] x [E, K, 2N] -> [tokens*topk, 2N] | ||
| 4. SwiGLU activation | ||
| 5. fused_moe_kernel_bf16 GEMM2: [tokens*topk, N] x [E, N, K] -> [tokens*topk, K] | ||
| (with MUL_ROUTED_WEIGHT=True to fuse router weight multiplication) | ||
| 6. Reshape + sum over topk dim | ||
| """ | ||
| token_num = x.shape[0] | ||
| if token_num == 0: | ||
| return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) | ||
|
|
||
| top_k = layer.top_k | ||
| num_local_experts = layer.num_local_experts | ||
| moe_intermediate_size = layer.moe_intermediate_size | ||
| hidden_size = layer.hidden_size | ||
|
|
||
| # --- 1. Routing --- | ||
| gate_out = gate(x) | ||
|
|
||
| if layer.topk_method == "noaux_tc": | ||
| use_fused = not fastdeploy.envs.FD_ENABLE_RL and current_platform.is_cuda() | ||
| if not use_fused: | ||
| gate_out = gate_out.cast("float32") | ||
|
|
||
| _, topk_weights, topk_ids = get_moe_scores( | ||
| gate_out, | ||
| layer.n_group, | ||
| layer.topk_group, | ||
| top_k, | ||
| layer.routed_scaling_factor, | ||
| layer.gate_correction_bias, | ||
| getattr(layer, "renormalize", True), | ||
| use_fused_cast=use_fused, | ||
| topk_reduce_func=getattr(layer, "topk_reduce_func", None), | ||
| ) | ||
| else: | ||
| gate_out = gate_out.cast("float32") | ||
| topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( | ||
| gate_out, | ||
| layer.gate_correction_bias, | ||
| top_k, | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| True, # apply_norm_weight | ||
| False, | ||
| ) | ||
|
|
||
| if topk_ids_hookfunc is not None: | ||
| topk_ids_hookfunc(topk_ids=topk_ids) | ||
|
|
||
| # --- 2. Preprocess: sort tokens by expert assignment --- | ||
| num_token_expert_pairs = token_num * top_k | ||
| cfg = self._get_default_config( | ||
| num_token_expert_pairs, moe_intermediate_size * 2, hidden_size, num_local_experts | ||
| ) | ||
|
|
||
| # Use naive_block_assignment when token count is very small (decode scenario). | ||
| # Each M-block handles exactly one token-expert pair, skipping the expensive | ||
| # preprocess sort kernel. | ||
| _SPARSITY_FACTOR = 4 | ||
| use_naive = num_token_expert_pairs * _SPARSITY_FACTOR <= num_local_experts | ||
|
|
||
| if use_naive: | ||
| expert_ids = topk_ids.reshape([-1]).cast("int32") | ||
| num_tokens_post_padded = paddle.full([1], num_token_expert_pairs * cfg["BLOCK_SIZE_M"], dtype="int32") | ||
| max_possible_num_post_padded = num_token_expert_pairs * cfg["BLOCK_SIZE_M"] | ||
| sorted_token_ids = expert_ids | ||
| else: | ||
| sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( | ||
| topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"] | ||
| ) | ||
| max_possible_num_post_padded = sorted_token_ids.shape[0] | ||
|
|
||
| # --- 3. GEMM1: hidden -> up_gate (BF16 x BF16 -> BF16) --- | ||
| # up_gate_proj_weight layout: [E, hidden_size, inter*2] => stride_be, stride_bk, stride_bn | ||
| up_gate_proj_out = paddle.empty( | ||
| [num_token_expert_pairs, moe_intermediate_size * 2], | ||
| dtype=x.dtype, | ||
| ) | ||
| grid1 = ( | ||
| ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"]) | ||
| * ceil_div(moe_intermediate_size * 2, cfg["BLOCK_SIZE_N"]), | ||
| ) | ||
| fused_moe_kernel_bf16[grid1]( | ||
| x, | ||
| layer.up_gate_proj_weight, | ||
| up_gate_proj_out, | ||
| None, # topk_weights_ptr (no weight mul on GEMM1) | ||
| sorted_token_ids, | ||
| expert_ids, | ||
| num_tokens_post_padded, | ||
| N=moe_intermediate_size * 2, | ||
| K=hidden_size, | ||
| EM=max_possible_num_post_padded, | ||
| num_valid_tokens=num_token_expert_pairs, | ||
| stride_am=x.strides[0], | ||
| stride_ak=x.strides[1], | ||
| stride_be=layer.up_gate_proj_weight.strides[0], | ||
| stride_bk=layer.up_gate_proj_weight.strides[1], | ||
| stride_bn=layer.up_gate_proj_weight.strides[2], | ||
| stride_cm=up_gate_proj_out.strides[0], | ||
| stride_cn=up_gate_proj_out.strides[1], | ||
| BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], | ||
| BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"], | ||
| BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"], | ||
| GROUP_SIZE_M=cfg["GROUP_SIZE_M"], | ||
| MUL_ROUTED_WEIGHT=False, | ||
| top_k=top_k, | ||
| compute_type=tl.bfloat16, | ||
| naive_block_assignment=use_naive, | ||
| even_Ks=(hidden_size % cfg["BLOCK_SIZE_K"] == 0), | ||
| num_warps=cfg["num_warps"], | ||
| num_stages=cfg["num_stages"], | ||
| ) | ||
|
Comment on lines
+2064
to
+2071
|
||
|
|
||
| # --- 4. SwiGLU activation --- | ||
| down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out) | ||
|
|
||
| # --- 5. GEMM2: inter -> hidden, fuse router weight multiplication --- | ||
| if not topk_weights.is_contiguous(): | ||
| topk_weights = topk_weights.contiguous() | ||
|
|
||
| # down_proj_weight layout: [E, moe_intermediate_size, hidden_size] => stride_be, stride_bk, stride_bn | ||
| down_proj_out = paddle.empty( | ||
| (num_token_expert_pairs, hidden_size), | ||
| dtype=x.dtype, | ||
| ) | ||
| grid2 = ( | ||
| ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"]) * ceil_div(hidden_size, cfg["BLOCK_SIZE_N"]), | ||
| ) | ||
| fused_moe_kernel_bf16[grid2]( | ||
| down_proj_input, | ||
| layer.down_proj_weight, | ||
| down_proj_out, | ||
| topk_weights, | ||
| sorted_token_ids, | ||
| expert_ids, | ||
| num_tokens_post_padded, | ||
| N=hidden_size, | ||
| K=moe_intermediate_size, | ||
| EM=max_possible_num_post_padded, | ||
| num_valid_tokens=num_token_expert_pairs, | ||
| stride_am=down_proj_input.strides[0], | ||
| stride_ak=down_proj_input.strides[1], | ||
| stride_be=layer.down_proj_weight.strides[0], | ||
| stride_bk=layer.down_proj_weight.strides[1], | ||
| stride_bn=layer.down_proj_weight.strides[2], | ||
| stride_cm=down_proj_out.strides[0], | ||
| stride_cn=down_proj_out.strides[1], | ||
| BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], | ||
| BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"], | ||
| BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"], | ||
| GROUP_SIZE_M=cfg["GROUP_SIZE_M"], | ||
| MUL_ROUTED_WEIGHT=True, | ||
| top_k=1, | ||
| compute_type=tl.bfloat16, | ||
| naive_block_assignment=use_naive, | ||
| even_Ks=(moe_intermediate_size % cfg["BLOCK_SIZE_K"] == 0), | ||
| num_warps=cfg["num_warps"], | ||
| num_stages=cfg["num_stages"], | ||
| ) | ||
|
|
||
| # --- 6. Reduce over topk --- | ||
| down_proj_out.reshape_([token_num, top_k, hidden_size]) | ||
| out = down_proj_out.sum(axis=1) | ||
| return out | ||
|
|
||
| def apply_ep_prefill( | ||
| self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None | ||
| ): | ||
| raise NotImplementedError("TritonMoEMethod does not support EP prefill yet.") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ❓ Question: since `apply_ep_prefill`/`apply_ep_decode` are unimplemented, consider failing fast at backend-selection time instead — e.g. inside the `if moe_backend == "triton":` branch:
    if layer is not None and getattr(layer, 'ep_size', 1) > 1:
        raise ValueError("FD_MOE_BACKEND=triton does not support EP (ep_size > 1) yet.")
    from .fused_moe_triton_backend import TritonMoEMethod
    return TritonMoEMethod(None) |
||
|
|
||
| def apply_ep_decode( | ||
| self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None | ||
| ): | ||
| raise NotImplementedError("TritonMoEMethod does not support EP decode yet.") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
🟡 Suggestion: the bare
`except:` replaces the original `except ImportError:` and will catch every exception (including `SystemExit` and `KeyboardInterrupt`), so real errors other than `ImportError` get silently swallowed and become very hard to debug. It is recommended to restore
`except ImportError:` (or at least use `except Exception:`), and add logging where necessary: