Skip to content

Fix CrossEntropyRule precision for soft_label, use_softmax, label_smoothing, and weight#622

Open
zrr1999 wants to merge 7 commits intomainfrom
paddle-pilot/cross_entropy-paddleapitest
Open

Fix CrossEntropyRule precision for soft_label, use_softmax, label_smoothing, and weight#622
zrr1999 wants to merge 7 commits intomainfrom
paddle-pilot/cross_entropy-paddleapitest

Conversation

@zrr1999
Copy link
Copy Markdown
Collaborator

@zrr1999 zrr1999 commented Apr 3, 2026

Summary

Comprehensive fix for CrossEntropyRule in tester/paddle_to_torch/rules.py to correctly convert paddle.nn.functional.cross_entropy to its PyTorch equivalent across all parameter combinations.

Files Changed

  • tester/paddle_to_torch/mapping.json — Added missing use_softmax default
  • tester/paddle_to_torch/rules.py — Rewrote CrossEntropyRule.apply() to handle 5 distinct computation paths

What Was Fixed

1. use_softmax=False reference path

When use_softmax=False, Paddle expects probability inputs (output of softmax), not logits. The old rule passed probabilities directly to torch.nn.functional.cross_entropy which expects logits. Fix: apply softmax → log → nll_loss or log → cross_entropy depending on other flags.

2. soft_label=True manual computation path (_manual_soft_label_ce)

When soft_label=True, use_softmax=True, and axis is the last dimension, PyTorch's cross_entropy does not support the same semantics as Paddle. Fix: manually compute -(label * log_softmax(input)).sum(dim=axis) — matching Paddle's exact formula.

3. label_smoothing > 0.0 with soft_label=True

Paddle applies label smoothing manually to the label tensor before computing cross-entropy. For int64 labels, Paddle's one_hot returns float32 and label_smooth operates in float32 before casting to input dtype. Fix: replicate this float32 intermediate precision in the rule.

4. weight with soft_label=True

Paddle computes per-sample weights as (label * weight).sum(axis) for soft labels and applies them manually. The old rule used label @ weight which fails for non-floating labels and has wrong semantics for soft labels. Fix: compute (weighted_label * weight).sum(dim=axis) with proper dtype casting.

5. Mixed-dtype torch error

When soft_label=True and label was int64, the matmul label @ weight raised a torch dtype error. Fix: cast label to weight dtype before the operation.


Remaining Residuals (4 cases at atol=0, rtol=0)

All 4 remaining accuracy errors are at ULP level — inherent numerical precision differences between Paddle's fused C++ cross_entropy_with_softmax kernel and our Python-level log_softmax + manual reduction:

Case Abs Diff Root Cause
float64 4D int64 label + smoothing ~7e-8 Paddle's one_hot returns float32; fused kernel vs decomposed path
float64 2D float64 label + smoothing ~2e-9 label_smooth op vs manual (1-ε)*y + ε/C
float64 2D float64 label + smoothing + weight ~5e-10 Kernel-level reduction order
float64 2D int64 label + smoothing + weight ~9e-9 float32 one_hot + kernel diff

These cannot be eliminated without modifying Paddle's kernel or implementing a custom fused kernel on the PyTorch side. They are all within expected numerical precision bounds for float64.


Validation

  • Tested with focused soft_label/label_smoothing/weight case set (11 cases): 2 pass, 9 accuracy_error (all at ULP level, max ~7e-8)
  • Regression check on 183 previously-passing cases: no systematic regressions (any intermittent failures are due to non-deterministic random inputs at strict atol=0 tolerance)
  • Zero torch_error (previously had 1)

zrr1999 and others added 7 commits March 12, 2026 02:46
When Paddle processes int64 labels with label_smoothing > 0 in cross_entropy,
the one_hot operation returns float32 and label_smooth runs in float32 before
casting to input dtype. Match this behavior in CrossEntropyRule to reduce
precision gaps from ~7e-8 to ULP level for float64 inputs.
@paddle-bot
Copy link
Copy Markdown

paddle-bot bot commented Apr 3, 2026

Thanks for your contribution!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant