diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 372efdc490..3d5d75d4af 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -7,6 +7,8 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ #define TRANSFORMER_ENGINE_FUSED_ROUTER_UTILS_H_ +#include + #include "transformer_engine/transformer_engine.h" namespace transformer_engine { @@ -203,50 +205,62 @@ __device__ inline void apply_softmax_on_float(float *scores, int data_size, int __syncwarp(); } -__device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int topk, - int *topk_indices, CompType *topk_scores, int lane_id) { - // Check if the index is masked by the later iteration - auto is_masked = [&topk_indices](int k, int index) { - if (k == 0) return false; - for (int i = 0; i < k; i++) { - if (topk_indices[i] == index) return true; - } - return false; - }; - // Topk Times: Find the max value and its index - // Then mask it, and record the index in the topk_indices - // After looping topk times, the topk_indices will be the topk indices +template +__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices, + T *topk_scores, int lane_id) { + // Bit i indicates whether the i-th local element (lane_id + i * warp_size) was selected. + uint32_t local_mask = 0; + assert(data_size <= static_cast(sizeof(local_mask) * 8 * kThreadsPerWarp) && + "local_mask too small for data_size > 1024"); + for (int k = 0; k < topk; k++) { - // Find the max value and its index - CompType val = (lane_id < data_size && !is_masked(k, lane_id)) - ? scores[lane_id] - : -std::numeric_limits::infinity(); - int index = (lane_id < data_size) ? lane_id : 0; - // Some value is hanlded in local thread - // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... - // Reduce the value in local thread - for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { - CompType cur_val = (is_masked(k, i)) ? -std::numeric_limits::infinity() : scores[i]; - if (cur_val > val) { - val = cur_val; - index = i; + CompType local_max_val = -std::numeric_limits::infinity(); + int local_max_idx = -1; + + // 1) Per-lane local max on unmasked elements. + int bit_idx = 0; + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + CompType cur_val = 0.0f; + uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u); + uint32_t x_bits = __float_as_uint(static_cast(scores[i])); + uint32_t result_bits = (~full_mask & x_bits) | (full_mask & 0xFF800000u); + cur_val = __uint_as_float(result_bits); + if (cur_val > local_max_val) { + local_max_val = cur_val; + local_max_idx = i; } + bit_idx++; } - // Warp shuffle between threads - for (int s = 16; s > 0; s /= 2) { - auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); - auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); - if (shuffled_val > val) { - val = shuffled_val; - index = shuffled_index; + + // 2) Warp reduction to find global max and index. + CompType global_max_val = local_max_val; + int global_max_idx = local_max_idx; + for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) { + CompType shuffled_val = __shfl_down_sync(0xffffffff, global_max_val, s); + int shuffled_idx = __shfl_down_sync(0xffffffff, global_max_idx, s); + if (shuffled_val > global_max_val) { + global_max_val = shuffled_val; + global_max_idx = shuffled_idx; } } + global_max_idx = __shfl_sync(0xffffffff, global_max_idx, 0); + global_max_val = __shfl_sync(0xffffffff, global_max_val, 0); + + // 3) Write top-k result. if (lane_id == 0) { - topk_indices[k] = index; - topk_scores[k] = val; + topk_indices[k] = global_max_idx; + topk_scores[k] = static_cast(global_max_val); + } + + // 4) Mark selected element in owning lane's local mask. + if (global_max_idx >= 0 && (global_max_idx % kThreadsPerWarp) == lane_id) { + int local_bit_pos = global_max_idx / kThreadsPerWarp; + if (local_bit_pos < 32) { + local_mask |= (1u << local_bit_pos); + } } - __syncwarp(); } + __syncwarp(); } // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future