From 0097a00a5e21d54c0661004e76089a7a3a311148 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Wed, 8 Apr 2026 11:17:11 +0800 Subject: [PATCH 1/2] [KSM] fix logz when topk --- .../model_executor/layers/sample/sampler.py | 51 ++++++++++++++++--- fastdeploy/worker/output.py | 4 ++ 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 97d1b8a34c5..f127c21d183 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -100,7 +100,7 @@ def _compute_sampling_mask( top_p: paddle.Tensor, top_k: Optional[paddle.Tensor] = None, top_k_list: Optional[list] = None, -) -> List[np.ndarray]: +) -> tuple[List[np.ndarray], np.ndarray]: """ Compute a combined top-k + top-p (nucleus) sampling mask as sparse retained-token indices. @@ -125,8 +125,11 @@ def _compute_sampling_mask( top-k filtering is needed at all. Returns: - sparse_indices: List of length num_reqs; element i is a 1-D int64 - numpy array of the retained vocab indices for request i. + Tuple of (sparse_indices, logz_per_batch): + - sparse_indices: List of length num_reqs; element i is a 1-D int64 + numpy array of the retained vocab indices for request i. + - logz_per_batch: 1-D numpy array of shape [num_reqs] containing + log(Z_K) where Z_K is the sum of probabilities in the candidate set. """ real_bsz = probs.shape[0] vocab_size = probs.shape[1] @@ -193,13 +196,45 @@ def _compute_sampling_mask( k_per_row = final_mask.astype("int32").sum(axis=-1) # [B] max_k = int(k_per_row.max().item()) + # ------------------------------------------------------------------ + # Stage 5: compute logZ for renormalization + # + # Goal: log π_mask(k) = log π_full(k) - logZ, where π_mask is the + # distribution actually sampled from (top-k truncated + top-p nucleus). + # + # When top_k is active the sampling pipeline first renormalises to + # π_topk, then applies top-p on π_topk. The total log-normaliser + # that maps π_full → π_mask absorbs both steps: + # + # logZ = log Z_topk + log Z_topp_on_renorm + # + # where Z_topk = Σ_{j∈topk} π_full(j) (= row_sums, already computed) + # Z_topp = Σ_{k∈K} π_topk(k) (sum of renorm probs in K) + # + # Substituting: + # log π_mask(k) = log π_full(k) - log Z_topk - log Z_topp + # = log π_topk(k) - log Z_topp ✓ + # + # When top_k is absent Z_topk = 1 → logZ = log Z_topp as before. + # ------------------------------------------------------------------ + if has_top_k: + # Z_topp: sum of renormed probs that survive the final mask + candidate_probs = paddle.where(final_mask, renorm_sorted_probs, paddle.zeros_like(renorm_sorted_probs)) + z_topp = candidate_probs.sum(axis=-1) # [B] + # row_sums: [B, 1] already clipped ≥ 1e-9, squeeze to [B] + log_z_topk = paddle.log(row_sums.squeeze(-1)) + logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy() # [B] + else: + candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs)) + z_k = candidate_probs.sum(axis=-1) # [B] + logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B] # Transfer only the leading max_k columns — typically max_k << vocab_size. indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy() # [B, max_k] mask_window_cpu = final_mask[:, :max_k].cpu().numpy() # [B, max_k] sparse_indices = [indices_window_cpu[i, mask_window_cpu[i]] for i in range(real_bsz)] - return sparse_indices + return sparse_indices, logz_per_batch class GuidedDecoding: @@ -647,8 +682,9 @@ def forward_cuda( # Compute sampling mask BEFORE top_k_top_p_sampling modifies probs. # Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated. sampling_mask = None + logz_per_batch = None if sampling_metadata.keep_sampling_mask: - sampling_mask = _compute_sampling_mask( + sampling_mask, logz_per_batch = _compute_sampling_mask( probs, sampling_metadata.top_p, top_k=sampling_metadata.top_k, @@ -679,6 +715,7 @@ def forward_cuda( logprobs_tensors=logprobs_tensors, logits=logits, sampling_mask=sampling_mask, + logz_per_batch=logz_per_batch, ) return sampler_output @@ -1015,6 +1052,7 @@ def forward_cuda( # Compute sampling mask at accepted token positions. # Shape: [total_accepted_tokens, vocab_size], bool (CPU). sampling_mask = None + logz_per_batch = None if keep_sampling_mask: # Expand top_p from [batch, 1] to [total_accepted, 1]. accept_top_p = sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) @@ -1027,7 +1065,7 @@ def forward_cuda( accept_top_k = ( sampling_metadata.top_k[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) ) - sampling_mask = _compute_sampling_mask( + sampling_mask, logz_per_batch = _compute_sampling_mask( target_probs, accept_top_p, top_k=accept_top_k, @@ -1041,6 +1079,7 @@ def forward_cuda( cu_batch_token_offset=share_inputs["cu_batch_token_offset"], logits=logits, sampling_mask=sampling_mask, + logz_per_batch=logz_per_batch, ) return sampler_output diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 659c9359a0d..ec07a1ae75d 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -193,6 +193,10 @@ class SamplerOutput: # check whether the current path is speculative or non-speculative when # interpreting the dimension. sampling_mask: Optional[List[np.ndarray]] = None + # logZ_K for each request: log(sum(probs in candidate set K)) + # Used for renormalizing logprobs to match the truncated sampling distribution. + # Shape: [num_reqs] + logz_per_batch: Optional[np.ndarray] = None @dataclass From 62c720f3f7e73912c12d8d9baadecd693aa79319 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Wed, 8 Apr 2026 15:02:43 +0800 Subject: [PATCH 2/2] add the logz renormalize --- .../model_executor/pre_and_post_process.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 5d2ed846057..e37b52a41c8 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -355,6 +355,26 @@ def post_process_normal( sampler_output.sampled_token_ids, model_output.is_block_step, ) + # Renormalize logprobs to match truncated sampling distribution (when enabled). + if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + # logprobs_tensors.logprobs: [B, max_num_logprobs + 1] + logprobs = sampler_output.logprobs_tensors.logprobs + # logz_per_batch: [B], log(sum(probs in candidate set K)) for each request + logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype) + # Renormalize: log π_masked = log π_full - log Z_K + # Only normalize valid candidates; padding positions use -inf + valid_mask = paddle.isfinite(logprobs) + normalized_logprobs = paddle.where( + valid_mask, + logprobs - logz.unsqueeze(1), # broadcast subtraction + paddle.full_like(logprobs, float("-inf")), + ) + # Update logprobs_tensors with normalized values + sampler_output.logprobs_tensors = LogprobsTensors( + logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids, + logprobs=normalized_logprobs, + selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks, + ) def save_output_normal( @@ -504,6 +524,20 @@ def post_process_specualate( model_output.mask_rollback, ) + # Renormalize logprobs to match truncated sampling distribution (when enabled). + if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + logprobs = sampler_output.logprobs_tensors.logprobs + logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype) + valid_mask = paddle.isfinite(logprobs) + normalized_logprobs = paddle.where( + valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf")) + ) + sampler_output.logprobs_tensors = LogprobsTensors( + logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids, + logprobs=normalized_logprobs, + selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks, + ) + if not skip_save_output: if sampler_output.logprobs_tensors is None: recover_model_output_map = recover_batch_index_for_output(