Skip to content

[fused_router][pytorch] Optimize naive topk path and add perf benchmark#2776

Open
XiaomingFun233 wants to merge 4 commits intoNVIDIA:mainfrom
XiaomingFun233:pr/fused-router-topk-opt
Open

[fused_router][pytorch] Optimize naive topk path and add perf benchmark#2776
XiaomingFun233 wants to merge 4 commits intoNVIDIA:mainfrom
XiaomingFun233:pr/fused-router-topk-opt

Conversation

@XiaomingFun233
Copy link

Summary

This PR ports and keeps a focused set of CUDA fused-router optimizations that showed consistent gains on the tested workload, while avoiding heavier variants that regressed performance.

1. Add fused-router performance benchmark test

  • Add tests/pytorch/test_fused_router_perf.py.
  • Benchmark coverage:
    • fused_topk_with_score_function
    • fused_compute_score_for_moe_aux_loss
    • fused_moe_aux_loss

2. Keep low-risk fused-router CUDA optimizations

  • transformer_engine/common/fused_router/utils.h
    • Add warp-level sum helper used in backward normalization path.
  • transformer_engine/common/fused_router/fused_topk_with_score_function.cu
    • Use warp-level sum reduction in backward normalization.
    • Add safe expert_bias.has_data() handling in forward to avoid invalid dtype switch when bias is absent.
  • transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu
    • Use warp-level sum reduction in backward normalization.

3. Optimize naive_topk_and_mask for small-k

  • transformer_engine/common/fused_router/utils.h
    • Add lightweight specialization for topk <= 8.
    • Keep generic fallback for compatibility.

Performance (A/B)

Measured with:

  • TE_RUN_PERF_TESTS=1 pytest -q tests/pytorch/test_fused_router_perf.py -s

Before

  • topk_router[softmax]: fused 0.029562 ms, speedup 8.3067x
  • topk_router[sigmoid]: fused 0.030138 ms, speedup 7.2715x
  • scores_for_aux_loss[softmax]: fused 0.026183 ms, speedup 3.8721x
  • scores_for_aux_loss[sigmoid]: fused 0.025872 ms, speedup 3.8892x
  • moe_aux_loss: fused 0.015680 ms, speedup 1.8884x

After

  • topk_router[softmax]: fused 0.022384 ms, speedup 11.1324x
  • topk_router[sigmoid]: fused 0.022840 ms, speedup 9.7714x
  • scores_for_aux_loss[softmax]: fused 0.017230 ms, speedup 5.9707x
  • scores_for_aux_loss[sigmoid]: fused 0.017049 ms, speedup 6.0205x
  • moe_aux_loss: fused 0.015412 ms, speedup 1.8424x

Notes

  • This PR intentionally avoids the larger full-port variant that previously regressed topk_router/scores_for_aux_loss performance on this setup.

XiaomingFun233 and others added 4 commits March 18, 2026 06:46
- restore forward hot paths to baseline behavior for topk/scores kernels\n- keep warp-level reduction helper for backward normalization\n- handle empty expert_bias safely in fused topk forward
Add a lightweight register-based small-k path and keep the generic fallback for compatibility.
Add CUDA perf benchmark for fused topk router, aux-loss score, and moe aux-loss kernels.
@XiaomingFun233
Copy link
Author

Test on H200 ,CUDA version 13.0

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 18, 2026

Greptile Summary

This PR makes three targeted improvements to the fused-router CUDA kernels — a crash-fix for the expert_bias-absent forward path, a warp-reduction refactor used in two backward kernels, and a compile-time–specialised top-k selection for topk ≤ 8 — and adds a performance benchmark test suite covering all three fused-router operations.

Key changes and findings:

  • Bug fix (fused_topk_with_score_function_forward): The previous code unconditionally called TE_ROUTER_PROBS_TYPE_SWITCH_ALL(expert_bias.data.dtype, ...) even when expert_bias was empty, potentially hitting an unsupported-dtype error. The new has_data() guard correctly falls through to a nullptr-bias instantiation when no bias is supplied.
  • Warp reduction refactor (warp_reduce_sum_float): Both backward kernels previously inlined a __shfl_xor_sync butterfly loop. This is replaced with a shared __shfl_down_sync-based helper. The helper correctly accumulates the total sum in lane 0 only, then broadcasts via __shfl_sync(…, 0) — the broadcast is load-bearing and should be documented.
  • Small-k topk specialisation (naive_topk_and_mask_smallk<K>): A template variant is introduced for K ∈ [1, 8]. The outer-loop #pragma unroll combined with a compile-time K allows the masking inner loop to be fully unrolled, avoiding runtime loop-bound checks. The XOR-reduction for argmax is logically equivalent to the generic version. OOB threads initialise index = 0 (a valid data index), but can never propagate that value to the output since they carry val = -inf, which can never win any comparison.
  • Performance benchmark (test_fused_router_perf.py): Correctness is verified by comparing against reference PyTorch implementations before timing. However, the timing assertions only check ms > 0; no minimum speedup threshold is enforced, so performance regressions will not be caught in CI.

