[fused_router][pytorch] Optimize naive topk path and add perf benchmark #2776

XiaomingFun233 wants to merge 4 commits into NVIDIA:main
Conversation
- restore forward hot paths to baseline behavior for topk/scores kernels
- keep warp-level reduction helper for backward normalization
- handle empty expert_bias safely in fused topk forward
Add a lightweight register-based small-k path and keep the generic fallback for compatibility.
Add CUDA perf benchmark for fused topk router, aux-loss score, and moe aux-loss kernels.
for more information, see https://pre-commit.ci
Tested on H200, CUDA version 13.0.
Greptile Summary

This PR makes three targeted improvements to the fused-router CUDA kernels: a crash fix for the empty `expert_bias` case in the fused topk forward path, a register-based small-k topk fast path, and a warp-level reduction helper for backward normalization.

Key changes and findings:
Confidence Score: 4/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fused_topk_with_score_function_forward] --> B{expert_bias.has_data?}
    B -->|Yes| C[TE_ROUTER_PROBS_TYPE_SWITCH_ALL on bias dtype\nLaunches kernel with BiasType template]
    B -->|No - NEW| D[Launch kernel with DataType as BiasType\nexpert_bias ptr = nullptr]
    C --> E[fused_topk_with_score_function_forward_kernel]
    D --> E
    E --> F{score_function}
    F -->|softmax| G[apply_softmax_on_float]
    F -->|sigmoid| H[apply_sigmoid_on_float]
    F -->|sqrtsoftplus| I[apply_sqrtsoftplus_on_float]
    G & H & I --> J{group_topk > 0?}
    J -->|Yes| K[group_limited_topk path]
    J -->|No| L[naive_topk_and_mask dispatcher - NEW]
    K --> L
    L --> M{topk <= 8?}
    M -->|Yes| N[naive_topk_and_mask_smallk K - unrolled]
    M -->|No| O[naive_topk_and_mask_generic - original]
    N & O --> P[Write probs / routing_map to global mem]
    subgraph Backward [Backward - Both kernels]
        Q[Accumulate local_sum_Output_x_Grad per lane]
        Q --> R[warp_reduce_sum_float - NEW\n__shfl_down_sync + broadcast from lane 0]
        R --> S[In-place grad update]
    end
```
Last reviewed commit: "[pre-commit.ci] auto..."
```python
)
def test_fused_topk_router_perf_against_torch(
    score_function, use_pre_softmax, enable_bias, record_property
):
```
The benchmark tests only verify that timings are positive (torch_ms > 0, fused_ms > 0) and confirm numerical correctness, but never assert that the fused kernel actually outperforms the PyTorch baseline. Any performance regression will silently pass CI while still printing a degraded speedup in the output.
Consider adding a minimum speedup threshold, for example:
```python
)
def test_fused_topk_router_perf_against_torch(
    score_function, use_pre_softmax, enable_bias, record_property
):
    ...
    assert torch_ms > 0, _perf_assert_message(f"topk_router[{score_function}]", torch_ms, fused_ms)
    assert fused_ms > 0, _perf_assert_message(f"topk_router[{score_function}]", torch_ms, fused_ms)
    MIN_SPEEDUP = 2.0
    assert torch_ms / fused_ms >= MIN_SPEEDUP, (
        f"topk_router[{score_function}] speedup {torch_ms/fused_ms:.2f}x < required {MIN_SPEEDUP}x. "
        + _perf_assert_message(f"topk_router[{score_function}]", torch_ms, fused_ms)
    )
```
The same gap exists in `test_fused_scores_for_aux_loss_perf_against_torch` (lines 270–275) and `test_fused_moe_aux_loss_perf_against_torch` (lines 345–348).
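One way to close the same gap in all three tests is a shared helper. The sketch below is hypothetical (the name `require_speedup` and the 2.0x default are not from the PR); it bundles the positivity checks and the threshold assertion so each perf test enforces the same policy:

```python
def require_speedup(name, baseline_ms, fused_ms, min_speedup=2.0):
    """Assert the fused kernel beats the baseline by at least min_speedup.

    Hypothetical helper, not part of the PR; timings are in milliseconds.
    """
    assert baseline_ms > 0 and fused_ms > 0, f"{name}: non-positive timing"
    speedup = baseline_ms / fused_ms
    assert speedup >= min_speedup, (
        f"{name} speedup {speedup:.2f}x < required {min_speedup}x "
        f"(baseline {baseline_ms:.6f} ms, fused {fused_ms:.6f} ms)"
    )
    return speedup
```

Each test would then end with a single `require_speedup(...)` call instead of the two positivity asserts.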
```cpp
__device__ inline float warp_reduce_sum_float(float val) {
  for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return __shfl_sync(0xffffffff, val, 0);
}
```
Broadcast from lane 0 should be documented
__shfl_down_sync accumulates the sum only in lane 0 after all steps — unlike the __shfl_xor_sync butterfly approach (used elsewhere in this file) which gives the correct sum to every lane simultaneously. The subsequent __shfl_sync(…, 0) is therefore load-bearing for correctness, not just an optimisation.
Adding a short comment here prevents a future reader from accidentally removing it thinking it's redundant:
Suggested change:

```cpp
__device__ inline float warp_reduce_sum_float(float val) {
  // __shfl_down_sync accumulates the total only in lane 0;
  // the broadcast below is required for all lanes to see the result.
  for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return __shfl_sync(0xffffffff, val, 0);
}
```
```cpp
#pragma unroll
for (int k = 0; k < K; ++k) {
  CompType val = -std::numeric_limits<CompType>::infinity();
  int index = (lane_id < data_size) ? lane_id : 0;
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    bool masked = false;
#pragma unroll
    for (int j = 0; j < k; ++j) {
      masked |= (selected[j] == i);
    }
    if (masked) continue;
    CompType cur_val = scores[i];
    if (cur_val > val) {
      val = cur_val;
      index = i;
    }
  }
  for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) {
    auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
    auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
    if (shuffled_val > val) {
      val = shuffled_val;
      index = shuffled_index;
    }
  }
```
OOB thread `index` initialised to 0 could shadow earlier selections

For threads where `lane_id >= data_size`, the inner loop body never executes, so `val` stays at `-inf` and `index` is set to 0:

```cpp
int index = (lane_id < data_size) ? lane_id : 0;
```

`0` is a valid data index that may already have been placed in `selected` by a previous `k` iteration. During the XOR-reduction phase, these OOB threads participate (they shuffle `-inf` values, which can never win the `shuffled_val > val` comparison), so the final `chosen_index` remains correct. However, after the broadcast:

```cpp
selected[k] = chosen_index;
```

every thread, including OOB ones, writes `chosen_index` to its register copy of `selected`, keeping all threads in sync. The algorithm is therefore correct, but initialising the fallback index to a sentinel value (e.g., `-1` or `data_size - 1`) would make the intent clearer and avoid confusion with a real element:
Suggested change:

```cpp
int index = (lane_id < data_size) ? lane_id : -1;  // -1: sentinel for out-of-range lane
```
This is purely a readability / defensive-programming concern given that invalid index values can never propagate to the output.
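The per-round selection logic under discussion can be modelled sequentially on the CPU (illustrative Python, not the CUDA kernel): each of the `k` rounds scans for the best score whose index is not yet in `selected`, which is exactly what the per-lane scan plus xor reduction computes in parallel:

```python
import math

def naive_topk(scores, k):
    """Sequential model of the warp's per-round masked argmax."""
    selected = []
    for _ in range(k):
        best_val, best_idx = -math.inf, -1  # -1 plays the sentinel role suggested above
        for i, v in enumerate(scores):
            if i in selected:
                continue  # same effect as the `masked` check in the kernel
            if v > best_val:  # strict >, so the lowest index wins ties
                best_val, best_idx = v, i
        selected.append(best_idx)
    return selected
```

Ties resolve to the lowest index because the comparison is strict, mirroring how the kernel's reduction keeps the incumbent on equal values.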
```python
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
```
Module-level seed is fragile for `tokens_per_expert` randomness
The random seed is set once at import time. If pytest collects or runs other tests before this module's tests execute, the global random state will have advanced and test_fused_moe_aux_loss_perf_against_torch (which calls torch.randint for tokens_per_expert) will use an unknown seed. While the numerical correctness check in that test (torch.testing.assert_close(torch_loss, fused_loss)) passes regardless of the specific random values, reproducible benchmarks are easier to debug.
Consider moving the seed setup into each individual test function, or using a pytest fixture to ensure a consistent state per test.
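A per-test context manager along these lines would give each test a known state. This is a sketch using the standard library's `random` as a stand-in; the real version would call `torch.manual_seed(seed)` and `torch.cuda.manual_seed(seed)` in the same place, and could be wrapped in an autouse pytest fixture:

```python
import contextlib
import random

@contextlib.contextmanager
def seeded(seed=42):
    """Seed the RNG for the enclosed block and restore prior state afterwards.

    Stand-in sketch: a torch version would also seed torch / torch.cuda here.
    """
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(state)
```

With this, `tokens_per_expert` draws become reproducible regardless of which tests pytest collected or ran first.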
Summary

This PR ports and keeps a focused set of CUDA fused-router optimizations that showed consistent gains on the tested workload, while avoiding heavier variants that regressed performance.

1. Add fused-router performance benchmark test
   - `tests/pytorch/test_fused_router_perf.py`: benchmarks `fused_topk_with_score_function`, `fused_compute_score_for_moe_aux_loss`, and `fused_moe_aux_loss`.
2. Keep low-risk fused-router CUDA optimizations
   - `transformer_engine/common/fused_router/utils.h`
   - `transformer_engine/common/fused_router/fused_topk_with_score_function.cu`: fix `expert_bias.has_data()` handling in forward to avoid an invalid dtype switch when bias is absent.
   - `transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu`
3. Optimize `naive_topk_and_mask` for small-k
   - `transformer_engine/common/fused_router/utils.h`: register-based fast path for `topk <= 8`.

Performance (A/B)
Measured with:

```shell
TE_RUN_PERF_TESTS=1 pytest -q tests/pytorch/test_fused_router_perf.py -s
```

Before

- topk_router[softmax]: fused 0.029562 ms, speedup 8.3067x
- topk_router[sigmoid]: fused 0.030138 ms, speedup 7.2715x
- scores_for_aux_loss[softmax]: fused 0.026183 ms, speedup 3.8721x
- scores_for_aux_loss[sigmoid]: fused 0.025872 ms, speedup 3.8892x
- moe_aux_loss: fused 0.015680 ms, speedup 1.8884x

After

- topk_router[softmax]: fused 0.022384 ms, speedup 11.1324x
- topk_router[sigmoid]: fused 0.022840 ms, speedup 9.7714x
- scores_for_aux_loss[softmax]: fused 0.017230 ms, speedup 5.9707x
- scores_for_aux_loss[sigmoid]: fused 0.017049 ms, speedup 6.0205x
- moe_aux_loss: fused 0.015412 ms, speedup 1.8424x

Notes
- The retained changes improve `topk_router`/`scores_for_aux_loss` performance on this setup.
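The A/B timing pattern behind these numbers can be sketched generically. The helper below is illustrative (the names `bench_ms`/`speedup` are not from the PR); a real GPU benchmark would additionally synchronize the device, e.g. via `torch.cuda.Event`, before and after the timed region:

```python
import time

def bench_ms(fn, iters=100, warmup=10):
    """Average wall-clock milliseconds per call of fn().

    Illustrative CPU-side harness; GPU timing needs device synchronization.
    """
    for _ in range(warmup):
        fn()  # warm caches / JIT / autotuning before measuring
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) * 1e3 / iters

def speedup(baseline_ms, fused_ms):
    """Speedup factor reported as e.g. '8.3067x' in the tables above."""
    return baseline_ms / fused_ms
```

Timing both the PyTorch reference and the fused kernel with the same harness, on the same inputs, is what makes the before/after speedup columns comparable.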