[KSM] fix logz when topk #7232

Merged
yuanlehome merged 3 commits into PaddlePaddle:release/2.5 from zeroRains:kms_2.5
Apr 8, 2026

Conversation

zeroRains (Contributor) commented Apr 8, 2026

Motivation

💡 If this PR is a cherry-pick, the PR title must start with the [Cherry-Pick] label and end with the original PR ID, e.g. [Cherry-Pick][CI] Add check trigger and logic(#5191)


Preserve the logZ computation and use logZ to renormalize the logprobs. See #7225 and #6966.

Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 8, 2026 03:19
paddle-bot commented Apr 8, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR preserves/restores the logZ computation (the log of the candidate set's probability mass) in the top_k + top_p truncated-sampling scenario, providing the information needed to later renormalize the logprobs to the actual truncated distribution; it also adds a field to SamplerOutput to carry this data.

Changes:

  • Add a logz_per_batch field to SamplerOutput to carry logZ for each batch row.
  • Adjust the return value of _compute_sampling_mask to additionally return logz_per_batch alongside the sparse sampling mask, and compute and backfill it in the sampling forward path.
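To make the described interface change concrete, here is a minimal NumPy sketch of what the adjusted return shape could look like. The function name, the nucleus-only truncation, and all values are illustrative assumptions, not the actual FastDeploy implementation (which also handles top_k and operates on Paddle tensors):

```python
from typing import List, Tuple

import numpy as np


def compute_sampling_mask_sketch(
    probs: np.ndarray,  # [B, vocab_size] row-normalized probabilities
    top_p: np.ndarray,  # [B, 1] nucleus threshold per row
) -> Tuple[List[np.ndarray], np.ndarray]:
    """Return (sparse retained indices per row, log(Z_K) per row)."""
    sparse_indices: List[np.ndarray] = []
    logz = np.empty(probs.shape[0], dtype=np.float64)
    for i, row in enumerate(probs):
        order = np.argsort(row)[::-1]             # sort descending by prob
        cum = np.cumsum(row[order])
        keep = (cum - row[order]) < top_p[i, 0]   # nucleus (top-p) cutoff
        retained = np.sort(order[keep]).astype(np.int64)
        sparse_indices.append(retained)
        logz[i] = np.log(row[retained].sum())     # log(Z_K), candidate-set mass
    return sparse_indices, logz
```

The second return value is what the PR threads through to SamplerOutput for later logprobs renormalization.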

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

File Description
fastdeploy/worker/output.py Adds the logz_per_batch field to SamplerOutput with an explanatory comment
fastdeploy/model_executor/layers/sample/sampler.py Adds the logZ computation to _compute_sampling_mask and adjusts call sites to pass it through to SamplerOutput

Comment on lines +128 to +132
Tuple of (sparse_indices, logz_per_batch):
- sparse_indices: List of length num_reqs; element i is a 1-D int64
numpy array of the retained vocab indices for request i.
- logz_per_batch: 1-D numpy array of shape [num_reqs] containing
log(Z_K) where Z_K is the sum of probabilities in the candidate set.

Copilot AI Apr 8, 2026


The logz_per_batch returned by _compute_sampling_mask actually has shape equal to probs.shape[0]. In speculative decoding the batch dimension here is expanded to total_accepted_tokens, so the docstring's "shape [num_reqs]" can mislead later users; a more accurate description is recommended (e.g. "shape [B] where B = probs.shape[0]", explaining what B means at each call site).

Copilot uses AI. Check for mistakes.
Comment on lines +685 to 689
        logz_per_batch = None
        if sampling_metadata.keep_sampling_mask:
-            sampling_mask = _compute_sampling_mask(
+            sampling_mask, logz_per_batch = _compute_sampling_mask(
                probs,
                sampling_metadata.top_p,

Copilot AI Apr 8, 2026


Currently logz_per_batch is only computed and stored in SamplerOutput when keep_sampling_mask is enabled, but nothing in the repository consumes or propagates this field (for example, _build_stream_transfer_data, the ZMQ side-channel, and the OpenAI response all omit it). If the goal of this PR is to fix logZ in the top_k scenario for logprobs normalization, the full chain that applies logZ to the logprobs (or passes logZ along with the output so upper layers can normalize) is still missing; otherwise the change has essentially no externally visible effect.

Comment on lines 1052 to 1058
        # Compute sampling mask at accepted token positions.
        # Shape: [total_accepted_tokens, vocab_size], bool (CPU).
        sampling_mask = None
        logz_per_batch = None
        if keep_sampling_mask:
            # Expand top_p from [batch, 1] to [total_accepted, 1].
            accept_top_p = sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1)

Copilot AI Apr 8, 2026


In the speculative decoding path, _compute_sampling_mask(target_probs, ...) is called with a target_probs whose batch dimension is total_accepted_tokens, so the resulting logz_per_batch is likewise expanded per accepted token. If logZ later needs to be returned or used per request, it should be regrouped/aligned by accept_num in the post-process stage; otherwise it can easily end up misaligned with other per-request fields.
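The per-request regrouping suggested in this comment can be sketched with NumPy; the accept counts and logZ values below are made-up examples, not values from the PR:

```python
import numpy as np

# Hypothetical per-accepted-token logZ values, as produced by the speculative path.
logz_per_token = np.array([-0.1, -0.2, -0.3, -0.4, -0.5, -0.6])
accept_nums = np.array([2, 1, 3])  # accepted tokens per request

# Split at the cumulative accept counts (the final boundary is implicit).
boundaries = np.cumsum(accept_nums)[:-1]
logz_per_request = np.split(logz_per_token, boundaries)
# logz_per_request[0] == [-0.1, -0.2]        (request 0 accepted 2 tokens)
# logz_per_request[1] == [-0.3]              (request 1 accepted 1 token)
# logz_per_request[2] == [-0.4, -0.5, -0.6]  (request 2 accepted 3 tokens)
```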

Comment on lines 100 to 104
    top_p: paddle.Tensor,
    top_k: Optional[paddle.Tensor] = None,
    top_k_list: Optional[list] = None,
-) -> List[np.ndarray]:
+) -> tuple[List[np.ndarray], np.ndarray]:
    """

Copilot AI Apr 8, 2026


The PR title is currently "[KSM] fix logz when topk" (still quoted), which does not match the template's [CLASS]Title format or the tag list; in addition, the Modifications/Usage/Accuracy Tests sections of the description are largely unfilled. It is recommended to complete the description per the template and to change the title to a conventional tag (e.g. [BugFix] ... or another semantically clear tag).

@@ -647,8 +682,9 @@ def forward_cuda(
# Compute sampling mask BEFORE top_k_top_p_sampling modifies probs.
# Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated.
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment here still says "Binary mask [num_reqs, vocab_size] bool tensor", but _compute_sampling_mask actually returns sparse indices as a List[np.ndarray] (on CPU), and in the speculative path the dimension is not simply num_reqs. The comment should be updated accordingly to avoid later misuse.

Suggested change
# Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated.
# `_compute_sampling_mask` returns CPU-side sparse retained-token indices,
# i.e. a List[np.ndarray], not a dense bool tensor of shape
# [num_reqs, vocab_size]. In speculative paths, the outer list length also
# does not necessarily match `num_reqs`.


@fastdeploy-bot fastdeploy-bot left a comment


📋 Review Summary

PR overview: Fixes the logZ computation in the top-k case so that the normalization constant after top-k + top-p truncation is computed correctly.

Scope: fastdeploy/model_executor/layers/sample/sampler.py, fastdeploy/worker/output.py

Impact tags: [BugFix] [OP]


📝 PR Convention Check

The PR title uses the unofficial tag [KSM], and as a Cherry-Pick PR it does not follow the convention.

Issues

  1. The title tag [KSM] is not in the official tag list
  2. As a Cherry-Pick PR (kms_2.5 → release/2.5), it lacks the required [Cherry-Pick] prefix and the original PR ID

Suggested title (copy-paste ready):

  • [Cherry-Pick][BugFix] fix logz when topk(#7225)

Suggested description

Motivation

Fix the issue that, under combined top-k + top-p sampling, the logZ computation did not account for top-k truncation, making the logprobs normalization inaccurate.

Modifications

  • Modify _compute_sampling_mask to return an additional logz_per_batch value
  • When top-k is active, correctly compute the two-stage normalization: logZ = log(Z_topk) + log(Z_topp_on_renorm)
  • Add the logz_per_batch field to the SamplerOutput class
  • Update the two CUDA sampling-path call sites

Issues

Level  File  Summary
🟡 Suggestion  PR conventions  Title tag is non-standard and the Cherry-Pick format is missing

Overall

The code logic is correct and fully accounts for the normalization under combined top-k + top-p sampling. The interface change is backward compatible (Optional field with a default value), and all call sites are adapted correctly. The main remaining issue is that the PR title does not follow the convention.


codecov-commenter commented Apr 8, 2026

Codecov Report

❌ Patch coverage is 30.76923% with 18 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.5@46ad25d). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/pre_and_post_process.py 0.00% 10 Missing and 2 partials ⚠️
fastdeploy/model_executor/layers/sample/sampler.py 53.84% 5 Missing and 1 partial ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.5    #7232   +/-   ##
==============================================
  Coverage               ?   68.99%           
==============================================
  Files                  ?      390           
  Lines                  ?    54405           
  Branches               ?     8577           
==============================================
  Hits                   ?    37535           
  Misses                 ?    14161           
  Partials               ?     2709           
Flag Coverage Δ
GPU 68.99% <30.76%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.


@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-08

📋 Review Summary

PR overview: Fixes the logZ computation when top-k is used, and renormalizes the logprobs using logZ

Scope: model_executor/layers/sample/sampler.py, pre_and_post_process.py, worker/output.py

Impact tags: [BugFix]

📝 PR Convention Check

The PR title includes a [KSM] tag, and the description explains the motivation and links the related PRs; it follows the convention.

Issues

Level  File  Summary
🔴 Bug  pre_and_post_process.py:531  Wrong variable name: uses undefined valid_mask; should be log_valid_mask
🟡 Suggestion  pre_and_post_process.py:359-377, 527-539  Code duplication: the logprobs renormalization logic is identical in the two functions
🟡 Suggestion  —  Missing unit tests: add test cases for the logZ computation in _compute_sampling_mask

Overall

The PR fixes the core logZ computation logic for the top-k case, and the math is correct. However, a variable-naming error would cause a runtime failure and must be fixed. It is also recommended to extract the duplicated renormalization logic into a helper function and to add test coverage.

    if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
        logprobs = sampler_output.logprobs_tensors.logprobs
        logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
        valid_mask = paddle.isfinite(logprobs)

🔴 Bug: variable-name error causes a runtime exception

In post_process_specualate, line 529 defines the variable log_valid_mask, but line 531's paddle.where uses valid_mask, which is undefined and will raise a NameError.

Suggested fix
Change valid_mask on line 531 to log_valid_mask.

model_output.mask_rollback,
)

# Renormalize logprobs to match truncated sampling distribution (when enabled).

🟡 Suggestion: code duplication

The logprobs renormalization logic in post_process_normal (lines 359-377) and post_process_specualate (lines 527-539) is identical.

Suggestion
Extract it into a shared helper function, for example:

def _renormalize_logprobs(sampler_output: SamplerOutput) -> None:
    """Renormalize logprobs to match truncated sampling distribution."""
    if sampler_output.logprobs_tensors is None or sampler_output.logz_per_batch is None:
        return
    logprobs = sampler_output.logprobs_tensors.logprobs
    logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
    valid_mask = paddle.isfinite(logprobs)
    normalized_logprobs = paddle.where(
        valid_mask,
        logprobs - logz.unsqueeze(1),
        paddle.full_like(logprobs, float("-inf")),
    )
    sampler_output.logprobs_tensors = LogprobsTensors(
        logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
        logprobs=normalized_logprobs,
        selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
    )

Copilot AI review requested due to automatic review settings April 8, 2026 08:48
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

# logprobs_tensors.logprobs: [B, max_num_logprobs + 1]
logprobs = sampler_output.logprobs_tensors.logprobs
# logz_per_batch: [B], log(sum(probs in candidate set K)) for each request
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)

Copilot AI Apr 8, 2026


logz is created from a NumPy array without specifying place, which can put it on CPU while logprobs is on GPU. In Paddle this commonly causes a device mismatch error when executing logprobs - logz.unsqueeze(1). Create logz on the same place/device as logprobs (e.g., pass place=logprobs.place or otherwise ensure the tensor is moved to the same device) before subtraction.

Suggested change
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
logz = paddle.to_tensor(
sampler_output.logz_per_batch, dtype=logprobs.dtype, place=logprobs.place
)

# Renormalize logprobs to match truncated sampling distribution (when enabled).
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
logprobs = sampler_output.logprobs_tensors.logprobs
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)

Copilot AI Apr 8, 2026


Same device-placement issue as the normal path: logz constructed from NumPy may land on CPU, while logprobs may be on GPU, breaking logprobs - logz.unsqueeze(1). Ensure logz is created/moved to logprobs's place/device before use.

Suggested change
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype, place=logprobs.place)

Comment on lines +358 to 379
# Renormalize logprobs to match truncated sampling distribution (when enabled).
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
    # logprobs_tensors.logprobs: [B, max_num_logprobs + 1]
    logprobs = sampler_output.logprobs_tensors.logprobs
    # logz_per_batch: [B], log(sum(probs in candidate set K)) for each request
    logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
    # Renormalize: log π_masked = log π_full - log Z_K
    # Only normalize valid candidates; padding positions use -inf
    valid_mask = paddle.isfinite(logprobs)
    normalized_logprobs = paddle.where(
        valid_mask,
        logprobs - logz.unsqueeze(1),  # broadcast subtraction
        paddle.full_like(logprobs, float("-inf")),
    )
    # Update logprobs_tensors with normalized values
    sampler_output.logprobs_tensors = LogprobsTensors(
        logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
        logprobs=normalized_logprobs,
        selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
    )




Copilot AI Apr 8, 2026


The renormalization block is duplicated in both post_process_normal and post_process_specualate with near-identical logic. Consider extracting this into a small helper (e.g., _renormalize_logprobs_with_logz(sampler_output)) to reduce duplication and the risk of future drift (especially around masking/device placement).

Suggested change
Replace the inline renormalization block with a helper call:

    _renormalize_logprobs_with_logz(sampler_output)

and add the helper:

def _renormalize_logprobs_with_logz(sampler_output: SamplerOutput):
    """Renormalize logprobs to match the truncated sampling distribution."""
    if sampler_output.logprobs_tensors is None or sampler_output.logz_per_batch is None:
        return
    # logprobs_tensors.logprobs: [B, max_num_logprobs + 1]
    logprobs = sampler_output.logprobs_tensors.logprobs
    # logz_per_batch: [B], log(sum(probs in candidate set K)) for each request
    logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
    # Renormalize: log π_masked = log π_full - log Z_K
    # Only normalize valid candidates; padding positions use -inf
    valid_mask = paddle.isfinite(logprobs)
    normalized_logprobs = paddle.where(
        valid_mask,
        logprobs - logz.unsqueeze(1),  # broadcast subtraction
        paddle.full_like(logprobs, float("-inf")),
    )
    sampler_output.logprobs_tensors = LogprobsTensors(
        logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
        logprobs=normalized_logprobs,
        selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
    )


Comment on lines +226 to +230
    logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy()  # [B]
else:
    candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
    z_k = candidate_probs.sum(axis=-1)  # [B]
    logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy()  # [B]
Copy link

Copilot AI Apr 8, 2026


logz_per_batch is computed on GPU, then immediately transferred to CPU (.cpu().numpy()), and later converted back to a Paddle tensor in post-processing for renormalization. This adds synchronization + extra H2D/D2H copies. If logz_per_batch is only used for renormalizing logprobs, consider keeping it as a paddle.Tensor on the same device as logprobs (and only convert to NumPy at the boundary if it must be returned externally).

Suggested change
-    logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy()  # [B]
-else:
-    candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
-    z_k = candidate_probs.sum(axis=-1)  # [B]
-    logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy()  # [B]
+    logz_per_batch = log_z_topk + paddle.log(z_topp + 1e-10)  # [B]
+else:
+    candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
+    z_k = candidate_probs.sum(axis=-1)  # [B]
+    logz_per_batch = paddle.log(z_k + 1e-10)  # [B]

else:
    candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
    z_k = candidate_probs.sum(axis=-1)  # [B]
    logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy()  # [B]
Copy link

Copilot AI Apr 8, 2026


logz_per_batch is computed on GPU, then immediately transferred to CPU (.cpu().numpy()), and later converted back to a Paddle tensor in post-processing for renormalization. This adds synchronization + extra H2D/D2H copies. If logz_per_batch is only used for renormalizing logprobs, consider keeping it as a paddle.Tensor on the same device as logprobs (and only convert to NumPy at the boundary if it must be returned externally).


@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-08

📋 Review Summary

PR overview: Fixes the logZ computation logic in top-k mode, preserving the logZ computation and using logZ to renormalize the logprobs

Scope: model_executor/layers/sample/sampler.py, model_executor/pre_and_post_process.py, worker/output.py

Impact tags: [Models]

📝 PR Convention Check

The PR title includes a [KSM] tag (a Models-type tag), and the description contains Motivation and Modifications sections; the format follows the convention.

Issues

Level  File  Summary
🟡 Suggestion  fastdeploy/worker/output.py:196  The logz_per_batch shape comment is inaccurate in the speculative decoding case

Overall

The implementation is correct: _compute_sampling_mask computes logZ (the log normalization factor) correctly, and post-process renormalizes the logprobs with it. Both the mathematical derivation and the code are correct. The only improvement needed is that the comment describing the shape of logz_per_batch should distinguish the non-speculative and speculative modes.


📐 Math and Logic Verification

logZ computation

When top_k and top_p are both enabled:

  1. First truncate to top-k and renormalize: π_topk(k) = π_full(k) / Z_topk
  2. Then apply top_p on π_topk to obtain the final candidate set K
  3. The total log normalization factor is: logZ = log(Z_topk) + log(Z_topp)

where:

  • Z_topk = Σ_{j∈topk} π_full(j) = row_sums
  • Z_topp = Σ_{k∈K} π_topk(k) = z_topp

The implementation matches this derivation exactly.

Post-process renormalization

log π_masked(k) = log π_full(k) - logZ

Implementation:

normalized_logprobs = logprobs - logz.unsqueeze(1)  # broadcast subtraction

This is correct and ensures the returned logprobs correspond to the actual sampling distribution.
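As a toy numeric check of the two-stage factorization above (all probabilities made up, not values from the PR):

```python
import numpy as np

p = np.array([0.5, 0.2, 0.15, 0.1, 0.05])  # full softmax distribution, sorted
top_k, top_p = 3, 0.8

z_topk = p[:top_k].sum()          # Z_topk: mass kept by top-k truncation
p_renorm = p[:top_k] / z_topk     # pi_topk: renormalized top-k distribution
# Nucleus (top-p) cutoff applied on the renormalized distribution.
keep = (np.cumsum(p_renorm) - p_renorm) < top_p
z_topp = p_renorm[keep].sum()     # Z_topp: mass of final set K under pi_topk

# logZ = log(Z_topk) + log(Z_topp) equals the log raw mass of the final set K.
logz = np.log(z_topk) + np.log(z_topp)
assert np.isclose(logz, np.log(p[:top_k][keep].sum()))
```

With these toy numbers the final candidate set keeps the first two tokens, and both sides evaluate to log(0.7), confirming the factorization.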

# check whether the current path is speculative or non-speculative when
# interpreting the dimension.
sampling_mask: Optional[List[np.ndarray]] = None
# logZ_K for each request: log(sum(probs in candidate set K))

🟡 Suggestion: the shape description of logz_per_batch in the comment should distinguish the two modes.

The current comment states the shape is [num_reqs], which only holds in the non-speculative decoding mode.

In speculative decoding mode (SpeculativeSampler.forward_cuda calls _compute_sampling_mask at line 1068), the input target_probs has shape [total_accepted_tokens, vocab_size], so the returned logz_per_batch also has shape [total_accepted_tokens], with an independent logZ value per accepted token.

Suggested comment update:

# logZ_K for logprobs renormalization:
#   - Non-speculative: shape [num_reqs], one value per request
#   - Speculative: shape [total_accepted_tokens], one per accepted token

@yuanlehome yuanlehome merged commit 7af95be into PaddlePaddle:release/2.5 Apr 8, 2026
35 of 42 checks passed
@zeroRains zeroRains deleted the kms_2.5 branch April 8, 2026 13:21


5 participants