From 6a5dafc3d300d69eb01a484e355725d01cc94d68 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 7 Apr 2026 20:57:26 +0800 Subject: [PATCH] fix logz when top_k --- .../model_executor/layers/sample/sampler.py | 36 +++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index a0fc666bca4..e2d62a63c15 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -197,13 +197,37 @@ def _compute_sampling_mask( max_k = int(k_per_row.max().item()) # ------------------------------------------------------------------ - # Stage 5: compute logZ_K for renormalization - # Z_K = sum(probs[i] * final_mask[i]) for each request i - # logZ_K = log(Z_K), with small constant to avoid log(0) + # Stage 5: compute logZ for renormalization + # + # Goal: log π_mask(k) = log π_full(k) - logZ, where π_mask is the + # distribution actually sampled from (top-k truncated + top-p nucleus). + # + # When top_k is active the sampling pipeline first renormalises to + # π_topk, then applies top-p on π_topk. The total log-normaliser + # that maps π_full → π_mask absorbs both steps: + # + # logZ = log Z_topk + log Z_topp_on_renorm + # + # where Z_topk = Σ_{j∈topk} π_full(j) (= row_sums, already computed) + # Z_topp = Σ_{k∈K} π_topk(k) (sum of renorm probs in K) + # + # Substituting: + # log π_mask(k) = log π_full(k) - log Z_topk - log Z_topp + # = log π_topk(k) - log Z_topp ✓ + # + # When top_k is absent Z_topk = 1 → logZ = log Z_topp as before. # ------------------------------------------------------------------ - candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs)) - z_k = candidate_probs.sum(axis=-1) # [B] - logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B] + if has_top_k: + # Z_topp: sum of renormed probs that survive the final mask + candidate_probs = paddle.where(final_mask, renorm_sorted_probs, paddle.zeros_like(renorm_sorted_probs)) + z_topp = candidate_probs.sum(axis=-1) # [B] + # row_sums: [B, 1] already clipped ≥ 1e-9, squeeze to [B] + log_z_topk = paddle.log(row_sums.squeeze(-1)) + logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy() # [B] + else: + candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs)) + z_k = candidate_probs.sum(axis=-1) # [B] + logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B] # Transfer only the leading max_k columns — typically max_k << vocab_size. indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy() # [B, max_k]