-
Notifications
You must be signed in to change notification settings - Fork 737
[KSM] fix logz when top_k #7225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
yuanlehome
merged 1 commit into
PaddlePaddle:release/2.4
from
yuanlehome:fix_log_z_when_top_k
Apr 8, 2026
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -197,13 +197,37 @@ def _compute_sampling_mask( | |
| max_k = int(k_per_row.max().item()) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Stage 5: compute logZ_K for renormalization | ||
| # Z_K = sum(probs[i] * final_mask[i]) for each request i | ||
| # logZ_K = log(Z_K), with small constant to avoid log(0) | ||
| # Stage 5: compute logZ for renormalization | ||
| # | ||
| # Goal: log π_mask(k) = log π_full(k) - logZ, where π_mask is the | ||
| # distribution actually sampled from (top-k truncated + top-p nucleus). | ||
| # | ||
| # When top_k is active the sampling pipeline first renormalises to | ||
| # π_topk, then applies top-p on π_topk. The total log-normaliser | ||
| # that maps π_full → π_mask absorbs both steps: | ||
| # | ||
| # logZ = log Z_topk + log Z_topp_on_renorm | ||
| # | ||
| # where Z_topk = Σ_{j∈topk} π_full(j) (= row_sums, already computed) | ||
| # Z_topp = Σ_{k∈K} π_topk(k) (sum of renorm probs in K) | ||
| # | ||
| # Substituting: | ||
| # log π_mask(k) = log π_full(k) - log Z_topk - log Z_topp | ||
| # = log π_topk(k) - log Z_topp ✓ | ||
| # | ||
| # When top_k is absent Z_topk = 1 → logZ = log Z_topp as before. | ||
| # ------------------------------------------------------------------ | ||
| candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs)) | ||
| z_k = candidate_probs.sum(axis=-1) # [B] | ||
| logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B] | ||
| if has_top_k: | ||
| # Z_topp: sum of renormed probs that survive the final mask | ||
| candidate_probs = paddle.where(final_mask, renorm_sorted_probs, paddle.zeros_like(renorm_sorted_probs)) | ||
| z_topp = candidate_probs.sum(axis=-1) # [B] | ||
| # row_sums: [B, 1] already clipped ≥ 1e-9, squeeze to [B] | ||
| log_z_topk = paddle.log(row_sums.squeeze(-1)) | ||
| logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy() # [B] | ||
|
Comment on lines
+220
to
+226
|
||
| else: | ||
| candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs)) | ||
| z_k = candidate_probs.sum(axis=-1) # [B] | ||
| logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B] | ||
|
|
||
| # Transfer only the leading max_k columns — typically max_k << vocab_size. | ||
| indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy() # [B, max_k] | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR 标题中的标签“[KSM]”不在模板提供的 tag 列表中(如 [BugFix]/[Engine]/[Optimization] 等),建议改为符合列表的标签;另外 PR 描述的 Motivation/Modifications/Usage/Accuracy Tests 目前未填写,建议补充清楚修复背景、影响范围以及如何验证。