Merged
36 changes: 30 additions & 6 deletions fastdeploy/model_executor/layers/sample/sampler.py
@@ -197,13 +197,37 @@ def _compute_sampling_mask(
max_k = int(k_per_row.max().item())

# ------------------------------------------------------------------
# Stage 5: compute logZ_K for renormalization
# Z_K = sum(probs[i] * final_mask[i]) for each request i
# logZ_K = log(Z_K), with small constant to avoid log(0)
# Stage 5: compute logZ for renormalization
#
# Goal: log π_mask(k) = log π_full(k) - logZ, where π_mask is the
# distribution actually sampled from (top-k truncated + top-p nucleus).
#
Comment on lines +200 to +204
Copilot AI Apr 7, 2026

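The tag "[KSM]" in the PR title is not in the tag list provided by the template (e.g. [BugFix]/[Engine]/[Optimization]); consider changing it to a tag from that list. In addition, the Motivation/Modifications/Usage/Accuracy Tests sections of the PR description are currently empty; please fill them in to explain the background of the fix, its scope of impact, and how it was verified.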
# When top_k is active the sampling pipeline first renormalises to
# π_topk, then applies top-p on π_topk. The total log-normaliser
# that maps π_full → π_mask absorbs both steps:
#
# logZ = log Z_topk + log Z_topp_on_renorm
#
# where Z_topk = Σ_{j∈topk} π_full(j) (= row_sums, already computed)
# Z_topp = Σ_{k∈K} π_topk(k) (sum of renorm probs in K)
#
# Substituting:
# log π_mask(k) = log π_full(k) - log Z_topk - log Z_topp
# = log π_topk(k) - log Z_topp ✓
#
# When top_k is absent Z_topk = 1 → logZ = log Z_topp as before.
# ------------------------------------------------------------------
candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
z_k = candidate_probs.sum(axis=-1) # [B]
logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B]
if has_top_k:
    # Z_topp: sum of renormed probs that survive the final mask
    candidate_probs = paddle.where(final_mask, renorm_sorted_probs, paddle.zeros_like(renorm_sorted_probs))
    z_topp = candidate_probs.sum(axis=-1)  # [B]
    # row_sums: [B, 1] already clipped ≥ 1e-9, squeeze to [B]
    log_z_topk = paddle.log(row_sums.squeeze(-1))
    logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy()  # [B]
Comment on lines +220 to +226
Copilot AI Apr 7, 2026

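This change affects the logz_per_batch returned when keep_sampling_mask is enabled, and therefore changes the renormalized logprobs. However, the repository's current sampler unit tests do not cover the "top_k + top_p + keep_sampling_mask" combination (in particular, checking that the renormalized probabilities over the candidate set sum to 1 and that the set matches the candidates actually sampled from). Consider adding a test case covering this path in tests/layers/test_sampler.py to avoid regressions.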

Copilot generated this review using guidance from repository custom instructions.
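
As a starting point for the test requested above, the following is a minimal NumPy sketch of the property to check, not the repository's implementation. The helper name `masked_log_normaliser` is hypothetical, and the nucleus rule used here (a token survives while the cumulative mass strictly before it is below `top_p`) is an assumption that may differ from the rule in `_compute_sampling_mask`. It computes logZ = log Z_topk + log Z_topp for a single row and confirms that exp(log π_full(k) - logZ) sums to 1 over the surviving candidates.

```python
import numpy as np


def masked_log_normaliser(probs, top_k, top_p):
    """Toy single-row version of top-k renormalisation followed by top-p.

    Returns (log_z, order, keep), where `keep` marks the surviving
    positions in descending-sorted order.
    """
    order = np.argsort(-probs, kind="stable")  # indices, largest prob first
    sorted_probs = probs[order]

    # Top-k: keep the k largest entries, then renormalise to pi_topk.
    topk_mask = np.arange(sorted_probs.size) < top_k
    z_topk = sorted_probs[topk_mask].sum()
    renorm = np.where(topk_mask, sorted_probs / z_topk, 0.0)

    # Top-p on pi_topk. ASSUMPTION: a token is kept while the mass
    # strictly before it is below top_p.
    mass_before = np.cumsum(renorm) - renorm
    keep = topk_mask & (mass_before < top_p)
    z_topp = renorm[keep].sum()

    # Total log-normaliser that maps pi_full to pi_mask.
    log_z = np.log(z_topk) + np.log(z_topp)
    return log_z, order, keep


probs = np.array([0.1, 0.4, 0.2, 0.3])
log_z, order, keep = masked_log_normaliser(probs, top_k=3, top_p=0.7)

# Renormalised probabilities of the surviving candidates.
masked = np.exp(np.log(probs[order][keep]) - log_z)
assert np.isclose(masked.sum(), 1.0)
```

A real test in tests/layers/test_sampler.py would replace the toy helper with a call into the sampler and compare its returned logz_per_batch and candidate set against values like these.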
else:
    candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
    z_k = candidate_probs.sum(axis=-1)  # [B]
    logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy()  # [B]

# Transfer only the leading max_k columns — typically max_k << vocab_size.
indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy() # [B, max_k]