From e572feb95dee0b04a58fb3767bd556634f0b3a8e Mon Sep 17 00:00:00 2001 From: Cem Bassoy Date: Fri, 20 Mar 2026 17:36:13 +0100 Subject: [PATCH 1/2] feat: add z_loss_weight support to parallel_cross_entropy Add z-loss regularization (z_loss_weight * log(Z)^2 per token) to the Triton cross-entropy kernel. The z_loss_weight parameter is a tl.constexpr, so it is dead-code-eliminated when set to 0.0. Forward: adds z_loss_weight * lse^2 to per-token loss. Backward: scales softmax gradient by (1 + 2 * z_loss_weight * lse). Tests: extend existing test infrastructure with z_loss_weight parameter in generate_infra and one_iteration_test. Z-loss tests cover FP32, BF16, ignore_idx, and zero-weight identity. Signed-off-by: Cem Bassoy --- tests/pytorch/test_parallel_cross_entropy.py | 134 +++++++++++++++++- .../common/triton/cross_entropy.py | 28 +++- transformer_engine/pytorch/cross_entropy.py | 16 ++- .../pytorch/triton/cross_entropy.py | 2 + 4 files changed, 165 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index b4ea193f06..7f5492ad27 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -4,6 +4,7 @@ import random import torch +import torch.nn.functional as F from transformer_engine.pytorch import parallel_cross_entropy from utils import dtype_tols @@ -14,11 +15,26 @@ class TestParallelCrossEntropy: def generate_iters(self, iters: int): self.iters = iters - def generate_infra(self, reduce_loss: bool, label_smoothing: float): + def generate_infra(self, reduce_loss: bool, label_smoothing: float, z_loss_weight: float = 0.0, ignore_idx: int = -100): self.test_loss_func = parallel_cross_entropy - self.ref_loss_func = torch.nn.CrossEntropyLoss( - label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none" - ) + if z_loss_weight == 0.0: + self.ref_loss_func = torch.nn.CrossEntropyLoss( + label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none", + ignore_index=ignore_idx, + ) + else: + + def ref_with_zloss(inp, tar): + inp = inp.float() + ce = F.cross_entropy(inp, tar, reduction="none", label_smoothing=label_smoothing, ignore_index=ignore_idx) + z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1)) + z_pen[tar == ignore_idx] = 0.0 + loss = ce + z_pen + if reduce_loss: + loss = loss.sum() / (tar != ignore_idx).sum() + return loss + + self.ref_loss_func = ref_with_zloss def generate_input( self, @@ -63,14 +79,20 @@ def one_iteration_test( label_smoothing: float, reduce_loss: bool, ignore_idx: bool = False, + z_loss_weight: float = 0.0, ): # Random data self.generate_input(dtype, swap_dim, ignore_idx) - # Forward pass + # Forward pass — default return is a single tensor (backward compatible) test_loss = self.test_loss_func( - self.input_test, self.tar_test, label_smoothing, reduce_loss, None + self.input_test, + self.tar_test, + label_smoothing, + reduce_loss, + None, + z_loss_weight=z_loss_weight, ) ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) @@ -168,6 +190,106 @@ def test_ignore_idx_reduced_loss(self): ignore_idx=True, ) + def test_z_loss(self): + self.generate_iters(5) + self.generate_infra(False, 0, z_loss_weight=0.001) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, + swap_dim=random.choice([True, False]), + label_smoothing=0, + reduce_loss=False, + z_loss_weight=0.001, + ) + + def test_z_loss_zero_weight(self): + self.generate_infra(False, 0) + self.generate_input(torch.float32, False, False) + + inp_base = self.input_test.clone().detach().requires_grad_(True) + inp_zero = self.input_test.clone().detach().requires_grad_(True) + + loss_base = self.test_loss_func(inp_base, self.tar_test) + loss_zero = self.test_loss_func(inp_zero, self.tar_test, z_loss_weight=0.0) + + assert torch.equal( + loss_base, loss_zero + ), "z_loss_weight=0.0 must be bit-identical to the default" + + loss_base.sum().backward() + loss_zero.sum().backward() + + assert torch.equal( + inp_base.grad, inp_zero.grad + ), "Gradients with z_loss_weight=0.0 must be bit-identical to the default" + + self.input_test = None + self.input_ref = None + self.tar_test = None + self.tar_ref = None + + def test_z_loss_with_ignore_idx(self): + self.generate_iters(5) + self.generate_infra(False, 0, z_loss_weight=0.001) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, + swap_dim=random.choice([True, False]), + label_smoothing=0, + reduce_loss=False, + ignore_idx=True, + z_loss_weight=0.001, + ) + + def test_z_loss_reduced(self): + self.generate_iters(5) + self.generate_infra(True, 0, z_loss_weight=0.001) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, + swap_dim=random.choice([True, False]), + label_smoothing=0, + reduce_loss=True, + z_loss_weight=0.001, + ) + + def test_z_loss_reduced_with_ignore_idx(self): + self.generate_iters(5) + self.generate_infra(True, 0, z_loss_weight=0.001) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, + swap_dim=random.choice([True, False]), + label_smoothing=0, + reduce_loss=True, + ignore_idx=True, + z_loss_weight=0.001, + ) + + def test_z_loss_label_smoothing(self): + self.generate_iters(3) + self.generate_infra(False, 0.1, z_loss_weight=0.001) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, + swap_dim=random.choice([True, False]), + label_smoothing=0.1, + reduce_loss=False, + z_loss_weight=0.001, + ) + + def test_z_loss_bfloat16(self): + self.generate_iters(3) + self.generate_infra(False, 0, z_loss_weight=0.001) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.bfloat16, + swap_dim=random.choice([True, False]), + label_smoothing=0, + reduce_loss=False, + z_loss_weight=0.001, + ) + def test_non_contiguous_transposed_input(): """Regression test: stride(-2) != shape[-1] should not produce wrong results.""" diff --git a/transformer_engine/common/triton/cross_entropy.py b/transformer_engine/common/triton/cross_entropy.py index bec2620467..b180f7414e 100644 --- a/transformer_engine/common/triton/cross_entropy.py +++ b/transformer_engine/common/triton/cross_entropy.py @@ -100,6 +100,7 @@ def cross_entropy_kernel( n_non_ignore, reduce_loss: tl.constexpr, label_smoothing: tl.constexpr, + z_loss_weight: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ @@ -121,6 +122,8 @@ def cross_entropy_kernel( n_rows (int): The number of rows in the batch (B * SQ), used for buffer indexing. n_non_ignore: The number of non-ignored elements in the batch. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token. + Compile-time constant (tl.constexpr); varying it across calls triggers kernel recompilation. BLOCK_SIZE (int): The block size for Triton operations. """ @@ -160,6 +163,9 @@ def cross_entropy_kernel( m = tl.maximum(m, m_new) ori_X_y = tl.maximum(ori_X_y, X_y_new) + # lse = log(Z): free to compute (m, d already in registers). + lse = m + tl.log(d) + # Label smoothing is a general case of normal cross entropy scaled_x_sum = 0.0 eps = label_smoothing / (n_cols * world_size) @@ -180,13 +186,16 @@ def cross_entropy_kernel( if label_smoothing > 0: # scale X beforehand to avoid overflow scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + # Softmax gradient + X_block = tl.exp(X_block - m) / d + # Z-loss gradient: d/dx_i[z_loss_weight * lse^2] = 2 * z_loss_weight * lse * softmax(x_i). + # Applied before eps subtraction so only pure softmax is scaled. + if z_loss_weight > 0: + X_block = X_block * (1.0 + 2.0 * z_loss_weight * lse) + X_block = X_block - eps # Scale gradients based on reduction mode - # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore - # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here if reduce_loss: - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) - else: - X_block = tl.exp(X_block - m) / d - eps + X_block = X_block / n_non_ignore tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) # We need tl.debug_barrier() to ensure the new result of X_ptr is written @@ -196,7 +205,8 @@ def cross_entropy_kernel( # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) - loss = -(ori_X_y - m - tl.log(d)) + # = lse - ori_X_y (reusing lse = m + log(d) already computed above) + loss = lse - ori_X_y # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) @@ -205,9 +215,13 @@ def cross_entropy_kernel( # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 if label_smoothing > 0: - smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) + smooth_loss = scaled_x_sum + label_smoothing * lse loss = loss * (1 - label_smoothing) + smooth_loss + # Z-loss regularization: adds z_loss_weight * log(Z)^2 per token. + if z_loss_weight > 0: + loss += z_loss_weight * lse * lse + # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` vocab_start_idx = rank * n_cols vocab_end_idx = (rank + 1) * n_cols diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py index 733b9c10e1..107ea09e51 100644 --- a/transformer_engine/pytorch/cross_entropy.py +++ b/transformer_engine/pytorch/cross_entropy.py @@ -33,6 +33,7 @@ def forward( dist_process_group=None, ignore_idx=-100, is_cg_capturable=False, + z_loss_weight=0.0, ): """ The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each @@ -45,7 +46,8 @@ def forward( label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension. dist_process_group (torch.dist.ProcessGroup): The distributed process group the loss computation is split across, None if on 1 device. - ignore_idx (int): The index for which loss and gradients are made to zero + ignore_idx (int): The index for which loss and gradients are made to zero. + z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token. Returns: tensor: The computed loss. @@ -57,6 +59,7 @@ def forward( reduce_loss, dist_process_group, ignore_idx, + z_loss_weight, ) ctx.save_for_backward(inp.detach()) @@ -85,6 +88,7 @@ def backward(ctx, grad_output): None, None, None, + None, ) @@ -96,11 +100,12 @@ def parallel_cross_entropy( dist_process_group: Optional[torch.distributed.ProcessGroup] = None, ignore_idx: int = -100, is_cg_capturable: bool = False, + z_loss_weight: float = 0.0, *, _input: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Cross Entropy loss with optional distributed reduction. + Cross Entropy loss with optional distributed reduction and z-loss regularization. The input tensor can be in BF16/FP32, the loss and gradient calculation happens in FP32 only. The returned loss is always in FP32, the input gradients are upcasted @@ -127,6 +132,9 @@ def parallel_cross_entropy( The index for which loss and gradients are made to zero. is_cg_capturable : bool, default = False Whether the operation is CUDA graph capturable. + z_loss_weight : float, default = 0.0 + Weight for z-loss regularization. Adds ``z_loss_weight * log(Z)^2`` per token. + This value is a Triton compile-time constant; use a fixed value during training. Returns ------- @@ -141,6 +149,9 @@ def parallel_cross_entropy( ) inp = _input + if not (z_loss_weight >= 0.0 and z_loss_weight != float("inf")): + raise ValueError(f"z_loss_weight must be a finite non-negative number, got {z_loss_weight}") + return CrossEntropyFunction.apply( inp, target, @@ -149,4 +160,5 @@ def parallel_cross_entropy( dist_process_group, ignore_idx, is_cg_capturable, + z_loss_weight, ) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 1401383c8f..03a4f7c91b 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -30,6 +30,7 @@ def cross_entropy_forward( reduce_loss: bool, dist_process_group: Union[dist.ProcessGroup, None], ignore_idx: int, + z_loss_weight: float = 0.0, ): """Forward implementation of Cross Entropy kernel""" @@ -100,6 +101,7 @@ def cross_entropy_forward( n_non_ignore=n_non_ignore, reduce_loss=reduce_loss, label_smoothing=label_smoothing, + z_loss_weight=z_loss_weight, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, ) From 966912a6415c96a3f1f718e8e93909e8b5165600 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 22 Mar 2026 11:22:11 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_parallel_cross_entropy.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 7f5492ad27..3e5114d1cf 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -15,18 +15,31 @@ class TestParallelCrossEntropy: def generate_iters(self, iters: int): self.iters = iters - def generate_infra(self, reduce_loss: bool, label_smoothing: float, z_loss_weight: float = 0.0, ignore_idx: int = -100): + def generate_infra( + self, + reduce_loss: bool, + label_smoothing: float, + z_loss_weight: float = 0.0, + ignore_idx: int = -100, + ): self.test_loss_func = parallel_cross_entropy if z_loss_weight == 0.0: self.ref_loss_func = torch.nn.CrossEntropyLoss( - label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none", + label_smoothing=label_smoothing, + reduction="mean" if reduce_loss else "none", ignore_index=ignore_idx, ) else: def ref_with_zloss(inp, tar): inp = inp.float() - ce = F.cross_entropy(inp, tar, reduction="none", label_smoothing=label_smoothing, ignore_index=ignore_idx) + ce = F.cross_entropy( + inp, + tar, + reduction="none", + label_smoothing=label_smoothing, + ignore_index=ignore_idx, + ) z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1)) z_pen[tar == ignore_idx] = 0.0 loss = ce + z_pen