From 68443e923ca76f3d3562ad9652efb9c9179e0bed Mon Sep 17 00:00:00 2001
From: ShaneGZhu <1092841848@qq.com>
Date: Mon, 11 May 2026 18:20:08 +0800
Subject: [PATCH 1/4] [Ops][Optimization] Kernel fusion:
 cast+sigmoid+bias+noauxtc

---
 custom_ops/gpu_ops/cpp_extensions.cc        |  11 +
 custom_ops/gpu_ops/grouped_topk_kernels.cu  | 759 ++++++++++++++++++++
 custom_ops/setup_ops.py                     |   2 +
 fastdeploy/model_executor/layers/moe/moe.py |  35 +-
 4 files changed, 795 insertions(+), 12 deletions(-)
 create mode 100644 custom_ops/gpu_ops/grouped_topk_kernels.cu

diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 204ea33e50b..9ee1cfae4a7 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -691,6 +691,15 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                     bool renormalize,
                                     float routed_scaling_factor);
 
+std::vector<paddle::Tensor> grouped_topk(
+    paddle::Tensor& gating_output,
+    paddle::Tensor& e_score_correction_bias,
+    int n_group,
+    int topk_group,
+    int topk,
+    bool renormalize,
+    float routed_scaling_factor);
+
 std::vector<paddle::Tensor> FusedCastSigmoidBias(const paddle::Tensor& input,
                                                  const paddle::Tensor& bias,
                                                  std::string cast_type);
@@ -1706,6 +1715,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
 
+  m.def("grouped_topk", &grouped_topk, "fused grouped topk for MoE routing");
+
   m.def("fused_cast_sigmoid_bias",
         &FusedCastSigmoidBias,
         "Fused cast+sigmoid+bias for MoE gating scores",
diff --git a/custom_ops/gpu_ops/grouped_topk_kernels.cu b/custom_ops/gpu_ops/grouped_topk_kernels.cu
new file mode 100644
index 00000000000..d9c99908e72
--- /dev/null
+++ b/custom_ops/gpu_ops/grouped_topk_kernels.cu
@@ -0,0 +1,759 @@
+
+// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cuda/std/limits>
+#include <cuda_bf16.h>
+
+#include "helper.h"
+
+namespace cg = cooperative_groups;
+
+constexpr unsigned FUSED_FULL_WARP_MASK = 0xffffffff;
+
+template <typename T_OUT, typename T_IN>
+__device__ inline T_OUT cuda_cast(T_IN val) {
+  return val;
+}
+
+template <>
+__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
+  return __bfloat162float(val);
+}
+
+template <>
+__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
+  return __float2bfloat16(val);
+}
+
+template <>
+__device__ inline float cuda_cast<float, __half>(__half val) {
+  return __half2float(val);
+}
+
+template <>
+__device__ inline __half cuda_cast<__half, float>(float val) {
+  return __float2half(val);
+}
+
+// Numerically stable sigmoid via tanh: σ(x) = 0.5 * tanh(0.5*x) + 0.5
+template <typename T>
+__device__ __forceinline__ T sigmoid_device(T x) {
+  float xf = cuda_cast<float>(x);
+  return cuda_cast<T>(0.5f * tanhf(0.5f * xf) + 0.5f);
+}
+
+// Sigmoid matching fused_cast_sigmoid_bias: 1 / (1 + exp(-x)).
+// Must use the same formula to get bit-identical results when comparing
+// against the fused_cast_sigmoid_bias + noaux_tc path.
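+//
+// The two forms are the same function algebraically,
+//   1 / (1 + exp(-x)) = 0.5 * tanh(0.5 * x) + 0.5,
+// but tanhf and expf round differently, so only the expf form reproduces
+// the unfused path bit-for-bit. Sanity example: at x = 0 both give exactly
+// 0.5, and for large negative x the intermediate expf(-x) overflows to
+// +inf, where the division correctly flushes the result to 0.0f.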
+
+template <typename InT>
+__device__ __forceinline__ float sigmoid_to_float(InT x) {
+  float xf = cuda_cast<float>(x);
+  return 1.0f / (1.0f + expf(-xf));
+}
+
+template <typename T>
+__device__ inline T neg_inf() {
+  return cuda_cast<T>(-cuda::std::numeric_limits<float>::infinity());
+}
+
+template <typename T>
+__device__ inline bool is_finite_val(T val) {
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
+  return cuda::std::isfinite(val);
+#else
+  return isfinite(cuda_cast<float>(val));
+#endif
+}
+
+namespace warp_topk_fused {
+
+template <int size, typename T>
+__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
+  if (len == 0) return 0;
+  return ((len - 1) / size + 1) * size;
+}
+
+template <typename T>
+constexpr __host__ __device__ bool isPowerOf2(T v) {
+  return (v && !(v & (v - 1)));
+}
+
+template <bool greater, typename T>
+__forceinline__ __device__ bool is_better_than(T val, T baseline) {
+  return (val > baseline && greater) || (val < baseline && !greater);
+}
+
+template <bool greater, typename T, typename idxT>
+__forceinline__ __device__ bool is_better_than(T val,
+                                               T baseline,
+                                               idxT index,
+                                               idxT baseline_index) {
+  bool res = (val > baseline && greater) || (val < baseline && !greater);
+  if (val == baseline)
+    res = (index < baseline_index && greater) ||
+          (index < baseline_index && !greater);
+  return res;
+}
+
+template <int size,
+          bool ascending,
+          bool reverse,
+          typename T,
+          typename idxT,
+          bool is_stable>
+struct BitonicMerge {
+  __device__ static void merge(T* __restrict__ val_arr,
+                               idxT* __restrict__ idx_arr) {
+    static_assert(isPowerOf2(size));
+    static_assert(size >= 2 * WARP_SIZE);
+    constexpr int arr_len = size / WARP_SIZE;
+    constexpr int stride = arr_len / 2;
+    for (int i = 0; i < stride; ++i) {
+      int const other_i = i + stride;
+      T& val = val_arr[i];
+      T& other_val = val_arr[other_i];
+      bool is_better;
+      if constexpr (is_stable)
+        is_better = is_better_than<ascending>(
+            val, other_val, idx_arr[i], idx_arr[other_i]);
+      else
+        is_better = is_better_than<ascending>(val, other_val);
+      if (is_better) {
+        T tmp = val;
+        val = other_val;
+        other_val = tmp;
+        idxT tmp2 = idx_arr[i];
+        idx_arr[i] = idx_arr[other_i];
+        idx_arr[other_i] = tmp2;
+      }
+    }
+    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
+        val_arr, idx_arr);
+    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
+        val_arr + arr_len / 2, idx_arr + arr_len / 2);
+  }
+};
+
+template <int size, bool ascending, typename T, typename idxT, bool is_stable>
+struct BitonicSort {
+  __device__ static void sort(T* __restrict__ val_arr,
+                              idxT* __restrict__ idx_arr) {
+    static_assert(isPowerOf2(size));
+    static_assert(size >= 2 * WARP_SIZE);
+    constexpr int arr_len = size / WARP_SIZE;
+    BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
+    BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
+        val_arr + arr_len / 2, idx_arr + arr_len / 2);
+    BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
+        val_arr, idx_arr);
+  }
+};
+
+template <bool ascending, typename T, typename idxT, bool is_stable>
+struct BitonicSort<32, ascending, T, idxT, is_stable> {
+  __device__ static void sort(T* __restrict__ val_arr,
+                              idxT* __restrict__ idx_arr) {
+    int const lane = threadIdx.x % WARP_SIZE;
+    for (int stage = 0; stage < 4; ++stage) {
+      for (int stride = (1 << stage); stride > 0; stride /= 2) {
+        bool reverse = (lane >> stage) & 2;
+        bool is_second = lane & stride;
+        T other = __shfl_xor_sync(FUSED_FULL_WARP_MASK, *val_arr, stride);
+        idxT other_idx =
+            __shfl_xor_sync(FUSED_FULL_WARP_MASK, *idx_arr, stride);
+        bool is_better;
+        if constexpr (is_stable) {
+          if constexpr (ascending)
+            is_better = ((*val_arr > other) ||
+                         ((*val_arr == other) && (*idx_arr < other_idx))) !=
+                        (reverse != is_second);
+          else
+            is_better = ((*val_arr > other) ||
+                         ((*val_arr == other) && (*idx_arr > other_idx))) !=
+                        (reverse != is_second);
+        } else {
+          is_better = (*val_arr != other &&
+                       (*val_arr > other) != (reverse != is_second));
+        }
+        if (is_better) {
+          *val_arr = other;
+          *idx_arr = other_idx;
+        }
+      }
+    }
+    BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(
+        val_arr, idx_arr);
+  }
+};
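+
+// Illustrative 4-lane trace of one exchange step above (an example, not
+// extra logic): with stride = 1 each lane trades values with its XOR
+// neighbour via __shfl_xor_sync and keeps the one that belongs on its side:
+//   lanes:   0  1  2  3
+//   before:  3  1  4  2
+//   after:   1  3  2  4   (ascending within pairs)
+// Later stages flip direction per pair (`reverse`) so larger strides merge
+// the sorted runs, and the idx array travels along for stable tie-breaks.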
+
+template <bool ascending, bool reverse, typename T, typename idxT,
+          bool is_stable>
+struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
+  __device__ static void merge(T* __restrict__ val_arr,
+                               idxT* __restrict__ idx_arr) {
+    int const lane = threadIdx.x % WARP_SIZE;
+    for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
+      bool is_second = lane & stride;
+      T& val = *val_arr;
+      T other = __shfl_xor_sync(FUSED_FULL_WARP_MASK, val, stride);
+      idxT& idx = *idx_arr;
+      idxT other_idx = __shfl_xor_sync(FUSED_FULL_WARP_MASK, idx, stride);
+      bool is_better;
+      if constexpr (is_stable) {
+        if constexpr (ascending)
+          is_better = ((*val_arr > other) ||
+                       ((*val_arr == other) && (*idx_arr < other_idx))) ==
+                      (reverse != is_second);
+        else
+          is_better = ((*val_arr > other) ||
+                       ((*val_arr == other) && (*idx_arr > other_idx))) ==
+                      (reverse != is_second);
+      } else {
+        is_better =
+            (val != other && ((val > other) == (ascending != is_second)));
+      }
+      if (is_better) {
+        val = other;
+        idx = other_idx;
+      }
+    }
+  }
+};
+
+template <int capacity, bool greater, typename T, typename idxT,
+          bool is_stable>
+class WarpSort {
+ public:
+  __device__ WarpSort(idxT k, T dummy)
+      : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
+    static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
+    for (int i = 0; i < max_arr_len_; ++i) {
+      val_arr_[i] = dummy_;
+      idx_arr_[i] = 0;
+    }
+  }
+
+  __device__ __forceinline__ idxT get_idx(int i = 0) const {
+    return idx_arr_[i];
+  }
+  __device__ __forceinline__ T get_val(int i = 0) const { return val_arr_[i]; }
+
+ protected:
+  static constexpr int max_arr_len_ = capacity / WARP_SIZE;
+  T val_arr_[max_arr_len_];
+  idxT idx_arr_[max_arr_len_];
+  int const lane_;
+  idxT const k_;
+  T const dummy_;
+};
+
+// WarpSelect WITHOUT __syncthreads() in done() — safe when only one warp is
+// active.
+template <int capacity, bool greater, typename T, typename idxT,
+          bool is_stable>
+class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
+ public:
+  __device__ WarpSelect(idxT k, T dummy)
+      : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
+        k_th_(dummy),
+        k_th_idx_(0),
+        k_th_lane_((k - 1) % WARP_SIZE) {
+    extern __shared__ char smem_buf[];
+    int const num_of_warp = blockDim.x / WARP_SIZE;
+    int const warp_id = threadIdx.x / WARP_SIZE;
+    val_smem_ = reinterpret_cast<T*>(smem_buf);
+    val_smem_ += warp_id * WARP_SIZE;
+    idx_smem_ = reinterpret_cast<idxT*>(
+        smem_buf +
+        round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
+    idx_smem_ += warp_id * WARP_SIZE;
+  }
+
+  __device__ void add(T val, idxT idx) {
+    bool do_add;
+    if constexpr (is_stable)
+      do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
+    else
+      do_add = is_better_than<greater>(val, k_th_);
+
+    uint32_t mask = __ballot_sync(FUSED_FULL_WARP_MASK, do_add);
+    if (mask == 0) return;
+
+    int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
+    if (do_add && pos < WARP_SIZE) {
+      val_smem_[pos] = val;
+      idx_smem_[pos] = idx;
+      do_add = false;
+    }
+    smem_buf_len_ += __popc(mask);
+    if (smem_buf_len_ >= WARP_SIZE) {
+      __syncwarp();
+      merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
+      smem_buf_len_ -= WARP_SIZE;
+    }
+    if (do_add) {
+      pos -= WARP_SIZE;
+      val_smem_[pos] = val;
+      idx_smem_[pos] = idx;
+    }
+    __syncwarp();
+  }
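+
+  // How add() compacts candidates, as a worked example (no extra logic):
+  // if lanes 1, 3 and 4 pass the current k-th threshold, __ballot_sync
+  // returns mask = 0b11010. Lane 3 takes slot
+  //   smem_buf_len_ + __popc(mask & 0b00111) = smem_buf_len_ + 1,
+  // so every accepted candidate lands in a dense, collision-free position
+  // of the staging buffer; once the buffer holds a full warp of candidates
+  // it is bitonic-merged into the sorted val_arr_/idx_arr_ registers.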
+
+  // NOTE: no __syncthreads() here — callers must sync externally if needed.
+  __device__ void done() {
+    if (smem_buf_len_) {
+      T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
+      idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
+      merge_buf_(val, idx);
+    }
+  }
+
+ private:
+  __device__ void set_k_th_() {
+    k_th_ = __shfl_sync(
+        FUSED_FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
+    if constexpr (is_stable)
+      k_th_idx_ = __shfl_sync(
+          FUSED_FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
+  }
+
+  __device__ void merge_buf_(T val, idxT idx) {
+    BitonicSort<32, greater, T, idxT, is_stable>::sort(&val, &idx);
+    T& old = val_arr_[max_arr_len_ - 1];
+    bool is_better;
+    if constexpr (is_stable)
+      is_better =
+          is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
+    else
+      is_better = is_better_than<greater>(val, old);
+    if (is_better) {
+      old = val;
+      idx_arr_[max_arr_len_ - 1] = idx;
+    }
+    BitonicMerge<capacity, !greater, !greater, T, idxT, is_stable>::merge(
+        val_arr_, idx_arr_);
+    set_k_th_();
+  }
+
+  using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
+
+  T* val_smem_;
+  idxT* idx_smem_;
+  int smem_buf_len_ = 0;
+  T k_th_;
+  idxT k_th_idx_;
+  int const k_th_lane_;
+};
+
+}  // namespace warp_topk_fused
+
+// ---------------------------------------------------------------------------
+// Fused kernel: group-score computation + group selection + expert topk
+// + sparse scores write-back, in one kernel launch.
+//
+// gridDim  = num_tokens (one block per token)
+// blockDim = n_group * WARP_SIZE (one warp per group)
+// ---------------------------------------------------------------------------
+template <typename InT, typename IdxT>
+__global__ void grouped_topk_fused_kernel(
+    float* scores,       // output: sparse routing weights
+                         // [num_tokens, num_experts]
+    float* topk_values,  // output: topk routing weights [num_tokens, topk]
+    IdxT* topk_indices,  // output: topk expert indices [num_tokens, topk]
+    InT const* gating_output,  // input: raw logits (float, half, or bf16)
+                               // [num_tokens, num_experts]
+    float const* e_score_correction_bias,  // input: bias [num_experts]
+    int64_t const num_tokens,
+    int64_t const num_experts,
+    int64_t const n_group,
+    int64_t const topk_group,
+    int64_t const topk,
+    bool const renormalize,
+    double routed_scaling_factor) {
+  int32_t const token_id = static_cast<int32_t>(blockIdx.x);
+  if (token_id >= static_cast<int32_t>(num_tokens)) return;
+
+  int32_t const warp_id = threadIdx.x / WARP_SIZE;
+  int32_t const lane_id = threadIdx.x % WARP_SIZE;
+  int32_t const n_group_i32 = static_cast<int32_t>(n_group);
+  int32_t const topk_group_i32 = static_cast<int32_t>(topk_group);
+  int32_t const topk_i32 = static_cast<int32_t>(topk);
+  int32_t const num_warps = blockDim.x / WARP_SIZE;
+
+  if (warp_id >= n_group_i32 || num_warps < n_group_i32) return;
+
+  int32_t const num_experts_per_group =
+      static_cast<int32_t>(num_experts) / n_group_i32;
+  int32_t const align_epg =
+      warp_topk_fused::round_up_to_multiple_of<WARP_SIZE>(
+          num_experts_per_group);
+
+  InT const* gate_token = gating_output + (int64_t)token_id * num_experts;
+  float* scores_token = scores + (int64_t)token_id * num_experts;
+
+  cg::thread_block block = cg::this_thread_block();
+  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
+
+  // smem layout: [val_staging (256B-aligned) | idx_staging | (16B pad) |
+  // s_group_scores]
+  extern __shared__ char smem_buf[];
+  size_t const val_aligned = warp_topk_fused::round_up_to_multiple_of<256>(
+      static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(float));
+  size_t const idx_bytes =
+      static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
+  uintptr_t ptr =
+      (reinterpret_cast<uintptr_t>(smem_buf + val_aligned + idx_bytes) + 15) &
+      ~static_cast<uintptr_t>(15);
+  float* s_group_scores = reinterpret_cast<float*>(ptr);
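+
+  // Concrete sizing example for n_group = 8 (fp32 values, int32 indices):
+  //   val staging : 8 warps * 32 lanes * 4 B = 1024 B (256 B aligned)
+  //   idx staging : 8 warps * 32 lanes * 4 B = 1024 B
+  //   pad         : at most 15 B up to the next 16 B boundary
+  //   group scores: 8 * 4 B = 32 B
+  // invokeFusedNoAuxTc() below sizes dynamicSmemBytes with the same formula;
+  // the two must stay in sync or s_group_scores would overrun the allocation.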
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
+
+  // ------------------------------------------------------------------
+  // Phase 1 (all warps): compute group score = top2 sum of (gate + bias)
+  // ------------------------------------------------------------------
+  {
+    int32_t const offset = warp_id * num_experts_per_group;
+    InT const* gate_g = gate_token + offset;
+    float const* bias_g = e_score_correction_bias + offset;
+
+    float largest = neg_inf<float>();
+    float second_largest = neg_inf<float>();
+
+    if (num_experts_per_group > WARP_SIZE) {
+      for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
+        float val = sigmoid_to_float(gate_g[i]) + bias_g[i];
+        if (val > largest) {
+          second_largest = largest;
+          largest = val;
+        } else if (val > second_largest) {
+          second_largest = val;
+        }
+      }
+    } else {
+      for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE)
+        largest = sigmoid_to_float(gate_g[i]) + bias_g[i];
+    }
+
+    float max1 = cg::reduce(tile, largest, cg::greater<float>());
+    float max2 = max1;
+    int cnt = __popc(__ballot_sync(FUSED_FULL_WARP_MASK, largest == max1));
+    if (cnt == 1) {
+      largest = (largest == max1) ? second_largest : largest;
+      max2 = cg::reduce(tile, largest, cg::greater<float>());
+    }
+    if (lane_id == 0) s_group_scores[warp_id] = max1 + max2;
+  }
+
+  __syncthreads();
+
+  // ------------------------------------------------------------------
+  // Phase 2 (warp0 only): group selection → expert selection → output
+  // ------------------------------------------------------------------
+  if (warp_id != 0) return;
+
+  topk_values += (int64_t)token_id * topk;
+  topk_indices += (int64_t)token_id * topk;
+
+  // Select top-topk_group groups
+  warp_topk_fused::WarpSelect<WARP_SIZE, true, float, int32_t, true>
+      group_sel(topk_group_i32, neg_inf<float>());
+
+  float gscore =
+      (lane_id < n_group_i32) ? s_group_scores[lane_id] : neg_inf<float>();
+  group_sel.add(gscore, lane_id);
+  group_sel.done();  // no __syncthreads() — only warp0 is active here
+
+  // Check if enough valid groups exist
+  bool proceed = false;
+  if (topk_group_i32 > 0) {
+    float kth = __shfl_sync(
+        FUSED_FULL_WARP_MASK, group_sel.get_val(0), topk_group_i32 - 1);
+    proceed = (kth != neg_inf<float>());
+  }
+
+  if (!proceed) {
+    // Fallback: zero scores, uniform topk
+    for (int i = lane_id; i < static_cast<int32_t>(num_experts);
+         i += WARP_SIZE)
+      scores_token[i] = 0.0f;
+    __syncwarp();
+    for (int i = lane_id; i < topk_i32; i += WARP_SIZE) {
+      topk_indices[i] = static_cast<IdxT>(i);
+      topk_values[i] = 1.0f / static_cast<float>(topk_i32);
+    }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.launch_dependents;");
+#endif
+    return;
+  }
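+
+  // Why lane (topk_group - 1) is the right probe: WarpSelect leaves the
+  // selected group scores sorted best-first across lanes, so with
+  // topk_group = 4, lane 3 holds the 4th-best group score. If that slot is
+  // still -inf, fewer than topk_group finite groups exist and the token
+  // falls back to uniform 1/topk routing over the first topk experts above.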
+
+  // Select top-topk experts from selected groups (using biased scores as
+  // candidates)
+  warp_topk_fused::WarpSelect<WARP_SIZE, true, float, int32_t, true>
+      expert_sel(topk_i32, neg_inf<float>());  // reuses same smem —
+                                               // group_sel is done
+
+  int32_t sel_gid = (lane_id < topk_group_i32) ? group_sel.get_idx(0) : 0;
+  for (int32_t g = 0; g < topk_group_i32; ++g) {
+    int32_t gid = __shfl_sync(FUSED_FULL_WARP_MASK, sel_gid, g);
+    int32_t offset = gid * num_experts_per_group;
+    for (int32_t i = lane_id; i < align_epg; i += WARP_SIZE) {
+      float cand = neg_inf<float>();
+      int32_t idx = 0;
+      if (i < num_experts_per_group) {
+        idx = offset + i;
+        float biased =
+            sigmoid_to_float(gate_token[idx]) + e_score_correction_bias[idx];
+        if (is_finite_val(biased)) cand = biased;
+      }
+      expert_sel.add(cand, idx);
+    }
+  }
+  expert_sel.done();
+
+  // Compute routing weights from unbiased scores
+  float lane_score = 0.0f;
+  IdxT lane_idx = 0;
+  if (lane_id < topk_i32) {
+    lane_idx = static_cast<IdxT>(expert_sel.get_idx(0));
+    lane_score = sigmoid_to_float(gate_token[static_cast<int64_t>(lane_idx)]);
+  }
+
+  float topk_sum = 1e-20f;
+  if (renormalize)
+    topk_sum += cg::reduce(tile, lane_score, cg::plus<float>());
+
+  float scale = static_cast<float>(routed_scaling_factor);
+  if (renormalize) scale /= topk_sum;
+
+  // Fill sparse scores: first zero out, then write selected experts' weights
+  for (int i = lane_id; i < static_cast<int32_t>(num_experts); i += WARP_SIZE)
+    scores_token[i] = 0.0f;
+  __syncwarp();
+
+  if (lane_id < topk_i32) {
+    float val = lane_score * scale;
+    scores_token[static_cast<int64_t>(lane_idx)] = val;
+    topk_indices[lane_id] = lane_idx;
+    topk_values[lane_id] = val;
+  }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
+// ---------------------------------------------------------------------------
+// Launch wrapper
+// ---------------------------------------------------------------------------
+template <typename InT, typename IdxT>
+void invokeFusedNoAuxTc(InT* gating_output,
+                        float* e_score_correction_bias,
+                        float* scores,
+                        float* topk_values,
+                        IdxT* topk_indices,
+                        int64_t const num_tokens,
+                        int64_t const num_experts,
+                        int64_t const n_group,
+                        int64_t const topk_group,
+                        int64_t const topk,
+                        bool const renormalize,
+                        double const routed_scaling_factor,
+                        cudaStream_t const stream) {
+  auto* kernel = &grouped_topk_fused_kernel<InT, IdxT>;
+
+  // blockDim = n_group * WARP_SIZE (one warp per group)
+  int32_t const num_warps = static_cast<int32_t>(n_group);
+
+  // smem = WarpSelect staging (float vals + int32 idxs) + 16B pad
+  //        + group_scores buffer (float)
+  size_t const val_aligned = warp_topk_fused::round_up_to_multiple_of<256>(
+      static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(float));
+  size_t const idx_bytes =
+      static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
+  size_t const extra_bytes =
+      16 + static_cast<size_t>(n_group) * sizeof(float);
+  size_t const smem_bytes = val_aligned + idx_bytes + extra_bytes;
+
+  cudaLaunchConfig_t config;
+  config.gridDim = static_cast<uint32_t>(num_tokens);
+  config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
+  config.dynamicSmemBytes = smem_bytes;
+  config.stream = stream;
+  cudaLaunchAttribute attrs[1];
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = false;
+  config.numAttrs = 1;
+  config.attrs = attrs;
+
+  cudaLaunchKernelEx(&config,
+                     kernel,
+                     scores,
+                     topk_values,
+                     topk_indices,
+                     gating_output,
+                     e_score_correction_bias,
+                     num_tokens,
+                     num_experts,
+                     n_group,
+                     topk_group,
+                     topk,
+                     renormalize,
+                     routed_scaling_factor);
+}
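+
+// Note on the launch above: cudaLaunchKernelEx with
+// cudaLaunchAttributeProgrammaticStreamSerialization pairs with the
+// griddepcontrol.wait / griddepcontrol.launch_dependents instructions in
+// the kernel (SM90+ programmatic dependent launch). With the attribute set
+// to false the launch stays stream-ordered and the griddepcontrol
+// instructions are effectively no-ops, so correctness does not depend on
+// PDL; flipping the flag to true is the intended opt-in for overlap.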
+
+#define INSTANTIATE_FUSED_NOAUX_TC(InT, IdxT)  \
+  template void invokeFusedNoAuxTc<InT, IdxT>( \
+      InT * gating_output,                     \
+      float* e_score_correction_bias,          \
+      float* scores,                           \
+      float* topk_values,                      \
+      IdxT* topk_indices,                      \
+      int64_t const num_tokens,                \
+      int64_t const num_experts,               \
+      int64_t const n_group,                   \
+      int64_t const topk_group,                \
+      int64_t const topk,                      \
+      bool const renormalize,                  \
+      double const routed_scaling_factor,      \
+      cudaStream_t const stream);
+
+INSTANTIATE_FUSED_NOAUX_TC(float, int64_t);
+INSTANTIATE_FUSED_NOAUX_TC(__nv_bfloat16, int64_t);
+INSTANTIATE_FUSED_NOAUX_TC(__half, int64_t);
+
+// ---------------------------------------------------------------------------
+// Paddle op wrapper
+// ---------------------------------------------------------------------------
+std::vector<paddle::Tensor> grouped_topk(
+    paddle::Tensor& gating_output,
+    paddle::Tensor& e_score_correction_bias,
+    int n_group,
+    int topk_group,
+    int topk,
+    bool renormalize,
+    float routed_scaling_factor) {
+  auto input_shape = gating_output.shape();
+  PD_CHECK(input_shape.size() == 2);
+  int64_t num_tokens = input_shape[0];
+  int64_t num_experts = input_shape[1];
+  auto place = gating_output.place();
+  PD_CHECK(n_group <= 32, "grouped_topk: n_group must be <= 32");
+  PD_CHECK(topk <= 32, "grouped_topk: topk must be <= WARP_SIZE (32)");
+
+  // Outputs are always float32 regardless of input dtype
+  auto scores = paddle::empty(
+      {num_tokens, num_experts}, paddle::DataType::FLOAT32, place);
+  auto topk_values =
+      paddle::empty({num_tokens, topk}, paddle::DataType::FLOAT32, place);
+  auto topk_indices =
+      paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
+
+  auto stream = gating_output.stream();
+  auto dtype = gating_output.dtype();
+
+  float* scores_ptr = reinterpret_cast<float*>(scores.data());
+  float* topk_values_ptr = reinterpret_cast<float*>(topk_values.data());
+  int64_t* topk_idx_ptr =
+      reinterpret_cast<int64_t*>(topk_indices.data());
+  float* bias_ptr =
+      reinterpret_cast<float*>(e_score_correction_bias.data());
+
+  if (dtype == paddle::DataType::BFLOAT16) {
+    invokeFusedNoAuxTc<__nv_bfloat16, int64_t>(
+        reinterpret_cast<__nv_bfloat16*>(gating_output.data()),
+        bias_ptr,
+        scores_ptr,
+        topk_values_ptr,
+        topk_idx_ptr,
+        num_tokens,
+        num_experts,
+        static_cast<int64_t>(n_group),
+        static_cast<int64_t>(topk_group),
+        static_cast<int64_t>(topk),
+        renormalize,
+        static_cast<double>(routed_scaling_factor),
+        stream);
+  } else if (dtype == paddle::DataType::FLOAT16) {
+    invokeFusedNoAuxTc<__half, int64_t>(
+        reinterpret_cast<__half*>(gating_output.data()),
+        bias_ptr,
+        scores_ptr,
+        topk_values_ptr,
+        topk_idx_ptr,
+        num_tokens,
+        num_experts,
+        static_cast<int64_t>(n_group),
+        static_cast<int64_t>(topk_group),
+        static_cast<int64_t>(topk),
+        renormalize,
+        static_cast<double>(routed_scaling_factor),
+        stream);
+  } else {
+    PD_CHECK(
+        dtype == paddle::DataType::FLOAT32,
+        "grouped_topk: gating_output must be float32, float16, or bfloat16");
+    invokeFusedNoAuxTc<float, int64_t>(
+        reinterpret_cast<float*>(gating_output.data()),
+        bias_ptr,
+        scores_ptr,
+        topk_values_ptr,
+        topk_idx_ptr,
+        num_tokens,
+        num_experts,
+        static_cast<int64_t>(n_group),
+        static_cast<int64_t>(topk_group),
+        static_cast<int64_t>(topk),
+        renormalize,
+        static_cast<double>(routed_scaling_factor),
+        stream);
+  }
+
+  return {scores, topk_values, topk_indices};
+}
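+
+// Usage sketch from the Python side (hedged; the binding is exposed through
+// fastdeploy.model_executor.ops.gpu, as the moe.py change below imports it):
+//
+//   scores, topk_values, topk_idx = grouped_topk(
+//       gating_output,            # [num_tokens, num_experts] raw logits
+//       e_score_correction_bias,  # [num_experts] float32 bias
+//       n_group, topk_group, top_k,
+//       renormalize, routed_scaling_factor)
+//
+// scores is the dense [num_tokens, num_experts] routing matrix with zeros
+// outside the selected experts; topk_values/topk_idx are the compact form.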
+
+std::vector<paddle::DataType> GroupedTopkInferDtype(
+    const paddle::DataType& /*gating_output_dtype*/,
+    const paddle::DataType& /*e_score_correction_bias_dtype*/) {
+  // Outputs are always float32: the cast is fused into the kernel.
+  return {paddle::DataType::FLOAT32,
+          paddle::DataType::FLOAT32,
+          paddle::DataType::INT64};
+}
+
+std::vector<std::vector<int64_t>> GroupedTopkInferShape(
+    const std::vector<int64_t>& gating_output_shape,
+    const std::vector<int64_t>&,
+    const int topk) {
+  auto num_tokens = gating_output_shape[0];
+  auto num_experts = gating_output_shape[1];
+  return {{num_tokens, num_experts}, {num_tokens, topk}, {num_tokens, topk}};
+}
+
+PD_BUILD_STATIC_OP(grouped_topk)
+    .Inputs({"gating_output", "e_score_correction_bias"})
+    .Outputs({"output_tensor", "topk_values", "topk_indices"})
+    .Attrs({"n_group: int",
+            "topk_group: int",
+            "topk: int",
+            "renormalize: bool",
+            "routed_scaling_factor: float"})
+    .SetKernelFn(PD_KERNEL(grouped_topk))
+    .SetInferShapeFn(PD_INFER_SHAPE(GroupedTopkInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(GroupedTopkInferDtype));
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index bf652968e46..3cf39a9fc94 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -331,6 +331,7 @@ def find_end_files(directory, end_str):
                 "gpu_ops/fused_rotary_position_encoding.cu",
                 "gpu_ops/noaux_tc.cu",
                 "gpu_ops/noaux_tc_redundant.cu",
+                "gpu_ops/grouped_topk_kernels.cu",
                 "gpu_ops/fused_cast_sigmoid_bias.cu",
                 "gpu_ops/custom_all_reduce/all_reduce.cu",
                 "gpu_ops/merge_prefill_decode_output.cu",
@@ -688,6 +689,7 @@ def find_end_files(directory, end_str):
                 "gpu_ops/recover_decode_task.cu",
                 "gpu_ops/noaux_tc.cu",
                 "gpu_ops/noaux_tc_redundant.cu",
+                "gpu_ops/grouped_topk_kernels.cu",
                 "gpu_ops/fused_cast_sigmoid_bias.cu",
                 "gpu_ops/fused_rotary_position_encoding.cu",
                 "gpu_ops/text_image_gather_scatter.cu",
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index c048248eec4..47bc8d18917 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -36,16 +36,18 @@
 from fastdeploy.worker.experts_manager import RedundantExpertManger
 
 try:
-    from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant
+    from fastdeploy.model_executor.ops.gpu import (
+        grouped_topk,
+        noaux_tc,
+        noaux_tc_redundant,
+    )
 except:
     logger.warning("import noaux_tc Failed!")
 import numpy as np
 
 if current_platform.is_cuda():
-    from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import (
-        fused_cast_sigmoid_bias,
-    )
+    pass
 
 
 def get_moe_method(layer=None):
@@ -103,11 +105,7 @@ def get_moe_scores(
     compute moe scores using e_score_correction_bias.
     """
     assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
- if use_fused_cast and current_platform.is_cuda(): - scores, scores_with_bias = fused_cast_sigmoid_bias(gating_output, e_score_correction_bias, cast_type="float32") - else: - scores = paddle.nn.functional.sigmoid(gating_output) - scores_with_bias = scores + e_score_correction_bias + use_fused = use_fused_cast and current_platform.is_cuda() if envs.FD_USE_PHI_MOE_TOPK: # calculate renormalize and routed_scaling_factor value outside the noaux_tc @@ -116,7 +114,9 @@ def get_moe_scores( renormalize = False routed_scaling_factor = 1.0 - if expert_id_to_ep_rank_array is None: + if expert_id_to_ep_rank_array is None and not use_fused: + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias scores, topk_values, topk_idx = noaux_tc( scores, scores_with_bias, @@ -126,9 +126,20 @@ def get_moe_scores( renormalize, routed_scaling_factor, ) + elif expert_id_to_ep_rank_array is None and use_fused: + # fused kernel: cast + sigmoid + add + noaux_tc + scores, topk_values, topk_idx = grouped_topk( + gating_output, + e_score_correction_bias, + n_group if n_group > 0 else 1, + topk_group if topk_group > 0 else 1, + top_k, + renormalize, + routed_scaling_factor, + ) else: - # noaux_tc_redundant returns 4 values: scores, topk_values, topk_idx, - # and tokens_per_expert_stats_list_out (inplace updated) + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias scores, topk_values, topk_idx, _ = noaux_tc_redundant( scores, scores_with_bias, From 49563e4243017e5d5b33978ee3e65955c9c03518 Mon Sep 17 00:00:00 2001 From: ShaneGZhu <1092841848@qq.com> Date: Mon, 11 May 2026 19:52:00 +0800 Subject: [PATCH 2/4] Add unit_test file --- tests/operators/test_grouped_topk_op.py | 485 ++++++++++++++++++++++++ 1 file changed, 485 insertions(+) create mode 100644 tests/operators/test_grouped_topk_op.py diff --git a/tests/operators/test_grouped_topk_op.py b/tests/operators/test_grouped_topk_op.py new file mode 100644 index 00000000000..1e76328eb93 --- /dev/null +++ b/tests/operators/test_grouped_topk_op.py @@ -0,0 +1,485 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the `grouped_topk` custom CUDA op (low-level interface). + +grouped_topk fuses sigmoid into the kernel and accepts raw logits directly, +unlike noaux_tc which requires Python-side sigmoid preprocessing. + +Algorithm: + 1. scores = sigmoid(gating_output) [fused inside kernel] + 2. scores_with_bias = scores + e_score_correction_bias + 3. group_scores = sum of top-2 biased expert scores per group + 4. Select top-topk_group groups + 5. Within selected groups select top-topk experts by biased score + 6. Gather unbiased sigmoid scores for selected experts as topk_values + 7. 
Optionally renormalize and scale by routed_scaling_factor + +Model configs covered: + DeepSeek-V3 / R1 num_experts=256, n_group=8, topk_group=4, topk=8, renorm=True, scale=2.5 + GLM-4.5-Air num_experts=128, n_group=1, topk_group=1, topk=8, renorm=True, scale=1.0 + Qwen3-30B-A3B num_experts=128, n_group=4, topk_group=2, topk=8, renorm=False, scale=1.0 + Kimi-K2 num_experts=384, n_group=8, topk_group=2, topk=8, renorm=False, scale=1.0 +""" + +import unittest + +import numpy as np +import paddle + +try: + from fastdeploy.model_executor.ops.gpu import grouped_topk + + _GROUPED_TOPK_AVAILABLE = True +except Exception: + _GROUPED_TOPK_AVAILABLE = False + + +def native_grouped_topk( + gating_output: paddle.Tensor, + e_score_correction_bias: paddle.Tensor, + n_group: int, + topk_group: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, +): + """Pure-Python reference that mirrors the grouped_topk kernel semantics. + + Args: + gating_output: raw logits, shape [num_tokens, num_experts] + e_score_correction_bias: bias added to sigmoid scores, shape [1, num_experts] or [num_experts] + n_group: number of expert groups + topk_group: number of groups selected per token + topk: number of experts selected per token + renormalize: whether to L1-normalise the selected weights + routed_scaling_factor: multiplicative scale applied after renorm + + Returns: + (scores_out, topk_values, topk_indices) + scores_out – sparse score tensor, shape [num_tokens, num_experts] + topk_values – weights for selected experts, shape [num_tokens, topk] + topk_indices – expert indices, shape [num_tokens, topk] (int64) + """ + num_tokens, num_experts = gating_output.shape + experts_per_group = num_experts // n_group + + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias + + # Step 1: group scores = sum of top-2 biased scores per group + biased = scores_with_bias.reshape([num_tokens, n_group, experts_per_group]) + group_scores = biased.topk(min(2, experts_per_group), axis=-1)[0].sum(axis=-1) + + # Step 2: select top-topk_group groups + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] + group_mask = paddle.zeros_like(group_scores) + group_mask.put_along_axis_(group_idx, paddle.ones_like(group_idx, dtype=group_mask.dtype), axis=-1) + score_mask = ( + group_mask.unsqueeze(-1).expand([num_tokens, n_group, experts_per_group]).reshape([num_tokens, num_experts]) + ) + + # Step 3: select top-topk experts within selected groups (biased score) + tmp_scores = scores_with_bias.masked_fill(~score_mask.cast(paddle.bool), float("-inf")) + topk_indices = paddle.topk(tmp_scores, topk, axis=-1)[1] + + # Step 4: gather unbiased sigmoid scores + topk_values = paddle.take_along_axis(scores, topk_indices, axis=1) + + # Step 5: renormalize + scale + if renormalize: + topk_values = topk_values / (topk_values.sum(axis=-1, keepdim=True) + 1e-20) + if routed_scaling_factor != 1.0: + topk_values = topk_values * routed_scaling_factor + + scores_out = paddle.zeros_like(scores) + scores_out.put_along_axis_(topk_indices, topk_values, axis=1) + + return scores_out, topk_values, topk_indices.cast(paddle.int64) + + +@unittest.skipUnless(_GROUPED_TOPK_AVAILABLE, "grouped_topk custom op not available (not compiled)") +class TestGroupedTopkOp(unittest.TestCase): + """Tests for the grouped_topk custom CUDA op.""" + + ATOL = 1e-3 + RTOL = 1e-3 + + def setUp(self): + paddle.seed(42) + + # ------------------------------------------------------------------ + # Parametrised 
helper + # ------------------------------------------------------------------ + def _run_case( + self, + num_tokens: int, + num_experts: int, + n_group: int, + topk_group: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, + input_dtype=paddle.float32, + bias_scale: float = 0.1, + seed: int = 42, + ): + paddle.seed(seed) + gating = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = (paddle.rand([1, num_experts], dtype=paddle.float32) - 0.5) * bias_scale + + # Reference always runs in fp32 + gating_fp32 = gating.cast(paddle.float32) if input_dtype != paddle.float32 else gating + ref_scores, ref_tv, ref_ti = native_grouped_topk( + gating_fp32.clone(), + bias.clone(), + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) + + op_scores, op_tv, op_ti = grouped_topk( + gating.clone(), + bias.clone(), + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) + + label = ( + f"T={num_tokens}, E={num_experts}, n_group={n_group}, " + f"topk_group={topk_group}, topk={topk}, " + f"renorm={renormalize}, scale={routed_scaling_factor}, dtype={input_dtype}" + ) + + self.assertEqual(op_tv.shape, [num_tokens, topk], f"[{label}] topk_values shape") + self.assertEqual(op_ti.shape, [num_tokens, topk], f"[{label}] topk_indices shape") + self.assertEqual(op_ti.dtype, paddle.int64, f"[{label}] topk_indices dtype") + self.assertEqual(op_tv.dtype, paddle.float32, f"[{label}] topk_values dtype") + + # Compare set-level index match (position order not guaranteed) + ref_sorted = paddle.sort(ref_ti, axis=-1) + op_sorted = paddle.sort(op_ti, axis=-1) + if not paddle.equal_all(ref_sorted, op_sorted).item(): + n_diff = (ref_sorted != op_sorted).sum().item() + self.fail(f"[{label}] topk_indices set mismatch: {n_diff} positions differ") + + # Align values by expert index before comparing + ref_ord = paddle.argsort(ref_ti, axis=-1) + op_ord = paddle.argsort(op_ti, axis=-1) + ref_tv_s = paddle.take_along_axis(ref_tv, ref_ord, axis=-1) + op_tv_s = paddle.take_along_axis(op_tv, op_ord, axis=-1) + if not paddle.allclose(op_tv_s, ref_tv_s, atol=self.ATOL, rtol=self.RTOL).item(): + max_diff = (op_tv_s - ref_tv_s).abs().max().item() + self.fail(f"[{label}] topk_values max_diff={max_diff:.2e}") + + # ------------------------------------------------------------------ + # GLM-4.5-Air: n_experts=128, n_group=1, topk_group=1, topk=8, renorm=True + # ------------------------------------------------------------------ + def test_glm45air_T1(self): + self._run_case(1, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T32(self): + self._run_case(32, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T128(self): + self._run_case(128, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T512(self): + self._run_case(512, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T1024(self): + self._run_case(1024, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T4096(self): + self._run_case(4096, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T8192(self): + self._run_case(8192, 128, 1, 1, 8, True, 1.0) + + # ------------------------------------------------------------------ + # DeepSeek-V3 / R1: n_experts=256, n_group=8, topk_group=4, topk=8, + # renorm=True, scale=2.5 + # ------------------------------------------------------------------ + def test_deepseek_v3_T1(self): + self._run_case(1, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T32(self): + self._run_case(32, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T128(self): + self._run_case(128, 256, 8, 4, 8, True, 2.5) + + def 
test_deepseek_v3_T512(self): + self._run_case(512, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T4096(self): + self._run_case(4096, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T8192(self): + self._run_case(8192, 256, 8, 4, 8, True, 2.5) + + # ------------------------------------------------------------------ + # Qwen3-30B-A3B: n_experts=128, n_group=4, topk_group=2, topk=8, + # renorm=False + # ------------------------------------------------------------------ + def test_qwen3_30b_T1(self): + self._run_case(1, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T128(self): + self._run_case(128, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T512(self): + self._run_case(512, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T4096(self): + self._run_case(4096, 128, 4, 2, 8, False, 1.0) + + # ------------------------------------------------------------------ + # Kimi-K2: n_experts=384, n_group=8, topk_group=2, topk=8, renorm=False + # ------------------------------------------------------------------ + def test_kimi_k2_T1(self): + self._run_case(1, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T128(self): + self._run_case(128, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T512(self): + self._run_case(512, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T4096(self): + self._run_case(4096, 384, 8, 2, 8, False, 1.0) + + # ------------------------------------------------------------------ + # bfloat16 input path: kernel should cast internally to fp32 + # ------------------------------------------------------------------ + def test_bf16_input_glm45air(self): + self._run_case(128, 128, 1, 1, 8, True, 1.0, input_dtype=paddle.bfloat16) + + def test_bf16_input_deepseek_v3(self): + self._run_case(128, 256, 8, 4, 8, True, 2.5, input_dtype=paddle.bfloat16) + + def test_bf16_input_qwen3_30b(self): + self._run_case(128, 128, 4, 2, 8, False, 1.0, input_dtype=paddle.bfloat16) + + # ------------------------------------------------------------------ + # Output shape and dtype sanity + # ------------------------------------------------------------------ + def test_output_shapes(self): + """Verify output shapes for various (T, E, topk) combinations.""" + cases = [ + (1, 128, 1, 1, 8), + (32, 256, 8, 4, 8), + (64, 384, 8, 2, 8), + ] + for T, E, ng, tkg, topk in cases: + gating = paddle.randn([T, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + self.assertEqual(tv.shape, [T, topk], f"T={T},E={E}: topk_values shape") + self.assertEqual(ti.shape, [T, topk], f"T={T},E={E}: topk_indices shape") + + def test_output_dtype_is_float32(self): + """topk_values must always be float32 regardless of input dtype.""" + for dtype in [paddle.float32, paddle.bfloat16]: + gating = paddle.randn([16, 128], dtype=dtype) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + self.assertEqual(tv.dtype, paddle.float32, f"dtype={dtype}: topk_values not float32") + self.assertEqual(ti.dtype, paddle.int64, f"dtype={dtype}: topk_indices not int64") + + # ------------------------------------------------------------------ + # Correctness invariants + # ------------------------------------------------------------------ + def test_topk_indices_in_valid_range(self): + """All selected expert indices must lie in [0, num_experts).""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8), (384, 8, 2, 8)]: + gating = paddle.randn([64, E], dtype=paddle.float32) + bias = 
paddle.zeros([1, E], dtype=paddle.float32) + _, _, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + self.assertTrue((ti >= 0).all().item(), f"E={E}: negative index found") + self.assertTrue((ti < E).all().item(), f"E={E}: index >= num_experts") + + def test_no_duplicate_experts_per_token(self): + """Each token must select exactly topk distinct experts.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + gating = paddle.randn([32, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, _, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + for row in ti.numpy(): + self.assertEqual(len(set(row.tolist())), topk, f"E={E}: duplicate expert indices in row {row}") + + def test_topk_values_non_negative(self): + """Sigmoid output is in (0,1); routing weights must be >= 0.""" + gating = paddle.randn([64, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + self.assertTrue((tv >= 0).all().item(), "topk_values contains negative weights") + + def test_renormalized_weights_sum_to_one(self): + """With renormalize=True and scale=1.0, per-token weights sum ≈ 1.""" + num_tokens = 64 + gating = paddle.randn([num_tokens, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + row_sums = tv.sum(axis=-1).numpy() + np.testing.assert_allclose( + row_sums, + np.ones(num_tokens, dtype=np.float32), + atol=1e-3, + err_msg="Renormalized weights do not sum to 1 per token", + ) + + def test_scaled_weights_sum(self): + """With renormalize=True and scale=2.5, per-token weights sum ≈ 2.5.""" + num_tokens, scale = 64, 2.5 + gating = paddle.randn([num_tokens, 256], dtype=paddle.float32) + bias = paddle.zeros([1, 256], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 8, 4, 8, True, scale) + row_sums = tv.sum(axis=-1).numpy() + np.testing.assert_allclose( + row_sums, + np.full(num_tokens, scale, dtype=np.float32), + atol=1e-2, + err_msg=f"Scaled weights do not sum to {scale} per token", + ) + + def test_no_renorm_weights_are_raw_sigmoid(self): + """With renormalize=False, topk_values must equal sigmoid(logits) at selected positions.""" + num_tokens, E = 32, 128 + gating = paddle.randn([num_tokens, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 8, False, 1.0) + expected = paddle.take_along_axis(paddle.nn.functional.sigmoid(gating), ti, axis=1) + np.testing.assert_allclose( + tv.numpy(), + expected.numpy(), + atol=1e-4, + err_msg="Without renorm, topk_values should equal sigmoid(gating) at selected positions", + ) + + def test_deterministic(self): + """Two identical calls must produce bit-for-bit identical outputs.""" + gating = paddle.randn([32, 256], dtype=paddle.float32) + bias = (paddle.rand([1, 256], dtype=paddle.float32) - 0.5) * 0.1 + _, tv1, ti1 = grouped_topk(gating.clone(), bias.clone(), 8, 4, 8, True, 2.5) + _, tv2, ti2 = grouped_topk(gating.clone(), bias.clone(), 8, 4, 8, True, 2.5) + self.assertTrue( + paddle.allclose(tv1, tv2, atol=0.0, rtol=0.0).item(), + "topk_values not deterministic across two identical calls", + ) + self.assertTrue( + paddle.equal_all(ti1, ti2).item(), + "topk_indices not deterministic across two identical calls", + ) + + def test_zero_bias(self): + """All-zero bias: biased == unbiased; reference and op must agree.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 
4, 8)]: + paddle.seed(16) + gating = paddle.randn([32, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, ref_tv, ref_ti = native_grouped_topk(gating.clone(), bias, ng, tkg, topk, True, 1.0) + _, op_tv, op_ti = grouped_topk(gating.clone(), bias, ng, tkg, topk, True, 1.0) + ref_s = paddle.sort(ref_ti, axis=-1) + op_s = paddle.sort(op_ti, axis=-1) + self.assertTrue( + paddle.equal_all(ref_s, op_s).item(), + f"E={E}/zero_bias: topk_indices set mismatch", + ) + + def test_large_bias_steers_routing(self): + """Large positive bias on first half of experts must dominate selection.""" + E, topk = 128, 8 + paddle.seed(17) + gating = paddle.randn([64, E], dtype=paddle.float32) + bias = paddle.concat( + [ + paddle.full([1, E // 2], 2.0, dtype=paddle.float32), + paddle.full([1, E // 2], -2.0, dtype=paddle.float32), + ], + axis=1, + ) + _, _, ti = grouped_topk(gating, bias, 1, 1, topk, True, 1.0) + self.assertTrue( + (ti < E // 2).all().item(), + "Large positive bias on experts [0, E/2) did not steer all selections there", + ) + + def test_extreme_logits_no_nan_inf(self): + """Very large logits must not produce NaN or Inf in outputs.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + paddle.seed(18) + gating = paddle.randn([8, E], dtype=paddle.float32) * 50.0 + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, ng, tkg, topk, False, 1.0) + self.assertFalse(paddle.isnan(tv).any().item(), f"E={E}: NaN in topk_values") + self.assertFalse(paddle.isinf(tv).any().item(), f"E={E}: Inf in topk_values") + + def test_single_expert_selected(self): + """topk=1: each token selects exactly one expert; weight == 1.0 with renorm.""" + num_tokens = 16 + gating = paddle.randn([num_tokens, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 1, True, 1.0) + self.assertEqual(tv.shape, [num_tokens, 1]) + self.assertEqual(ti.shape, [num_tokens, 1]) + np.testing.assert_allclose( + tv.numpy(), + np.ones((num_tokens, 1), dtype=np.float32), + atol=1e-5, + err_msg="With topk=1 and renorm=True, each weight should be 1.0", + ) + + def test_sparse_scores_consistency(self): + """Sparse scores tensor: non-zero at selected positions must equal topk_values; zero elsewhere.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + gating = paddle.randn([16, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + s, tv, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + gathered = paddle.take_along_axis(s, ti, axis=1) + np.testing.assert_allclose( + gathered.numpy(), + tv.numpy(), + atol=1e-6, + err_msg=f"E={E}: sparse scores at topk positions != topk_values", + ) + nonzero_count = (s != 0).sum(axis=-1) + self.assertTrue( + (nonzero_count == topk).all().item(), + f"E={E}: non-zero count per token != topk", + ) + + def test_irregular_token_counts(self): + """Non-power-of-2 token counts must produce correct shapes and values.""" + irregular_T = [3, 7, 15, 33, 65, 127, 129, 257, 511, 513, 900] + for T in irregular_T: + gating = paddle.randn([T, 128], dtype=paddle.float32) + bias = (paddle.rand([1, 128], dtype=paddle.float32) - 0.5) * 0.1 + _, ref_tv, ref_ti = native_grouped_topk(gating.clone(), bias.clone(), 1, 1, 8, True, 1.0) + _, op_tv, op_ti = grouped_topk(gating.clone(), bias.clone(), 1, 1, 8, True, 1.0) + self.assertEqual(op_tv.shape, [T, 8], f"T={T}: topk_values shape mismatch") + self.assertEqual(op_ti.shape, [T, 8], f"T={T}: 
topk_indices shape mismatch") + ref_s = paddle.sort(ref_ti, axis=-1) + op_s = paddle.sort(op_ti, axis=-1) + if not paddle.equal_all(ref_s, op_s).item(): + n_diff = (ref_s != op_s).sum().item() + self.fail(f"T={T}: topk_indices mismatch, {n_diff} positions differ") + + +if __name__ == "__main__": + unittest.main() From a132799d703e47e361d960b89a66af729b59b008 Mon Sep 17 00:00:00 2001 From: ShaneGZhu <1092841848@qq.com> Date: Tue, 12 May 2026 16:30:55 +0800 Subject: [PATCH 3/4] Fixed the 1e-8 precision issue that had been introduced; the current diff is 0.00. --- custom_ops/gpu_ops/grouped_topk_kernels.cu | 173 ++++++++++++--------- 1 file changed, 100 insertions(+), 73 deletions(-) diff --git a/custom_ops/gpu_ops/grouped_topk_kernels.cu b/custom_ops/gpu_ops/grouped_topk_kernels.cu index d9c99908e72..ef5ed8533f0 100644 --- a/custom_ops/gpu_ops/grouped_topk_kernels.cu +++ b/custom_ops/gpu_ops/grouped_topk_kernels.cu @@ -423,6 +423,8 @@ __global__ void grouped_topk_fused_kernel( (reinterpret_cast(smem_buf + val_aligned + idx_bytes) + 15) & ~static_cast(15); float* s_group_scores = reinterpret_cast(ptr); + float* s_topk_value = + reinterpret_cast(smem_buf); // val_staging (256B-aligned) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); @@ -453,7 +455,7 @@ __global__ void grouped_topk_fused_kernel( for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) largest = sigmoid_to_float(gate_g[i]) + bias_g[i]; } - + __syncwarp(); float max1 = cg::reduce(tile, largest, cg::greater()); float max2 = max1; int cnt = __popc(__ballot_sync(FUSED_FULL_WARP_MASK, largest == max1)); @@ -464,97 +466,122 @@ __global__ void grouped_topk_fused_kernel( if (lane_id == 0) s_group_scores[warp_id] = max1 + max2; } - __syncthreads(); + __syncthreads(); // __syncwarp() maybe better? // ------------------------------------------------------------------ // Phase 2 (warp0 only): group selection → expert selection → output // ------------------------------------------------------------------ if (warp_id != 0) return; - topk_values += (int64_t)token_id * topk; - topk_indices += (int64_t)token_id * topk; + float value = neg_inf(); + float topk_group_value = neg_inf(); + int32_t num_equalto_topkth_group; + if (token_id < num_tokens) { + int32_t want_neg_inf_num = WARP_SIZE - n_group + topk_group; + if (lane_id < n_group && (isfinite(s_group_scores[lane_id]))) { + value = s_group_scores[lane_id]; + } - // Select top-topk_group groups - warp_topk_fused::WarpSelect group_sel( - topk_group_i32, neg_inf()); - - float gscore = - (lane_id < n_group_i32) ? 
-
-  // Check if enough valid groups exist
-  bool proceed = false;
-  if (topk_group_i32 > 0) {
-    float kth = __shfl_sync(
-        FUSED_FULL_WARP_MASK, group_sel.get_val(0), topk_group_i32 - 1);
-    proceed = (kth != neg_inf<float>());
-  }
-
-  if (!proceed) {
-    // Fallback: zero scores, uniform topk
-    for (int i = lane_id; i < static_cast<int32_t>(num_experts);
-         i += WARP_SIZE)
-      scores_token[i] = 0.0f;
-    __syncwarp();
-    for (int i = lane_id; i < topk_i32; i += WARP_SIZE) {
-      topk_indices[i] = static_cast<IdxT>(i);
-      topk_values[i] = 1.0f / static_cast<float>(topk_i32);
-    }
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
-#endif
-    return;
-  }
-
-  // Select top-topk experts from selected groups (using biased scores as
-  // candidates)
-  warp_topk_fused::WarpSelect<WARP_SIZE, true, float, int32_t, true>
-      expert_sel(topk_i32, neg_inf<float>());  // reuses same smem —
-                                               // group_sel is done
-
-  int32_t sel_gid = (lane_id < topk_group_i32) ? group_sel.get_idx(0) : 0;
-  for (int32_t g = 0; g < topk_group_i32; ++g) {
-    int32_t gid = __shfl_sync(FUSED_FULL_WARP_MASK, sel_gid, g);
-    int32_t offset = gid * num_experts_per_group;
-    for (int32_t i = lane_id; i < align_epg; i += WARP_SIZE) {
-      float cand = neg_inf<float>();
-      int32_t idx = 0;
-      if (i < num_experts_per_group) {
-        idx = offset + i;
-        float biased =
-            sigmoid_to_float(gate_token[idx]) + e_score_correction_bias[idx];
-        if (is_finite_val(biased)) cand = biased;
-      }
-      expert_sel.add(cand, idx);
-    }
-  }
-  expert_sel.done();
+  float value = neg_inf<float>();
+  float topk_group_value = neg_inf<float>();
+  int32_t num_equalto_topkth_group;
+  if (token_id < num_tokens) {
+    int32_t want_neg_inf_num = WARP_SIZE - n_group + topk_group;
+    if (lane_id < n_group && (isfinite(s_group_scores[lane_id]))) {
+      value = s_group_scores[lane_id];
+    }
+
+    int neg_inf_num = WARP_SIZE - n_group;
+    int last_neg_inf_num = 0;
+    // Loop until the topk_group largest group scores have been found
+    while (neg_inf_num < want_neg_inf_num) {
+      __syncwarp();  // Ensure all lanes hold valid data before the reduction
+      topk_group_value = cg::reduce(tile, value, cg::greater<float>());
+      if (value == topk_group_value) {
+        value = neg_inf<float>();
+      }
+      last_neg_inf_num = neg_inf_num;
+
+      neg_inf_num = __popc(
+          __ballot_sync(FUSED_FULL_WARP_MASK, (value == neg_inf<float>())));
+    }
+    // Several groups may tie on the k-th group score, but only
+    // num_equalto_topkth_group of them are accepted below.
+    num_equalto_topkth_group = want_neg_inf_num - last_neg_inf_num;
+  }
+  __syncwarp();
+
+  warp_topk_fused::WarpSelect<WARP_SIZE, true, float, int32_t, true>
+      queue((int32_t)topk, neg_inf<float>());
+  int count_equalto_topkth_group = 0;
+  bool if_proceed_next_topk = (topk_group_value != neg_inf<float>());
+  if (token_id < num_tokens && if_proceed_next_topk) {
+    for (int i_group = 0; i_group < n_group; i_group++) {
+      if ((s_group_scores[i_group] > topk_group_value) ||
+          ((s_group_scores[i_group] == topk_group_value) &&
+           (count_equalto_topkth_group < num_equalto_topkth_group))) {
+        int32_t offset = i_group * num_experts_per_group;
+        for (int32_t i = lane_id; i < align_epg; i += WARP_SIZE) {
+          float candidates = neg_inf<float>();
+          if (i < num_experts_per_group) {
+            float biased = sigmoid_to_float(gate_token[offset + i]) +
+                           e_score_correction_bias[offset + i];
+            if (is_finite_val(biased)) candidates = biased;
+          }
+          queue.add(candidates, offset + i);
+        }
+        if (s_group_scores[i_group] == topk_group_value) {
+          count_equalto_topkth_group++;
+        }
+      }
+    }
+    queue.done();
+    __syncwarp();
+  }
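+  // Worked example of the tie handling above (illustrative numbers):
+  // n_group = 8, topk_group = 4, group scores {9, 9, 7, 7, 7, 3, 2, 1}.
+  // The loop first knocks out the two 9s, then all three tied 7s,
+  // overshooting the target, so topk_group_value = 7 and
+  // num_equalto_topkth_group = 4 - 2 = 2: the scan accepts both 9s plus
+  // only the first two tied 7s, i.e. exactly topk_group groups feed
+  // candidates into the expert queue.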
-
-  // Compute routing weights from unbiased scores
-  float lane_score = 0.0f;
-  IdxT lane_idx = 0;
-  if (lane_id < topk_i32) {
-    lane_idx = static_cast<IdxT>(expert_sel.get_idx(0));
-    lane_score = sigmoid_to_float(gate_token[static_cast<int64_t>(lane_idx)]);
-  }
-
-  float topk_sum = 1e-20f;
-  if (renormalize)
-    topk_sum += cg::reduce(tile, lane_score, cg::plus<float>());
-
-  float scale = static_cast<float>(routed_scaling_factor);
-  if (renormalize) scale /= topk_sum;
-
-  // Fill sparse scores: first zero out, then write selected experts' weights
-  for (int i = lane_id; i < static_cast<int32_t>(num_experts); i += WARP_SIZE)
-    scores_token[i] = 0.0f;
-  __syncwarp();
-
-  if (lane_id < topk_i32) {
-    float val = lane_score * scale;
-    scores_token[static_cast<int64_t>(lane_idx)] = val;
-    topk_indices[lane_id] = lane_idx;
-    topk_values[lane_id] = val;
-  }
-
+  float topk_sum = 1e-20f;
+  if (token_id < num_tokens && if_proceed_next_topk) {
+    for (int i = lane_id;
+         i < warp_topk_fused::round_up_to_multiple_of<WARP_SIZE>(topk);
+         i += WARP_SIZE) {
+      int32_t idx = i / WARP_SIZE;
+      float value =
+          i < topk ? sigmoid_to_float(gate_token[queue.get_idx(idx)]) : 0.0f;
+      if (i < topk) {
+        s_topk_value[i] = value;
+      }
+      topk_sum += cg::reduce(tile, value, cg::plus<float>());
+    }
+  }
+  __syncwarp();
+
+  if (token_id < num_tokens) {
+    // Zero the dense scores in the fallback case too: the output tensor
+    // comes from paddle::empty() and would otherwise hold garbage.
+    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
+      scores_token[i] = 0;
+    }
+  }
+  __syncwarp();
+
+  topk_values += (int64_t)token_id * topk;
+  topk_indices += (int64_t)token_id * topk;
+  if (token_id < num_tokens) {
+    if (if_proceed_next_topk) {
+      for (int i = lane_id; i < topk; i += WARP_SIZE) {
+        float value;
+        if (renormalize) {
+          value = s_topk_value[i] / topk_sum * routed_scaling_factor;
+        } else {
+          value = s_topk_value[i] * routed_scaling_factor;
+        }
+        int32_t idx = i / WARP_SIZE;  // topk may be bigger than WARP_SIZE
+        scores_token[queue.get_idx(idx)] = value;
+        topk_indices[i] = queue.get_idx(idx);
+        topk_values[i] = value;
+      }
+    } else {
+      // Fallback: queue was never filled, so route uniformly over the
+      // first topk experts instead of reading its dummy indices.
+      for (int i = lane_id; i < topk; i += WARP_SIZE) {
+        topk_indices[i] = i;
+        topk_values[i] = static_cast<float>(1.0f / topk);
+      }
+    }
+  }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.launch_dependents;");
 #endif

From 0495e88bbefa9767d1cd15a728f7346b200d2ede Mon Sep 17 00:00:00 2001
From: ShaneGZhu <1092841848@qq.com>
Date: Tue, 12 May 2026 16:34:56 +0800
Subject: [PATCH 4/4] Clean up: drop the now-empty is_cuda() import branch

---
 fastdeploy/model_executor/layers/moe/moe.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 47bc8d18917..edc2ee41933 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -46,9 +46,6 @@
 import numpy as np
 
-if current_platform.is_cuda():
-    pass
-
 
 def get_moe_method(layer=None):
     """