Fix CrossEntropyRule precision for soft_label, use_softmax, label_smoothing, and weight#622
Open
When Paddle processes int64 labels with `label_smoothing > 0` in `cross_entropy`, the `one_hot` operation returns float32 and `label_smooth` runs in float32 before casting to the input dtype. Match this behavior in `CrossEntropyRule` to reduce precision gaps from ~7e-8 to ULP level for float64 inputs.
## Summary
Comprehensive fix for `CrossEntropyRule` in `tester/paddle_to_torch/rules.py` to correctly convert `paddle.nn.functional.cross_entropy` to its PyTorch equivalent across all parameter combinations.

## Files Changed

- `tester/paddle_to_torch/mapping.json`: added missing `use_softmax` default
- `tester/paddle_to_torch/rules.py`: rewrote `CrossEntropyRule.apply()` to handle 5 distinct computation paths

## What Was Fixed
### 1. `use_softmax=False` reference path

When `use_softmax=False`, Paddle expects probability inputs (the output of a softmax), not logits. The old rule passed these probabilities directly to `torch.nn.functional.cross_entropy`, which expects logits. Fix: apply `softmax → log → nll_loss` or `log → cross_entropy`, depending on the other flags.
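The probability-input case can be sketched as follows. This is an illustration of the idea, not the rule's actual emitted code; the variable names are mine:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5, dtype=torch.float64)
target = torch.tensor([0, 2, 1, 4])

# With use_softmax=False, Paddle receives probabilities, not logits.
probs = F.softmax(logits, dim=-1)

# Old (wrong): feed probabilities to cross_entropy, which expects logits.
wrong = F.cross_entropy(probs, target)

# Fixed: log the probabilities, then nll_loss.
fixed = F.nll_loss(torch.log(probs), target)

# The fixed path matches cross_entropy on the original logits.
reference = F.cross_entropy(logits, target)
assert torch.allclose(fixed, reference)
```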
### 2. `soft_label=True` manual computation path (`_manual_soft_label_ce`)

When `soft_label=True`, `use_softmax=True`, and `axis` is the last dimension, PyTorch's `cross_entropy` does not support the same semantics as Paddle's. Fix: manually compute `-(label * log_softmax(input)).sum(dim=axis)`, matching Paddle's exact formula.
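A minimal sketch of the manual soft-label formula, with names of my own choosing. In the plain 2D case torch's probability-target `cross_entropy` happens to agree, since its fixed class dimension (dim 1) coincides with the last axis; the manual path matters for other `axis` values:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
axis = -1
logits = torch.randn(4, 5, dtype=torch.float64)
# Soft labels: each row is a probability distribution over classes.
label = F.softmax(torch.randn(4, 5, dtype=torch.float64), dim=axis)

# Manual computation matching Paddle's formula for soft_label=True:
loss = -(label * F.log_softmax(logits, dim=axis)).sum(dim=axis)

# For 2D inputs, torch's probability-target cross_entropy agrees.
ref = F.cross_entropy(logits, label, reduction='none')
assert torch.allclose(loss, ref)
```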
### 3. `label_smoothing > 0.0` with `soft_label=True`

Paddle applies label smoothing to the label tensor manually before computing cross-entropy. For int64 labels, Paddle's `one_hot` returns float32 and `label_smooth` operates in float32 before casting to the input dtype. Fix: replicate this float32 intermediate precision in the rule.
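The float32-intermediate behavior described above can be sketched with a hypothetical helper (the name and signature are illustrative, not from the rule):

```python
import torch

def smooth_labels_paddle_style(index_label, num_classes, epsilon, out_dtype):
    # Mirror the described Paddle behavior: one_hot yields float32,
    # label smoothing runs in float32, and only the final result is
    # cast to the input dtype.
    onehot = torch.nn.functional.one_hot(index_label, num_classes).to(torch.float32)
    smoothed = (1.0 - epsilon) * onehot + epsilon / num_classes
    return smoothed.to(out_dtype)

label = torch.tensor([1, 0, 3])
smoothed = smooth_labels_paddle_style(label, 4, 0.1, torch.float64)
# Row for class 1 is [0.025, 0.925, 0.025, 0.025], up to float32 rounding.
print(smoothed[0])
```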
### 4. `weight` with `soft_label=True`

Paddle computes per-sample weights as `(label * weight).sum(axis)` for soft labels and applies them manually. The old rule used `label @ weight`, which fails for non-floating labels and has the wrong semantics for soft labels. Fix: compute `(weighted_label * weight).sum(dim=axis)` with proper dtype casting.
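A sketch of the per-sample weighting, including the dtype cast that also resolves the int64 matmul error described below (names are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5, dtype=torch.float64)
weight = torch.rand(5, dtype=torch.float64)

# int64 one-hot labels: cast to the weight's dtype before any
# arithmetic, otherwise mixed-dtype ops like `label @ weight` fail.
hard = torch.tensor([0, 2, 1, 4])
label = F.one_hot(hard, 5).to(weight.dtype)

# Per-sample weights as Paddle computes them for soft labels.
sample_w = (label * weight).sum(dim=-1)

loss = -(label * F.log_softmax(logits, dim=-1)).sum(dim=-1) * sample_w
# For one-hot soft labels this matches torch's weighted cross_entropy.
ref = F.cross_entropy(logits, hard, weight=weight, reduction='none')
assert torch.allclose(loss, ref)
```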
### 5. Mixed-dtype torch error

When `soft_label=True` and the label was int64, the matmul `label @ weight` raised a torch dtype error. Fix: cast the label to the weight's dtype before the operation.
## Remaining Residuals (4 cases at atol=0, rtol=0)

All 4 remaining accuracy errors are at ULP level: inherent numerical-precision differences between Paddle's fused C++ `cross_entropy_with_softmax` kernel and our Python-level `log_softmax` plus manual reduction:

- `one_hot` returns float32: fused kernel vs. decomposed path
- `label_smooth` op vs. manual `(1 - ε) * y + ε / C`

These cannot be eliminated without modifying Paddle's kernel or implementing a custom fused kernel on the PyTorch side. They are all within expected numerical-precision bounds for float64.
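To give a feel for the scale of these residuals, here is a small illustration (not the test harness's actual check) of the gap a float32 smoothing intermediate leaves in an otherwise float64 computation:

```python
import torch

eps, C = 0.1, 10
idx = torch.tensor([3])
y64 = torch.nn.functional.one_hot(idx, C).to(torch.float64)
y32 = torch.nn.functional.one_hot(idx, C).to(torch.float32)

# Smoothing entirely in float64 vs. the float32-intermediate path.
smooth64 = (1 - eps) * y64 + eps / C
smooth32 = ((1 - eps) * y32 + eps / C).to(torch.float64)

gap = (smooth64 - smooth32).abs().max().item()
# Nonzero and on the order of 1e-8: float32 rounding, not a logic bug.
print(gap)
```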
## Validation