[KSM] fix logz when topk #7232
Changes from all commits
```diff
@@ -100,7 +100,7 @@ def _compute_sampling_mask(
     top_p: paddle.Tensor,
     top_k: Optional[paddle.Tensor] = None,
     top_k_list: Optional[list] = None,
-) -> List[np.ndarray]:
+) -> tuple[List[np.ndarray], np.ndarray]:
     """
     Compute a combined top-k + top-p (nucleus) sampling mask as sparse
     retained-token indices.
@@ -125,8 +125,11 @@ def _compute_sampling_mask(
     top-k filtering is needed at all.

     Returns:
-        sparse_indices: List of length num_reqs; element i is a 1-D int64
-            numpy array of the retained vocab indices for request i.
+        Tuple of (sparse_indices, logz_per_batch):
+        - sparse_indices: List of length num_reqs; element i is a 1-D int64
+          numpy array of the retained vocab indices for request i.
+        - logz_per_batch: 1-D numpy array of shape [num_reqs] containing
+          log(Z_K) where Z_K is the sum of probabilities in the candidate set.
     """
     real_bsz = probs.shape[0]
     vocab_size = probs.shape[1]
```
```diff
@@ -193,13 +196,45 @@ def _compute_sampling_mask(
     k_per_row = final_mask.astype("int32").sum(axis=-1)  # [B]
     max_k = int(k_per_row.max().item())
+    # ------------------------------------------------------------------
+    # Stage 5: compute logZ for renormalization
+    #
+    # Goal: log π_mask(k) = log π_full(k) - logZ, where π_mask is the
+    # distribution actually sampled from (top-k truncated + top-p nucleus).
+    #
+    # When top_k is active the sampling pipeline first renormalises to
+    # π_topk, then applies top-p on π_topk. The total log-normaliser
+    # that maps π_full → π_mask absorbs both steps:
+    #
+    #     logZ = log Z_topk + log Z_topp_on_renorm
+    #
+    # where Z_topk = Σ_{j∈topk} π_full(j)  (= row_sums, already computed)
+    #       Z_topp = Σ_{k∈K} π_topk(k)     (sum of renorm probs in K)
+    #
+    # Substituting:
+    #     log π_mask(k) = log π_full(k) - log Z_topk - log Z_topp
+    #                   = log π_topk(k) - log Z_topp  ✓
+    #
+    # When top_k is absent Z_topk = 1 → logZ = log Z_topp as before.
+    # ------------------------------------------------------------------
+    if has_top_k:
+        # Z_topp: sum of renormed probs that survive the final mask
+        candidate_probs = paddle.where(final_mask, renorm_sorted_probs, paddle.zeros_like(renorm_sorted_probs))
+        z_topp = candidate_probs.sum(axis=-1)  # [B]
+        # row_sums: [B, 1] already clipped ≥ 1e-9, squeeze to [B]
+        log_z_topk = paddle.log(row_sums.squeeze(-1))
+        logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy()  # [B]
+    else:
+        candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
+        z_k = candidate_probs.sum(axis=-1)  # [B]
+        logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy()  # [B]
```
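The logZ identity derived in the Stage 5 comment can be checked numerically on a toy distribution. This is an illustrative sketch with made-up probabilities, independent of the PR's Paddle code:

```python
import numpy as np

# Toy full distribution over a 5-token vocab (illustrative values only).
probs = np.array([0.4, 0.3, 0.15, 0.1, 0.05])
top_k, top_p = 3, 0.8

# Top-k truncation, then renormalize to pi_topk.
topk_idx = np.argsort(probs)[::-1][:top_k]
z_topk = probs[topk_idx].sum()              # Z_topk = 0.85
pi_topk = probs[topk_idx] / z_topk

# Top-p (nucleus) on the renormalized distribution: keep the smallest
# prefix whose cumulative mass reaches top_p.
cum = np.cumsum(pi_topk)
keep = np.arange(top_k) <= np.searchsorted(cum, top_p)
z_topp = pi_topk[keep].sum()                # Z_topp on the renormed probs

# Combined normalizer: logZ = log Z_topk + log Z_topp.
logz = np.log(z_topk) + np.log(z_topp)

# Check: log pi_mask(k) = log pi_full(k) - logZ for every retained token,
# where pi_mask is the fully truncated-and-renormalized distribution.
kept_full = probs[topk_idx][keep]
pi_mask = kept_full / kept_full.sum()
assert np.allclose(np.log(pi_mask), np.log(kept_full) - logz)
```

With these values two tokens survive (0.4 and 0.3), the combined mass is Z = 0.7, and subtracting a single logZ from the full-distribution logprobs reproduces the truncated distribution exactly, which is what the renormalization in post-processing relies on.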
Copilot AI · Apr 8, 2026 (comment on lines +226 to +230):

`logz_per_batch` is computed on GPU, then immediately transferred to CPU (`.cpu().numpy()`), and later converted back to a Paddle tensor in post-processing for renormalization. This adds synchronization plus extra H2D/D2H copies. If `logz_per_batch` is only used for renormalizing logprobs, consider keeping it as a `paddle.Tensor` on the same device as `logprobs` (and only converting to NumPy at the boundary if it must be returned externally). Suggested change:

```diff
-            logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy()  # [B]
+            logz_per_batch = log_z_topk + paddle.log(z_topp + 1e-10)  # [B]
         else:
             candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
             z_k = candidate_probs.sum(axis=-1)  # [B]
-            logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy()  # [B]
+            logz_per_batch = paddle.log(z_k + 1e-10)  # [B]
```
Copilot AI · Apr 8, 2026:

The comment here still says "Binary mask [num_reqs, vocab_size] bool tensor", but `_compute_sampling_mask` actually returns CPU-side sparse indices (`List[np.ndarray]`), and in the speculative path the dimension is not simply `num_reqs` either. The comment should be updated to match, to avoid future misuse. Suggested change:

```diff
-    # Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated.
+    # `_compute_sampling_mask` returns CPU-side sparse retained-token indices,
+    # i.e. a List[np.ndarray], not a dense bool tensor of shape
+    # [num_reqs, vocab_size]. In speculative paths, the outer list length also
+    # does not necessarily match `num_reqs`.
```
Copilot AI · Apr 8, 2026:

Currently `logz_per_batch` is only computed when `keep_sampling_mask` is enabled and stored into `SamplerOutput`, but nothing in the repository consumes or forwards this field (for example, `_build_stream_transfer_data`, the ZMQ side-channel, and the OpenAI response all omit it). If the goal of this PR is to fix logZ in the top_k case for logprobs normalization, the chain that actually applies logZ to the logprobs (or passes logZ up to the caller for normalization) is still missing; without it, the change has essentially no externally visible effect.
Copilot AI · Apr 8, 2026:

In the speculative decoding path, `_compute_sampling_mask(target_probs, ...)` is called with a `target_probs` batch dimension of `total_accepted_tokens`, so the resulting `logz_per_batch` is also laid out per accepted token. If logZ is later returned or used per request, it should be regrouped/aligned by `accept_num` in the post-processing stage; otherwise it can easily become misaligned with the other per-request fields.
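The per-request regrouping the reviewer suggests could be sketched as follows. The helper name and the `accept_num` list (number of accepted tokens per request) are hypothetical illustrations, not part of the PR:

```python
import numpy as np

def regroup_logz_by_request(logz_flat, accept_num):
    """Split a flat [total_accepted_tokens] logZ array into per-request chunks."""
    assert sum(accept_num) == len(logz_flat), "accept_num must sum to total accepted tokens"
    out, offset = [], 0
    for n in accept_num:
        out.append(logz_flat[offset:offset + n])
        offset += n
    return out

# Example: 3 requests that accepted 2, 1, and 3 speculative tokens respectively.
logz_flat = np.arange(6, dtype=np.float64)
chunks = regroup_logz_by_request(logz_flat, [2, 1, 3])
# chunks[0] covers tokens 0-1, chunks[1] token 2, chunks[2] tokens 3-5
```

Keeping the split keyed on `accept_num` ensures the per-token logZ values line up with whatever per-request structure the caller uses downstream.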
```diff
@@ -355,6 +355,26 @@ def post_process_normal(
         sampler_output.sampled_token_ids,
         model_output.is_block_step,
     )
+    # Renormalize logprobs to match truncated sampling distribution (when enabled).
+    if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
+        # logprobs_tensors.logprobs: [B, max_num_logprobs + 1]
+        logprobs = sampler_output.logprobs_tensors.logprobs
+        # logz_per_batch: [B], log(sum(probs in candidate set K)) for each request
+        logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
```

Suggested change:

```diff
-        logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
+        logz = paddle.to_tensor(
+            sampler_output.logz_per_batch, dtype=logprobs.dtype, place=logprobs.place
+        )
```
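The renormalization in the hunk above can be illustrated with a NumPy sketch (hypothetical logprob values; the real code operates on Paddle tensors and uses `paddle.where` to preserve padding):

```python
import numpy as np

# Hypothetical top-logprobs for two requests; -inf marks padding slots.
logprobs = np.array([
    [np.log(0.4), np.log(0.3), -np.inf],
    [np.log(0.5), np.log(0.2), np.log(0.1)],
])
# log(Z_K) per request: total probability mass retained by top-k/top-p.
logz = np.array([np.log(0.7), np.log(0.8)])

# log pi_masked = log pi_full - log Z_K; padding positions stay at -inf.
valid = np.isfinite(logprobs)
renormed = np.where(valid, logprobs - logz[:, None], -np.inf)

# After renormalization the retained probabilities sum to 1 per request.
row_sums = np.exp(renormed).sum(axis=1)
```

Subtracting a per-request scalar in log space is exactly dividing each retained probability by Z_K, so each candidate set becomes a proper distribution again.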
Copilot AI · Apr 8, 2026:

The renormalization block is duplicated in both `post_process_normal` and `post_process_specualate` with near-identical logic. Consider extracting this into a small helper (e.g., `_renormalize_logprobs_with_logz(sampler_output)`) to reduce duplication and the risk of future drift (especially around masking/device placement). Suggested change (replace the inline block with a call plus a shared helper):

```python
    _renormalize_logprobs_with_logz(sampler_output)


def _renormalize_logprobs_with_logz(sampler_output: SamplerOutput):
    """Renormalize logprobs to match the truncated sampling distribution."""
    if sampler_output.logprobs_tensors is None or sampler_output.logz_per_batch is None:
        return
    # logprobs_tensors.logprobs: [B, max_num_logprobs + 1]
    logprobs = sampler_output.logprobs_tensors.logprobs
    # logz_per_batch: [B], log(sum(probs in candidate set K)) for each request
    logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
    # Renormalize: log π_masked = log π_full - log Z_K
    # Only normalize valid candidates; padding positions use -inf
    valid_mask = paddle.isfinite(logprobs)
    normalized_logprobs = paddle.where(
        valid_mask,
        logprobs - logz.unsqueeze(1),  # broadcast subtraction
        paddle.full_like(logprobs, float("-inf")),
    )
    sampler_output.logprobs_tensors = LogprobsTensors(
        logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
        logprobs=normalized_logprobs,
        selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
    )
```
Review comment:

🟡 Suggestion — code duplication

The logprobs renormalization logic in `post_process_normal` (lines 359-377) and `post_process_specualate` (lines 527-539) is identical. Suggest extracting it into a shared helper, e.g.:

```python
def _renormalize_logprobs(sampler_output: SamplerOutput) -> None:
    """Renormalize logprobs to match truncated sampling distribution."""
    if sampler_output.logprobs_tensors is None or sampler_output.logz_per_batch is None:
        return
    logprobs = sampler_output.logprobs_tensors.logprobs
    logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
    valid_mask = paddle.isfinite(logprobs)
    normalized_logprobs = paddle.where(
        valid_mask,
        logprobs - logz.unsqueeze(1),
        paddle.full_like(logprobs, float("-inf")),
    )
    sampler_output.logprobs_tensors = LogprobsTensors(
        logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids,
        logprobs=normalized_logprobs,
        selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks,
    )
```
Copilot AI · Apr 8, 2026:

Same device-placement issue as the normal path: `logz` constructed from NumPy may land on CPU while `logprobs` may be on GPU, breaking `logprobs - logz.unsqueeze(1)`. Ensure `logz` is created on (or moved to) `logprobs`'s place/device before use. Suggested change:

```diff
-        logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype)
+        logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype, place=logprobs.place)
```
Review comment:

🔴 Bug — wrong variable name causes a runtime error

In `post_process_specualate`, line 529 defines the variable `log_valid_mask`, but the `paddle.where` on line 531 uses `valid_mask`; since that name is undefined, this raises a `NameError`. Suggested fix: change `valid_mask` on line 531 to `log_valid_mask`.
```diff
@@ -193,6 +193,10 @@ class SamplerOutput:
     # check whether the current path is speculative or non-speculative when
     # interpreting the dimension.
     sampling_mask: Optional[List[np.ndarray]] = None
+    # logZ_K for each request: log(sum(probs in candidate set K))
+    # Used for renormalizing logprobs to match the truncated sampling distribution.
+    # Shape: [num_reqs]
+    logz_per_batch: Optional[np.ndarray] = None


 @dataclass
```

Review comment:

🟡 Suggestion — the `Shape: [num_reqs]` comment is inaccurate. In speculative decoding mode, `logz_per_batch` is laid out per accepted token rather than per request. Suggested comment update:

```python
# logZ_K for logprobs renormalization:
# - Non-speculative: shape [num_reqs], one value per request
# - Speculative: shape [total_accepted_tokens], one per accepted token
```

zeroRains marked this conversation as resolved.
Review comment:

The PR title is currently "[KSM] fix logz when topk" (quotes included), which does not match the `[CLASS]Title` tag convention required by the template; in addition, the Modifications/Usage/Accuracy Tests sections of the PR description are largely unfilled. Suggest completing the description per the template and changing the title to a conventional tag (e.g. `[BugFix] ...` or another semantically clear label).