diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu
new file mode 100644
index 0000000000..35ca329587
--- /dev/null
+++ b/transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu
@@ -0,0 +1,171 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+// NOTE(review): the three system-include targets below were stripped in the
+// original patch (angle-bracket contents lost); reconstructed to match the
+// v1 fused_moe_aux_loss.cu -- verify before merging.
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <math.h>
+
+#include "../common.h"
+#include "../util/logging.h"
+#include "../utils.cuh"
+#include "common/util/cuda_runtime.h"
+#include "utils.h"
+
+namespace transformer_engine {
+
+// Using float to handle all the calculations
+using CompType = float;
+
+/*
+ * Forward kernel for the MoE auxiliary (load-balancing) loss.
+ *
+ * Layout: threads stride over columns (experts); CTAs grid-stride over
+ * rows (tokens).  Each CTA folds its partial dot-product into
+ * aux_loss[0] with a single atomicAdd.
+ *
+ * Preconditions:
+ *   - aux_loss[0] is zero-initialized on the same stream before launch
+ *     (the launcher below issues a cudaMemsetAsync).
+ *   - dynamic shared memory: blockDim.x * sizeof(CompType) bytes.
+ *   - blockDim.x is a whole multiple of kThreadsPerWarp.
+ *
+ * NOTE(review): the atomicAdd reinterprets aux_loss as float*, which is
+ * only valid when DataType == float; an fp16/bf16 aux_loss buffer would
+ * be corrupted -- TODO confirm the dispatch macros only instantiate float.
+ */
+template <typename DataType, typename IndexType>
+__global__ void fused_moe_aux_loss_forward_kernel_v2(const DataType* probs,
+                                                     const IndexType* tokens_per_expert,
+                                                     int total_num_tokens, int num_experts,
+                                                     int num_rows, int num_cols, int topk,
+                                                     float coeff, DataType* aux_loss,
+                                                     float* Const_buf) {
+  // -----------------------------------------------------------------------
+  // 1) Compute the constant coefficient (identical for all threads)
+  // -----------------------------------------------------------------------
+  const float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
+
+  // Persist the coefficient for the backward pass.  One thread writes it;
+  // nothing in this kernel reads it back, so no synchronization is needed.
+  // (The in-kernel zeroing of aux_loss[0] that used to live here raced
+  // against other CTAs' atomicAdds; the launcher memsets instead.)
+  if (blockIdx.x == 0 && threadIdx.x == 0) {
+    Const_buf[0] = C_coeff;
+  }
+
+  // -----------------------------------------------------------------------
+  // 2) Each CTA computes a partial dot-product:
+  //      sum_col ( sum_row probs[row, col] ) * tokens_per_expert[col]
+  // -----------------------------------------------------------------------
+  CompType thread_sum = CompType(0);
+
+  // Each thread processes a subset of columns.
+  for (int col = threadIdx.x; col < num_cols; col += blockDim.x) {
+    CompType col_sum = CompType(0);
+
+    // Accumulate probs over the rows assigned to this CTA (grid-stride),
+    // so every row is processed exactly once across the grid.
+    for (int row = blockIdx.x; row < num_rows; row += gridDim.x) {
+      col_sum += CompType(probs[row * num_cols + col]);
+    }
+
+    // Multiply by the token count for this expert.
+    col_sum *= CompType(tokens_per_expert[col]);
+
+    // Accumulate the per-column contribution into the thread-local sum.
+    thread_sum += col_sum;
+  }
+
+  // -----------------------------------------------------------------------
+  // 3) Block-level reduction of thread_sum using warp_reduce_on_shmem
+  // -----------------------------------------------------------------------
+  extern __shared__ float shmem[];
+  CompType* shmem_block = reinterpret_cast<CompType*>(shmem);
+  shmem_block[threadIdx.x] = thread_sum;
+  __syncthreads();
+
+  const int warp_id = threadIdx.x / kThreadsPerWarp;
+  const int lane_id = threadIdx.x % kThreadsPerWarp;
+  if (warp_id == 0) {
+    CompType block_sum =
+        warp_reduce_on_shmem(shmem_block, blockDim.x, ReduceFuncType::SUM, lane_id);
+    __syncwarp();
+
+    // ---------------------------------------------------------------------
+    // 4) One atomic add per CTA to the global accumulator.
+    //    The multiplication by C_coeff is folded into the atomic.
+    // ---------------------------------------------------------------------
+    if (lane_id == 0) {
+      atomicAdd(reinterpret_cast<float*>(aux_loss), static_cast<float>(block_sum * C_coeff));
+    }
+  }
+}
+
+/* -------------------------------------------------------------------------
+ * Kernel launcher - simplified (no cluster launch).
+ * ------------------------------------------------------------------------- */
+template <typename DataType, typename IndexType>
+void fused_moe_aux_loss_forward_kernel_launcher_v2(const DataType* probs,
+                                                   const IndexType* tokens_per_expert,
+                                                   int total_num_tokens, int num_experts,
+                                                   int num_rows, int num_cols, int topk,
+                                                   float coeff, DataType* aux_loss,
+                                                   float* Const_buf, cudaStream_t stream) {
+  // Round the block size up to a whole number of warps so the warp-level
+  // reduction never runs with a partially populated warp.
+  const int block_size =
+      ((std::min(1024, num_cols) + kThreadsPerWarp - 1) / kThreadsPerWarp) * kThreadsPerWarp;
+  const int grid_size = sm_count() * 2;
+
+  // One CompType per thread in shared memory.
+  const size_t smem_size = block_size * sizeof(CompType);
+
+  // Zero the accumulator on the same stream before any CTA can add to it;
+  // zeroing inside the kernel would race against other CTAs' atomics.
+  NVTE_CHECK_CUDA(cudaMemsetAsync(aux_loss, 0, sizeof(DataType), stream));
+
+  fused_moe_aux_loss_forward_kernel_v2<DataType, IndexType>
+      <<<grid_size, block_size, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens,
+                                                     num_experts, num_rows, num_cols, topk, coeff,
+                                                     aux_loss, Const_buf);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
+void fused_moe_aux_loss_forward_v2(const Tensor& probs, const Tensor& tokens_per_expert,
+                                   int total_num_tokens, int num_experts, int num_rows,
+                                   int num_cols, int topk, float coeff, Tensor& aux_loss,
+                                   Tensor& Const_buf, cudaStream_t stream) {
+  TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
+      probs.data.dtype, DataType,
+      TE_ROUTER_INDEX_TYPE_SWITCH_ALL(
+          tokens_per_expert.data.dtype, IndexType,
+          fused_moe_aux_loss_forward_kernel_launcher_v2<DataType, IndexType>(
+              reinterpret_cast<const DataType*>(probs.data.dptr),
+              reinterpret_cast<const IndexType*>(tokens_per_expert.data.dptr), total_num_tokens,
+              num_experts, num_rows, num_cols, topk, coeff,
+              reinterpret_cast<DataType*>(aux_loss.data.dptr),
+              reinterpret_cast<float*>(Const_buf.data.dptr), stream);););
+}
+
+}  // namespace transformer_engine
+
+// C API entry point, defined outside the namespace: defining it inside
+// transformer_engine would produce transformer_engine::nvte_... rather
+// than the global symbol declared in the public header.
+void nvte_fused_moe_aux_loss_forward_v2(const NVTETensor probs, const NVTETensor tokens_per_expert,
+                                        int total_num_tokens, int num_experts, int num_rows,
+                                        int num_cols, int topk, float coeff, NVTETensor aux_loss,
+                                        NVTETensor Const_buf, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_fused_moe_aux_loss_forward_v2);  // was tagged with the v1 name
+  using namespace transformer_engine;
+  fused_moe_aux_loss_forward_v2(
+      *convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens,
+      num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss),
+      *convertNVTETensorCheck(Const_buf), stream);
+}
diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h
index 794880d324..5cc9f6e8e5 100644
--- a/transformer_engine/common/include/transformer_engine/fused_router.h
+++ b/transformer_engine/common/include/transformer_engine/fused_router.h
@@ -110,6 +110,15 @@ void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor to
                                      int num_cols, int topk, float coeff, NVTETensor aux_loss,
                                      NVTETensor Const_buf, cudaStream_t stream);
 
