
[Feature] Add TritonBF16MoEMethod for BF16 MoE inference#7734

Open
xuanyuanminzheng wants to merge 5 commits into
PaddlePaddle:develop from
xuanyuanminzheng:develop

Conversation

@xuanyuanminzheng
Collaborator

@xuanyuanminzheng xuanyuanminzheng commented May 7, 2026

Motivation

This PR adds a native Triton kernel backend (TritonBF16MoEMethod) for the BF16 unquantized MoE scenario, activated via the environment variable FD_MOE_BACKEND=triton. The existing Cutlass/quantized paths cannot directly handle unquantized BF16 weights; this PR adds that path to support a wider range of BF16 model inference scenarios.

Modifications

  • fused_moe_triton_backend.py: added the TritonBF16MoEMethod class, inheriting from QuantMethodBase and implementing the full BF16 FusedMoE forward flow (routing → preprocess → Triton GEMM1 → SwiGLU → Triton GEMM2 + router-weight fusion)
  • triton_moe_kernels.py: added the fused_moe_kernel_bf16 Triton kernel, supporting BF16 accumulation, int64 token indexing to prevent overflow, and fused router-weight multiplication (MUL_ROUTED_WEIGHT)
  • moe.py: added an FD_MOE_BACKEND=triton branch in get_moe_method that returns TritonBF16MoEMethod
  • __init__.py: exported TritonBF16MoEMethod
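
The forward flow listed above (routing → preprocess → GEMM1 → SwiGLU → GEMM2 with router-weight fusion) can be sketched in plain NumPy. This is an illustrative reference model only, not the FastDeploy API; the function names here are hypothetical:

```python
import numpy as np

def swiglu(gate_up):
    # gate_up: [..., 2 * inter]; SwiGLU = silu(gate) * up
    gate, up = np.split(gate_up, 2, axis=-1)
    return (gate / (1.0 + np.exp(-gate))) * up

def moe_forward_sketch(x, w1, w2, topk_ids, topk_weights):
    # x: [T, H]; w1: [E, H, 2I]; w2: [E, I, H]
    # topk_ids / topk_weights: [T, K] routing results from the gate
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for k in range(topk_ids.shape[1]):
            e = topk_ids[t, k]
            h = swiglu(x[t] @ w1[e])                    # GEMM1 + SwiGLU
            out[t] += topk_weights[t, k] * (h @ w2[e])  # GEMM2 + router weight
    return out
```

The Triton kernels replace the per-token Python loops with tiled GEMMs over tokens grouped by expert; the accumulation over the top-k experts corresponds to the final top-k reduce step.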

Usage or Command

export FD_MOE_BACKEND=triton
# Start the inference service (with a BF16 MoE model); the Triton BF16 backend is used automatically

Accuracy Tests

(accuracy results provided as a screenshot in the original PR)

Checklist

  • [√] Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • [√] Format your code; run pre-commit before committing.
  • [√] Add unit tests. If there are no unit tests, explain why in this PR.
  • [√] Provide accuracy results.
  • [√] If the current PR targets a release branch, make sure it has first been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot Bot commented May 7, 2026

Thanks for your contribution!

@PaddlePaddle-bot

PaddlePaddle-bot commented May 7, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-11 16:48:44

CI report generated from the code below (updated every 30 minutes):


1 Task Overview

No required tasks are currently failing, but 7 workflows are in the action_required state (they will only run after manual approval), including the main test pipeline PR Build and Test; the main body of CI has not run yet.

Total runs (reruns) Total tasks ✅ Passed ❌ Failed ⏳ Running ⏸️ Pending Skipped
2 (0) 2 1 0 1 0 0

⚠️ Note: the following 7 workflows are in the action_required state (they will only run after approval): ILUVATAR-CI, Approval, Check PR Template, Codestyle-Check, CI_HPU, PR Build and Test, CI_XPU. These workflows require manual approval to trigger; the main test task Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage is among them and has not yet run.

Note: action_required workflows are not counted in the table above.


2 Task Status Summary

2.1 Required tasks: 0/0 passed

GitHub Branch Protection Rules currently report no required tasks (possibly not yet triggered, or insufficient API permissions); all tasks are marked optional.

No required task records yet.

2.2 Optional tasks — 1/2 passed

Optional tasks do not block merging and are for reference only.

Status Task Duration Log Rerun
Trigger Jenkins for PR - Job -
1 other optional task passed - - -

3 Failure details (required only)

No required tasks failed.

PaddlePaddle-bot

This comment was marked as outdated.


Contributor

Copilot AI left a comment


Pull request overview

This PR adds a native Triton kernel backend (TritonBF16MoEMethod) for the BF16 unquantized MoE scenario, enabled on the CUDA platform via the environment variable FD_MOE_BACKEND=triton, to cover the BF16 unquantized inference path that the existing Cutlass/quantized paths cannot handle.

Changes:

  • New BF16 Triton MoE kernel: fused_moe_kernel_bf16, supporting BF16 compute/accumulation and fused routed-weight multiplication.
  • New TritonBF16MoEMethod: implements the BF16 FusedMoE forward (routing → preprocess → GEMM1 → SwiGLU → GEMM2 (+router weight) → topk reduce).
  • Extended get_moe_method and the __init__.py exports, and added unit tests and a precision comparison test.

Consider also documenting (e.g., in the environment-variable docs) the prerequisites for FD_MOE_BACKEND=triton (Triton available + BF16 dtype) and its scope of applicability.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

File Description
tests/layers/test_fused_moe_triton_backend.py New TritonBF16MoEMethod unit tests and a Triton vs. Cutlass BF16 precision comparison test
fastdeploy/model_executor/layers/moe/triton_moe_kernels.py New BF16 Triton fused MoE GEMM kernel (fused_moe_kernel_bf16)
fastdeploy/model_executor/layers/moe/moe.py Added an FD_MOE_BACKEND=triton branch selecting TritonBF16MoEMethod
fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py New TritonBF16MoEMethod implementation and Triton/ops imports
fastdeploy/model_executor/layers/moe/__init__.py Exported TritonBF16MoEMethod

Comment on lines +279 to +283
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

# A pointer: a_ptr[token_idx, :K] where token_idx = offs_token // top_k
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)
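
The `offs_token // top_k` mapping in the snippet above recovers the source-token row: after routing, each token appears top_k times in the expanded index space, so integer division collapses an expanded row back to its original token. A small illustrative example (the kernel's int64 cast performs the same indexing, widened to avoid overflow on large tensors):

```python
import numpy as np

top_k = 2
# 3 tokens, each replicated top_k times -> 6 expanded rows.
offs_token = np.arange(6, dtype=np.int64)
token_idx = offs_token // top_k  # maps expanded rows back to source tokens
print(token_idx.tolist())        # -> [0, 0, 1, 1, 2, 2]
```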
Comment on lines +296 to +300
b = tl.load(
    b_ptrs,
    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
    other=0.0,
)
Comment on lines +37 to +42
try:
    import triton.language as tl

    from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func

    from .triton_moe_kernels import fused_moe_kernel_paddle
    from .triton_moe_kernels import fused_moe_kernel_bf16, fused_moe_kernel_paddle
Comment on lines +2059 to +2062
    MUL_ROUTED_WEIGHT=False,
    top_k=top_k,
    compute_type=tl.bfloat16,
)
Comment on lines 56 to +60
if current_platform.is_cuda():
    moe_backend = envs.FD_MOE_BACKEND.lower()
    if moe_backend == "triton":
        from .fused_moe_triton_backend import TritonBF16MoEMethod

Comment on lines +1265 to +1269
@pytest.mark.skipif(not paddle.is_compiled_with_cuda(), reason="requires CUDA")
class TestTritonBF16MoEPrecision:
"""
Precision tests: Triton BF16 path vs. Cutlass BF16 path.

@codecov-commenter

codecov-commenter commented May 7, 2026

Codecov Report

❌ Patch coverage is 59.67742% with 50 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@78b5462). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...oy/model_executor/layers/moe/triton_moe_kernels.py 6.66% 42 Missing ⚠️
...el_executor/layers/moe/fused_moe_triton_backend.py 93.24% 3 Missing and 2 partials ⚠️
fastdeploy/model_executor/layers/moe/moe.py 25.00% 2 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7734   +/-   ##
==========================================
  Coverage           ?   71.54%           
==========================================
  Files              ?      396           
  Lines              ?    55828           
  Branches           ?     8727           
==========================================
  Hits               ?    39944           
  Misses             ?    13138           
  Partials           ?     2746           
Flag Coverage Δ
GPU 71.54% <59.67%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.


# --- 1. Routing ---
gate_out = gate(x)
gate_out = gate_out.cast("float32")

qingqing01 previously approved these changes May 8, 2026
Comment on lines +1908 to +1927
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
    """
    Reuse UnquantizedFusedMoEMethod weight creation logic.
    Weight shapes on CUDA (non-torch format):
        up_gate_proj_weight: [E, hidden_size, moe_intermediate_size * 2] (K-major)
        down_proj_weight: [E, moe_intermediate_size, hidden_size] (K-major)
    The Triton kernel reads B as [E, K, N], which maps directly to these shapes.
    """
    from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
        UnquantizedFusedMoEMethod,
    )

    UnquantizedFusedMoEMethod.create_weights(self, layer, **extra_weight_attrs)

def process_weights_after_loading(self, layer: nn.Layer):
    from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
        UnquantizedFusedMoEMethod,
    )

    UnquantizedFusedMoEMethod.process_weights_after_loading(self, layer)
Collaborator

Could TritonBF16MoEMethod inherit from UnquantizedFusedMoEMethod instead? That seems like it would be cleaner.
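
The suggestion above would replace the explicit delegation in the snippet with plain inheritance. A minimal sketch of the idea, using stand-in stub classes (the real UnquantizedFusedMoEMethod lives in fused_moe_backend_base and carries more behavior):

```python
class UnquantizedFusedMoEMethod:  # stand-in stub for the real base class
    def create_weights(self, layer, **extra_weight_attrs):
        return "weights-created"

    def process_weights_after_loading(self, layer):
        return "weights-processed"

class TritonBF16MoEMethod(UnquantizedFusedMoEMethod):
    # create_weights / process_weights_after_loading are inherited directly,
    # so only the Triton apply() path needs to be implemented here.
    def apply(self, layer, x):
        return "triton-apply"
```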

)


class TritonBF16MoEMethod(QuantMethodBase):
Collaborator

Personally, I'd suggest not emphasizing BF16 in the name; in inference scenarios BF16 precision is the default assumption, so how about just TritonMoEMethod?



@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 Paddle-CI-Agent | pr_review | 2026-05-11 16:48:22

📋 Review Summary

PR overview: adds a native Triton kernel backend (TritonMoEMethod) for the BF16 unquantized MoE scenario, activated via the environment variable FD_MOE_BACKEND=triton, and fixes an existing bug in the stride index ordering of TritonWeightOnlyMoEMethod.
Scope of changes: fastdeploy/model_executor/layers/moe/ (fused_moe_triton_backend.py, triton_moe_kernels.py, moe.py, __init__.py), tests/layers/
Impact tags: [Feature] [OP]

📝 PR Convention Check

The PR title format is compliant ([Feature] is an official tag) and the description template is complete (Motivation / Modifications / Usage / Accuracy Tests / Checklist are all filled in).

One formatting reminder: the checklist uses [√] markers, but GitHub Markdown task lists only recognize [x]/[X]; [√] will not render as "checked", so consider changing it to [x].

Issues

Level File Summary
🟡 Suggestion fused_moe_triton_backend.py:50 A bare except: replaces except ImportError:, silently swallowing all exceptions
❓ Question fused_moe_triton_backend.py:2128 The EP path raises NotImplementedError, but get_moe_method has no upfront EP guard, so the runtime error is unfriendly

Overall Assessment

The overall implementation is clear: the pipeline design (routing → preprocess → GEMM1 → SwiGLU → GEMM2) is aligned with vLLM, the naive_block_assignment decode optimization path is complete, and unit test coverage is high. Before merging, fix the bare except: fallback and add an EP guard so that failure scenarios remain debuggable.

    from .triton_moe_kernels import fused_moe_kernel_paddle
except ImportError:
    from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
except:

🟡 Suggestion: a bare except: replaced the original except ImportError:. It catches all exceptions (including SystemExit and KeyboardInterrupt), so real errors that are not ImportError get silently swallowed and become very hard to debug.

Suggest reverting to except ImportError: (or at least except Exception:) and adding logging where necessary:

try:
    import triton.language as tl
    from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
except ImportError:
    pass

def apply_ep_prefill(
    self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None
):
    raise NotImplementedError("TritonMoEMethod does not support EP prefill yet.")

❓ Question: TritonMoEMethod raises NotImplementedError directly on the EP path, but get_moe_method has no EP guard. When a user sets FD_MOE_BACKEND=triton with EP enabled (ep_size > 1), the base class apply() routes to apply_ep_prefill and fails at runtime with an unfriendly error.

Suggest adding an EP check to the triton branch of get_moe_method so it fails early:

if moe_backend == "triton":
    if layer is not None and getattr(layer, 'ep_size', 1) > 1:
        raise ValueError("FD_MOE_BACKEND=triton does not support EP (ep_size > 1) yet.")
    from .fused_moe_triton_backend import TritonMoEMethod
    return TritonMoEMethod(None)
