Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 141 additions & 6 deletions tests/pytorch/test_parallel_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import random
import torch
import torch.nn.functional as F
from transformer_engine.pytorch import parallel_cross_entropy

from utils import dtype_tols
Expand All @@ -14,11 +15,39 @@ class TestParallelCrossEntropy:
def generate_iters(self, iters: int):
self.iters = iters

def generate_infra(self, reduce_loss: bool, label_smoothing: float):
def generate_infra(
self,
reduce_loss: bool,
label_smoothing: float,
z_loss_weight: float = 0.0,
ignore_idx: int = -100,
):
self.test_loss_func = parallel_cross_entropy
self.ref_loss_func = torch.nn.CrossEntropyLoss(
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
if z_loss_weight == 0.0:
self.ref_loss_func = torch.nn.CrossEntropyLoss(
label_smoothing=label_smoothing,
reduction="mean" if reduce_loss else "none",
ignore_index=ignore_idx,
)
Comment on lines +27 to +31
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 ignore_idx not forwarded to CrossEntropyLoss

generate_infra now accepts an ignore_idx parameter, and the z_loss_weight > 0 branch correctly passes it to both F.cross_entropy(..., ignore_index=ignore_idx) and z_pen[tar == ignore_idx]. However, the z_loss_weight == 0.0 branch silently falls back to PyTorch's default (-100), ignoring the parameter:

self.ref_loss_func = torch.nn.CrossEntropyLoss(
    label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
    # ignore_index=ignore_idx is missing here
)

All current tests happen to pass ignore_idx=-100 (the default), so there is no visible failure now. But if any test ever calls generate_infra(..., z_loss_weight=0.0, ignore_idx=42), the reference would still ignore token id -100 instead of 42, producing a silently incorrect reference and a false-passing comparison.

Suggested change
self.ref_loss_func = torch.nn.CrossEntropyLoss(
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
self.ref_loss_func = torch.nn.CrossEntropyLoss(
label_smoothing=label_smoothing,
reduction="mean" if reduce_loss else "none",
ignore_index=ignore_idx,
)

else:

def ref_with_zloss(inp, tar):
inp = inp.float()
ce = F.cross_entropy(
inp,
tar,
reduction="none",
label_smoothing=label_smoothing,
ignore_index=ignore_idx,
)
z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
z_pen[tar == ignore_idx] = 0.0
loss = ce + z_pen
if reduce_loss:
loss = loss.sum() / (tar != ignore_idx).sum()
return loss
Comment on lines +34 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 F.cross_entropy does not receive the parameterized ignore_index

generate_infra now accepts an ignore_idx parameter and correctly uses it to zero out z_pen, but F.cross_entropy is called without ignore_index=ignore_idx. PyTorch's default is -100, so all current tests pass since generate_input always uses -100. However, if a future test passes a non-default ignore_idx, the CE component of the reference would still ignore -100 while the real kernel would ignore the custom index, producing a silently incorrect reference loss and false-passing gradient tests.

def ref_with_zloss(inp, tar):
    inp = inp.float()
    ce = F.cross_entropy(
        inp, tar, reduction="none",
        label_smoothing=label_smoothing,
        ignore_index=ignore_idx,   # <-- forward the parameter
    )
    z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
    z_pen[tar == ignore_idx] = 0.0
    ...


self.ref_loss_func = ref_with_zloss

def generate_input(
self,
Expand Down Expand Up @@ -63,14 +92,20 @@ def one_iteration_test(
label_smoothing: float,
reduce_loss: bool,
ignore_idx: bool = False,
z_loss_weight: float = 0.0,
):

# Random data
self.generate_input(dtype, swap_dim, ignore_idx)

# Forward pass
# Forward pass — default return is a single tensor (backward compatible)
test_loss = self.test_loss_func(
self.input_test, self.tar_test, label_smoothing, reduce_loss, None
self.input_test,
self.tar_test,
label_smoothing,
reduce_loss,
None,
z_loss_weight=z_loss_weight,
)
ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)

Expand Down Expand Up @@ -168,6 +203,106 @@ def test_ignore_idx_reduced_loss(self):
ignore_idx=True,
)

def test_z_loss(self):
self.generate_iters(5)
self.generate_infra(False, 0, z_loss_weight=0.001)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32,
swap_dim=random.choice([True, False]),
label_smoothing=0,
reduce_loss=False,
z_loss_weight=0.001,
)

