From 8b866e3a47f7864ae2d5463ac380b42e8114ad59 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 13 Mar 2026 11:04:18 +0000 Subject: [PATCH 1/2] added new implementation of fused_moe_aux_loss_forward kernel Signed-off-by: Alp Dener --- .../fused_router/fused_moe_aux_loss_v2.cu | 150 ++++++++++++++++++ .../include/transformer_engine/fused_router.h | 5 + .../jax/csrc/extensions/router.cpp | 8 +- .../pytorch/csrc/extensions/router.cpp | 6 +- 4 files changed, 162 insertions(+), 7 deletions(-) create mode 100644 transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu new file mode 100644 index 0000000000..7b47cfe03c --- /dev/null +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu @@ -0,0 +1,150 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../utils.cuh" +#include "common/util/cuda_runtime.h" +#include "utils.h" + +namespace transformer_engine { + +// Using float to handle all the calculations +using CompType = float; + +template +__global__ void fused_moe_aux_loss_forward_kernel_v2(const DataType* probs, + const IndexType* tokens_per_expert, + int total_num_tokens, int num_experts, + int num_rows, int num_cols, int topk, + float coeff, + DataType* aux_loss, float* Const_buf) { + // ----------------------------------------------------------------------- + // 1) Compute the constant coefficient (identical for all threads) + // ----------------------------------------------------------------------- + const float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; + + // Write the coefficient once and initialize the output accumulator. + if (blockIdx.x == 0 && threadIdx.x == 0) { + Const_buf[0] = C_coeff; + aux_loss[0] = static_cast(0); + } + + // ----------------------------------------------------------------------- + // 2) Each CTA computes a partial dot‑product: + // Σ_col ( Σ_row probs[row, col] ) * tokens_per_expert[col] + // ----------------------------------------------------------------------- + CompType thread_sum = CompType(0); + + // Grid‑stride over rows so that every row is processed exactly once. + // Each thread processes a subset of columns. + for (int col = threadIdx.x; col < num_cols; col += blockDim.x) { + CompType col_sum = CompType(0); + + // Accumulate probs over the rows assigned to this CTA (grid‑stride). + for (int row = blockIdx.x; row < num_rows; row += gridDim.x) { + col_sum += CompType(probs[row * num_cols + col]); + } + + // Multiply by the token count for this expert. + col_sum *= CompType(tokens_per_expert[col]); + + // Accumulate the per‑column contribution into the thread‑local sum. + thread_sum += col_sum; + } + + // ----------------------------------------------------------------------- + // 3) Block‑level reduction of thread_sum using warp_reduce_on_shmem + // ----------------------------------------------------------------------- + extern __shared__ float shmem[]; + CompType* shmem_block = reinterpret_cast(shmem); + shmem_block[threadIdx.x] = thread_sum; + __syncthreads(); + + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int lane_id = threadIdx.x % kThreadsPerWarp; + if (warp_id == 0) { + CompType block_sum = warp_reduce_on_shmem(shmem_block, + blockDim.x, + ReduceFuncType::SUM, + lane_id); + __syncwarp(); + + // ----------------------------------------------------------------------- + // 4) One atomic add per CTA to the global accumulator. + // The multiplication by C_coeff is folded into the atomic. + // ----------------------------------------------------------------------- + if (lane_id == 0) { + atomicAdd(reinterpret_cast(aux_loss), + static_cast(block_sum * C_coeff)); + } + } +} + +/* ------------------------------------------------------------------------- + * Kernel launcher – simplified (no cluster launch). + * ------------------------------------------------------------------------- */ +template +void fused_moe_aux_loss_forward_kernel_launcher_v2(const DataType* probs, + const IndexType* tokens_per_expert, + int total_num_tokens, int num_experts, + int num_rows, int num_cols, int topk, + float coeff, + DataType* aux_loss, float* Const_buf, + cudaStream_t stream) { + const int block_size = std::min(1024, num_cols); + const int grid_size = sm_count() * 2; + + // One CompType per thread in shared memory. + const size_t smem_size = block_size * sizeof(CompType); + + fused_moe_aux_loss_forward_kernel_v2 + <<>>(probs, + tokens_per_expert, + total_num_tokens, + num_experts, + num_rows, + num_cols, + topk, + coeff, + aux_loss, + Const_buf); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void fused_moe_aux_loss_forward_v2(const Tensor& probs, const Tensor& tokens_per_expert, + int total_num_tokens, int num_experts, int num_rows, int num_cols, + int topk, float coeff, Tensor& aux_loss, Tensor& Const_buf, + cudaStream_t stream) { + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + probs.data.dtype, DataType, + TE_ROUTER_INDEX_TYPE_SWITCH_ALL( + tokens_per_expert.data.dtype, IndexType, + fused_moe_aux_loss_forward_kernel_launcher_v2( + reinterpret_cast(probs.data.dptr), + reinterpret_cast(tokens_per_expert.data.dptr), total_num_tokens, + num_experts, num_rows, num_cols, topk, coeff, + reinterpret_cast(aux_loss.data.dptr), + reinterpret_cast(Const_buf.data.dptr), stream););); +} + +void nvte_fused_moe_aux_loss_forward_v2(const NVTETensor probs, const NVTETensor tokens_per_expert, + int total_num_tokens, int num_experts, int num_rows, + int num_cols, int topk, float coeff, NVTETensor aux_loss, + NVTETensor Const_buf, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_moe_aux_loss_forward); + using namespace transformer_engine; + fused_moe_aux_loss_forward_v2( + *convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens, + num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss), + *convertNVTETensorCheck(Const_buf), stream); +} + +} // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h index 794880d324..5cc9f6e8e5 100644 --- a/transformer_engine/common/include/transformer_engine/fused_router.h +++ b/transformer_engine/common/include/transformer_engine/fused_router.h @@ -110,6 +110,11 @@ void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor to int num_cols, int topk, float coeff, NVTETensor aux_loss, NVTETensor Const_buf, cudaStream_t stream); +void nvte_fused_moe_aux_loss_forward_v2(const NVTETensor probs, const NVTETensor tokens_per_expert, + int total_num_tokens, int num_experts, int num_rows, + int num_cols, int topk, float coeff, NVTETensor aux_loss, + NVTETensor Const_buf, cudaStream_t stream); + /*! \brief Backward pass for auxiliary loss. * * \param[in] Const_buf Constant buffer from the forward pass. diff --git a/transformer_engine/jax/csrc/extensions/router.cpp b/transformer_engine/jax/csrc/extensions/router.cpp index c81671f104..1a3a25dfef 100644 --- a/transformer_engine/jax/csrc/extensions/router.cpp +++ b/transformer_engine/jax/csrc/extensions/router.cpp @@ -184,10 +184,10 @@ Error_Type FusedMoEAuxLossForwardFFI(cudaStream_t stream, auto aux_loss_tensor = TensorWrapper(aux_loss_buf->untyped_data(), scalar_shape, dtype); auto const_buf_tensor = TensorWrapper(const_buf->untyped_data(), scalar_shape, DType::kFloat32); - nvte_fused_moe_aux_loss_forward(probs_tensor.data(), tpe_tensor.data(), num_tokens, num_experts, - num_tokens, num_experts, static_cast(topk), - static_cast(coeff), aux_loss_tensor.data(), - const_buf_tensor.data(), stream); + nvte_fused_moe_aux_loss_forward_v2(probs_tensor.data(), tpe_tensor.data(), num_tokens, num_experts, + num_tokens, num_experts, static_cast(topk), + static_cast(coeff), aux_loss_tensor.data(), + const_buf_tensor.data(), stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 94625c0f12..9e220098e5 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -155,9 +155,9 @@ std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, auto aux_loss_cu = makeTransformerEngineTensor(aux_loss); auto Const_buf_cu = makeTransformerEngineTensor(Const_buf); - nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens, - num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(), - Const_buf_cu.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_moe_aux_loss_forward_v2(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens, + num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(), + Const_buf_cu.data(), at::cuda::getCurrentCUDAStream()); return std::make_tuple(aux_loss, Const_buf); } From 1071b6b2ab6d84b8a327801bb0bfac1d81ffb88f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 11:22:40 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../fused_router/fused_moe_aux_loss_v2.cu | 37 +++++++------------ .../jax/csrc/extensions/router.cpp | 4 +- .../pytorch/csrc/extensions/router.cpp | 5 ++- 3 files changed, 18 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu index 7b47cfe03c..35ca329587 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu @@ -24,8 +24,8 @@ __global__ void fused_moe_aux_loss_forward_kernel_v2(const DataType* probs, const IndexType* tokens_per_expert, int total_num_tokens, int num_experts, int num_rows, int num_cols, int topk, - float coeff, - DataType* aux_loss, float* Const_buf) { + float coeff, DataType* aux_loss, + float* Const_buf) { // ----------------------------------------------------------------------- // 1) Compute the constant coefficient (identical for all threads) // ----------------------------------------------------------------------- @@ -71,10 +71,8 @@ __global__ void fused_moe_aux_loss_forward_kernel_v2(const DataType* probs, const int warp_id = threadIdx.x / kThreadsPerWarp; const int lane_id = threadIdx.x % kThreadsPerWarp; if (warp_id == 0) { - CompType block_sum = warp_reduce_on_shmem(shmem_block, - blockDim.x, - ReduceFuncType::SUM, - lane_id); + CompType block_sum = + warp_reduce_on_shmem(shmem_block, blockDim.x, ReduceFuncType::SUM, lane_id); __syncwarp(); // ----------------------------------------------------------------------- @@ -82,8 +80,7 @@ __global__ void fused_moe_aux_loss_forward_kernel_v2(const DataType* probs, // The multiplication by C_coeff is folded into the atomic. // ----------------------------------------------------------------------- if (lane_id == 0) { - atomicAdd(reinterpret_cast(aux_loss), - static_cast(block_sum * C_coeff)); + atomicAdd(reinterpret_cast(aux_loss), static_cast(block_sum * C_coeff)); } } } @@ -96,9 +93,8 @@ void fused_moe_aux_loss_forward_kernel_launcher_v2(const DataType* probs, const IndexType* tokens_per_expert, int total_num_tokens, int num_experts, int num_rows, int num_cols, int topk, - float coeff, - DataType* aux_loss, float* Const_buf, - cudaStream_t stream) { + float coeff, DataType* aux_loss, + float* Const_buf, cudaStream_t stream) { const int block_size = std::min(1024, num_cols); const int grid_size = sm_count() * 2; @@ -106,23 +102,16 @@ void fused_moe_aux_loss_forward_kernel_launcher_v2(const DataType* probs, const size_t smem_size = block_size * sizeof(CompType); fused_moe_aux_loss_forward_kernel_v2 - <<>>(probs, - tokens_per_expert, - total_num_tokens, - num_experts, - num_rows, - num_cols, - topk, - coeff, - aux_loss, - Const_buf); + <<>>(probs, tokens_per_expert, total_num_tokens, + num_experts, num_rows, num_cols, topk, coeff, + aux_loss, Const_buf); NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_moe_aux_loss_forward_v2(const Tensor& probs, const Tensor& tokens_per_expert, - int total_num_tokens, int num_experts, int num_rows, int num_cols, - int topk, float coeff, Tensor& aux_loss, Tensor& Const_buf, - cudaStream_t stream) { + int total_num_tokens, int num_experts, int num_rows, + int num_cols, int topk, float coeff, Tensor& aux_loss, + Tensor& Const_buf, cudaStream_t stream) { TE_ROUTER_PROBS_TYPE_SWITCH_ALL( probs.data.dtype, DataType, TE_ROUTER_INDEX_TYPE_SWITCH_ALL( diff --git a/transformer_engine/jax/csrc/extensions/router.cpp b/transformer_engine/jax/csrc/extensions/router.cpp index 1a3a25dfef..df570f25d8 100644 --- a/transformer_engine/jax/csrc/extensions/router.cpp +++ b/transformer_engine/jax/csrc/extensions/router.cpp @@ -184,8 +184,8 @@ Error_Type FusedMoEAuxLossForwardFFI(cudaStream_t stream, auto aux_loss_tensor = TensorWrapper(aux_loss_buf->untyped_data(), scalar_shape, dtype); auto const_buf_tensor = TensorWrapper(const_buf->untyped_data(), scalar_shape, DType::kFloat32); - nvte_fused_moe_aux_loss_forward_v2(probs_tensor.data(), tpe_tensor.data(), num_tokens, num_experts, - num_tokens, num_experts, static_cast(topk), + nvte_fused_moe_aux_loss_forward_v2(probs_tensor.data(), tpe_tensor.data(), num_tokens, + num_experts, num_tokens, num_experts, static_cast(topk), static_cast(coeff), aux_loss_tensor.data(), const_buf_tensor.data(), stream); diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 9e220098e5..4d4d8660db 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -156,8 +156,9 @@ std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, auto Const_buf_cu = makeTransformerEngineTensor(Const_buf); nvte_fused_moe_aux_loss_forward_v2(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens, - num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(), - Const_buf_cu.data(), at::cuda::getCurrentCUDAStream()); + num_experts, num_rows, num_cols, topk, coeff, + aux_loss_cu.data(), Const_buf_cu.data(), + at::cuda::getCurrentCUDAStream()); return std::make_tuple(aux_loss, Const_buf); }