+/*! \brief Forward pass for auxiliary loss (v2, multi-CTA atomic reduction).
+ *
+ *  Same argument contract as nvte_fused_moe_aux_loss_forward.
+ */
+void nvte_fused_moe_aux_loss_forward_v2(const NVTETensor probs, const NVTETensor tokens_per_expert,
+                                        int total_num_tokens, int num_experts, int num_rows,
+                                        int num_cols, int topk, float coeff, NVTETensor aux_loss,
+                                        NVTETensor Const_buf, cudaStream_t stream);
+
 /*! \brief Backward pass for auxiliary loss.
  *
  * \param[in] Const_buf      Constant buffer from the forward pass.
diff --git a/transformer_engine/jax/csrc/extensions/router.cpp b/transformer_engine/jax/csrc/extensions/router.cpp
index c81671f104..df570f25d8 100644
--- a/transformer_engine/jax/csrc/extensions/router.cpp
+++ b/transformer_engine/jax/csrc/extensions/router.cpp
@@ -184,10 +184,12 @@ Error_Type FusedMoEAuxLossForwardFFI(cudaStream_t stream,
   auto aux_loss_tensor = TensorWrapper(aux_loss_buf->untyped_data(), scalar_shape, dtype);
   auto const_buf_tensor = TensorWrapper(const_buf->untyped_data(), scalar_shape, DType::kFloat32);
 
-  nvte_fused_moe_aux_loss_forward(probs_tensor.data(), tpe_tensor.data(), num_tokens, num_experts,
-                                  num_tokens, num_experts, static_cast<int>(topk),
-                                  static_cast<float>(coeff), aux_loss_tensor.data(),
-                                  const_buf_tensor.data(), stream);
+  // Route through the v2 kernel (multi-CTA reduction); argument contract is
+  // identical to v1.
+  nvte_fused_moe_aux_loss_forward_v2(probs_tensor.data(), tpe_tensor.data(), num_tokens,
+                                     num_experts, num_tokens, num_experts, static_cast<int>(topk),
+                                     static_cast<float>(coeff), aux_loss_tensor.data(),
+                                     const_buf_tensor.data(), stream);
 
   return ffi_with_cuda_error_check();
 }
diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp
index 94625c0f12..4d4d8660db 100644
--- a/transformer_engine/pytorch/csrc/extensions/router.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/router.cpp
@@ -155,9 +155,11 @@ std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
   auto aux_loss_cu = makeTransformerEngineTensor(aux_loss);
   auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
 
-  nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens,
-                                  num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(),
-                                  Const_buf_cu.data(), at::cuda::getCurrentCUDAStream());
+  // Route through the v2 kernel; argument contract is identical to v1.
+  nvte_fused_moe_aux_loss_forward_v2(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens,
+                                     num_experts, num_rows, num_cols, topk, coeff,
+                                     aux_loss_cu.data(), Const_buf_cu.data(),
+                                     at::cuda::getCurrentCUDAStream());
 
   return std::make_tuple(aux_loss, Const_buf);
 }