def test_z_loss_zero_weight(self):
self.generate_infra(False, 0)
self.generate_input(torch.float32, False, False)

inp_base = self.input_test.clone().detach().requires_grad_(True)
inp_zero = self.input_test.clone().detach().requires_grad_(True)

loss_base = self.test_loss_func(inp_base, self.tar_test)
loss_zero = self.test_loss_func(inp_zero, self.tar_test, z_loss_weight=0.0)

assert torch.equal(
loss_base, loss_zero
), "z_loss_weight=0.0 must be bit-identical to the default"

loss_base.sum().backward()
loss_zero.sum().backward()

assert torch.equal(
inp_base.grad, inp_zero.grad
), "Gradients with z_loss_weight=0.0 must be bit-identical to the default"

self.input_test = None
self.input_ref = None
self.tar_test = None
self.tar_ref = None
Comment on lines +218 to +242
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 test_z_loss_zero_weight only validates the forward pass

The test clones the input tensor but never calls .requires_grad_(True), so no gradient is accumulated and the backward path is never exercised. The Triton kernel eliminates the z-loss branches at compile time via tl.constexpr, so validating that the gradient is also bit-identical for z_loss_weight=0.0 would strengthen the regression value of this test.

Consider adding backward verification:

def test_z_loss_zero_weight(self):
    self.generate_infra(False, 0)
    self.generate_input(torch.float32, False, False)

    inp_base = self.input_test.clone().requires_grad_(True)
    inp_zero = self.input_test.clone().requires_grad_(True)

    loss_base = self.test_loss_func(inp_base, self.tar_test)
    loss_zero = self.test_loss_func(inp_zero, self.tar_test, z_loss_weight=0.0)

    assert torch.equal(loss_base, loss_zero), "z_loss_weight=0.0 must be bit-identical to the default"

    loss_base.sum().backward()
    loss_zero.sum().backward()

    assert torch.equal(inp_base.grad, inp_zero.grad), \
        "Gradients with z_loss_weight=0.0 must be bit-identical to the default"

    self.input_test = self.input_ref = self.tar_test = self.tar_ref = None


def test_z_loss_with_ignore_idx(self):
self.generate_iters(5)
self.generate_infra(False, 0, z_loss_weight=0.001)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32,
swap_dim=random.choice([True, False]),
label_smoothing=0,
reduce_loss=False,
ignore_idx=True,
z_loss_weight=0.001,
)

def test_z_loss_reduced(self):
self.generate_iters(5)
self.generate_infra(True, 0, z_loss_weight=0.001)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32,
swap_dim=random.choice([True, False]),
label_smoothing=0,
reduce_loss=True,
z_loss_weight=0.001,
)
Comment on lines +257 to +267
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing test combination: reduce_loss=True + z-loss + ignore_idx

The test suite covers reduce_loss=True + z-loss (test_z_loss_reduced) and reduce_loss=False + z-loss + ignore_idx (test_z_loss_with_ignore_idx), but no test exercises all three together. This is the most semantically interesting combination: n_non_ignore is used to normalize both the loss value (in Python) and the gradient (in the Triton kernel, line 198), so an incorrect interaction would only appear when tokens are actually masked and reduction is active.

Consider adding:

def test_z_loss_reduced_with_ignore_idx(self):
    self.generate_iters(5)
    self.generate_infra(True, 0, z_loss_weight=0.001)
    for i in range(self.iters):
        self.one_iteration_test(
            dtype=torch.float32,
            swap_dim=random.choice([True, False]),
            label_smoothing=0,
            reduce_loss=True,
            ignore_idx=True,
            z_loss_weight=0.001,
        )


def test_z_loss_reduced_with_ignore_idx(self):
self.generate_iters(5)
self.generate_infra(True, 0, z_loss_weight=0.001)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32,
swap_dim=random.choice([True, False]),
label_smoothing=0,
reduce_loss=True,
ignore_idx=True,
z_loss_weight=0.001,
)

def test_z_loss_label_smoothing(self):
self.generate_iters(3)
self.generate_infra(False, 0.1, z_loss_weight=0.001)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32,
swap_dim=random.choice([True, False]),
label_smoothing=0.1,
reduce_loss=False,
z_loss_weight=0.001,
)

def test_z_loss_bfloat16(self):
self.generate_iters(3)
self.generate_infra(False, 0, z_loss_weight=0.001)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.bfloat16,
swap_dim=random.choice([True, False]),
label_smoothing=0,
reduce_loss=False,
z_loss_weight=0.001,
)


