From 8ad26ca61aabb7200540b76fd449a646fafd395d Mon Sep 17 00:00:00 2001 From: yosh20004 <2172622103@qq.com> Date: Tue, 17 Mar 2026 06:13:30 +0000 Subject: [PATCH 1/4] [Common] Optimize naive top-k masking in fused router Refactor naive_topk_and_mask to track selections with a per-lane mask and reduce across the warp more directly. This keeps the top-k routing path cleaner while preserving the existing interface. Co-authored-by: Guitar_Players XiaomingFun233 Signed-off-by: yosh20004 <2172622103@qq.com> --- .../common/fused_router/utils.h | 91 +++++++++++-------- 1 file changed, 55 insertions(+), 36 deletions(-) diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 372efdc490..c398372228 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -203,50 +203,69 @@ __device__ inline void apply_softmax_on_float(float *scores, int data_size, int __syncwarp(); } -__device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int topk, - int *topk_indices, CompType *topk_scores, int lane_id) { - // Check if the index is masked by the later iteration - auto is_masked = [&topk_indices](int k, int index) { - if (k == 0) return false; - for (int i = 0; i < k; i++) { - if (topk_indices[i] == index) return true; - } - return false; - }; - // Topk Times: Find the max value and its index - // Then mask it, and record the index in the topk_indices - // After looping topk times, the topk_indices will be the topk indices +template +__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices, + T *topk_scores, int lane_id) { + // Bit i indicates whether the i-th local element (lane_id + i * warp_size) was selected. + uint32_t local_mask = 0; + for (int k = 0; k < topk; k++) { - // Find the max value and its index - CompType val = (lane_id < data_size && !is_masked(k, lane_id)) - ? scores[lane_id] - : -std::numeric_limits::infinity(); - int index = (lane_id < data_size) ? lane_id : 0; - // Some value is hanlded in local thread - // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... - // Reduce the value in local thread - for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { - CompType cur_val = (is_masked(k, i)) ? -std::numeric_limits::infinity() : scores[i]; - if (cur_val > val) { - val = cur_val; - index = i; + CompType local_max_val = -std::numeric_limits::infinity(); + int local_max_idx = -1; + + // 1) Per-lane local max on unmasked elements. + int bit_idx = 0; + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + CompType cur_val = 0.0f; + if constexpr (std::is_same_v) { + uint64_t mask = -(uint64_t)((local_mask >> bit_idx) & 1u); + uint64_t x_bits = __double_as_longlong(static_cast(scores[i])); + uint64_t result_bits = + (~mask & x_bits) | (mask & 0xFFF0000000000000ULL); + cur_val = __longlong_as_double(result_bits); + } else { + uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u); + uint32_t x_bits = __float_as_uint(static_cast(scores[i])); + uint32_t result_bits = + (~full_mask & x_bits) | (full_mask & 0xFF800000u); + cur_val = __uint_as_float(result_bits); + } + if (cur_val > local_max_val) { + local_max_val = cur_val; + local_max_idx = i; } + bit_idx++; } - // Warp shuffle between threads - for (int s = 16; s > 0; s /= 2) { - auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); - auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); - if (shuffled_val > val) { - val = shuffled_val; - index = shuffled_index; + + // 2) Warp reduction to find global max and index. + CompType global_max_val = local_max_val; + int global_max_idx = local_max_idx; + for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) { + CompType shuffled_val = __shfl_down_sync(0xffffffff, global_max_val, s); + int shuffled_idx = __shfl_down_sync(0xffffffff, global_max_idx, s); + if (shuffled_val > global_max_val) { + global_max_val = shuffled_val; + global_max_idx = shuffled_idx; } } + global_max_idx = __shfl_sync(0xffffffff, global_max_idx, 0); + global_max_val = __shfl_sync(0xffffffff, global_max_val, 0); + + // 3) Write top-k result. if (lane_id == 0) { - topk_indices[k] = index; - topk_scores[k] = val; + topk_indices[k] = global_max_idx; + topk_scores[k] = static_cast(global_max_val); + } + + // 4) Mark selected element in owning lane's local mask. + if (global_max_idx >= 0 && (global_max_idx % kThreadsPerWarp) == lane_id) { + int local_bit_pos = global_max_idx / kThreadsPerWarp; + if (local_bit_pos < 32) { + local_mask |= (1u << local_bit_pos); + } } - __syncwarp(); } + __syncwarp(); } // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future From d6dfdcf673eff8217b6539788b08889bcd113732 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Mar 2026 06:46:22 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/fused_router/utils.h | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index c398372228..71ad0e1a82 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -205,7 +205,7 @@ __device__ inline void apply_softmax_on_float(float *scores, int data_size, int template __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices, - T *topk_scores, int lane_id) { + T *topk_scores, int lane_id) { // Bit i indicates whether the i-th local element (lane_id + i * warp_size) was selected. uint32_t local_mask = 0; @@ -220,14 +220,12 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i if constexpr (std::is_same_v) { uint64_t mask = -(uint64_t)((local_mask >> bit_idx) & 1u); uint64_t x_bits = __double_as_longlong(static_cast(scores[i])); - uint64_t result_bits = - (~mask & x_bits) | (mask & 0xFFF0000000000000ULL); - cur_val = __longlong_as_double(result_bits); + uint64_t result_bits = (~mask & x_bits) | (mask & 0xFFF0000000000000ULL); + cur_val = __longlong_as_double(result_bits); } else { uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u); uint32_t x_bits = __float_as_uint(static_cast(scores[i])); - uint32_t result_bits = - (~full_mask & x_bits) | (full_mask & 0xFF800000u); + uint32_t result_bits = (~full_mask & x_bits) | (full_mask & 0xFF800000u); cur_val = __uint_as_float(result_bits); } if (cur_val > local_max_val) { From 28844c1a0dacc75b8e1d2e72a65ad2bcbc00cc57 Mon Sep 17 00:00:00 2001 From: yosh20005 <2172622103@qq.com> Date: Fri, 20 Mar 2026 00:13:44 +0800 Subject: [PATCH 3/4] fixup! [Common] Optimize naive top-k masking in fused router Signed-off-by: yosh20005 <2172622103@qq.com> --- transformer_engine/common/fused_router/utils.h | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 71ad0e1a82..abe8800bd6 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -217,17 +217,10 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i int bit_idx = 0; for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { CompType cur_val = 0.0f; - if constexpr (std::is_same_v) { - uint64_t mask = -(uint64_t)((local_mask >> bit_idx) & 1u); - uint64_t x_bits = __double_as_longlong(static_cast(scores[i])); - uint64_t result_bits = (~mask & x_bits) | (mask & 0xFFF0000000000000ULL); - cur_val = __longlong_as_double(result_bits); - } else { - uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u); - uint32_t x_bits = __float_as_uint(static_cast(scores[i])); - uint32_t result_bits = (~full_mask & x_bits) | (full_mask & 0xFF800000u); - cur_val = __uint_as_float(result_bits); - } + uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u); + uint32_t x_bits = __float_as_uint(static_cast(scores[i])); + uint32_t result_bits = (~full_mask & x_bits) | (full_mask & 0xFF800000u); + cur_val = __uint_as_float(result_bits); if (cur_val > local_max_val) { local_max_val = cur_val; local_max_idx = i; From 1958bb4ebd14f6c4958095bb55dbeff6c372409b Mon Sep 17 00:00:00 2001 From: yosh20005 <2172622103@qq.com> Date: Fri, 20 Mar 2026 01:34:49 +0800 Subject: [PATCH 4/4] [Common] Add top-k local mask bounds assert Signed-off-by: yosh20005 <2172622103@qq.com> --- transformer_engine/common/fused_router/utils.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index abe8800bd6..3d5d75d4af 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -7,6 +7,8 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ #define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ +#include + #include "transformer_engine/transformer_engine.h" namespace transformer_engine { @@ -208,6 +210,8 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i T *topk_scores, int lane_id) { // Bit i indicates whether the i-th local element (lane_id + i * warp_size) was selected. uint32_t local_mask = 0; + assert(data_size <= static_cast(sizeof(local_mask) * 8 * kThreadsPerWarp) && + "local_mask too small for data_size > 1024"); for (int k = 0; k < topk; k++) { CompType local_max_val = -std::numeric_limits::infinity();