From 4534154a05f8f6743e5ef38c5f5951f2a322be91 Mon Sep 17 00:00:00 2001 From: chen Date: Fri, 10 Apr 2026 07:26:54 +0000 Subject: [PATCH 1/2] Refactor(linear): split LinearBackward kernel into 3 independent kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move grad_flags logic from kernel to autograd layer. The monolithic LinearBackward kernel is replaced by LinearBackwardInput, LinearBackwardWeight, and LinearBackwardBias — each a pure compute operation with no autograd-related parameters. --- infini_train/include/autograd/linear.h | 6 - infini_train/src/autograd/linear.cc | 30 ++- infini_train/src/kernels/cpu/linear.cc | 84 ++++--- infini_train/src/kernels/cuda/linear.cu | 297 ++++++++++++------------ 4 files changed, 203 insertions(+), 214 deletions(-) diff --git a/infini_train/include/autograd/linear.h b/infini_train/include/autograd/linear.h index 21d107b9..cebed3b2 100644 --- a/infini_train/include/autograd/linear.h +++ b/infini_train/include/autograd/linear.h @@ -12,12 +12,6 @@ class Tensor; namespace infini_train::autograd { -struct LinearGradFlags { - bool input = false; - bool weight = false; - bool bias = false; -}; - class Linear : public Function { public: static constexpr char kType[] = "LinearFunction"; diff --git a/infini_train/src/autograd/linear.cc b/infini_train/src/autograd/linear.cc index c9ed1dbb..ff0283ce 100644 --- a/infini_train/src/autograd/linear.cc +++ b/infini_train/src/autograd/linear.cc @@ -56,17 +56,29 @@ std::vector> Linear::Backward(const std::vector 1 && needs_input_grad_[1], - .bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]}; + bool need_grad_input = needs_input_grad_[0]; + bool need_grad_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1]; + bool need_grad_bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]; auto device = grad_output->GetDevice().type(); - // TODO: skip autograd graph construction entirely when no input requires grad - auto [grad_input, grad_weight, grad_bias] - = Dispatcher::Instance() - .Call, std::shared_ptr, std::shared_ptr>>( - {device, "LinearBackward"}, input, weight, transpose_, in_features_, out_features_, input_dims_, - grad_output, bias_, grad_flags); + + std::shared_ptr grad_input = nullptr; + std::shared_ptr grad_weight = nullptr; + std::shared_ptr grad_bias = nullptr; + + if (need_grad_input) { + grad_input = Dispatcher::Instance().Call>( + {device, "LinearBackwardInput"}, weight, grad_output, transpose_, in_features_, out_features_, input_dims_); + } + if (need_grad_weight) { + grad_weight = Dispatcher::Instance().Call>( + {device, "LinearBackwardWeight"}, input, grad_output, transpose_, in_features_, out_features_); + } + if (need_grad_bias) { + grad_bias = Dispatcher::Instance().Call>({device, "LinearBackwardBias"}, grad_output, + out_features_); + } + if (bias_) { return {grad_input, grad_weight, grad_bias}; } else { diff --git a/infini_train/src/kernels/cpu/linear.cc b/infini_train/src/kernels/cpu/linear.cc index 2b209417..f238135c 100644 --- a/infini_train/src/kernels/cpu/linear.cc +++ b/infini_train/src/kernels/cpu/linear.cc @@ -5,7 +5,6 @@ #include "glog/logging.h" -#include "infini_train/include/autograd/linear.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" @@ -146,62 +145,55 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons return output; } -// TODO(dcj): support linear without bias later -std::tuple, std::shared_ptr, std::shared_ptr> -LinearBackward(const 
std::shared_ptr &input, const std::shared_ptr &weight, bool transpose, - int64_t in_features, int64_t out_features, const std::vector &input_dims, - const std::shared_ptr &grad_output, bool bias, - infini_train::autograd::LinearGradFlags grad_flags) { +std::shared_ptr LinearBackwardInput(const std::shared_ptr &weight, + const std::shared_ptr &grad_output, bool transpose, + int64_t in_features, int64_t out_features, + const std::vector &input_dims) { /* transpose: grad_input = grad_output * weight grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features] - grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features] - grad_bias[out_features] = grad_output[*, out_features].sum(axis=0) !transpose: grad_input = grad_output * weight^T grad_input[*, in_features] = grad_output[_, out_features] * weight[in_features, out_features]^T - grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features] - grad_bias[out_features] = grad_output[*, out_features].sum(axis=0) */ - const auto compute_grad_input = grad_flags.input; - const auto compute_grad_weight = grad_flags.weight; - const auto compute_grad_bias = grad_flags.bias; - CHECK_GE(input_dims.size(), 2); - - std::vector weight_dims - = transpose ? std::vector{out_features, in_features} : std::vector{in_features, out_features}; - - std::shared_ptr grad_input = nullptr; - std::shared_ptr grad_weight = nullptr; - std::shared_ptr grad_bias = nullptr; - - if (compute_grad_input) { - CHECK(weight != nullptr) << "compute_grad_input=true but weight is nullptr (selective save mismatch)"; - grad_input = std::make_shared(input_dims, DataType::kFLOAT32); - if (transpose) { - grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix(); - } else { - grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose(); - } + auto grad_input = std::make_shared(input_dims, DataType::kFLOAT32); + if (transpose) { + grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix(); + } else { + grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose(); } + return grad_input; +} - if (compute_grad_weight) { - CHECK(input != nullptr) << "compute_grad_weight=true but input is nullptr (selective save mismatch)"; - grad_weight = std::make_shared(weight_dims, DataType::kFLOAT32); - if (transpose) { - grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix(); - } else { - grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix(); - } - } +std::shared_ptr LinearBackwardWeight(const std::shared_ptr &input, + const std::shared_ptr &grad_output, bool transpose, + int64_t in_features, int64_t out_features) { + /* + transpose: + grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features] - if (compute_grad_bias && bias) { - grad_bias = std::make_shared(std::vector{out_features}, DataType::kFLOAT32); - grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum(); + !transpose: + grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features] + */ + std::vector weight_dims + = transpose ? 
std::vector{out_features, in_features} : std::vector{in_features, out_features}; + auto grad_weight = std::make_shared(weight_dims, DataType::kFLOAT32); + if (transpose) { + grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix(); + } else { + grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix(); } + return grad_weight; +} - return {grad_input, grad_weight, grad_bias}; +std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_output, int64_t out_features) { + /* + grad_bias[out_features] = grad_output[*, out_features].sum(axis=0) + */ + auto grad_bias = std::make_shared(std::vector{out_features}, DataType::kFLOAT32); + grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum(); + return grad_bias; } } // namespace infini_train::kernels::cpu @@ -211,6 +203,8 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptr #include -#include "infini_train/include/autograd/linear.h" #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" #include "infini_train/include/core/runtime/device_guard.h" @@ -317,183 +316,171 @@ __global__ void ReduceColumnsKernel(const TIn *__restrict__ input, TOut *__restr } } -std::tuple, std::shared_ptr, std::shared_ptr> -LinearBackward(const std::shared_ptr &input, const std::shared_ptr &weight, bool transpose, - int64_t in_features, int64_t out_features, const std::vector &input_dims, - const std::shared_ptr &grad_output, bool bias, - infini_train::autograd::LinearGradFlags grad_flags) { - const auto compute_grad_input = grad_flags.input; - const auto compute_grad_weight = grad_flags.weight; - const auto compute_grad_bias = grad_flags.bias; - +std::shared_ptr LinearBackwardInput(const std::shared_ptr &weight, + const std::shared_ptr &grad_output, bool transpose, + int64_t in_features, int64_t out_features, + const std::vector &input_dims) { CHECK_GE(input_dims.size(), 2); const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies{}); - const std::vector weight_dims - = transpose ? std::vector{out_features, in_features} : std::vector{in_features, out_features}; + auto compute_dtype = weight->Dtype(); + auto grad_output_dtype = grad_output->Dtype(); + auto grad_output_promoted + = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); - auto dtype = grad_output->Dtype(); + // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; + // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output. + auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); - // For type promotion, use available tensors - DataType input_dtype = input ? input->Dtype() : (weight ? weight->Dtype() : dtype); - DataType weight_dtype = weight ? weight->Dtype() : (input ? 
input->Dtype() : dtype); - // Compute dtype determined by saved tensors (forward compute dtype), not grad_output - DataType compute_dtype = DispatchFunc, DataTypeList>( - {input_dtype, weight_dtype}, [=]() { return DataTypeMap_v>; }, - "CUDA LinearBackward"); + auto device = grad_output->GetDevice(); + float alpha = 1.0f; + float beta = 0.0f; + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); + + // TODO(zbl): use cublasSgemv if possible + // - if transpose: + // weight is [out_features, in_features] here + // d_input = d_output * weight --> d_input.T = weight.T * d_output.T + // C = d_input.T[in_features, bs] + // A = weight.T[in_features, out_features] + // B = d_output.T[out_features, bs] + // + // - if not transpose: + // weight is [in_features, out_features] here + // d_input = d_output * weight.T --> d_input.T = weight * d_output.T + // C = d_input.T[in_features, bs] + // A = weight.T[out_features, in_features] + // B = d_output.T[out_features, bs] + auto trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; + auto lda = transpose ? in_features : out_features; + switch (compute_dtype) { + DISPATCH_CASE(WRAP({ + CUBLAS_CHECK(cublasSgemm(handle, trans_a, CUBLAS_OP_N, in_features, bs, out_features, &alpha, + static_cast(weight->DataPtr()), lda, + static_cast(grad_output_promoted->DataPtr()), + out_features, &beta, static_cast(grad_input->DataPtr()), + in_features)); + }), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP({ + CUBLAS_CHECK(cublasGemmEx( + handle, trans_a, CUBLAS_OP_N, in_features, bs, out_features, &alpha, weight->DataPtr(), + CUDA_R_16BF, lda, grad_output_promoted->DataPtr(), CUDA_R_16BF, out_features, &beta, + grad_input->DataPtr(), CUDA_R_32F, in_features, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + }), + DataType::kBFLOAT16) + } + + return grad_input; +} + +std::shared_ptr LinearBackwardWeight(const std::shared_ptr &input, + const std::shared_ptr &grad_output, bool transpose, + int64_t in_features, int64_t out_features) { + const auto &grad_output_dims = grad_output->Dims(); + CHECK_GE(grad_output_dims.size(), 2); + const int64_t bs + = std::accumulate(grad_output_dims.rbegin() + 1, grad_output_dims.rend(), 1, std::multiplies{}); + + auto compute_dtype = input->Dtype(); + auto grad_output_dtype = grad_output->Dtype(); auto grad_output_promoted - = dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); + = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; + const std::vector weight_dims + = transpose ? std::vector{out_features, in_features} : std::vector{in_features, out_features}; + // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output. + auto grad_weight = std::make_shared(weight_dims, output_dtype, grad_output->GetDevice()); - // Allocate only needed gradient tensors (selective save: input/weight may be nullptr). 
- std::shared_ptr grad_input = nullptr; - std::shared_ptr grad_weight = nullptr; - std::shared_ptr grad_bias = nullptr; + auto device = grad_output->GetDevice(); + float alpha = 1.0f; + float beta = 0.0f; + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); - if (compute_grad_input) { - grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); - } - if (compute_grad_weight) { - grad_weight = std::make_shared(weight_dims, output_dtype, grad_output->GetDevice()); - } - // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output, and ReduceColumnsKernel assigns directly. - if (compute_grad_bias && bias) { - grad_bias - = std::make_shared(std::vector{out_features}, output_dtype, grad_output->GetDevice()); + // - if transpose: + // d_weight = d_output.T * input --> d_weight.T = input.T * d_output + // C = d_weight.T[in_features, out_features] + // A = input.T[in_features, bs] + // B = d_output.T[out_features, bs] + // + // - if not transpose: + // d_weight = input.T * d_output --> d_weight.T = d_output.T * input + // C = d_weight.T[out_features, in_features] + // A = d_output.T[out_features, bs] + // B = input.T[in_features, bs] + int m = transpose ? in_features : out_features; + int n = transpose ? out_features : in_features; + auto ldc = transpose ? in_features : out_features; + + switch (compute_dtype) { + DISPATCH_CASE(WRAP({ + const void *a = transpose ? input->DataPtr() : grad_output_promoted->DataPtr(); + const void *b = transpose ? grad_output_promoted->DataPtr() : input->DataPtr(); + auto lda = transpose ? in_features : out_features; + auto ldb = transpose ? out_features : in_features; + CUBLAS_CHECK(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, m, n, bs, &alpha, + static_cast(a), lda, static_cast(b), + ldb, &beta, static_cast(grad_weight->DataPtr()), ldc)); + }), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP({ + const void *a = transpose ? input->DataPtr() : grad_output_promoted->DataPtr(); + const void *b = transpose ? grad_output_promoted->DataPtr() : input->DataPtr(); + auto lda = transpose ? in_features : out_features; + auto ldb = transpose ? out_features : in_features; + CUBLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, m, n, bs, &alpha, a, CUDA_R_16BF, + lda, b, CUDA_R_16BF, ldb, &beta, grad_weight->DataPtr(), CUDA_R_32F, + ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + }), + DataType::kBFLOAT16) } + return grad_weight; +} + +std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_output, int64_t out_features) { + const auto &dims = grad_output->Dims(); + CHECK_GE(dims.size(), 2); + const int64_t bs = std::accumulate(dims.rbegin() + 1, dims.rend(), 1, std::multiplies{}); + + auto compute_dtype = grad_output->Dtype(); + // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? 
DataType::kFLOAT32 : compute_dtype; + auto grad_bias + = std::make_shared(std::vector{out_features}, output_dtype, grad_output->GetDevice()); + auto device = grad_output->GetDevice(); const auto &cuda_stream = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - float alpha = 1.0f; - float beta = 0.0f; - - cublasHandle_t handle = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) - ->cublas_handle(); - + // d_bias = \sum_i(i=0, bs-1) d_output[i] + // TODO(dcj): use thrust::fill or reduce kernel do this + constexpr int BLOCK_SIZE = 256; switch (compute_dtype) { - // TODO(zbl): use cublasSgemv if possible - DISPATCH_CASE( - WRAP({ - if (compute_grad_input) { - // - if transpose: - // weight is [out_features, in_features] here - // d_input = d_output * weight --> d_input.T = weight.T * d_output.T - // C = d_input.T[in_features, bs] - // A = weight.T[in_features, out_features] - // B = d_output.T[out_features, bs] - // - // - if not transpose: - // weight is [in_features, out_features] here - // d_input = d_output * weight.T --> d_input.T = weight * d_output.T - // C = d_input.T[in_features, bs] - // A = weight.T[out_features, in_features] - // B = d_output.T[out_features, bs] - CHECK(weight != nullptr) - << "compute_grad_input=true but weight is nullptr (selective save mismatch)"; - auto weight_promoted - = weight_dtype == compute_dtype ? weight : std::make_shared(weight->To(compute_dtype)); - auto trans_a1 = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; - auto lda1 = transpose ? in_features : out_features; - CUBLAS_CHECK(cublasSgemm(handle, trans_a1, CUBLAS_OP_N, in_features, bs, out_features, &alpha, - static_cast(weight_promoted->DataPtr()), lda1, - static_cast(grad_output_promoted->DataPtr()), out_features, - &beta, static_cast(grad_input->DataPtr()), in_features)); - } - if (compute_grad_weight) { - // - if transpose: - // d_weight = d_output.T * input --> d_weight.T = input.T * d_output - // C = d_weight.T[in_features, out_features] - // A = input.T[in_features, bs] - // B = d_output.T[out_features, bs] - // - // - if not transpose: - // d_weight = input.T * d_output --> d_weight.T = d_output.T * input - // C = d_weight.T[out_features, in_features] - // A = d_output.T[out_features, bs] - // B = input.T[in_features, bs] - CHECK(input != nullptr) - << "compute_grad_weight=true but input is nullptr (selective save mismatch)"; - auto input_promoted - = input_dtype == compute_dtype ? input : std::make_shared(input->To(compute_dtype)); - auto trans_a2 = CUBLAS_OP_N; - auto trans_b2 = CUBLAS_OP_T; - int m2 = transpose ? in_features : out_features; - int n2 = transpose ? out_features : in_features; - const void *a2 = transpose ? input_promoted->DataPtr() : grad_output_promoted->DataPtr(); - const void *b2 = transpose ? grad_output_promoted->DataPtr() : input_promoted->DataPtr(); - auto lda2 = transpose ? in_features : out_features; - auto ldb2 = transpose ? out_features : in_features; - auto ldc2 = transpose ? 
in_features : out_features; - CUBLAS_CHECK(cublasSgemm(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, - static_cast(a2), lda2, static_cast(b2), ldb2, - &beta, static_cast(grad_weight->DataPtr()), ldc2)); - } - // d_bias = \sum_i(i=0, bs-1) d_output[i] - // TODO(dcj): use thrust::fill or reduce kernel do this - if (compute_grad_bias && bias) { - constexpr int BLOCK_SIZE = 256; - int threads_per_block = BLOCK_SIZE; - int num_blocks = out_features; - ReduceColumnsKernel<<>>( - static_cast(grad_output_promoted->DataPtr()), - static_cast(grad_bias->DataPtr()), out_features, bs); - } - }), - DataType::kFLOAT32) DISPATCH_CASE(WRAP({ - if (compute_grad_input) { - CHECK(weight != nullptr) - << "compute_grad_input=true but weight is nullptr (selective save mismatch)"; - auto weight_promoted = weight_dtype == compute_dtype - ? weight - : std::make_shared(weight->To(compute_dtype)); - auto trans_a1 = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; - auto lda1 = transpose ? in_features : out_features; - CUBLAS_CHECK(cublasGemmEx(handle, trans_a1, CUBLAS_OP_N, in_features, bs, out_features, - &alpha, weight_promoted->DataPtr(), CUDA_R_16BF, lda1, - grad_output_promoted->DataPtr(), CUDA_R_16BF, out_features, - &beta, grad_input->DataPtr(), CUDA_R_32F, in_features, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - } - if (compute_grad_weight) { - CHECK(input != nullptr) - << "compute_grad_weight=true but input is nullptr (selective save mismatch)"; - auto input_promoted = input_dtype == compute_dtype - ? input - : std::make_shared(input->To(compute_dtype)); - auto trans_a2 = CUBLAS_OP_N; - auto trans_b2 = CUBLAS_OP_T; - int m2 = transpose ? in_features : out_features; - int n2 = transpose ? out_features : in_features; - const void *a2 = transpose ? input_promoted->DataPtr() : grad_output_promoted->DataPtr(); - const void *b2 = transpose ? grad_output_promoted->DataPtr() : input_promoted->DataPtr(); - auto lda2 = transpose ? in_features : out_features; - auto ldb2 = transpose ? out_features : in_features; - auto ldc2 = transpose ? in_features : out_features; - CUBLAS_CHECK(cublasGemmEx(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, a2, CUDA_R_16BF, - lda2, b2, CUDA_R_16BF, ldb2, &beta, grad_weight->DataPtr(), - CUDA_R_32F, ldc2, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - } - if (compute_grad_bias && bias) { - constexpr int BLOCK_SIZE = 256; - int threads_per_block = BLOCK_SIZE; - int num_blocks = out_features; - ReduceColumnsKernel<<>>( - static_cast(grad_output_promoted->DataPtr()), - static_cast(grad_bias->DataPtr()), out_features, bs); - } + ReduceColumnsKernel<<>>( + static_cast(grad_output->DataPtr()), + static_cast(grad_bias->DataPtr()), out_features, bs); + }), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP({ + ReduceColumnsKernel<<>>( + static_cast(grad_output->DataPtr()), + static_cast(grad_bias->DataPtr()), out_features, bs); }), DataType::kBFLOAT16) } - return {grad_input, grad_weight, grad_bias}; + return grad_bias; } } // namespace infini_train::kernels::cuda @@ -503,6 +490,8 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptr Date: Fri, 10 Apr 2026 08:15:48 +0000 Subject: [PATCH 2/2] refactor(matmul): split MatmulBackward kernel into 2 independent kernels Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel is replaced by MatmulBackwardInput1 and MatmulBackwardInput2. 
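For reference, the math behind the two new kernels is the standard matmul backward: grad_input1 needs only input2, and grad_input2 needs only input1, which is why SetupContext can drop whichever operand is not required. The sketch below is illustrative only and not part of the patch; it restates the 2-D fp32 case with Eigen, and the *Ref names are invented for the example.

    #include <Eigen/Dense>

    // Reference semantics, single (non-batched) matrices:
    //   output[m, n] = input1[m, k] * input2[k, n]
    Eigen::MatrixXf MatmulBackwardInput1Ref(const Eigen::MatrixXf &input2,        // [k, n]
                                            const Eigen::MatrixXf &grad_output) { // [m, n]
        // grad_input1[m, k] = grad_output[m, n] * input2[k, n]^T
        return grad_output * input2.transpose();
    }

    Eigen::MatrixXf MatmulBackwardInput2Ref(const Eigen::MatrixXf &input1,        // [m, k]
                                            const Eigen::MatrixXf &grad_output) { // [m, n]
        // grad_input2[k, n] = input1[m, k]^T * grad_output[m, n]
        return input1.transpose() * grad_output;
    }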
--- infini_train/src/autograd/matmul.cc | 33 ++++- infini_train/src/kernels/cpu/linear.cc | 71 ++++++--- infini_train/src/kernels/cuda/linear.cu | 182 ++++++++++++++---------- infini_train/src/kernels/cuda/outer.cu | 2 +- 4 files changed, 185 insertions(+), 103 deletions(-) diff --git a/infini_train/src/autograd/matmul.cc b/infini_train/src/autograd/matmul.cc index 49f593bf..259cb4a4 100644 --- a/infini_train/src/autograd/matmul.cc +++ b/infini_train/src/autograd/matmul.cc @@ -31,10 +31,17 @@ void Matmul::SetupContext(const std::vector> &input_tens // FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be // determined by autocast, not derived from output->Dtype(). auto compute_dtype = output->Dtype(); - saved_tensors_ = { - input1->Dtype() == compute_dtype ? input1 : std::make_shared(input1->To(compute_dtype)), - input2->Dtype() == compute_dtype ? input2 : std::make_shared(input2->To(compute_dtype)), + + // grad_input1 = grad_output @ input2^T, so input2 is needed + // grad_input2 = grad_output^T @ input1, so input1 is needed + bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0]; + bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1]; + + auto cast = [&](const std::shared_ptr &t) { + return t->Dtype() == compute_dtype ? t : std::make_shared(t->To(compute_dtype)); }; + + saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr}; out_features_ = output->Dims()[0]; } @@ -45,10 +52,24 @@ std::vector> Matmul::Backward(const std::vector 0 && needs_input_grad_[0]; + bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1]; + auto device = input1->GetDevice().type(); - auto [grad_input1, grad_input2] - = Dispatcher::Instance().Call, std::shared_ptr>>( - {device, "MatmulBackward"}, input1, input2, grad_output); + + std::shared_ptr grad_input1 = nullptr; + std::shared_ptr grad_input2 = nullptr; + + if (need_grad_input1) { + grad_input1 = Dispatcher::Instance().Call>({device, "MatmulBackwardInput1"}, input2, + grad_output, input1->Dims()); + } + if (need_grad_input2) { + grad_input2 = Dispatcher::Instance().Call>({device, "MatmulBackwardInput2"}, input1, + grad_output, input2->Dims()); + } + return {grad_input1, grad_input2}; } } // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/linear.cc b/infini_train/src/kernels/cpu/linear.cc index f238135c..361c56f8 100644 --- a/infini_train/src/kernels/cpu/linear.cc +++ b/infini_train/src/kernels/cpu/linear.cc @@ -50,38 +50,71 @@ std::shared_ptr MatmulForward(const std::shared_ptr &input, cons return {output}; } -std::tuple, std::shared_ptr> -MatmulBackward(const std::shared_ptr &input, const std::shared_ptr &other, - const std::shared_ptr &grad_output) { +std::shared_ptr MatmulBackwardInput1(const std::shared_ptr &other, + const std::shared_ptr &grad_output, + const std::vector &input_dims) { /* grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T - grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] */ - const auto &input_dims = input->Dims(); const auto &other_dims = other->Dims(); const auto &grad_output_dims = grad_output->Dims(); + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(other_dims.size(), grad_output_dims.size()); + + const int64_t m = grad_output_dims[grad_output_dims.size() - 2]; + const int64_t k = other_dims[other_dims.size() - 2]; + const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; + + const int64_t bs + = std::accumulate(grad_output_dims.rbegin() 
+ 2, grad_output_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < grad_output_dims.size() - 2; ++i) { + CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match"; + } + + auto grad_input = std::make_shared(input_dims, DataType::kFLOAT32); + grad_input->Fill(0.0f); + + for (int64_t b = 0; b < bs; ++b) { + for (int64_t i = 0; i < m; ++i) { + for (int64_t j = 0; j < n; ++j) { + const float grad = static_cast(grad_output->DataPtr())[b * m * n + i * n + j]; + for (int64_t p = 0; p < k; ++p) { + const auto other_idx = b * k * n + p * n + j; + static_cast(grad_input->DataPtr())[b * m * k + i * k + p] + += grad * static_cast(other->DataPtr())[other_idx]; + } + } + } + } + return grad_input; +} + +std::shared_ptr MatmulBackwardInput2(const std::shared_ptr &input1, + const std::shared_ptr &grad_output, + const std::vector &other_dims) { + /* + grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] + */ + const auto &input_dims = input1->Dims(); + const auto &grad_output_dims = grad_output->Dims(); + CHECK_GE(input_dims.size(), 2); - CHECK_EQ(input_dims.size(), other_dims.size()); CHECK_EQ(input_dims.size(), grad_output_dims.size()); const int64_t m = input_dims[input_dims.size() - 2]; const int64_t k = input_dims[input_dims.size() - 1]; - CHECK_EQ(k, other_dims[other_dims.size() - 2]); - const int64_t n = other_dims[other_dims.size() - 1]; - + const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); - CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + CHECK_EQ(k, other_dims[other_dims.size() - 2]); const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); for (int64_t i = 0; i < input_dims.size() - 2; ++i) { - CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; } - auto grad_input = std::make_shared(input_dims, DataType::kFLOAT32); auto grad_other = std::make_shared(other_dims, DataType::kFLOAT32); - grad_input->Fill(0.0f); grad_other->Fill(0.0f); for (int64_t b = 0; b < bs; ++b) { @@ -90,16 +123,13 @@ MatmulBackward(const std::shared_ptr &input, const std::shared_ptr(grad_output->DataPtr())[b * m * n + i * n + j]; for (int64_t p = 0; p < k; ++p) { const auto input_idx = b * m * k + i * k + p; - const auto other_idx = b * k * n + p * n + j; - static_cast(grad_input->DataPtr())[input_idx] - += grad * static_cast(other->DataPtr())[other_idx]; - static_cast(grad_other->DataPtr())[other_idx] - += grad * static_cast(input->DataPtr())[input_idx]; + static_cast(grad_other->DataPtr())[b * k * n + p * n + j] + += grad * static_cast(input1->DataPtr())[input_idx]; } } } } - return {grad_input, grad_other}; + return grad_other; } std::shared_ptr LinearForward(const std::shared_ptr &input, const std::shared_ptr &weight, @@ -201,7 +231,8 @@ std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_o REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_LINEAR_KERNEL(MatmulForward) -REGISTER_CPU_LINEAR_KERNEL(MatmulBackward) +REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput1) +REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput2) REGISTER_CPU_LINEAR_KERNEL(LinearForward) REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput) REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight) diff --git a/infini_train/src/kernels/cuda/linear.cu 
b/infini_train/src/kernels/cuda/linear.cu index ec079b5d..cbc74c5e 100644 --- a/infini_train/src/kernels/cuda/linear.cu +++ b/infini_train/src/kernels/cuda/linear.cu @@ -80,112 +80,141 @@ std::shared_ptr MatmulForward(const std::shared_ptr &input, cons return output; } -std::tuple, std::shared_ptr> -MatmulBackward(const std::shared_ptr &input, const std::shared_ptr &other, - const std::shared_ptr &grad_output) { +std::shared_ptr MatmulBackwardInput1(const std::shared_ptr &other, + const std::shared_ptr &grad_output, + const std::vector &input_dims) { /* grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T - grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] */ - auto input_dtype = input->Dtype(); - auto other_dtype = other->Dtype(); - auto grad_output_dtype = grad_output->Dtype(); - // Compute dtype determined by saved tensors (forward compute dtype), not grad_output - DataType compute_dtype = DispatchFunc, DataTypeList>( - {input_dtype, other_dtype}, [=]() { return DataTypeMap_v>; }, - "CUDA MatmulBackward"); + const auto &other_dims = other->Dims(); + const auto &grad_output_dims = grad_output->Dims(); - auto input_promoted = input_dtype == compute_dtype ? input : std::make_shared(input->To(compute_dtype)); - auto other_promoted = other_dtype == compute_dtype ? other : std::make_shared(other->To(compute_dtype)); + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(other_dims.size(), grad_output_dims.size()); + + const int64_t m = grad_output_dims[grad_output_dims.size() - 2]; + const int64_t k = other_dims[other_dims.size() - 2]; + const int64_t n = other_dims[other_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); + CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + + const int64_t bs + = std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < grad_output_dims.size() - 2; ++i) { + CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match"; + } + + auto compute_dtype = other->Dtype(); + auto grad_output_dtype = grad_output->Dtype(); auto grad_output_promoted = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); - const auto &input_dims = input->Dims(); - const auto &other_dims = other->Dims(); - const auto &grad_output_dims = grad_output->Dims(); + // For bf16 compute, output in fp32 to preserve accumulation precision. + auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; + auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); + + // No Fill(0) needed: cuBLAS beta=0.0f means C is fully overwritten, never read. 
+ + auto device = grad_output->GetDevice(); + const float alpha = 1.0f, beta = 0.0f; + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); + // cuBLAS is colmun-major + // grad_input = grad_output * other.T --> grad_input.T = other * grad_output.T + // C = A.T * B ==> grad_input.T[*, k, m] = other[*, k, n] * grad_output.T[*, n, m] + // C = grad_input.T[*, k, m] + // A = other.T[*, n, k] + // B = grad_output.T[*, n, m] + const int lda = n, ldb = n, ldc = k; + const int64_t stride_a = k * n; + const int64_t stride_b = n * m; + const int64_t stride_c = m * k; + switch (compute_dtype) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_32F, lda, + stride_a, grad_output_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, + grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_16BF, lda, + stride_a, grad_output_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, + grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) + } + + return grad_input; +} + +std::shared_ptr MatmulBackwardInput2(const std::shared_ptr &input1, + const std::shared_ptr &grad_output, + const std::vector &other_dims) { + /* + grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] + */ + + const auto &input_dims = input1->Dims(); + const auto &grad_output_dims = grad_output->Dims(); CHECK_GE(input_dims.size(), 2); - CHECK_EQ(input_dims.size(), other_dims.size()); CHECK_EQ(input_dims.size(), grad_output_dims.size()); const int64_t m = input_dims[input_dims.size() - 2]; const int64_t k = input_dims[input_dims.size() - 1]; - const int64_t n = other_dims[other_dims.size() - 1]; - CHECK_EQ(k, other_dims[other_dims.size() - 2]); + const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + CHECK_EQ(input_dims[input_dims.size() - 1], other_dims[other_dims.size() - 2]); const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); for (int64_t i = 0; i < input_dims.size() - 2; ++i) { - CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; } - // For bf16 compute, output in fp32 to preserve accumulation precision (matches PyTorch behavior) + auto compute_dtype = input1->Dtype(); + auto grad_output_dtype = grad_output->Dtype(); + auto grad_output_promoted + = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); + + // For bf16 compute, output in fp32 to preserve accumulation precision. auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; - auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); auto grad_other = std::make_shared(other_dims, output_dtype, grad_output->GetDevice()); // No Fill(0) needed: cuBLAS beta=0.0f means C is fully overwritten, never read. 
- auto device = input_promoted->GetDevice(); + auto device = grad_output->GetDevice(); const float alpha = 1.0f, beta = 0.0f; cublasHandle_t handle = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) ->cublas_handle(); - { - // cuBLAS is colmun-major - // grad_input = grad_output * other.T --> grad_input.T = other * grad_output.T - // C = A.T * B ==> grad_input.T[*, k, m] = other[*, k, n] * grad_output.T[*, n, m] - // C = grad_input.T[*, k, m] - // A = other.T[*, n, k] - // B = grad_output.T[*, n, m] - const int lda = n, ldb = n, ldc = k; - const int64_t stride_a = k * n; - const int64_t stride_b = n * m; - const int64_t stride_c = m * k; - switch (compute_dtype) { - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other_promoted->DataPtr(), CUDA_R_32F, - lda, stride_a, grad_output_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, - grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other_promoted->DataPtr(), CUDA_R_16BF, - lda, stride_a, grad_output_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, - grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) - } - } - - { - // cuBLAS is colmun-major - // grad_other = input.T * grad_output --> grad_other.T = grad_output.T * input - // C = A * B.T ==> grad_other.T[*, n, k] = grad_output.T[*, n, m] * input[*, m, k] - // C = grad_other.T[*, n, k] - // A = grad_output.T[*, n, m] - // B = input.T[*, k, m] - const int lda = n, ldb = k, ldc = n; - const int64_t stride_a = n * m; - const int64_t stride_b = k * m; - const int64_t stride_c = n * k; - switch (compute_dtype) { - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), - CUDA_R_32F, lda, stride_a, input_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, - grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), - CUDA_R_16BF, lda, stride_a, input_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, - grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) - } + // cuBLAS is colmun-major + // grad_other = input.T * grad_output --> grad_other.T = grad_output.T * input + // C = A * B.T ==> grad_other.T[*, n, k] = grad_output.T[*, n, m] * input[*, m, k] + // C = grad_other.T[*, n, k] + // A = grad_output.T[*, n, m] + // B = input.T[*, k, m] + const int lda = n, ldb = k, ldc = n; + const int64_t stride_a = n * m; + const int64_t stride_b = k * m; + const int64_t stride_c = n * k; + switch (compute_dtype) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), + CUDA_R_32F, lda, stride_a, input1->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, + grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), + 
CUDA_R_16BF, lda, stride_a, input1->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, + grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) } - return {grad_input, grad_other}; + return grad_other; } template __global__ void BiasCopyKernel(T *output, const T *bias, int bs, int out_features) { @@ -328,7 +357,7 @@ std::shared_ptr LinearBackwardInput(const std::shared_ptr &weigh auto grad_output_promoted = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); - // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + // For bf16 compute, accumulate in fp32 to preserve precision. auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output. auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); @@ -391,7 +420,7 @@ std::shared_ptr LinearBackwardWeight(const std::shared_ptr &inpu auto grad_output_promoted = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); - // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + // For bf16 compute, accumulate in fp32 to preserve precision. auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; const std::vector weight_dims = transpose ? std::vector{out_features, in_features} : std::vector{in_features, out_features}; @@ -452,7 +481,7 @@ std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_o const int64_t bs = std::accumulate(dims.rbegin() + 1, dims.rend(), 1, std::multiplies{}); auto compute_dtype = grad_output->Dtype(); - // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + // For bf16 compute, accumulate in fp32 to preserve precision. auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; auto grad_bias = std::make_shared(std::vector{out_features}, output_dtype, grad_output->GetDevice()); @@ -488,7 +517,8 @@ std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_o REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_LINEAR_KERNEL(MatmulForward) -REGISTER_CUDA_LINEAR_KERNEL(MatmulBackward) +REGISTER_CUDA_LINEAR_KERNEL(MatmulBackwardInput1) +REGISTER_CUDA_LINEAR_KERNEL(MatmulBackwardInput2) REGISTER_CUDA_LINEAR_KERNEL(LinearForward) REGISTER_CUDA_LINEAR_KERNEL(LinearBackwardInput) REGISTER_CUDA_LINEAR_KERNEL(LinearBackwardWeight) diff --git a/infini_train/src/kernels/cuda/outer.cu b/infini_train/src/kernels/cuda/outer.cu index ae7c9f7b..f3140ca5 100644 --- a/infini_train/src/kernels/cuda/outer.cu +++ b/infini_train/src/kernels/cuda/outer.cu @@ -90,7 +90,7 @@ std::tuple, std::shared_ptr> OuterBackward(const auto grad_output_promoted = grad_output_dtype == promoted_type ? grad_output : std::make_shared(grad_output->To(promoted_type)); - // For bf16 compute, output in fp32 to preserve accumulation precision (matches PyTorch behavior) + // For bf16 compute, output in fp32 to preserve accumulation precision. auto output_dtype = (promoted_type == DataType::kBFLOAT16) ? DataType::kFLOAT32 : promoted_type; auto grad_input = std::make_shared(std::vector{M}, output_dtype, grad_output->GetDevice()); auto grad_other = std::make_shared(std::vector{N}, output_dtype, grad_output->GetDevice());
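A note on the bf16 path that several hunks above share ("accumulate in fp32", "No Fill(0) needed"): when the compute dtype is kBFLOAT16, the gradient tensor is allocated as kFLOAT32 and cuBLAS writes into it with beta = 0, so the output is fully overwritten and never needs to be zero-initialized. Below is a minimal standalone sketch of that call pattern; the function name and the plain row-major C = A * B layout are invented for the example, while the dtype and compute-type arguments mirror the cublasGemmEx usage in the patch.

    #include <cublas_v2.h>
    #include <cuda_bf16.h>

    // Sketch: bf16 inputs, fp32 output and accumulation.
    // Row-major C[m, n] = A[m, k] * B[k, n] is issued to column-major cuBLAS as
    // C^T = B^T * A^T. beta = 0 means C is fully overwritten, so no prior Fill(0).
    cublasStatus_t GemmBf16AccumFp32(cublasHandle_t handle, int m, int n, int k,
                                     const __nv_bfloat16 *a, const __nv_bfloat16 *b, float *c) {
        const float alpha = 1.0f, beta = 0.0f;
        return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, /*m=*/n, /*n=*/m, /*k=*/k, &alpha,
                            b, CUDA_R_16BF, /*lda=*/n,
                            a, CUDA_R_16BF, /*ldb=*/k,
                            &beta,
                            c, CUDA_R_32F, /*ldc=*/n,
                            CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
    }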