def test_non_contiguous_transposed_input():
"""Regression test: stride(-2) != shape[-1] should not produce wrong results."""
Expand Down
28 changes: 21 additions & 7 deletions transformer_engine/common/triton/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def cross_entropy_kernel(
n_non_ignore,
reduce_loss: tl.constexpr,
label_smoothing: tl.constexpr,
z_loss_weight: tl.constexpr,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 tl.constexpr specialization will recompile for every unique float value

z_loss_weight is declared tl.constexpr, which means Triton compiles a separate kernel for each unique Python float value passed. The PR describes this as intentional for dead-code elimination when z_loss_weight=0.0, and for a fixed training hyperparameter that's fine. However, if callers ever want to anneal or schedule z_loss_weight across training steps (e.g. a warmup from 0 → 0.001), every distinct float encountered will trigger a fresh JIT compilation, stalling the training loop.

Consider documenting this behaviour in the docstring of both cross_entropy_kernel and parallel_cross_entropy:

z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token.
    This value is used as a Triton compile-time constant (tl.constexpr); varying it across
    calls will trigger kernel recompilation. Use a fixed value during training.

BLOCK_SIZE: tl.constexpr,
):
"""
Expand All @@ -121,6 +122,8 @@ def cross_entropy_kernel(
n_rows (int): The number of rows in the batch (B * SQ), used for buffer indexing.
n_non_ignore: The number of non-ignored elements in the batch.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token.
Compile-time constant (tl.constexpr); varying it across calls triggers kernel recompilation.
BLOCK_SIZE (int): The block size for Triton operations.
"""

