Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions transformer_engine/common/fused_router/fused_moe_aux_loss_v2.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <assert.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_router.h>

#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
#include "common/util/cuda_runtime.h"
#include "utils.h"

namespace transformer_engine {

// Using float to handle all the calculations
using CompType = float;

// ---------------------------------------------------------------------------
// Main reduction kernel.
//
// Each CTA computes a partial value of
//     sum_col ( sum_row probs[row, col] ) * tokens_per_expert[col]
// and contributes it to a pre-zeroed float accumulator with a single
// atomicAdd per CTA.  Threads stride over columns, CTAs stride over rows,
// so every (row, col) element is read exactly once across the grid.
//
// Launch expectations: 1-D grid/block; blockDim.x is a multiple of
// kThreadsPerWarp; dynamic shared memory of blockDim.x * sizeof(CompType).
// partial_sum must be zero-initialized before launch — the kernel only
// ever adds to it, which removes the init/accumulate race the previous
// version had between block 0's `aux_loss[0] = 0` store and other blocks'
// atomicAdds (inter-block execution order is not guaranteed by CUDA).
template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_forward_kernel_v2(const DataType* probs,
                                                     const IndexType* tokens_per_expert,
                                                     int num_rows, int num_cols,
                                                     float* partial_sum) {
  CompType thread_sum = CompType(0);

  // Per-thread accumulation: columns strided by blockDim, rows strided by
  // gridDim (grid-stride loop), all math in CompType (float).
  for (int col = threadIdx.x; col < num_cols; col += blockDim.x) {
    CompType col_sum = CompType(0);
    for (int row = blockIdx.x; row < num_rows; row += gridDim.x) {
      col_sum += CompType(probs[row * num_cols + col]);
    }
    // Weight the per-column sum by the token count for this expert.
    thread_sum += col_sum * CompType(tokens_per_expert[col]);
  }

  // Block-level reduction of the per-thread partials via shared memory.
  extern __shared__ float shmem[];
  CompType* shmem_block = reinterpret_cast<CompType*>(shmem);
  shmem_block[threadIdx.x] = thread_sum;
  __syncthreads();

  const int warp_id = threadIdx.x / kThreadsPerWarp;
  const int lane_id = threadIdx.x % kThreadsPerWarp;
  if (warp_id == 0) {
    CompType block_sum =
        warp_reduce_on_shmem(shmem_block, blockDim.x, ReduceFuncType::SUM, lane_id);
    __syncwarp();

    // One float atomic per CTA into float-typed scratch.  The previous
    // version atomicAdd'ed a float into `aux_loss` reinterpret-cast from
    // DataType*, which corrupts 16-bit (fp16/bf16) output buffers.
    if (lane_id == 0) {
      atomicAdd(partial_sum, static_cast<float>(block_sum));
    }
  }
}

// ---------------------------------------------------------------------------
// Finalize kernel (single thread): publishes the backward-pass coefficient
// and converts the accumulated float sum to the output dtype exactly once.
// Stream ordering after the reduction kernel guarantees partial_sum is final.
template <typename DataType>
__global__ void fused_moe_aux_loss_forward_finalize_kernel_v2(const float* partial_sum,
                                                              float C_coeff,
                                                              DataType* aux_loss,
                                                              float* Const_buf) {
  Const_buf[0] = C_coeff;
  aux_loss[0] = static_cast<DataType>(partial_sum[0] * C_coeff);
}

/* -------------------------------------------------------------------------
 * Kernel launcher – simplified (no cluster launch).
 *
 * probs:             [num_rows, num_cols] router probabilities (DataType).
 * tokens_per_expert: [num_cols] per-expert token counts (IndexType).
 * aux_loss:          scalar output (DataType).
 * Const_buf:         scalar float output, coefficient for the backward pass.
 * ------------------------------------------------------------------------- */
template <typename DataType, typename IndexType>
void fused_moe_aux_loss_forward_kernel_launcher_v2(const DataType* probs,
                                                   const IndexType* tokens_per_expert,
                                                   int total_num_tokens, int num_experts,
                                                   int num_rows, int num_cols, int topk,
                                                   float coeff, DataType* aux_loss,
                                                   float* Const_buf, cudaStream_t stream) {
  // Same float expression the device code used, evaluated once on the host.
  const float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;

  // Stream-ordered scratch accumulator, zeroed before the reduction runs.
  // NOTE: cudaMallocAsync/cudaFreeAsync require CUDA 11.2+.
  float* partial_sum = nullptr;
  NVTE_CHECK_CUDA(cudaMallocAsync(&partial_sum, sizeof(float), stream));
  NVTE_CHECK_CUDA(cudaMemsetAsync(partial_sum, 0, sizeof(float), stream));

  if (num_rows > 0 && num_cols > 0) {
    // Round the block up to a warp multiple so warp 0's shared-memory
    // reduction sees fully populated warps, and clamp to the 1024-thread
    // limit.  (The old min(1024, num_cols) could produce 0 threads for
    // empty input and non-warp-multiple blocks for small num_cols.)
    int block_size = (num_cols + kThreadsPerWarp - 1) / kThreadsPerWarp * kThreadsPerWarp;
    block_size = std::min(1024, block_size);
    const int grid_size = sm_count() * 2;

    // One CompType per thread in dynamic shared memory.
    const size_t smem_size = block_size * sizeof(CompType);

    fused_moe_aux_loss_forward_kernel_v2<DataType, IndexType>
        <<<grid_size, block_size, smem_size, stream>>>(probs, tokens_per_expert, num_rows,
                                                       num_cols, partial_sum);
  }

  // Always publish the outputs, even for empty input (loss = 0).
  fused_moe_aux_loss_forward_finalize_kernel_v2<DataType>
      <<<1, 1, 0, stream>>>(partial_sum, C_coeff, aux_loss, Const_buf);

  NVTE_CHECK_CUDA(cudaFreeAsync(partial_sum, stream));
  NVTE_CHECK_CUDA(cudaGetLastError());
}

// Type-dispatching entry point for the v2 aux-loss forward pass.
//
// Expands the nested TE dtype-switch macros over the probs dtype (DataType)
// and the tokens_per_expert dtype (IndexType), then forwards the raw device
// pointers to the templated launcher.  The trailing `;););` closes the two
// macro invocations plus the statement.
//
// probs:             router probabilities, logically [num_rows, num_cols].
// tokens_per_expert: per-expert token counts, logically [num_cols].
// aux_loss:          scalar output tensor (same dtype as probs).
// Const_buf:         scalar float32 tensor holding the coefficient that the
//                    backward pass reads.
void fused_moe_aux_loss_forward_v2(const Tensor& probs, const Tensor& tokens_per_expert,
                                   int total_num_tokens, int num_experts, int num_rows,
                                   int num_cols, int topk, float coeff, Tensor& aux_loss,
                                   Tensor& Const_buf, cudaStream_t stream) {
  TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
      probs.data.dtype, DataType,
      TE_ROUTER_INDEX_TYPE_SWITCH_ALL(
          tokens_per_expert.data.dtype, IndexType,
          fused_moe_aux_loss_forward_kernel_launcher_v2<DataType, IndexType>(
              reinterpret_cast<DataType*>(probs.data.dptr),
              reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), total_num_tokens,
              num_experts, num_rows, num_cols, topk, coeff,
              reinterpret_cast<DataType*>(aux_loss.data.dptr),
              reinterpret_cast<float*>(Const_buf.data.dptr), stream);););
}