Confidence Score: 4/5

  • Safe to merge with minor suggestions — no correctness bugs found in CUDA paths, and the forward crash-fix improves robustness.
  • The CUDA kernel changes are logically correct: warp_reduce_sum_float is semantically equivalent to the previous __shfl_xor_sync butterfly, and naive_topk_and_mask_smallk produces the same top-k results as the generic path. The has_data() guard fixes a real latent crash. The score is not 5 because the performance benchmark lacks minimum-speedup assertions (making it unable to catch regressions) and has a fragile module-level seed setup.
  • tests/pytorch/test_fused_router_perf.py — missing speedup assertions and fragile global seed setup. transformer_engine/common/fused_router/utils.h — the mandatory lane-0 broadcast in warp_reduce_sum_float should be documented.

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/utils.h Adds warp_reduce_sum_float (correct __shfl_down_sync + broadcast pattern), introduces naive_topk_and_mask_smallk<K> template for K∈[1,8] with unrolled masking, and restructures naive_topk_and_mask as a dispatch wrapper. Logic is sound; minor readability concerns around the OOB index=0 sentinel and missing comment on the mandatory lane-0 broadcast.
transformer_engine/common/fused_router/fused_topk_with_score_function.cu Forward path now safely checks expert_bias.has_data() before accessing data.dtype, fixing a latent crash when no bias is provided. Backward normalisation reduction replaced with warp_reduce_sum_float; semantics are equivalent to the old __shfl_xor_sync butterfly.
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Backward normalisation reduction replaced with warp_reduce_sum_float; the change is a clean refactor with equivalent semantics and no functional impact.
tests/pytorch/test_fused_router_perf.py New perf benchmark covering three fused-router ops. Correctness is verified against reference PyTorch implementations. Key concern: performance assertions only check ms > 0, not minimum speedup thresholds, so regressions would go undetected in CI.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fused_topk_with_score_function_forward] --> B{expert_bias.has_data?}
    B -->|Yes| C[TE_ROUTER_PROBS_TYPE_SWITCH_ALL on bias dtype\nLaunches kernel with BiasType template]
    B -->|No - NEW| D[Launch kernel with DataType as BiasType\nexpert_bias ptr = nullptr]
    C --> E[fused_topk_with_score_function_forward_kernel]
    D --> E
    E --> F{score_function}
    F -->|softmax| G[apply_softmax_on_float]
    F -->|sigmoid| H[apply_sigmoid_on_float]
    F -->|sqrtsoftplus| I[apply_sqrtsoftplus_on_float]
    G & H & I --> J{group_topk > 0?}
    J -->|Yes| K[group_limited_topk path]
    J -->|No| L[naive_topk_and_mask dispatcher - NEW]
    K --> L
    L --> M{topk <= 8?}
    M -->|Yes| N[naive_topk_and_mask_smallk K - unrolled]
    M -->|No| O[naive_topk_and_mask_generic - original]
    N & O --> P[Write probs / routing_map to global mem]

    subgraph Backward [Backward - Both kernels]
        Q[Accumulate local_sum_Output_x_Grad per lane]
        Q --> R[warp_reduce_sum_float - NEW\n__shfl_down_sync + broadcast from lane 0]
        R --> S[In-place grad update]
    end
Loading

Last reviewed commit: "[pre-commit.ci] auto..."

