From 85257996880d2a2158f3534e6664be169dd01e6e Mon Sep 17 00:00:00 2001 From: Carlos Gomes Date: Tue, 10 Feb 2026 12:42:06 +0100 Subject: [PATCH 1/7] add cudnn dln+add Signed-off-by: CarlosGomes98 --- .../common/normalization/common.cpp | 22 ++++++++++++++--- .../common/normalization/common.h | 2 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 24 ++++++++++++------- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 11f12775c5..1e0e308c73 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -395,6 +395,19 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]); if (_dbeta != nullptr) NVTE_ERROR("cuDNN rmsnorm dbias incorrectly returned."); } + // Fuse the add for BackwardAdd stage + if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) { + _add = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("add") + .set_dim({batch_dim, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim})); + auto add_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::ADD) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto _dx_with_add = _graph.pointwise(_dx, _add, add_options); + _dx->set_output(false); + _dx = _dx_with_add; + } _dx->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); _dgamma->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); } @@ -467,13 +480,16 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_ void* rsigma_dptr, void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) { - // cuDNN does not currently support fused backward+add - NVTE_CHECK(add_dptr == nullptr); - // Binding data pointers to graph tensors _variant_pack = { {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}}; + // Bind the add tensor for fused backward+add + if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) { + NVTE_CHECK(add_dptr != nullptr, "add_dptr must not be null for BackwardAdd"); + _variant_pack.insert({{_add, add_dptr}}); + } + if (_zero_centered) _variant_pack.insert({{_scalar_offset, reinterpret_cast(this->_scalar_dptr.get())}, {_gamma_zero, gamma_dptr}}); diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 79de2ac140..0cbd5a99f9 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -294,7 +294,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { std::shared_ptr _z_mx_row, _z_mx_col, _sf_row, _sf_col; const bool _training; // BWD - std::shared_ptr _dz, _dx, _dgamma, _dbeta; + std::shared_ptr _dz, _dx, _dgamma, _dbeta, _add; fe::graph::Graph _graph; std::unordered_map, void*> _variant_pack; diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 6f6656534a..d430bf5267 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -206,16 +206,22 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const CheckOutputTensor(*dgamma, "dgamma"); } - // cuDNN does not currently support fused backward+add - NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te; - - // TE backend does not currently support zero_centered_gamma_in_weight_dtype - NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(), - "zero_centered_gamma_in_weight_dtype is currently not supported for rmsnorm_bwd_add"); - - bool is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, - dz.data.dptr, dgamma->data.dptr, add.data.dptr); + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; bool gamma_in_weight_dtype = false; + if (use_cudnn_norm_bwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); + } else { + norm_backend = NVTE_Norm_Backend::Te; + // TE backend does not currently support zero_centered_gamma_in_weight_dtype + NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(), + "zero_centered_gamma_in_weight_dtype is currently not supported " + "for rmsnorm_bwd_add with TE backend"); + is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, dgamma->data.dptr, add.data.dptr); + } auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::BackwardAdd, From debe1309a5521eac66ebab25bd68112fe85a015a Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:20:20 +0100 Subject: [PATCH 2/7] try fixing cudnn build issue Signed-off-by: CarlosGomes98 --- transformer_engine/common/normalization/common.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 1e0e308c73..bc2aabdffd 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -400,12 +400,14 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor _add = _graph.tensor(fe::graph::Tensor_attributes() .set_name("add") .set_dim({batch_dim, hidden_dim, 1, 1}) - .set_stride({hidden_dim, 1, hidden_dim, hidden_dim})); + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(itype))); auto add_options = fe::graph::Pointwise_attributes() .set_mode(fe::PointwiseMode_t::ADD) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); auto _dx_with_add = _graph.pointwise(_dx, _add, add_options); - _dx->set_output(false); + _dx_with_add->set_data_type(get_cudnn_fe_dtype(itype)); + _dx->set_output(false).set_data_type(get_cudnn_fe_dtype(itype)); _dx = _dx_with_add; } _dx->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); From 793d4cd7f07a46a8866fce3cccf8cb70262469ea Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 19 Feb 2026 10:23:27 +0100 Subject: [PATCH 3/7] guard against cudnn version Signed-off-by: CarlosGomes98 --- transformer_engine/common/normalization/common.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index bc2aabdffd..d449babf43 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -397,6 +397,9 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor } // Fuse the add for BackwardAdd stage if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) { + NVTE_CHECK(cudnnGetVersion() >= 92100, + "Fused BackwardAdd requires cuDNN >= 9.21.0, but found ", cudnnGetVersion()); + _add = _graph.tensor(fe::graph::Tensor_attributes() .set_name("add") .set_dim({batch_dim, hidden_dim, 1, 1}) From 2134cc660dc698141a8d92ab56ddd2c4a2897ddc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:29:24 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/normalization/common.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index d449babf43..4883689def 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -399,7 +399,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) { NVTE_CHECK(cudnnGetVersion() >= 92100, "Fused BackwardAdd requires cuDNN >= 9.21.0, but found ", cudnnGetVersion()); - + _add = _graph.tensor(fe::graph::Tensor_attributes() .set_name("add") .set_dim({batch_dim, hidden_dim, 1, 1}) From 118a108e0d925e90c21a0ea7e9d2d2939ef870b1 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Wed, 18 Mar 2026 15:45:54 +0100 Subject: [PATCH 5/7] change itype to wtype for add in rmsnorm_bwd Signed-off-by: CarlosGomes98 --- 3rdparty/cudnn-frontend | 2 +- transformer_engine/common/normalization/common.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index d33027a41a..a8f90f3119 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit d33027a41a93af9c85f089c6364ab415fce98982 +Subproject commit a8f90f31197590b42d7b7daffdefc7fcae75f4aa diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 4883689def..d674ad06e5 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -404,7 +404,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor .set_name("add") .set_dim({batch_dim, hidden_dim, 1, 1}) .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) - .set_data_type(get_cudnn_fe_dtype(itype))); + .set_data_type(get_cudnn_fe_dtype(wtype))); auto add_options = fe::graph::Pointwise_attributes() .set_mode(fe::PointwiseMode_t::ADD) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); From a16b8d8bc39185d782bf9117c077add898e74ef8 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Wed, 18 Mar 2026 16:31:32 +0100 Subject: [PATCH 6/7] remove dead code Signed-off-by: CarlosGomes98 --- 3rdparty/cudnn-frontend | 2 +- transformer_engine/common/normalization/common.cpp | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index a8f90f3119..d33027a41a 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit a8f90f31197590b42d7b7daffdefc7fcae75f4aa +Subproject commit d33027a41a93af9c85f089c6364ab415fce98982 diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index d674ad06e5..7dd942b314 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -409,7 +409,6 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor .set_mode(fe::PointwiseMode_t::ADD) .set_compute_data_type(get_cudnn_fe_dtype(ctype)); auto _dx_with_add = _graph.pointwise(_dx, _add, add_options); - _dx_with_add->set_data_type(get_cudnn_fe_dtype(itype)); _dx->set_output(false).set_data_type(get_cudnn_fe_dtype(itype)); _dx = _dx_with_add; } From 71db968d723d7d714773fc633e75e172d5027798 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 19 Mar 2026 21:46:08 +0100 Subject: [PATCH 7/7] remove dangling todo Signed-off-by: CarlosGomes98 --- transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index d430bf5267..adf2ccee04 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -210,7 +210,6 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const bool is_aligned = true; bool gamma_in_weight_dtype = false; if (use_cudnn_norm_bwd()) { - // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); } else {