[KSM] fix logz when top_k #7225
Conversation
Thanks for your contribution!
Pull request overview
This PR fixes the computation of `logz_per_batch` for the sampling candidate set when top_k is enabled (and combined with top_p): the normalizer now matches the definition implied by the distribution actually sampled from (top-k renormalization first, then top-p truncation), so the subsequent logprobs renormalization is more accurate and consistent.
Changes:
- Redefine and implement logZ for the top_k case: combine Z_topk with Z_topp (computed on the renormalized distribution) into the total normalization term. The non-top_k case keeps the original log(sum(probs in candidate set)) logic unchanged.
- Add more detailed derivation comments explaining the normalization relationship π_full → π_mask.
```python
# Stage 5: compute logZ for renormalization
#
# Goal: log π_mask(k) = log π_full(k) - logZ, where π_mask is the
# distribution actually sampled from (top-k truncated + top-p nucleus).
#
```
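The decomposition behind that goal can be written out explicitly (a derivation sketch using the same symbols as the comment, not text from the repository):

```latex
% Top-k renormalization: Z_topk is the retained probability mass
\pi_{\mathrm{topk}}(k) = \frac{\pi_{\mathrm{full}}(k)}{Z_{\mathrm{topk}}},
\qquad Z_{\mathrm{topk}} = \sum_{j \in \mathrm{top}\text{-}k} \pi_{\mathrm{full}}(j)

% Top-p truncation applied to the renormalized distribution
\pi_{\mathrm{mask}}(k) = \frac{\pi_{\mathrm{topk}}(k)}{Z_{\mathrm{topp}}},
\qquad Z_{\mathrm{topp}} = \sum_{j \in \mathrm{nucleus}} \pi_{\mathrm{topk}}(j)

% Substituting the first line into the second gives the combined normalizer
\log \pi_{\mathrm{mask}}(k)
  = \log \pi_{\mathrm{full}}(k)
  - \underbrace{\left(\log Z_{\mathrm{topk}} + \log Z_{\mathrm{topp}}\right)}_{\log Z}
```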
The "[KSM]" tag in the PR title is not in the tag list provided by the template (e.g. [BugFix]/[Engine]/[Optimization]); consider replacing it with a tag from that list. Also, the Motivation/Modifications/Usage/Accuracy Tests sections of the PR description are currently unfilled; please describe the background of the fix, its scope of impact, and how it was verified.
```python
if has_top_k:
    # Z_topp: sum of renormed probs that survive the final mask
    candidate_probs = paddle.where(final_mask, renorm_sorted_probs, paddle.zeros_like(renorm_sorted_probs))
    z_topp = candidate_probs.sum(axis=-1)  # [B]
    # row_sums: [B, 1] already clipped ≥ 1e-9, squeeze to [B]
    log_z_topk = paddle.log(row_sums.squeeze(-1))
    logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy()  # [B]
```
This change affects the `logz_per_batch` returned when keep_sampling_mask is enabled, and therefore changes the renormalized logprobs; however, the repository's current sampler unit tests do not cover the "top_k + top_p + keep_sampling_mask" combination (in particular, verifying that the renormalized probabilities over the candidate set sum to 1 and that the set matches the sampling candidates). Consider adding a test covering this path in tests/layers/test_sampler.py to prevent regressions.
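One way to frame the suggested regression test is as a normalization property: after subtracting logZ = log Z_topk + log Z_topp, the surviving candidate probabilities must sum to 1. A standalone NumPy sketch of that property (hypothetical helper, not the repository's actual sampler or test):

```python
import numpy as np

def check_logz_property(probs, top_k, top_p, atol=1e-6):
    """Property the suggested unit test would assert: after subtracting
    logZ = log Z_topk + log Z_topp, the candidate probabilities sum to 1."""
    sorted_probs = np.sort(probs)[::-1]
    topk = sorted_probs[:top_k]
    z_topk = topk.sum()
    renorm = topk / z_topk                   # top-k renormalized distribution
    cum = np.cumsum(renorm)
    final_mask = (cum - renorm) < top_p      # nucleus mask on the renormed dist
    z_topp = renorm[final_mask].sum()
    log_z = np.log(z_topk) + np.log(z_topp)
    # log pi_mask = log pi_full - logZ on the surviving set
    pi_mask = np.exp(np.log(topk[final_mask]) - log_z)
    return abs(pi_mask.sum() - 1.0) < atol

# Property should hold across random distributions and cutoffs
rng = np.random.default_rng(0)
for _ in range(100):
    p = rng.dirichlet(np.ones(32))
    assert check_logz_property(p, top_k=8, top_p=0.9)
```

A real test in tests/layers/test_sampler.py would additionally compare the surviving set against the sampler's actual candidate mask.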
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           release/2.4    #7225   +/- ##
==============================================
  Coverage         ?        56.29%
==============================================
  Files            ?           333
  Lines            ?         42621
  Branches         ?          6478
==============================================
  Hits             ?         23992
  Misses           ?         16763
  Partials         ?          1866
==============================================
```

Flags with carried forward coverage won't be shown.
fastdeploy-bot
left a comment
🤖 AI Code Review
2026-04-08
📋 Review Summary
PR overview: fixes a bug in the logZ computation when top_k is active
Scope of changes: model_executor/layers/sample/
Impact tags: [BugFix] [OP]
📝 PR Convention Check
The PR title is missing a required [Tag], and the Motivation, Modifications, and Checklist sections of the description are all unfilled.
Suggested title (copy-paste ready):
[BugFix] [KSM] fix logz when top_k
Description template (copy-paste ready):
## Motivation
When top_k and top_p are both active, the logZ computation is incorrect. Stage 2 renormalizes the probabilities into `renorm_sorted_probs`, but Stage 5 computes logZ from the original `sorted_probs`, so the final sampling probabilities are wrong.
The correct logic is:
- logZ = log Z_topk + log Z_topp
- where Z_topk = row_sums (the total top-k probability mass)
- Z_topp should be computed from the renormalized `renorm_sorted_probs` (the total top-p probability mass on top of top-k)
## Modifications
Change the Stage 5 logZ computation in the `_compute_sampling_mask` function:
- When has_top_k is True, compute Z_topp from `renorm_sorted_probs`
- logZ = log(row_sums) + log(Z_topp)
- When has_top_k is False, keep the original logic unchanged
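The fixed Stage 5 path can be sketched as a NumPy stand-in (the names `sorted_probs`, `renorm_sorted_probs`, `row_sums`, `final_mask`, and `logz_per_batch` mirror the variables quoted in the diff; the batch shapes, random inputs, and the mask construction are hypothetical illustration, not FastDeploy code):

```python
import numpy as np

B, V, top_k, top_p = 2, 6, 3, 0.7
rng = np.random.default_rng(1)
probs = rng.dirichlet(np.ones(V), size=B)                 # [B, V]

# Top-k truncation and renormalization (Stage 2 analogue)
sorted_probs = np.sort(probs, axis=-1)[:, ::-1][:, :top_k]
row_sums = np.clip(sorted_probs.sum(-1, keepdims=True), 1e-9, None)  # [B, 1]
renorm_sorted_probs = sorted_probs / row_sums

# Top-p nucleus mask computed on the renormalized distribution
cum = np.cumsum(renorm_sorted_probs, axis=-1)
final_mask = (cum - renorm_sorted_probs) < top_p          # keeps >= 1 token/row

# Stage 5 (fixed): Z_topp from the renormed probs that survive the mask
candidate_probs = np.where(final_mask, renorm_sorted_probs, 0.0)
z_topp = candidate_probs.sum(-1)                          # [B]
log_z_topk = np.log(row_sums.squeeze(-1))                 # [B]
logz_per_batch = log_z_topk + np.log(z_topp + 1e-10)      # [B]

# Sanity check: per-row candidate probabilities sum to 1 after renorm
masked = np.where(final_mask, sorted_probs, 0.0).sum(-1)
assert np.allclose(np.exp(np.log(masked) - logz_per_batch), 1.0, atol=1e-6)
```

The final assertion is exactly the normalization property the fix is meant to restore for the top_k branch.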
## Checklist
- [x] Format your code, run `pre-commit` before commit.
Issues
| Level | File | Summary |
|---|---|---|
| None | - | No blocking issues found |
Overall assessment
The fix logic is correct, the mathematical derivation is clear, and the comments are detailed. The variables renorm_sorted_probs, row_sums, and has_top_k are all defined in the correct scope. Recommend filling in the Motivation and Modifications sections of the PR description to aid later code review.
Motivation
Modifications
Usage or Command
Accuracy Tests
Checklist
[FDConfig],[APIServer],[Engine],[Scheduler],[PD Disaggregation],[Executor],[Graph Optimization],[Speculative Decoding],[RL],[Models],[Quantization],[Loader],[OP],[KVCache],[DataProcessor],[BugFix],[Docs],[CI],[Optimization],[Feature],[Benchmark],[Others],[XPU],[HPU],[GCU],[DCU],[Iluvatar],[Metax]
Run `pre-commit` before commit.
For a `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.