diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 11f12775c5..7dd942b314 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -395,6 +395,23 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]); if (_dbeta != nullptr) NVTE_ERROR("cuDNN rmsnorm dbias incorrectly returned."); } + // Fuse the add for BackwardAdd stage + if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) { + NVTE_CHECK(cudnnGetVersion() >= 92100, + "Fused BackwardAdd requires cuDNN >= 9.21.0, but found ", cudnnGetVersion()); + + _add = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("add") + .set_dim({batch_dim, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(wtype))); + auto add_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::ADD) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto _dx_with_add = _graph.pointwise(_dx, _add, add_options); + _dx->set_output(false).set_data_type(get_cudnn_fe_dtype(itype)); + _dx = _dx_with_add; + } _dx->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); _dgamma->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); } @@ -467,13 +484,16 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_ void* rsigma_dptr, void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) { - // cuDNN does not currently support fused backward+add - NVTE_CHECK(add_dptr == nullptr); - // Binding data pointers to graph tensors _variant_pack = { {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}}; + // Bind the add tensor for fused backward+add + if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) { + NVTE_CHECK(add_dptr != nullptr, "add_dptr must not be null for BackwardAdd"); + _variant_pack.insert({{_add, add_dptr}}); + } + if (_zero_centered) _variant_pack.insert({{_scalar_offset, reinterpret_cast(this->_scalar_dptr.get())}, {_gamma_zero, gamma_dptr}}); diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 79de2ac140..0cbd5a99f9 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -294,7 +294,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { std::shared_ptr _z_mx_row, _z_mx_col, _sf_row, _sf_col; const bool _training; // BWD - std::shared_ptr _dz, _dx, _dgamma, _dbeta; + std::shared_ptr _dz, _dx, _dgamma, _dbeta, _add; fe::graph::Graph _graph; std::unordered_map, void*> _variant_pack; diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 6f6656534a..adf2ccee04 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -206,16 +206,21 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const CheckOutputTensor(*dgamma, "dgamma"); } - // cuDNN does not currently support fused backward+add - NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te; - - // TE backend does not currently support zero_centered_gamma_in_weight_dtype - NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(), - "zero_centered_gamma_in_weight_dtype is currently not supported for rmsnorm_bwd_add"); - - bool is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, - dz.data.dptr, dgamma->data.dptr, add.data.dptr); + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; bool gamma_in_weight_dtype = false; + if (use_cudnn_norm_bwd()) { + norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); + } else { + norm_backend = NVTE_Norm_Backend::Te; + // TE backend does not currently support zero_centered_gamma_in_weight_dtype + NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(), + "zero_centered_gamma_in_weight_dtype is currently not supported " + "for rmsnorm_bwd_add with TE backend"); + is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, dgamma->data.dptr, add.data.dptr); + } auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::BackwardAdd,