// Public C API wrapper for the v2 aux-loss forward pass.  Validates and
// unwraps the NVTETensor handles, then calls the typed C++ implementation.
void nvte_fused_moe_aux_loss_forward_v2(const NVTETensor probs, const NVTETensor tokens_per_expert,
                                        int total_num_tokens, int num_experts, int num_rows,
                                        int num_cols, int topk, float coeff, NVTETensor aux_loss,
                                        NVTETensor Const_buf, cudaStream_t stream) {
  // Fix: the original tagged this v2 entry point with the v1 name
  // (nvte_fused_moe_aux_loss_forward), mislabeling API-call tracing/NVTX.
  NVTE_API_CALL(nvte_fused_moe_aux_loss_forward_v2);
  using namespace transformer_engine;
  fused_moe_aux_loss_forward_v2(
      *convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens,
      num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss),
      *convertNVTETensorCheck(Const_buf), stream);
}

} // namespace transformer_engine
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor to
int num_cols, int topk, float coeff, NVTETensor aux_loss,
NVTETensor Const_buf, cudaStream_t stream);

void nvte_fused_moe_aux_loss_forward_v2(const NVTETensor probs, const NVTETensor tokens_per_expert,
int total_num_tokens, int num_experts, int num_rows,
int num_cols, int topk, float coeff, NVTETensor aux_loss,
NVTETensor Const_buf, cudaStream_t stream);

/*! \brief Backward pass for auxiliary loss.
*
* \param[in] Const_buf Constant buffer from the forward pass.
Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/jax/csrc/extensions/router.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ Error_Type FusedMoEAuxLossForwardFFI(cudaStream_t stream,
auto aux_loss_tensor = TensorWrapper(aux_loss_buf->untyped_data(), scalar_shape, dtype);
auto const_buf_tensor = TensorWrapper(const_buf->untyped_data(), scalar_shape, DType::kFloat32);

nvte_fused_moe_aux_loss_forward(probs_tensor.data(), tpe_tensor.data(), num_tokens, num_experts,
num_tokens, num_experts, static_cast<int>(topk),
static_cast<float>(coeff), aux_loss_tensor.data(),
const_buf_tensor.data(), stream);
nvte_fused_moe_aux_loss_forward_v2(probs_tensor.data(), tpe_tensor.data(), num_tokens,
num_experts, num_tokens, num_experts, static_cast<int>(topk),
static_cast<float>(coeff), aux_loss_tensor.data(),
const_buf_tensor.data(), stream);

return ffi_with_cuda_error_check();
}
Expand Down
7 changes: 4 additions & 3 deletions transformer_engine/pytorch/csrc/extensions/router.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
auto aux_loss_cu = makeTransformerEngineTensor(aux_loss);
auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);

nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens,
num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(),
Const_buf_cu.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_moe_aux_loss_forward_v2(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens,
num_experts, num_rows, num_cols, topk, coeff,
aux_loss_cu.data(), Const_buf_cu.data(),
at::cuda::getCurrentCUDAStream());

return std::make_tuple(aux_loss, Const_buf);
}
Expand Down