Comment on lines +198 to +201
)
def test_fused_topk_router_perf_against_torch(
score_function, use_pre_softmax, enable_bias, record_property
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 No minimum speedup assertion

The benchmark tests only verify that timings are positive (torch_ms > 0, fused_ms > 0) and confirm numerical correctness, but never assert that the fused kernel actually outperforms the PyTorch baseline. Any performance regression will silently pass CI while still printing a degraded speedup in the output.

Consider adding a minimum speedup threshold, for example:

Suggested change
)
def test_fused_topk_router_perf_against_torch(
score_function, use_pre_softmax, enable_bias, record_property
):
assert torch_ms > 0, _perf_assert_message(f"topk_router[{score_function}]", torch_ms, fused_ms)
assert fused_ms > 0, _perf_assert_message(f"topk_router[{score_function}]", torch_ms, fused_ms)
MIN_SPEEDUP = 2.0
assert torch_ms / fused_ms >= MIN_SPEEDUP, (
f"topk_router[{score_function}] speedup {torch_ms/fused_ms:.2f}x < required {MIN_SPEEDUP}x. "
+ _perf_assert_message(f"topk_router[{score_function}]", torch_ms, fused_ms)
)

The same gap exists in test_fused_scores_for_aux_loss_perf_against_torch (line 270–275) and test_fused_moe_aux_loss_perf_against_torch (line 345–348).

Comment on lines +41 to +46
__device__ inline float warp_reduce_sum_float(float val) {
for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
return __shfl_sync(0xffffffff, val, 0);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Broadcast from lane 0 should be documented

__shfl_down_sync accumulates the sum only in lane 0 after all steps — unlike the __shfl_xor_sync butterfly approach (used elsewhere in this file) which gives the correct sum to every lane simultaneously. The subsequent __shfl_sync(…, 0) is therefore load-bearing for correctness, not just an optimisation.

Adding a short comment here prevents a future reader from accidentally removing it thinking it's redundant:

Suggested change
__device__ inline float warp_reduce_sum_float(float val) {
for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
return __shfl_sync(0xffffffff, val, 0);
}
__device__ inline float warp_reduce_sum_float(float val) {
// __shfl_down_sync accumulates the total only in lane 0;
// the broadcast below is required for all lanes to see the result.
for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
return __shfl_sync(0xffffffff, val, 0);
}

Comment on lines +224 to +248
#pragma unroll
for (int k = 0; k < K; ++k) {
CompType val = -std::numeric_limits<CompType>::infinity();
int index = (lane_id < data_size) ? lane_id : 0;
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
bool masked = false;
#pragma unroll
for (int j = 0; j < k; ++j) {
masked |= (selected[j] == i);
}
if (masked) continue;
CompType cur_val = scores[i];
if (cur_val > val) {
val = cur_val;
index = i;
}
}
for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) {
auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
if (shuffled_val > val) {
val = shuffled_val;
index = shuffled_index;
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 OOB thread index initialised to 0 could shadow earlier selections

For threads where lane_id >= data_size, the inner loop body never executes, so val stays at -inf and index is set to 0:

int index = (lane_id < data_size) ? lane_id : 0;

0 is a valid data index that may already have been placed in selected by a previous k iteration. During the XOR-reduction phase, these OOB threads participate (they shuffle -inf values, which can never win the shuffled_val > val comparison), so the final chosen_index remains correct. However, after the broadcast:

selected[k] = chosen_index;

every thread — including OOB ones — writes chosen_index to their register copy of selected, keeping all threads in sync. The algorithm is therefore correct, but initialising the fallback index to a sentinel value (e.g., -1 or data_size - 1) would make the intent clearer and avoid confusion with a real element:

Suggested change
#pragma unroll
for (int k = 0; k < K; ++k) {
CompType val = -std::numeric_limits<CompType>::infinity();
int index = (lane_id < data_size) ? lane_id : 0;
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
bool masked = false;
#pragma unroll
for (int j = 0; j < k; ++j) {
masked |= (selected[j] == i);
}
if (masked) continue;
CompType cur_val = scores[i];
if (cur_val > val) {
val = cur_val;
index = i;
}
}
for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) {
auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
if (shuffled_val > val) {
val = shuffled_val;
index = shuffled_index;
}
}
int index = (lane_id < data_size) ? lane_id : -1; // -1: sentinel for out-of-range lane

This is purely a readability / defensive-programming concern given that invalid index values can never propagate to the output.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +17 to +21

seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Module-level seed is fragile for tokens_per_expert randomness

The random seed is set once at import time. If pytest collects or runs other tests before this module's tests execute, the global random state will have advanced and test_fused_moe_aux_loss_perf_against_torch (which calls torch.randint for tokens_per_expert) will use an unknown seed. While the numerical correctness check in that test (torch.testing.assert_close(torch_loss, fused_loss)) passes regardless of the specific random values, reproducible benchmarks are easier to debug.

Consider moving the seed setup into each individual test function, or using a pytest fixture to ensure a consistent state per test.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant