-
Notifications
You must be signed in to change notification settings - Fork 669
[Common][PyTorch] Add z_loss_weight to parallel_cross_entropy #2707
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
|
|
||
| import random | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from transformer_engine.pytorch import parallel_cross_entropy | ||
|
|
||
| from utils import dtype_tols | ||
|
|
@@ -14,11 +15,39 @@ class TestParallelCrossEntropy: | |
| def generate_iters(self, iters: int): | ||
| self.iters = iters | ||
|
|
||
| def generate_infra(self, reduce_loss: bool, label_smoothing: float): | ||
| def generate_infra( | ||
| self, | ||
| reduce_loss: bool, | ||
| label_smoothing: float, | ||
| z_loss_weight: float = 0.0, | ||
| ignore_idx: int = -100, | ||
| ): | ||
| self.test_loss_func = parallel_cross_entropy | ||
| self.ref_loss_func = torch.nn.CrossEntropyLoss( | ||
| label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none" | ||
| ) | ||
| if z_loss_weight == 0.0: | ||
| self.ref_loss_func = torch.nn.CrossEntropyLoss( | ||
| label_smoothing=label_smoothing, | ||
| reduction="mean" if reduce_loss else "none", | ||
| ignore_index=ignore_idx, | ||
| ) | ||
| else: | ||
|
|
||
| def ref_with_zloss(inp, tar): | ||
| inp = inp.float() | ||
| ce = F.cross_entropy( | ||
| inp, | ||
| tar, | ||
| reduction="none", | ||
| label_smoothing=label_smoothing, | ||
| ignore_index=ignore_idx, | ||
| ) | ||
| z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1)) | ||
| z_pen[tar == ignore_idx] = 0.0 | ||
| loss = ce + z_pen | ||
| if reduce_loss: | ||
| loss = loss.sum() / (tar != ignore_idx).sum() | ||
| return loss | ||
|
Comment on lines
+34
to
+48
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
def ref_with_zloss(inp, tar):
inp = inp.float()
ce = F.cross_entropy(
inp, tar, reduction="none",
label_smoothing=label_smoothing,
ignore_index=ignore_idx, # <-- forward the parameter
)
z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
z_pen[tar == ignore_idx] = 0.0
... |
||
|
|
||
| self.ref_loss_func = ref_with_zloss | ||
|
|
||
| def generate_input( | ||
| self, | ||
|
|
@@ -63,14 +92,20 @@ def one_iteration_test( | |
| label_smoothing: float, | ||
| reduce_loss: bool, | ||
| ignore_idx: bool = False, | ||
| z_loss_weight: float = 0.0, | ||
| ): | ||
|
|
||
| # Random data | ||
| self.generate_input(dtype, swap_dim, ignore_idx) | ||
|
|
||
| # Forward pass | ||
| # Forward pass — default return is a single tensor (backward compatible) | ||
| test_loss = self.test_loss_func( | ||
| self.input_test, self.tar_test, label_smoothing, reduce_loss, None | ||
| self.input_test, | ||
| self.tar_test, | ||
| label_smoothing, | ||
| reduce_loss, | ||
| None, | ||
| z_loss_weight=z_loss_weight, | ||
| ) | ||
| ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) | ||
|
|
||
|
|
@@ -168,6 +203,106 @@ def test_ignore_idx_reduced_loss(self): | |
| ignore_idx=True, | ||
| ) | ||
|
|
||
| def test_z_loss(self): | ||
| self.generate_iters(5) | ||
| self.generate_infra(False, 0, z_loss_weight=0.001) | ||
| for i in range(self.iters): | ||
| self.one_iteration_test( | ||
| dtype=torch.float32, | ||
| swap_dim=random.choice([True, False]), | ||
| label_smoothing=0, | ||
| reduce_loss=False, | ||
| z_loss_weight=0.001, | ||
| ) | ||
|
|
||
| def test_z_loss_zero_weight(self): | ||
| self.generate_infra(False, 0) | ||
| self.generate_input(torch.float32, False, False) | ||
|
|
||
| inp_base = self.input_test.clone().detach().requires_grad_(True) | ||
| inp_zero = self.input_test.clone().detach().requires_grad_(True) | ||
|
|
||
| loss_base = self.test_loss_func(inp_base, self.tar_test) | ||
| loss_zero = self.test_loss_func(inp_zero, self.tar_test, z_loss_weight=0.0) | ||
|
|
||
| assert torch.equal( | ||
| loss_base, loss_zero | ||
| ), "z_loss_weight=0.0 must be bit-identical to the default" | ||
|
|
||
| loss_base.sum().backward() | ||
| loss_zero.sum().backward() | ||
|
|
||
| assert torch.equal( | ||
| inp_base.grad, inp_zero.grad | ||
| ), "Gradients with z_loss_weight=0.0 must be bit-identical to the default" | ||
|
|
||
| self.input_test = None | ||
| self.input_ref = None | ||
| self.tar_test = None | ||
| self.tar_ref = None | ||
|
Comment on lines
+218
to
+242
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The test clones the input tensor but never calls `backward()`, so gradient equivalence between the default path and `z_loss_weight=0.0` is not verified. Consider adding backward verification: def test_z_loss_zero_weight(self):
self.generate_infra(False, 0)
self.generate_input(torch.float32, False, False)
inp_base = self.input_test.clone().requires_grad_(True)
inp_zero = self.input_test.clone().requires_grad_(True)
loss_base = self.test_loss_func(inp_base, self.tar_test)
loss_zero = self.test_loss_func(inp_zero, self.tar_test, z_loss_weight=0.0)
assert torch.equal(loss_base, loss_zero), "z_loss_weight=0.0 must be bit-identical to the default"
loss_base.sum().backward()
loss_zero.sum().backward()
assert torch.equal(inp_base.grad, inp_zero.grad), \
"Gradients with z_loss_weight=0.0 must be bit-identical to the default"
self.input_test = self.input_ref = self.tar_test = self.tar_ref = None |
||
|
|
||
| def test_z_loss_with_ignore_idx(self): | ||
| self.generate_iters(5) | ||
| self.generate_infra(False, 0, z_loss_weight=0.001) | ||
| for i in range(self.iters): | ||
| self.one_iteration_test( | ||
| dtype=torch.float32, | ||
| swap_dim=random.choice([True, False]), | ||
| label_smoothing=0, | ||
| reduce_loss=False, | ||
| ignore_idx=True, | ||
| z_loss_weight=0.001, | ||
| ) | ||
|
|
||
| def test_z_loss_reduced(self): | ||
| self.generate_iters(5) | ||
| self.generate_infra(True, 0, z_loss_weight=0.001) | ||
| for i in range(self.iters): | ||
| self.one_iteration_test( | ||
| dtype=torch.float32, | ||
| swap_dim=random.choice([True, False]), | ||
| label_smoothing=0, | ||
| reduce_loss=True, | ||
| z_loss_weight=0.001, | ||
| ) | ||
|
Comment on lines
+257
to
+267
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The test suite covers z-loss with `ignore_idx` (unreduced) and z-loss with `reduce_loss=True` separately, but not the two in combination. Consider adding: def test_z_loss_reduced_with_ignore_idx(self):
self.generate_iters(5)
self.generate_infra(True, 0, z_loss_weight=0.001)
for i in range(self.iters):
self.one_iteration_test(
dtype=torch.float32,
swap_dim=random.choice([True, False]),
label_smoothing=0,
reduce_loss=True,
ignore_idx=True,
z_loss_weight=0.001,
) |
||
|
|
||
| def test_z_loss_reduced_with_ignore_idx(self): | ||
| self.generate_iters(5) | ||
| self.generate_infra(True, 0, z_loss_weight=0.001) | ||
| for i in range(self.iters): | ||
| self.one_iteration_test( | ||
| dtype=torch.float32, | ||
| swap_dim=random.choice([True, False]), | ||
| label_smoothing=0, | ||
| reduce_loss=True, | ||
| ignore_idx=True, | ||
| z_loss_weight=0.001, | ||
| ) | ||
|
|
||
| def test_z_loss_label_smoothing(self): | ||
| self.generate_iters(3) | ||
| self.generate_infra(False, 0.1, z_loss_weight=0.001) | ||
| for i in range(self.iters): | ||
| self.one_iteration_test( | ||
| dtype=torch.float32, | ||
| swap_dim=random.choice([True, False]), | ||
| label_smoothing=0.1, | ||
| reduce_loss=False, | ||
| z_loss_weight=0.001, | ||
| ) | ||
|
|
||
| def test_z_loss_bfloat16(self): | ||
| self.generate_iters(3) | ||
| self.generate_infra(False, 0, z_loss_weight=0.001) | ||
| for i in range(self.iters): | ||
| self.one_iteration_test( | ||
| dtype=torch.bfloat16, | ||
| swap_dim=random.choice([True, False]), | ||
| label_smoothing=0, | ||
| reduce_loss=False, | ||
| z_loss_weight=0.001, | ||
| ) | ||
|
|
||
|
|
||
| def test_non_contiguous_transposed_input(): | ||
| """Regression test: stride(-2) != shape[-1] should not produce wrong results.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -100,6 +100,7 @@ def cross_entropy_kernel( | |
| n_non_ignore, | ||
| reduce_loss: tl.constexpr, | ||
| label_smoothing: tl.constexpr, | ||
| z_loss_weight: tl.constexpr, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Consider documenting this behaviour in the docstrings of both the Triton kernel and the public `parallel_cross_entropy` API: z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token.
This value is used as a Triton compile-time constant (tl.constexpr); varying it across
calls will trigger kernel recompilation. Use a fixed value during training. |
||
| BLOCK_SIZE: tl.constexpr, | ||
| ): | ||
| """ | ||
|
|
@@ -121,6 +122,8 @@ def cross_entropy_kernel( | |
| n_rows (int): The number of rows in the batch (B * SQ), used for buffer indexing. | ||
| n_non_ignore: The number of non-ignored elements in the batch. | ||
| label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. | ||
| z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token. | ||
| Compile-time constant (tl.constexpr); varying it across calls triggers kernel recompilation. | ||
| BLOCK_SIZE (int): The block size for Triton operations. | ||
| """ | ||
|
|
||
|
|
@@ -160,6 +163,9 @@ def cross_entropy_kernel( | |
| m = tl.maximum(m, m_new) | ||
| ori_X_y = tl.maximum(ori_X_y, X_y_new) | ||
|
|
||
| # lse = log(Z): free to compute (m, d already in registers). | ||
| lse = m + tl.log(d) | ||
|
|
||
| # Label smoothing is a general case of normal cross entropy | ||
| scaled_x_sum = 0.0 | ||
| eps = label_smoothing / (n_cols * world_size) | ||
|
|
@@ -180,13 +186,16 @@ def cross_entropy_kernel( | |
| if label_smoothing > 0: | ||
| # scale X beforehand to avoid overflow | ||
| scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) | ||
| # Softmax gradient | ||
| X_block = tl.exp(X_block - m) / d | ||
| # Z-loss gradient: d/dx_i[z_loss_weight * lse^2] = 2 * z_loss_weight * lse * softmax(x_i). | ||
| # Applied before eps subtraction so only pure softmax is scaled. | ||
| if z_loss_weight > 0: | ||
| X_block = X_block * (1.0 + 2.0 * z_loss_weight * lse) | ||
|
Comment on lines
+191
to
+194
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
At this point `X_block` holds the softmax value, and scaling it in place by `(1.0 + 2.0 * z_loss_weight * lse)` bakes the z-loss factor into the value that the subsequent `- eps` subtraction and reduction scaling operate on. But the correct z-loss gradient is purely `2 * z_loss_weight * lse * softmax(x_i)`, additive to the CE gradient — it should not interact with the label-smoothing `eps` term. For typical training settings (small `z_loss_weight`) the discrepancy is small, but it is systematic. The correct approach is to add the z-loss gradient additively, using the pre-eps softmax value: if reduce_loss:
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
else:
X_block = tl.exp(X_block - m) / d - eps
# Z-loss gradient: 2 * z_loss_weight * lse * softmax(x_i), additive to CE gradient.
if z_loss_weight > 0:
softmax_i = tl.exp(X_block_fp32 - m) / d # pure softmax, before subtracting eps
if reduce_loss:
X_block = X_block + softmax_i * (2.0 * z_loss_weight * lse) / n_non_ignore
else:
X_block = X_block + softmax_i * (2.0 * z_loss_weight * lse)where |
||
| X_block = X_block - eps | ||
| # Scale gradients based on reduction mode | ||
| # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore | ||
| # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here | ||
| if reduce_loss: | ||
| X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) | ||
| else: | ||
| X_block = tl.exp(X_block - m) / d - eps | ||
| X_block = X_block / n_non_ignore | ||
| tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) | ||
|
|
||
| # We need tl.debug_barrier() to ensure the new result of X_ptr is written | ||
|
|
@@ -196,7 +205,8 @@ def cross_entropy_kernel( | |
|
|
||
| # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) | ||
| # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) | ||
| loss = -(ori_X_y - m - tl.log(d)) | ||
| # = lse - ori_X_y (reusing lse = m + log(d) already computed above) | ||
| loss = lse - ori_X_y | ||
|
|
||
| # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps | ||
| # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) | ||
|
|
@@ -205,9 +215,13 @@ def cross_entropy_kernel( | |
| # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) | ||
| # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 | ||
| if label_smoothing > 0: | ||
| smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) | ||
| smooth_loss = scaled_x_sum + label_smoothing * lse | ||
| loss = loss * (1 - label_smoothing) + smooth_loss | ||
|
|
||
| # Z-loss regularization: adds z_loss_weight * log(Z)^2 per token. | ||
| if z_loss_weight > 0: | ||
| loss += z_loss_weight * lse * lse | ||
|
|
||
| # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` | ||
| vocab_start_idx = rank * n_cols | ||
| vocab_end_idx = (rank + 1) * n_cols | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -33,6 +33,7 @@ def forward( | |||||
| dist_process_group=None, | ||||||
| ignore_idx=-100, | ||||||
| is_cg_capturable=False, | ||||||
| z_loss_weight=0.0, | ||||||
| ): | ||||||
| """ | ||||||
| The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each | ||||||
|
|
@@ -45,7 +46,8 @@ def forward( | |||||
| label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. | ||||||
| reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension. | ||||||
| dist_process_group (torch.dist.ProcessGroup): The distributed process group the loss computation is split across, None if on 1 device. | ||||||
| ignore_idx (int): The index for which loss and gradients are made to zero | ||||||
| ignore_idx (int): The index for which loss and gradients are made to zero. | ||||||
| z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token. | ||||||
|
|
||||||
| Returns: | ||||||
| tensor: The computed loss. | ||||||
|
|
@@ -57,6 +59,7 @@ def forward( | |||||
| reduce_loss, | ||||||
| dist_process_group, | ||||||
| ignore_idx, | ||||||
| z_loss_weight, | ||||||
| ) | ||||||
|
|
||||||
| ctx.save_for_backward(inp.detach()) | ||||||
|
|
@@ -85,6 +88,7 @@ def backward(ctx, grad_output): | |||||
| None, | ||||||
| None, | ||||||
| None, | ||||||
| None, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -96,11 +100,12 @@ def parallel_cross_entropy( | |||||
| dist_process_group: Optional[torch.distributed.ProcessGroup] = None, | ||||||
| ignore_idx: int = -100, | ||||||
| is_cg_capturable: bool = False, | ||||||
| z_loss_weight: float = 0.0, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
A negative value is mathematically well-formed but semantically inverts the regularization (rewarding large logit magnitudes). Given the docstring describes this as a "regularization weight", an early guard against negative values would make the API safer and the intent explicit:
Suggested change
Consider adding, before the `CrossEntropyFunction.apply` call: if z_loss_weight < 0.0:
raise ValueError(f"z_loss_weight must be non-negative, got {z_loss_weight}") |
||||||
| *, | ||||||
| _input: Optional[torch.Tensor] = None, | ||||||
| ) -> torch.Tensor: | ||||||
| """ | ||||||
| Cross Entropy loss with optional distributed reduction. | ||||||
| Cross Entropy loss with optional distributed reduction and z-loss regularization. | ||||||
|
|
||||||
| The input tensor can be in BF16/FP32, the loss and gradient calculation happens in | ||||||
| FP32 only. The returned loss is always in FP32, the input gradients are upcasted | ||||||
|
|
@@ -127,6 +132,9 @@ def parallel_cross_entropy( | |||||
| The index for which loss and gradients are made to zero. | ||||||
| is_cg_capturable : bool, default = False | ||||||
| Whether the operation is CUDA graph capturable. | ||||||
| z_loss_weight : float, default = 0.0 | ||||||
| Weight for z-loss regularization. Adds ``z_loss_weight * log(Z)^2`` per token. | ||||||
| This value is a Triton compile-time constant; use a fixed value during training. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
|
|
@@ -141,6 +149,9 @@ def parallel_cross_entropy( | |||||
| ) | ||||||
| inp = _input | ||||||
|
|
||||||
| if not (z_loss_weight >= 0.0 and z_loss_weight != float("inf")): | ||||||
| raise ValueError(f"z_loss_weight must be a finite non-negative number, got {z_loss_weight}") | ||||||
|
|
||||||
| return CrossEntropyFunction.apply( | ||||||
| inp, | ||||||
| target, | ||||||
|
|
@@ -149,4 +160,5 @@ def parallel_cross_entropy( | |||||
| dist_process_group, | ||||||
| ignore_idx, | ||||||
| is_cg_capturable, | ||||||
| z_loss_weight, | ||||||
| ) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`ignore_idx` not forwarded to `CrossEntropyLoss`: `generate_infra` now accepts an `ignore_idx` parameter, and the `z_loss_weight > 0` branch correctly passes it to both `F.cross_entropy(..., ignore_index=ignore_idx)` and `z_pen[tar == ignore_idx]`. However, the `z_loss_weight == 0.0` branch silently falls back to PyTorch's default (`-100`), ignoring the parameter:
ignore_idx=-100(the default), so there is no visible failure now. But if any test ever callsgenerate_infra(..., z_loss_weight=0.0, ignore_idx=42), the reference would still ignore token id-100instead of42, producing a silently incorrect reference and a false-passing comparison.