Expand Down Expand Up @@ -160,6 +163,9 @@ def cross_entropy_kernel(
m = tl.maximum(m, m_new)
ori_X_y = tl.maximum(ori_X_y, X_y_new)

# lse = log(Z): free to compute (m, d already in registers).
lse = m + tl.log(d)

# Label smoothing is a general case of normal cross entropy
scaled_x_sum = 0.0
eps = label_smoothing / (n_cols * world_size)
Expand All @@ -180,13 +186,16 @@ def cross_entropy_kernel(
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
# Softmax gradient
X_block = tl.exp(X_block - m) / d
# Z-loss gradient: d/dx_i[z_loss_weight * lse^2] = 2 * z_loss_weight * lse * softmax(x_i).
# Applied before eps subtraction so only pure softmax is scaled.
if z_loss_weight > 0:
X_block = X_block * (1.0 + 2.0 * z_loss_weight * lse)
Comment on lines +191 to +194
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Incorrect z-loss gradient when label_smoothing > 0

At this point X_block = (softmax(x_i) - eps) / N (or softmax - eps without reduction). Multiplying the combined CE gradient by (1 + 2 * z_loss_weight * lse) expands to:

(softmax - eps)/N * (1 + 2*z*lse)
= (softmax - eps)/N + (softmax - eps)/N * 2*z*lse

But the correct z-loss gradient is purely softmax/N * 2 * z * lse — the z-loss term should be additive on top of the CE gradient, not multiplicative against the entire (softmax - eps) expression. The error introduced is -eps/N * 2 * z_loss_weight * lse per element.

For typical training settings (label_smoothing=0.1, V=64000, z_loss_weight=0.001, lse≈11) the error is on the order of 3e-8, which is below float32 precision for large vocabularies — explaining why test_z_loss_label_smoothing still passes. However, for small vocabularies (e.g. V=32) the error becomes measurable and the implementation is mathematically incorrect.

The correct approach is to add the z-loss gradient additively, using the pre-eps softmax value:

        if reduce_loss:
            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
        else:
            X_block = tl.exp(X_block - m) / d - eps
        # Z-loss gradient: 2 * z_loss_weight * lse * softmax(x_i), additive to CE gradient.
        if z_loss_weight > 0:
            softmax_i = tl.exp(X_block_fp32 - m) / d  # pure softmax, before subtracting eps
            if reduce_loss:
                X_block = X_block + softmax_i * (2.0 * z_loss_weight * lse) / n_non_ignore
            else:
                X_block = X_block + softmax_i * (2.0 * z_loss_weight * lse)

where X_block_fp32 is the logit block before the CE computation (currently loaded at the top of the loop).

X_block = X_block - eps
# Scale gradients based on reduction mode
# For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore
# For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here
if reduce_loss:
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
else:
X_block = tl.exp(X_block - m) / d - eps
X_block = X_block / n_non_ignore
tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written
Expand All @@ -196,7 +205,8 @@ def cross_entropy_kernel(

# loss = -log (softmax(X_y)) = -log (e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
# = -((X_y - max(X)) - log(sum(e ^ (X - max(X)))))
loss = -(ori_X_y - m - tl.log(d))
# = lse - ori_X_y (reusing lse = m + log(d) already computed above)
loss = lse - ori_X_y

# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
Expand All @@ -205,9 +215,13 @@ def cross_entropy_kernel(
# = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
smooth_loss = scaled_x_sum + label_smoothing * lse
loss = loss * (1 - label_smoothing) + smooth_loss

# Z-loss regularization: adds z_loss_weight * log(Z)^2 per token.
if z_loss_weight > 0:
loss += z_loss_weight * lse * lse

# 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing)) / N`
vocab_start_idx = rank * n_cols
vocab_end_idx = (rank + 1) * n_cols
Expand Down
16 changes: 14 additions & 2 deletions transformer_engine/pytorch/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def forward(
dist_process_group=None,
ignore_idx=-100,
is_cg_capturable=False,
z_loss_weight=0.0,
):
"""
The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each
Expand All @@ -45,7 +46,8 @@ def forward(
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
dist_process_group (torch.dist.ProcessGroup): The distributed process group the loss computation is split across, None if on 1 device.
ignore_idx (int): The index for which loss and gradients are made to zero
ignore_idx (int): The index for which loss and gradients are made to zero.
z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token.

Returns:
tensor: The computed loss.
Expand All @@ -57,6 +59,7 @@ def forward(
reduce_loss,
dist_process_group,
ignore_idx,
z_loss_weight,
)

ctx.save_for_backward(inp.detach())
Expand Down Expand Up @@ -85,6 +88,7 @@ def backward(ctx, grad_output):
None,
None,
None,
None,
)


Expand All @@ -96,11 +100,12 @@ def parallel_cross_entropy(
dist_process_group: Optional[torch.distributed.ProcessGroup] = None,
ignore_idx: int = -100,
is_cg_capturable: bool = False,
z_loss_weight: float = 0.0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 No validation that z_loss_weight is non-negative

A negative value is mathematically well-formed but semantically inverts the regularization (rewarding large logit magnitudes). Given the docstring describes this as a "regularization weight", an early guard against negative values would make the API safer and the intent explicit:

Suggested change
z_loss_weight: float = 0.0,
z_loss_weight: float = 0.0,

Consider adding before the CrossEntropyFunction.apply(...) call:

if z_loss_weight < 0.0:
    raise ValueError(f"z_loss_weight must be non-negative, got {z_loss_weight}")

*,
_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Cross Entropy loss with optional distributed reduction.
Cross Entropy loss with optional distributed reduction and z-loss regularization.

The input tensor can be in BF16/FP32, the loss and gradient calculation happens in
FP32 only. The returned loss is always in FP32, the input gradients are upcasted
Expand All @@ -127,6 +132,9 @@ def parallel_cross_entropy(
The index for which loss and gradients are made to zero.
is_cg_capturable : bool, default = False
Whether the operation is CUDA graph capturable.
z_loss_weight : float, default = 0.0
Weight for z-loss regularization. Adds ``z_loss_weight * log(Z)^2`` per token.
This value is a Triton compile-time constant; use a fixed value during training.

Returns
-------
Expand All @@ -141,6 +149,9 @@ def parallel_cross_entropy(
)
inp = _input

if not (z_loss_weight >= 0.0 and z_loss_weight != float("inf")):
raise ValueError(f"z_loss_weight must be a finite non-negative number, got {z_loss_weight}")

return CrossEntropyFunction.apply(
inp,
target,
Expand All @@ -149,4 +160,5 @@ def parallel_cross_entropy(
dist_process_group,
ignore_idx,
is_cg_capturable,
z_loss_weight,
)
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/triton/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def cross_entropy_forward(
reduce_loss: bool,
dist_process_group: Union[dist.ProcessGroup, None],
ignore_idx: int,
z_loss_weight: float = 0.0,
):
"""Forward implementation of Cross Entropy kernel"""

Expand Down Expand Up @@ -100,6 +101,7 @@ def cross_entropy_forward(
n_non_ignore=n_non_ignore,
reduce_loss=reduce_loss,
label_smoothing=label_smoothing,
z_loss_weight=z_loss_weight,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,
)
Expand Down