diff --git a/fastdeploy/config.py b/fastdeploy/config.py index b15a6dc824b..f86508c4f55 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -211,6 +211,7 @@ def __init__( self.enable_logprob = False self.max_logprobs = 20 self.logprobs_mode = "raw_logprobs" + self.enable_keep_sampling_mask = False self.redundant_experts_num = 0 self.seed = 0 self.quantization = None diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index ff0965c56bb..92f49d2837b 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -460,6 +460,14 @@ class EngineArgs: Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values. """ + enable_keep_sampling_mask: bool = False + """ + When enabled, the server returns a sparse index list for each generated token, indicating + which vocabulary positions were retained after top_p/top_k sampling, and streams it to + the client. In MTP (multi-token prediction) scenarios this field is a List[List[int]], + where each inner list contains the retained vocabulary indices for a predicted token. + """ + max_logprobs: int = 20 """ Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes the default for the @@ -893,6 +901,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.enable_logprob, help="Enable output of token-level log probabilities.", ) + model_group.add_argument( + "--enable-keep-sampling-mask", + action="store_true", + default=EngineArgs.enable_keep_sampling_mask, + help=( + "Enable output of sampling mask as a sparse index list over the vocabulary. " + "For non-MTP decoding, this is a list[int] per token step indicating which " + "vocabulary indices were kept after top_p/top_k sampling. " + "For MTP decoding, this is a list[list[int]] per token step, where each inner " + "list corresponds to one MTP group." 
+ ), + ) model_group.add_argument( "--max-logprobs", type=int, diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index f1152c6e22c..4cd2ad38512 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -2508,6 +2508,7 @@ def _start_worker_service(self): "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32, "enable_entropy": self.cfg.model_config.enable_entropy, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, + "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 283693fae8c..44edea80d34 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -655,6 +655,7 @@ def _start_worker_service(self): "enable_entropy": self.cfg.model_config.enable_entropy, "ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, + "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 0e95cd5e1fb..ccab1ac4114 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -727,6 +727,10 @@ class CompletionOutput: delta_message: Optional[DeltaMessage] = None multipart: Optional[list[Any]] = None num_image_tokens: Optional[int] = None + # Sparse indices of retained vocab ids: + # - Non-MTP: list[int] + # - MTP: list[list[int]] + sampling_mask: Optional[Any] = None def to_dict(self): """ @@ -745,6 +749,7 @@ def to_dict(self): "text": self.text, "reasoning_content": self.reasoning_content, "reasoning_token_num": self.reasoning_token_num, + "sampling_mask": self.sampling_mask, } @classmethod diff --git 
a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 3560f3a8aef..42923623776 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -270,6 +270,8 @@ class ChatCompletionResponseChoice(BaseModel): prompt_logprobs: Optional[PromptLogprobs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] speculate_metrics: Optional[SpeculateMetrics] = None + # Per-token retained vocab indices from top_p/top_k sampling: List[List[int]], one list of vocab indices per token + sampling_mask: Optional[List[List[int]]] = None class ChatCompletionResponse(BaseModel): @@ -333,6 +335,9 @@ class ChatCompletionResponseStreamChoice(BaseModel): logprobs: Optional[LogProbs] = None draft_logprobs: Optional[LogProbs] = None prompt_logprobs: Optional[PromptLogprobs] = None + # Per-token index list of retained positions after top_p sampling. + # Non-MTP: [[idx, ...]] (1 token/step). MTP: [[idx, ...], ...] (N accepted tokens/step). 
+ sampling_mask: Optional[List[List[int]]] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None arrival_time: Optional[float] = None speculate_metrics: Optional[SpeculateMetrics] = None diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index eb106f6550f..55bd37412a0 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -435,6 +435,11 @@ async def chat_completion_stream_generator( delta=delta_message, logprobs=logprobs_res, draft_logprobs=draft_logprobs_res, + sampling_mask=( + self._make_sampling_mask_list(output["sampling_mask"]) + if output.get("sampling_mask") is not None + else None + ), arrival_time=arrival_time, speculate_metrics=output_speculate_metrics, ) @@ -580,6 +585,7 @@ async def chat_completion_full_generator( decoder_base_url=self.tokenizer_base_url, ) prompt_logprobs_res_list = [[] for _ in range(num_choices)] + sampling_mask_list = [[] for _ in range(num_choices)] speculate_metrics = [None for _ in range(num_choices)] choices = [] while num_choices > 0: @@ -660,6 +666,9 @@ async def chat_completion_full_generator( ) if prompt_logprobs_res: prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res)) + output_sampling_mask = output.get("sampling_mask", None) + if output_sampling_mask is not None: + sampling_mask_list[idx].append(self._make_sampling_mask_list(output_sampling_mask)) speculate_metrics[idx] = data["metrics"].get("speculate_metrics", None) if data["finished"]: trace_carrier = data.get("trace_carrier") @@ -695,6 +704,7 @@ async def chat_completion_full_generator( draft_logprob_contents=draft_logprob_contents, response_processor=response_processor, prompt_logprobs_res_list=prompt_logprobs_res_list, + sampling_mask_list=sampling_mask_list, max_tokens=max_tokens, speculate_metrics=speculate_metrics[idx], ) @@ -749,6 +759,7 @@ async def 
_create_chat_completion_choice( logprob_contents: list, draft_logprob_contents: list, prompt_logprobs_res_list: list, + sampling_mask_list: list, response_processor: ChatResponseProcessor, max_tokens: int, speculate_metrics: SpeculateMetrics | None, @@ -787,6 +798,11 @@ async def _create_chat_completion_choice( if prompt_logprobs_res_list[idx]: prompt_logprobs_full_res = prompt_logprobs_res_list[idx] + # Flatten per-step List[List[int]] into a single List[List[int]] over all tokens. + sampling_mask_full_res = None + if sampling_mask_list and sampling_mask_list[idx]: + sampling_mask_full_res = [mask for step in sampling_mask_list[idx] for mask in step] + num_cached_tokens[idx] = data.get("num_cached_tokens", 0) num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0) num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0) @@ -810,6 +826,7 @@ async def _create_chat_completion_choice( logprobs=logprobs_full_res, draft_logprobs=draft_logprobs_full_res, prompt_logprobs=prompt_logprobs_full_res, + sampling_mask=sampling_mask_full_res, finish_reason=finish_reason, speculate_metrics=speculate_metrics, ) @@ -1000,3 +1017,18 @@ def _make_logprob_dict( ) for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens) } + + @staticmethod + def _make_sampling_mask_list(sampling_mask) -> List[List[int]]: + """Wrap sampling_mask into a uniform List[List[int]] format. + + sampling_mask is already in sparse-index form (no bool-to-index conversion needed): + Non-MTP: List[int] (indices for 1 token/step) → [[idx, ...]] + MTP: List[List[int]] (indices for N tokens/step) → [[idx, ...], ...] 
+ """ + assert sampling_mask is not None + if sampling_mask and isinstance(sampling_mask[0], list): + # MTP: already List[List[int]], return as-is + return sampling_mask + # Non-MTP: already List[int], wrap in outer list for uniform format + return [sampling_mask] diff --git a/fastdeploy/model_executor/layers/sample/logprobs.py b/fastdeploy/model_executor/layers/sample/logprobs.py index 559abdb298e..71796d73444 100644 --- a/fastdeploy/model_executor/layers/sample/logprobs.py +++ b/fastdeploy/model_executor/layers/sample/logprobs.py @@ -133,7 +133,7 @@ def build_output_logprobs( is_naive: bool = False, logprobs_mode: str = "default", compute_logprobs_fn: Optional[Callable] = None, -) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]: +) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor], Optional[paddle.Tensor]]: """ Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes. @@ -153,15 +153,12 @@ def build_output_logprobs( scaling and top_p normalization. Used when logprobs_mode == "raw_logprobs". 
Returns: - tuple: (logprobs_tensors, cu_batch_token_offset) + tuple: (logprobs_tensors, cu_batch_token_offset, output_logits) """ num_logprobs = sampling_metadata.max_num_logprobs logprobs_tensors = None cu_batch_token_offset = None - if num_logprobs is None: - return logprobs_tensors, cu_batch_token_offset - real_bsz = share_inputs["seq_lens_this_time"].shape[0] if is_naive: @@ -208,6 +205,10 @@ def build_output_logprobs( mask = idx < share_inputs["accept_num"].unsqueeze(1) token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask) + # Adapt for sampling mask + if num_logprobs is None: + return None, None, output_logits + # Compute logprobs with temperature scaling and top_p normalization if logprobs_mode == "raw_logprobs": raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata) @@ -217,5 +218,5 @@ def build_output_logprobs( raw_logprobs = F.log_softmax(output_logits, axis=-1) logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) - - return logprobs_tensors, cu_batch_token_offset + # output_logits is used to compute sampling_mask + return logprobs_tensors, cu_batch_token_offset, output_logits diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index 0d7f6915ab4..e2ecb276957 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -66,3 +66,5 @@ class SamplingMetadata: # Add for HPU post-processing seq_lens_encoder: Optional[paddle.Tensor] = None seq_lens_decoder: Optional[paddle.Tensor] = None + # Add for keep sampling mask + keep_sampling_mask: Optional[bool] = None diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py index ff072e1a8ef..717e6bf8efa 100644 --- a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py +++ 
b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py @@ -125,6 +125,7 @@ def top_k_top_p_sampling( if topp_seed is not None: topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype) topp_seed_device.copy_(topp_seed, False) + _, ids = paddle.tensor.top_p_sampling( x, top_p, diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 08a33c11096..a458513572b 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -19,6 +19,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, List, Optional +import numpy as np import paddle import paddle.nn.functional as F from paddle import nn @@ -105,6 +106,149 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le return top_p_padding, top_k_padding, topp_seed +def _compute_sampling_mask( + probs: paddle.Tensor, + top_p: paddle.Tensor, + top_k: Optional[paddle.Tensor] = None, + top_k_list: Optional[list] = None, +) -> tuple[List[np.ndarray], np.ndarray]: + """ + Compute a combined top-k + top-p (nucleus) sampling mask as sparse + retained-token indices. + + Processing order: + 1. Sort probs descending once (shared by top-k and top-p stages). + 2. top-k mask — zero out positions beyond top_k[i] in sorted order. + 3. top-k renorm — renormalise in-place after truncation. + 4. top-p mask — cumsum on the already-sorted renormed probs; no + second argsort needed. + 5. intersect — AND of the two masks, applied on GPU before D2H. + + Either filter can be disabled: + - top-k is skipped when top_k_list is None or all values <= 0. + - top-p[i] >= 1.0 → keep all tokens for that request. + + Args: + probs: [num_reqs, vocab_size] softmax probabilities (GPU). + top_p: [num_reqs, 1] top-p threshold per request (GPU). + top_k: [num_reqs, 1] top-k per request (GPU, int); 0 = disabled. 
+ top_k_list: Python list of top-k values; used to decide whether any + top-k filtering is needed at all. + + Returns: + Tuple of (sparse_indices, logz_per_batch): + - sparse_indices: List of length num_reqs; element i is a 1-D int64 + numpy array of the retained vocab indices for request i. + - logz_per_batch: 1-D numpy array of shape [num_reqs] containing + log(Z_K) where Z_K is the sum of probabilities in the candidate set. + """ + real_bsz = probs.shape[0] + vocab_size = probs.shape[1] + top_p = top_p[:real_bsz] # [B, 1] + + has_top_k = top_k is not None and top_k_list and any(x > 0 for x in top_k_list) + + # ------------------------------------------------------------------ + # Stage 1: single sort — descending by probability. + # sorted_indices / sorted_probs are reused by both top-k and top-p. + # ------------------------------------------------------------------ + sorted_indices = paddle.argsort(probs, axis=-1, descending=True) # [B, V] + sorted_probs = paddle.take_along_axis(probs, sorted_indices, axis=-1) # [B, V] + + # ------------------------------------------------------------------ + # Stage 2: top-k mask (GPU, no D2H) + # ------------------------------------------------------------------ + if has_top_k: + top_k = top_k[:real_bsz] # [B, 1] + # col_idx[0, j] == j; compare against per-row top_k threshold. + col_idx = paddle.arange(vocab_size, dtype=top_k.dtype).unsqueeze(0) # [1, V] + # top_k == 0 means "disabled" → keep all columns for that row. + effective_k = paddle.where(top_k > 0, top_k, paddle.full_like(top_k, vocab_size)) + topk_mask = col_idx < effective_k # [B, V], True = inside top-k + + # Zero out tail, then renorm row-wise. 
+ masked_sorted_probs = paddle.where(topk_mask, sorted_probs, paddle.zeros_like(sorted_probs)) + row_sums = masked_sorted_probs.sum(axis=-1, keepdim=True).clip(min=1e-9) + renorm_sorted_probs = masked_sorted_probs / row_sums # [B, V] + else: + topk_mask = None + renorm_sorted_probs = sorted_probs + + # ------------------------------------------------------------------ + # Stage 3: top-p mask on already-sorted renormed probs (no re-sort). + # ------------------------------------------------------------------ + cum_probs = paddle.cumsum(renorm_sorted_probs, axis=-1) # [B, V] + topp_mask = (cum_probs - renorm_sorted_probs) <= top_p # [B, V] + # When top_p[i] >= 1.0, keep the entire row. + topp_mask = paddle.where( + (top_p >= 1.0).expand_as(topp_mask), + paddle.ones_like(topp_mask), + topp_mask, + ) + + # Extend mask to cover sort tie-breaking: include all tokens whose + # probability >= the boundary token's probability (last retained + # in sorted order). In descending-sorted probs this just extends + # the contiguous True block by the run of equal-prob tokens. + k_per_row = topp_mask.astype("int32").sum(axis=-1, keepdim=True) # [B,1] + # boundary_idx = last True position (k-1), clamp for safety + boundary_idx = (k_per_row - 1).clip(min=0) # [B, 1] + boundary_prob = paddle.take_along_axis( + renorm_sorted_probs, + boundary_idx, + axis=-1, + ) # [B, 1] + topp_mask = topp_mask | (renorm_sorted_probs >= boundary_prob) + + # ------------------------------------------------------------------ + # Stage 4: intersect on GPU, then minimal D2H. 
+ # ------------------------------------------------------------------ + final_mask = topk_mask & topp_mask if has_top_k else topp_mask # [B, V] + + k_per_row = final_mask.astype("int32").sum(axis=-1) # [B] + max_k = int(k_per_row.max().item()) + + # ------------------------------------------------------------------ + # Stage 5: compute logZ for renormalization + # + # Goal: log π_mask(k) = log π_full(k) - logZ, where π_mask is the + # distribution actually sampled from (top-k truncated + top-p nucleus). + # + # When top_k is active the sampling pipeline first renormalises to + # π_topk, then applies top-p on π_topk. The total log-normaliser + # that maps π_full → π_mask absorbs both steps: + # + # logZ = log Z_topk + log Z_topp_on_renorm + # + # where Z_topk = Σ_{j∈topk} π_full(j) (= row_sums, already computed) + # Z_topp = Σ_{k∈K} π_topk(k) (sum of renorm probs in K) + # + # Substituting: + # log π_mask(k) = log π_full(k) - log Z_topk - log Z_topp + # = log π_topk(k) - log Z_topp ✓ + # + # When top_k is absent Z_topk = 1 → logZ = log Z_topp as before. + # ------------------------------------------------------------------ + if has_top_k: + # Z_topp: sum of renormed probs that survive the final mask + candidate_probs = paddle.where(final_mask, renorm_sorted_probs, paddle.zeros_like(renorm_sorted_probs)) + z_topp = candidate_probs.sum(axis=-1) # [B] + # row_sums: [B, 1] already clipped ≥ 1e-9, squeeze to [B] + log_z_topk = paddle.log(row_sums.squeeze(-1)) + logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy() # [B] + else: + candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs)) + z_k = candidate_probs.sum(axis=-1) # [B] + logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B] + + # Transfer only the leading max_k columns — typically max_k << vocab_size. 
+ indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy() # [B, max_k] + mask_window_cpu = final_mask[:, :max_k].cpu().numpy() # [B, max_k] + + sparse_indices = [indices_window_cpu[i, mask_window_cpu[i]] for i in range(real_bsz)] + return sparse_indices, logz_per_batch + + class GuidedDecoding: """ processor for guided decoding. @@ -554,6 +698,19 @@ def forward_cuda( _record_logits_diagnostic(logits, tag="post_penalty_logits", probs=probs) probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list) + + # Compute sampling mask BEFORE top_k_top_p_sampling modifies probs. + # Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated. + sampling_mask = None + logz_per_batch = None + if sampling_metadata.keep_sampling_mask: + sampling_mask, logz_per_batch = _compute_sampling_mask( + probs, + sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + top_k_list=sampling_metadata.top_k_list, + ) + _, next_tokens = top_k_top_p_sampling( probs, sampling_metadata.top_p, @@ -577,6 +734,8 @@ def forward_cuda( sampled_token_ids=next_tokens, logprobs_tensors=logprobs_tensors, logits=logits, + sampling_mask=sampling_mask, + logz_per_batch=logz_per_batch, ) return sampler_output @@ -1029,9 +1188,10 @@ def forward_cuda( reject_all_drafts, ) + keep_sampling_mask = sampling_metadata.keep_sampling_mask # Build logprobs via unified path (outside of sampling logic) - if sampling_metadata.max_num_logprobs is not None: - logprobs_tensors, cu_batch_token_offset = build_output_logprobs( + if sampling_metadata.max_num_logprobs is not None or keep_sampling_mask: + logprobs_tensors, cu_batch_token_offset, target_logits = build_output_logprobs( logits, sampling_metadata, share_inputs, @@ -1042,6 +1202,33 @@ def forward_cuda( sampler_output.logprobs_tensors = logprobs_tensors if cu_batch_token_offset is not None: sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu() + if keep_sampling_mask: + real_bsz = 
share_inputs["seq_lens_this_time"].shape[0] + accept_nums = share_inputs["accept_num"][:real_bsz].reshape([-1]) + # Derive target probs from already-extracted target_logits; avoids a second kernel call. + target_probs = F.softmax(target_logits, axis=-1) + # Compute sampling mask at accepted token positions. + # Shape: [total_accepted_tokens, vocab_size], bool (CPU). + # Expand top_p from [batch, 1] to [total_accepted, 1]. + # total_accepted = accept_nums.sum() + accept_top_p = ( + sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) + ) + accept_top_k = None + if ( + sampling_metadata.top_k is not None + and sampling_metadata.top_k_list + and any(x > 0 for x in sampling_metadata.top_k_list) + ): + accept_top_k = ( + sampling_metadata.top_k[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) + ) + sampler_output.sampling_mask, sampler_output.logz_per_batch = _compute_sampling_mask( + target_probs, + accept_top_p, + top_k=accept_top_k, + top_k_list=sampling_metadata.top_k_list, + ) return sampler_output def forward_xpu( diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 0fc6bfde5d0..9be94c2bd49 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -21,7 +21,7 @@ import paddle from fastdeploy import envs -from fastdeploy.config import SpeculativeConfig +from fastdeploy.config import PREEMPTED_TOKEN_ID, SpeculativeConfig from fastdeploy.platforms import current_platform from fastdeploy.worker.input_batch import ( InputBatch, @@ -216,6 +216,7 @@ def _build_stream_transfer_data( pooler_outputs: List[PoolingSequenceGroupOutput] = None, logprobs: Optional[LogprobsTensors] = None, prompt_logprobs_list: Optional[LogprobsTensors] = None, + sampling_mask: Optional[List[np.ndarray]] = None, ): """Split output_tokens and output""" @@ -225,6 +226,8 @@ def _build_stream_transfer_data( 
output_tokens = output_tokens.numpy().reshape([-1]) output_tokens_lists = np.split(output_tokens, output_tokens.shape[0]) + sampling_mask_list = sampling_mask + for bid, output_token_per_sample in enumerate(output_tokens_lists): stream_transfer_data = StreamTransferData( decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid @@ -233,6 +236,8 @@ def _build_stream_transfer_data( stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1) if prompt_logprobs_list: stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid] + if sampling_mask_list is not None: + stream_transfer_data.sampling_mask = sampling_mask_list[bid] stream_transfer_datas.append(stream_transfer_data) elif pooler_outputs is not None: for bid, pooler_output in enumerate(pooler_outputs): @@ -250,6 +255,86 @@ def _build_stream_transfer_data( return stream_transfer_datas +def _build_speculative_stream_transfer_data( + accept_tokens_cpu, + accept_num_cpu, + logprobs: Optional[LogprobsTensors] = None, + prompt_logprobs_list=None, + sampling_mask=None, + cu_batch_token_offset=None, + output_type: int = 3, + last_preempted_idx=None, +): + """Build StreamTransferData list for speculative decoding output. + + Args: + accept_tokens_cpu: paddle.Tensor [max_bsz, max_draft+1] of accepted token IDs (CPU pinned). + accept_num_cpu: paddle.Tensor [max_bsz] of per-request accept counts (CPU pinned). + logprobs: LogprobsTensors with rows flattened across all accepted tokens. + prompt_logprobs_list: per-request prompt logprobs list. + sampling_mask: per-token sampling mask list. + cu_batch_token_offset: paddle.Tensor cumulative token offset for logprobs slicing. + output_type: 3=target, 4=draft. 
+ """ + stream_transfer_datas = [] + accept_num_np = accept_num_cpu.numpy().flatten() + accept_tokens_np = accept_tokens_cpu.numpy() + batch_size = accept_num_np.shape[0] + + # Inject PREEMPTED_TOKEN_ID for slots that were preempted this step, + # mirroring what speculate_save_output kernel does in the non-ZMQ path. + if last_preempted_idx is not None: + preempted_np = last_preempted_idx.numpy().flatten() + for bid in range(min(len(preempted_np), batch_size)): + if preempted_np[bid] != 0: + accept_num_np[bid] = PREEMPTED_TOKEN_ID + + # Build cumulative offset for logprobs slicing + logprobs_offset = 0 + cu_offsets = None + if cu_batch_token_offset is not None: + cu_offsets = cu_batch_token_offset.numpy().flatten() + + for bid in range(batch_size): + accept_num_val = int(accept_num_np[bid]) + num_tokens = max(accept_num_val, 0) + + tokens_np = accept_tokens_np[bid, :num_tokens] if num_tokens > 0 else np.array([], dtype=np.int64) + + stream_data = StreamTransferData( + decoder_state=DecoderState.TEXT, + batch_id=bid, + tokens=tokens_np, + speculaive_decoding=True, + accept_tokens=tokens_np, + accept_num=np.array([accept_num_val], dtype=np.int32), + output_type=output_type, + ) + + if cu_offsets is not None and len(cu_offsets) > bid + 1: + start = int(cu_offsets[bid]) + end = int(cu_offsets[bid + 1]) + else: + start = logprobs_offset + end = logprobs_offset + num_tokens + + # Slice logprobs for this request's accepted tokens + if logprobs is not None and num_tokens > 0: + stream_data.logprobs = logprobs.slice_rows(start, end) + + if prompt_logprobs_list and bid < len(prompt_logprobs_list): + stream_data.prompt_logprobs = prompt_logprobs_list[bid] + + if sampling_mask is not None and num_tokens > 0: + # Slice the flat per-token mask list to get this request's masks + stream_data.sampling_mask = sampling_mask[start:end] + + logprobs_offset += num_tokens + stream_transfer_datas.append(stream_data) + + return stream_transfer_datas + + def post_process_normal( 
sampler_output: SamplerOutput, model_output: ModelOutputData, @@ -367,6 +452,26 @@ def post_process_normal( sampler_output.sampled_token_ids, model_output.is_block_step, ) + # Renormalize logprobs to match truncated sampling distribution (when enabled). + if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + # logprobs_tensors.logprobs: [B, max_num_logprobs + 1] + logprobs = sampler_output.logprobs_tensors.logprobs + # logz_per_batch: [B], log(sum(probs in candidate set K)) for each request + logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype) + # Renormalize: log π_masked = log π_full - log Z_K + # Only normalize valid candidates; padding positions use -inf + valid_mask = paddle.isfinite(logprobs) + normalized_logprobs = paddle.where( + valid_mask, + logprobs - logz.unsqueeze(1), # broadcast subtraction + paddle.full_like(logprobs, float("-inf")), + ) + # Update logprobs_tensors with normalized values + sampler_output.logprobs_tensors = LogprobsTensors( + logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids, + logprobs=normalized_logprobs, + selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks, + ) def save_output_normal( @@ -393,6 +498,7 @@ def save_output_normal( recover_share_inputs_map["sampled_token_ids"], logprobs=sampler_output.logprobs_tensors, prompt_logprobs_list=model_output.prompt_logprobs_list, + sampling_mask=sampler_output.sampling_mask, ) async_output_queue.put(output) else: @@ -520,40 +626,32 @@ def post_process_specualate( model_output.max_dec_len, # max_dec_len ) + # Renormalize logprobs to match truncated sampling distribution (when enabled). 
+ if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + logprobs = sampler_output.logprobs_tensors.logprobs + logz = paddle.to_tensor(sampler_output.logz_per_batch, dtype=logprobs.dtype) + valid_mask = paddle.isfinite(logprobs) + normalized_logprobs = paddle.where( + valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf")) + ) + sampler_output.logprobs_tensors = LogprobsTensors( + logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids, + logprobs=normalized_logprobs, + selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks, + ) + def save_output_specualate( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: InputBatch, + async_output_queue: queue.Queue = None, save_each_rank: bool = False, skip_save_output: bool = False, + enable_draft_logprob: bool = False, ): - if not skip_save_output: - if sampler_output.logprobs_tensors is None: - recover_share_inputs = recover_batch_index_for_output( - share_inputs, - model_output.index_to_batch_id, - model_output.enable_pd_reorder, - [ - "accept_tokens_cpu", - "accept_num_cpu", - "seq_lens_decoder_cpu", - "prompt_lens_cpu", - "last_preempted_idx", - ], - ) - speculate_save_output( - recover_share_inputs["accept_tokens_cpu"], - recover_share_inputs["accept_num_cpu"], - model_output.not_need_stop, - recover_share_inputs["seq_lens_decoder_cpu"], - recover_share_inputs["prompt_lens_cpu"], - recover_share_inputs["last_preempted_idx"], - model_output.mp_rank, - save_each_rank, - bool(envs.ENABLE_V1_KVCACHE_SCHEDULER), - ) - else: + if envs.FD_USE_GET_SAVE_OUTPUT_V1: + if save_each_rank or model_output.mp_rank == 0: recover_batch_index_for_sampler_output( sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder ) @@ -570,22 +668,91 @@ def save_output_specualate( "last_preempted_idx", ], ) - speculate_save_output_topk( - recover_share_inputs["sampled_token_ids"], - 
sampler_output.logprobs_tensors.logprob_token_ids, - sampler_output.logprobs_tensors.logprobs, - sampler_output.logprobs_tensors.selected_token_ranks, - recover_share_inputs["accept_num_cpu"], - sampler_output.cu_batch_token_offset, - model_output.not_need_stop, - recover_share_inputs["seq_lens_decoder_cpu"], - recover_share_inputs["prompt_lens_cpu"], - recover_share_inputs["last_preempted_idx"], - 3, # mtype - model_output.mp_rank, - save_each_rank, + # target tokens (mtype=3) + output = _build_speculative_stream_transfer_data( + accept_tokens_cpu=recover_share_inputs["accept_tokens_cpu"], + accept_num_cpu=recover_share_inputs["accept_num_cpu"], + logprobs=sampler_output.logprobs_tensors, + prompt_logprobs_list=model_output.prompt_logprobs_list, + sampling_mask=sampler_output.sampling_mask, + cu_batch_token_offset=sampler_output.cu_batch_token_offset, + output_type=3, + last_preempted_idx=recover_share_inputs["last_preempted_idx"], ) - share_inputs["last_preempted_idx"][:] = 0 + async_output_queue.put(output) + + # draft tokens (mtype=4): when enable_draft_logprob and logprobs available + if sampler_output.logprobs_tensors is not None and enable_draft_logprob: + draft_output = _build_speculative_stream_transfer_data( + accept_tokens_cpu=recover_share_inputs["accept_tokens_cpu"], + accept_num_cpu=recover_share_inputs["accept_num_cpu"], + logprobs=sampler_output.logprobs_tensors, + sampling_mask=None, + cu_batch_token_offset=sampler_output.cu_batch_token_offset, + output_type=4, + ) + async_output_queue.put(draft_output) + + share_inputs["last_preempted_idx"][:] = 0 + else: + if not skip_save_output: + if sampler_output.logprobs_tensors is None: + recover_share_inputs = recover_batch_index_for_output( + share_inputs, + model_output.index_to_batch_id, + model_output.enable_pd_reorder, + [ + "accept_tokens_cpu", + "accept_num_cpu", + "seq_lens_decoder_cpu", + "prompt_lens_cpu", + "last_preempted_idx", + ], + ) + speculate_save_output( + 
recover_share_inputs["accept_tokens_cpu"], + recover_share_inputs["accept_num_cpu"], + model_output.not_need_stop, + recover_share_inputs["seq_lens_decoder_cpu"], + recover_share_inputs["prompt_lens_cpu"], + recover_share_inputs["last_preempted_idx"], + model_output.mp_rank, + save_each_rank, + bool(envs.ENABLE_V1_KVCACHE_SCHEDULER), + ) + else: + recover_batch_index_for_sampler_output( + sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder + ) + recover_share_inputs = recover_batch_index_for_output( + share_inputs, + model_output.index_to_batch_id, + model_output.enable_pd_reorder, + [ + "sampled_token_ids", + "accept_tokens_cpu", + "accept_num_cpu", + "seq_lens_decoder_cpu", + "prompt_lens_cpu", + "last_preempted_idx", + ], + ) + speculate_save_output_topk( + recover_share_inputs["sampled_token_ids"], + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + recover_share_inputs["accept_num_cpu"], + sampler_output.cu_batch_token_offset, + model_output.not_need_stop, + recover_share_inputs["seq_lens_decoder_cpu"], + recover_share_inputs["prompt_lens_cpu"], + recover_share_inputs["last_preempted_idx"], + 3, # mtype + model_output.mp_rank, + save_each_rank, + ) + share_inputs["last_preempted_idx"][:] = 0 def post_process( diff --git a/fastdeploy/output/stream_transfer_data.py b/fastdeploy/output/stream_transfer_data.py index b32e01c954f..ded86b3d7b6 100644 --- a/fastdeploy/output/stream_transfer_data.py +++ b/fastdeploy/output/stream_transfer_data.py @@ -44,5 +44,10 @@ class StreamTransferData: prompt_logprobs: Optional[LogprobsTensors] = None accept_tokens: Optional[np.array] = None accept_num: Optional[np.array] = None + output_type: int = 3 # 3=target, 4=draft # [num_reqs, hidden_size] pooler_output: Optional[np.array] = None + # Sparse sampling mask(s) for top_p/top_k: + # - Non-speculative: single 1-D int64 numpy array of retained vocab 
indices. + # - Speculative: List[np.ndarray], one 1-D array per accepted token. + sampling_mask: Optional[np.array] = None diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 1ab0b48f350..47322b20144 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -83,6 +83,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.speculative_decoding = self.cfg.speculative_config.method is not None self.use_logprobs = self.cfg.model_config.enable_logprob + self.use_sampling_mask = getattr(self.cfg.model_config, "enable_keep_sampling_mask", False) self.enable_draft_logprob = self.cfg.speculative_config.enable_draft_logprob if self.speculative_decoding: @@ -263,7 +264,7 @@ def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: main_process_metrics.request_token_ratio.observe(token_ratio) llm_logger.info(f"{self.resource_manager.info()}") if self.cfg.speculative_config.method: - self._compute_speculative_status() + self._compute_speculative_status(result) if not is_prefill: self._record_completion_metrics(task, current_time) self._recycle_resources(task_id, batch_id, task, result, is_prefill) @@ -282,6 +283,13 @@ def _process_batch_output_use_zmq(self, receive_datas): task: Request = self.resource_manager.tasks_list[i] task_id = task.request_id + + if self.speculative_decoding and getattr(stream_data, "speculaive_decoding", False): + result = self._process_speculative_output_use_zmq(stream_data, task, i, batch_result) + if result is not None: + batch_result.append(result) + continue + token_ids = stream_data.tokens # numpy.array if token_ids is not None and token_ids[-1] < 0: if envs.ENABLE_V1_KVCACHE_SCHEDULER: @@ -357,6 +365,12 @@ def _process_batch_output_use_zmq(self, receive_datas): result.prompt_logprobs = stream_data.prompt_logprobs except Exception as e: llm_logger.warning(f"Failed to parse prompt_logprobs from 
StreamTransferData: {e}") + if self.use_sampling_mask: + if getattr(stream_data, "sampling_mask", None) is not None: + try: + result.outputs.sampling_mask = stream_data.sampling_mask.tolist() + except Exception as e: + llm_logger.warning(f"Failed to parse sampling_mask from StreamTransferData: {e}") if self.tokens_counter[task_id] == 0: if task.messages is not None: result.prompt = task.messages @@ -372,12 +386,273 @@ def _process_batch_output_use_zmq(self, receive_datas): return batch_result + def _process_speculative_output_use_zmq(self, stream_data, task, batch_id, batch_result): + """ + Process speculative decoding output from a single StreamTransferData. + + Args: + stream_data: StreamTransferData with speculative decoding fields populated. + task: Request object for this batch entry. + batch_id: The batch index. + batch_result: The batch result list (for draft token path to append to). + + Returns: + RequestOutput or None (None means the request was skipped/preempted). + """ + task_id = task.request_id + accept_num_val = int(stream_data.accept_num[0]) + mtype = getattr(stream_data, "output_type", 3) + + # --- Draft token path (mtype=4) --- + if mtype == 4: + return self._process_draft_output_use_zmq(stream_data, task, batch_id, accept_num_val) + + # --- Target token path (mtype=3) --- + + # Per-request per-head accept count tracking + self._record_speculative_decoding_accept_num_per_request(task_id, accept_num_val) + + # Preemption handling via accept_num + if accept_num_val == PREEMPTED_TOKEN_ID: + llm_logger.info(f"sync preemption for request_id {task_id} done.") + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if task_id in self.resource_manager.to_be_aborted_req_id_set: + self.resource_manager.recycle_abort_task(task_id) + if task_id in self.resource_manager.to_be_rescheduled_request_id_set: + self.resource_manager.reschedule_preempt_task(task_id) + return None + + # Recovery stop + recovery_stop = False + if accept_num_val == -3: + recovery_stop = True + 
llm_logger.info(f"recovery stop signal found at task {task_id}") + token_ids = [RECOVERY_STOP_SIGNAL] + else: + token_ids = stream_data.accept_tokens[:accept_num_val].tolist() + + # No tokens accepted this step + if accept_num_val == 0: + return None + + if self.cfg.scheduler_config.splitwise_role == "decode": + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if task_id in self.resource_manager.to_be_aborted_req_id_set: + return None + if task_id in self.resource_manager.to_be_rescheduled_request_id_set: + return None + + # Global counters for speculative status computation + self.total_step += 1 + self.number_of_output_tokens += len(token_ids) + + # Metrics recording + current_time = time.time() + if self.tokens_counter[task_id] == 0: + task.metrics.record_recv_first_token() + task.metrics.cal_cost_time() + metrics = copy.copy(task.metrics) + self._record_first_token_metrics(task, current_time) + else: + task.metrics.record_recv_token() + if self.tokens_counter[task_id] == 1 and self.cfg.scheduler_config.splitwise_role == "decode": + task.metrics.record_decode_recv_second_token() + metrics = copy.copy(task.metrics) + + self._record_metrics(task, current_time, token_ids) + + is_prefill = task.disaggregate_info is not None and task.disaggregate_info.get("role") == "prefill" + is_decode = task.disaggregate_info is not None and self.cfg.scheduler_config.splitwise_role == "decode" + + # Build RequestOutput with output_type + result = RequestOutput( + request_id=task_id, + output_type=mtype, + outputs=CompletionOutput( + index=batch_id, + send_idx=self.tokens_counter[task_id], + token_ids=[], + draft_token_ids=[], + ), + finished=False, + metrics=metrics, + ic_req_data=task.ic_req_data, + prompt_token_ids_len=task.prompt_token_ids_len, + ) + + # First token additional info + if self.tokens_counter[task_id] == 0: + if task.messages is not None: + result.prompt = task.messages + if task.get("multimodal_inputs", None): + result.num_input_image_tokens = 
task.multimodal_inputs.get("num_input_image_tokens", 0) + result.num_input_video_tokens = task.multimodal_inputs.get("num_input_video_tokens", 0) + result.num_cached_tokens = task.num_cached_tokens + + # is_prefill + multi-token -> draft_token_ids + if is_prefill and len(token_ids) > 1: + result.outputs.draft_token_ids = copy.deepcopy(token_ids) + + # Parse logprobs once if available + logprobs_lists = None + if self.use_logprobs and getattr(stream_data, "logprobs", None) is not None: + try: + logprobs_lists = stream_data.logprobs.tolists() + except Exception as e: + llm_logger.warning(f"Failed to parse speculative logprobs from StreamTransferData: {e}") + + if getattr(stream_data, "prompt_logprobs", None) is not None: + try: + result.prompt_logprobs = stream_data.prompt_logprobs + except Exception as e: + llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}") + + if self.use_sampling_mask: + if getattr(stream_data, "sampling_mask", None) is not None: + try: + # Speculative: stream_data.sampling_mask is List[np.ndarray], + # one per accepted token → List[List[int]] + mask = stream_data.sampling_mask + if isinstance(mask, list): + result.outputs.sampling_mask = [m.tolist() for m in mask] + else: + result.outputs.sampling_mask = mask.tolist() + except Exception as e: + llm_logger.warning(f"Failed to parse sampling_mask from StreamTransferData: {e}") + + # Multi-token loop processing + for batch_token_index, token_id in enumerate(token_ids): + self.tokens_counter[task_id] += 1 + if token_id != RECOVERY_STOP_SIGNAL: + if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids): + result.outputs.token_ids.append(token_id) + task.output_token_ids.append(token_id) + + # Speculative logprobs handling + if logprobs_lists is not None and batch_token_index < len(logprobs_lists.logprobs): + result.outputs.logprob = float(logprobs_lists.logprobs[batch_token_index][0]) + topk_token_ids = 
logprobs_lists.logprob_token_ids[batch_token_index] + topk_logprobs = logprobs_lists.logprobs[batch_token_index] + sampled_rank = logprobs_lists.sampled_token_ranks[batch_token_index] + + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) + + if token_id in task.eos_token_ids or is_prefill or recovery_stop: + result.finished = True + if recovery_stop: + result.error_msg = "Recover is not supported, the result is incomplete!" + + # Calculate statistics for the combined log + inference_start_time = task.metrics.get_inference_start_time(is_decode) + task.metrics.cal_cost_time() + e2e_time = current_time - inference_start_time + token_ratio = self.tokens_counter[task_id] / e2e_time + + # Get cache information + gpu_cache = getattr(task.metrics, "gpu_cache_token_num", 0) + cpu_cache = getattr(task.metrics, "cpu_cache_token_num", 0) + total_cached = gpu_cache + cpu_cache + + # Build cached detail dict + cached_detail = f'{{"CachedToken": {total_cached}, "GPU": {gpu_cache}, "CPU": {cpu_cache}}}' + + # Print combined log + ttft = task.metrics.first_token_time if task.metrics.first_token_time else 0 + ttft_s = ttft + task.metrics.time_in_queue + llm_logger.info( + f"Request={task_id}, InputToken={task.prompt_token_ids_len}, " + f"CachedDetail={cached_detail}, OutputToken={self.tokens_counter[task_id]}, " + f"TokenRatio={token_ratio:.2f}, TTFT={ttft:.2f}, TTFT_S={ttft_s:.2f}, " + f"E2E={e2e_time:.2f}, IsPrefill={is_prefill}, RecoveryStop={recovery_stop}, " + f"PreemptedCount={getattr(task.metrics, 'preempted_count', 0)}" + ) + + main_process_metrics.request_token_ratio.observe(token_ratio) + 
llm_logger.info(f"{self.resource_manager.info()}") + if self.cfg.speculative_config.method: + self._compute_speculative_status(result) + if not is_prefill: + self._record_completion_metrics(task, current_time) + llm_logger.info(f"task {task_id} received eos token. Recycling.") + if ( + envs.ENABLE_V1_KVCACHE_SCHEDULER + and self.cfg.cache_config.enable_prefix_caching + and self.cfg.cache_config.enable_output_caching + ): + self.resource_manager.cache_output_tokens(task) + self._recycle_resources(task_id, batch_id, task, result, is_prefill) + llm_logger.info(f"eos token {task_id} Recycle end.") + break + + if not is_prefill or self.cfg.scheduler_config.name == "splitwise": + return result + return None + + def _process_draft_output_use_zmq(self, stream_data, task, batch_id, accept_num_val): + """ + Process draft token output (mtype=4) from speculative decoding via ZMQ. + Only processes logprobs (draft_top_logprobs), not actual tokens. + + Args: + stream_data: StreamTransferData with output_type=4. + task: Request object. + batch_id: Batch index. + accept_num_val: Number of accepted tokens. + + Returns: + RequestOutput with draft_top_logprobs populated. 
+ """ + task_id = task.request_id + result = RequestOutput( + request_id=task_id, + output_type=4, + outputs=CompletionOutput( + index=batch_id, + send_idx=None, + token_ids=[], + draft_token_ids=[], + ), + finished=False, + metrics=None, + ) + + num_tokens = max(accept_num_val, 0) + if num_tokens > 0 and getattr(stream_data, "logprobs", None) is not None: + try: + logprobs_lists = stream_data.logprobs.tolists() + for idx in range(min(num_tokens, len(logprobs_lists.logprobs))): + topk_token_ids = logprobs_lists.logprob_token_ids[idx] + topk_logprobs = logprobs_lists.logprobs[idx] + sampled_rank = logprobs_lists.sampled_token_ranks[idx] + + if result.outputs.draft_top_logprobs is None: + result.outputs.draft_top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) + except Exception as e: + llm_logger.warning(f"Failed to parse draft logprobs from StreamTransferData: {e}") + + return result + def process_sampling_results_use_zmq(self): """ use zmq to receive outputs from worker and process them """ - if self.speculative_decoding: - raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support speculative decoding") rank_id = self.cfg.parallel_config.local_data_parallel_id while True: try: @@ -392,7 +667,27 @@ def process_sampling_results_use_zmq(self): self._reschedule_preempt_task_use_zmq(receive_datas) batch_result = self._process_batch_output_use_zmq(receive_datas) - self.postprocess(batch_result) + + # Determine mtype for metrics and postprocess + mtype = 3 + if receive_datas and hasattr(receive_datas[0], "output_type"): + mtype = receive_datas[0].output_type + + # Batch-level speculative decoding metrics: only record for + # mtype=3 (target tokens); 
skip mtype=4 (draft logprobs) to + # avoid double-counting draft_tokens / max_emitted_tokens. + # Also filter out non-positive values (preempted=-9, recovery=-3, + # skipped=0) which do not represent real decoding steps. + if self.speculative_decoding and mtype == 3: + accept_nums = [ + int(sd.accept_num[0]) + for sd in receive_datas + if getattr(sd, "accept_num", None) is not None and int(sd.accept_num[0]) > 0 + ] + if accept_nums: + self._record_speculative_decoding_metrics(accept_nums) + + self.postprocess(batch_result, mtype) except Exception as e: llm_logger.error(f"Receive message:{receive_datas}, error:{e}") continue diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index bc315c3646b..e63c3084974 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -126,6 +126,7 @@ def __init__( self.spec_method = self.fd_config.speculative_config.method self.speculative_decoding = self.spec_method is not None self.enable_logprob = fd_config.model_config.enable_logprob + self.enable_keep_sampling_mask = fd_config.model_config.enable_keep_sampling_mask self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling" self.ori_vocab_size = self.fd_config.model_config.ori_vocab_size @@ -1231,6 +1232,7 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"], logits_processors=self.share_inputs["logits_processors"], share_inputs=self.share_inputs, + keep_sampling_mask=self.enable_keep_sampling_mask, ) return token_num, token_num_event @@ -2466,8 +2468,10 @@ def _save_model_output( sampler_output=sampler_output, model_output=model_output_data, share_inputs=self.share_inputs, + async_output_queue=self.async_output_queue, save_each_rank=self.parallel_config.use_ep, skip_save_output=skip_save_output, + 
enable_draft_logprob=self.speculative_config.enable_draft_logprob, ) else: save_output_normal( diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 365fec12475..44cc9cb9e16 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -15,8 +15,9 @@ """ from dataclasses import dataclass, field -from typing import NamedTuple, Optional +from typing import List, NamedTuple, Optional +import numpy as np import paddle @@ -178,6 +179,24 @@ class SamplerOutput: token_num_per_batch: Optional[paddle.Tensor] = None cu_batch_token_offset: Optional[paddle.Tensor] = None logits: Optional[paddle.Tensor] = None + # Sparse sampling mask for top_p/top_k: + # - Non-speculative decoding: per-request mask. This is a list of length + # num_reqs, where element i is a 1-D int32 numpy array of vocab indices + # retained by top_p/top_k for request i. Replaces the previous dense + # [num_reqs, vocab_size] bool tensor. + # - Speculative decoding: flattened per-accepted-token mask. This may be + # stored as a list aligned with all accepted tokens + # (e.g. length = total_accepted_tokens) and is regrouped by accept_num + # (number of accepted tokens per request) in post-processing before + # being sent back as per-request data. + # Callers MUST NOT assume this is always shaped by num_reqs; they should + # check whether the current path is speculative or non-speculative when + # interpreting the dimension. + sampling_mask: Optional[List[np.ndarray]] = None + # logZ_K for each request: log(sum(probs in candidate set K)) + # Used for renormalizing logprobs to match the truncated sampling distribution. 
+ # Shape: [num_reqs] + logz_per_batch: Optional[np.ndarray] = None @dataclass diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 8182e06990b..4255d35eb12 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -1129,6 +1129,16 @@ def parse_args(): help="Maximum tokens per item in mm input.", ) + parser.add_argument( + "--enable_keep_sampling_mask", + "--enable-keep-sampling-mask", + action="store_true", + help=( + "Enable output of keep_sampling_mask as sparse vocab index list per token step " + "(Non-MTP: List[int]; MTP: List[List[int]])." + ), + ) + parser.add_argument( "--num_cpu_blocks", type=int, diff --git a/tests/e2e/test_ernie_21b_mtp.py b/tests/e2e/test_ernie_21b_mtp.py index dc60a213217..339d89ab2ce 100644 --- a/tests/e2e/test_ernie_21b_mtp.py +++ b/tests/e2e/test_ernie_21b_mtp.py @@ -83,18 +83,22 @@ def setup_and_run_server(): json.dumps(speculative_config), "--graph-optimization-config", '{"use_cudagraph":true, "use_unique_memory_pool":true, "draft_model_use_cudagraph":true}', + "--enable-keep-sampling-mask", ] # Start subprocess in new process group # 清除log目录 if os.path.exists("log"): shutil.rmtree("log") + env = os.environ.copy() + env["FD_USE_GET_SAVE_OUTPUT_V1"] = "1" with open(log_path, "w") as logfile: process = subprocess.Popen( cmd, stdout=logfile, stderr=subprocess.STDOUT, start_new_session=True, # Enables killing full group via os.killpg + env=env, ) # Wait up to 300 seconds for API server to be ready @@ -366,3 +370,176 @@ def test_mtp_accept_ratio(api_url): prompt_tokens = chunks[-1]["usage"]["prompt_tokens"] cached_tokens = chunks[-1]["usage"]["prompt_tokens_details"]["cached_tokens"] assert cached_tokens == prompt_tokens // 64 * 64, "cached_tokens数量有问题" + + +def _assert_sampling_mask_format(sampling_mask, max_tokens): + """验证 sampling_mask 字段格式的公共辅助函数。 + + sampling_mask 是 List[List[int]]: + - 外层列表长度 == 生成的 token 数(completion_tokens),对应 MTP 每步可接受多个 token + - 
内层列表为保留位置的词汇表索引(int),非空且非负 + """ + assert sampling_mask is not None, "sampling_mask 不应为 None" + assert isinstance(sampling_mask, list), "sampling_mask 应为 list" + assert len(sampling_mask) > 0, "sampling_mask 不应为空" + assert len(sampling_mask) <= max_tokens, "sampling_mask 长度不应超过 max_tokens" + + for token_mask in sampling_mask: + assert isinstance(token_mask, list), f"每个 token 的 mask 应为 list,实际: {type(token_mask)}" + assert len(token_mask) > 0, "每个 token 的 mask 不应为空(至少保留采样到的 token)" + for idx in token_mask: + assert isinstance(idx, int), f"mask 中的每个元素应为 int,实际: {type(idx)}" + assert idx >= 0, f"mask 索引不应为负数,实际: {idx}" + + +def test_keep_sampling_mask_stream(api_url): + """测试流式响应中 keep_sampling_mask 功能(MTP 模式)。 + + 验证: + 1. 每个非空 chunk 的 choices[0].sampling_mask 格式为 List[List[int]] + 2. 内层列表包含词汇表保留位置的索引,非空且非负 + 3. 最终 sampling_mask 总长度等于 completion_tokens + """ + max_tokens = 20 + payload = { + "model": "default", + "temperature": 1.0, + "top_p": 0.9, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请用一句话介绍Python语言。"}, + ], + "max_tokens": max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + + assert len(chunks) > 1, "流式响应应包含至少两个 chunk" + + all_sampling_masks = [] + for chunk in chunks[:-1]: # 最后一个 chunk 是 usage-only + choice = chunk["choices"][0] + # 仅当 delta 有实际内容时才应携带 sampling_mask(首个 role chunk 内容为空,不含该字段) + has_content = bool(choice.get("delta", {}).get("content")) + mask = choice.get("sampling_mask") + if has_content: + assert mask is not None, f"有内容的 chunk 缺少 sampling_mask 字段: {choice}" + if mask is not None: + assert isinstance(mask, list), f"sampling_mask 应为 list,实际: {type(mask)}" + for token_mask in mask: + assert isinstance(token_mask, list), "每个 token mask 应为 list" + assert len(token_mask) > 0, "每个 token mask 不应为空" + for idx in token_mask: + assert 
isinstance(idx, int) and idx >= 0, f"mask 索引应为非负 int,实际: {idx}" + all_sampling_masks.extend(mask) + + # 最后一个 chunk 携带 usage 信息 + usage = chunks[-1].get("usage") + if usage: + completion_tokens = usage["completion_tokens"] + assert ( + len(all_sampling_masks) == completion_tokens + ), f"sampling_mask 总长度 {len(all_sampling_masks)} 应等于 completion_tokens {completion_tokens}" + + +def test_keep_sampling_mask_non_stream(api_url): + """测试非流式响应中 keep_sampling_mask 功能(MTP 模式)。 + + 验证: + 1. choices[0].sampling_mask 格式为 List[List[int]] + 2. 长度等于 completion_tokens + 3. 内层列表包含非负的词汇表索引 + """ + max_tokens = 20 + payload = { + "model": "default", + "temperature": 1.0, + "top_p": 0.9, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请用一句话介绍Python语言。"}, + ], + "max_tokens": max_tokens, + "stream": False, + } + + response = send_request(url=api_url, payload=payload).json() + assert "choices" in response, f"响应缺少 choices 字段: {response}" + choice = response["choices"][0] + assert "sampling_mask" in choice, f"choice 缺少 sampling_mask 字段: {choice}" + + sampling_mask = choice["sampling_mask"] + completion_tokens = response["usage"]["completion_tokens"] + _assert_sampling_mask_format(sampling_mask, max_tokens) + assert ( + len(sampling_mask) == completion_tokens + ), f"sampling_mask 长度 {len(sampling_mask)} 应等于 completion_tokens {completion_tokens}" + + +def test_keep_sampling_mask_top_p_1_stream(api_url): + """测试 top_p=1.0 时流式响应的 sampling_mask(MTP 模式)。 + + top_p=1.0 表示保留全部词汇,每个 token mask 应包含所有词汇表位置。 + 验证 mask 非空且每个内层列表长度 > 1(至少保留多个候选 token)。 + """ + max_tokens = 10 + payload = { + "model": "default", + "temperature": 1.0, + "top_p": 1.0, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "1+1="}, + ], + "max_tokens": max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + + response = send_request(url=api_url, 
payload=payload) + chunks = get_stream_chunks(response) + assert len(chunks) > 1, "流式响应应包含至少两个 chunk" + + for chunk in chunks[:-1]: + choice = chunk["choices"][0] + mask = choice.get("sampling_mask") + if mask is not None: + for token_mask in mask: + assert len(token_mask) > 1, "top_p=1.0 时每个 token 的候选集应大于 1" + + +def test_keep_sampling_mask_consistent_with_top_p(api_url): + """对比 top_p=0.1 与 top_p=0.9 时 sampling_mask 的候选集大小(非流式,MTP 模式)。 + + top_p 越小,保留的候选 token 越少,平均 mask 长度应更短。 + """ + max_tokens = 15 + + def get_avg_mask_len(top_p): + payload = { + "model": "default", + "temperature": 1.0, + "top_p": top_p, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请列举三种编程语言。"}, + ], + "max_tokens": max_tokens, + "stream": False, + } + resp = send_request(url=api_url, payload=payload).json() + mask = resp["choices"][0].get("sampling_mask") + if not mask: + return 0 + return sum(len(m) for m in mask) / len(mask) + + avg_small = get_avg_mask_len(0.1) + avg_large = get_avg_mask_len(0.9) + assert avg_small <= avg_large, f"top_p=0.1 的平均 mask 长度 ({avg_small:.1f}) 应 <= top_p=0.9 ({avg_large:.1f})" diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py index d98e79b74f2..bd7b6482b09 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -577,6 +577,7 @@ async def test_create_chat_completion_choice(self): response_processor=mock_response_processor, max_tokens=max_tokens_list[idx], speculate_metrics=None, + sampling_mask_list=None, ) expected = case["expected"] diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 46282cd386a..2682aee0713 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -166,7 +166,7 @@ def setup_token_processor(self, 
speculative_decoding=False, use_logprobs=False): processor.total_step_per_request = {} processor.accept_token_num_per_head_per_request = {} processor.accept_token_num_per_head = [0] * MAX_DRAFT_TOKENS - + processor.use_sampling_mask = False # processor._recycle_resources = Mock() if speculative_decoding: