Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def __init__(
self.enable_logprob = False
self.max_logprobs = 20
self.logprobs_mode = "raw_logprobs"
self.enable_keep_sampling_mask = False
self.redundant_experts_num = 0
self.seed = 0
self.quantization = None
Expand Down
20 changes: 20 additions & 0 deletions fastdeploy/engine/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,14 @@ class EngineArgs:
Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
"""

enable_keep_sampling_mask: bool = False
"""
When enabled, the server returns a sparse index list for each generated token, indicating
which vocabulary positions were retained after top_p/top_k sampling, and streams it to
the client. In MTP (multi-token prediction) scenarios this field is a List[List[int]],
where each inner list contains the retained vocabulary indices for a predicted token.
"""

max_logprobs: int = 20
"""
Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes from the default for the
Expand Down Expand Up @@ -893,6 +901,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.enable_logprob,
help="Enable output of token-level log probabilities.",
)
model_group.add_argument(
"--enable-keep-sampling-mask",
action="store_true",
default=EngineArgs.enable_keep_sampling_mask,
help=(
Comment on lines +904 to +908
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按仓库 PR 规范,标题需要使用预置的 tag(例如 [Feature]/[APIServer]/[Engine]/[Speculative Decoding] 等)。当前标题使用的 [KSM] 不在模板列出的 tag 列表里且语义不够明确,建议改为更通用且可检索的 tag(如本 PR 涉及采样/接口输出,可考虑 [Feature] 或 [APIServer]/[Engine] 组合)。

Copilot generated this review using guidance from repository custom instructions.
"Enable output of sampling mask as a sparse index list over the vocabulary. "
Comment on lines +904 to +909
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR 标题目前使用了 [KSM] 标签,但仓库模板给出的 tag 列表里没有该项,可能影响自动检查/统计。建议将标题 tag 调整为列表内语义最接近的类型(例如 [Feature]、[APIServer] 或 [Engine]),并保持格式为“[TAG] Title”。

Copilot uses AI. Check for mistakes.
"For non-MTP decoding, this is a list[int] per token step indicating which "
"vocabulary indices were kept after top_p/top_k sampling. "
"For MTP decoding, this is a list[list[int]] per token step, where each inner "
"list corresponds to one MTP group."
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

开启 keep_sampling_mask 时,如果 top_p=1.0 且未设置有效 top_k,按当前实现会保留整个 vocab 的索引列表(每个 token 都可能返回接近 vocab_size 个整数),不仅计算端需要完整 argsort/拷贝,网络/序列化开销也会非常大。建议在 CLI help/文档里明确提示该开销风险,或在实现侧对 top_p>=1.0 时提供可选的上限/降采样策略(例如强制要求 top_k>0 或支持只返回 boundary 阈值而非全量索引)。

Suggested change
"list corresponds to one MTP group."
"list corresponds to one MTP group. Warning: when top_p >= 1.0 and top_k is "
"unset or non-positive, the returned index list may include nearly the entire "
"vocabulary for each token step, which can significantly increase compute, "
"memory, serialization, and network overhead. Prefer using this option with a "
"bounded top_k to avoid very large responses."

Copilot uses AI. Check for mistakes.
),
)
model_group.add_argument(
"--max-logprobs",
type=int,
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,6 +2508,7 @@ def _start_worker_service(self):
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
"enable_entropy": self.cfg.model_config.enable_entropy,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,7 @@ def _start_worker_service(self):
"enable_entropy": self.cfg.model_config.enable_entropy,
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/engine/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,10 @@ class CompletionOutput:
delta_message: Optional[DeltaMessage] = None
multipart: Optional[list[Any]] = None
num_image_tokens: Optional[int] = None
# Sparse indices of retained vocab ids:
# - Non-MTP: list[int]
# - MTP: list[list[int]]
sampling_mask: Optional[Any] = None

def to_dict(self):
"""
Expand All @@ -745,6 +749,7 @@ def to_dict(self):
"text": self.text,
"reasoning_content": self.reasoning_content,
"reasoning_token_num": self.reasoning_token_num,
"sampling_mask": self.sampling_mask,
}

@classmethod
Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ class ChatCompletionResponseChoice(BaseModel):
prompt_logprobs: Optional[PromptLogprobs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
speculate_metrics: Optional[SpeculateMetrics] = None
# Per-token retained vocab indices from top_p/top_k sampling: List[List[int]], one list of vocab indices per token
sampling_mask: Optional[List[List[int]]] = None


class ChatCompletionResponse(BaseModel):
Expand Down Expand Up @@ -333,6 +335,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
logprobs: Optional[LogProbs] = None
draft_logprobs: Optional[LogProbs] = None
prompt_logprobs: Optional[PromptLogprobs] = None
# Per-token index list of retained positions after top_p sampling.
# Non-MTP: [[idx, ...]] (1 token/step). MTP: [[idx, ...], ...] (N accepted tokens/step).
sampling_mask: Optional[List[List[int]]] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None
arrival_time: Optional[float] = None
speculate_metrics: Optional[SpeculateMetrics] = None
Expand Down
32 changes: 32 additions & 0 deletions fastdeploy/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,11 @@ async def chat_completion_stream_generator(
delta=delta_message,
logprobs=logprobs_res,
draft_logprobs=draft_logprobs_res,
sampling_mask=(
self._make_sampling_mask_list(output["sampling_mask"])
if output.get("sampling_mask") is not None
else None
),
arrival_time=arrival_time,
speculate_metrics=output_speculate_metrics,
)
Expand Down Expand Up @@ -580,6 +585,7 @@ async def chat_completion_full_generator(
decoder_base_url=self.tokenizer_base_url,
)
prompt_logprobs_res_list = [[] for _ in range(num_choices)]
sampling_mask_list = [[] for _ in range(num_choices)]
speculate_metrics = [None for _ in range(num_choices)]
choices = []
while num_choices > 0:
Expand Down Expand Up @@ -660,6 +666,9 @@ async def chat_completion_full_generator(
)
if prompt_logprobs_res:
prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res))
output_sampling_mask = output.get("sampling_mask", None)
if output_sampling_mask is not None:
sampling_mask_list[idx].append(self._make_sampling_mask_list(output_sampling_mask))
speculate_metrics[idx] = data["metrics"].get("speculate_metrics", None)
if data["finished"]:
trace_carrier = data.get("trace_carrier")
Expand Down Expand Up @@ -695,6 +704,7 @@ async def chat_completion_full_generator(
draft_logprob_contents=draft_logprob_contents,
response_processor=response_processor,
prompt_logprobs_res_list=prompt_logprobs_res_list,
sampling_mask_list=sampling_mask_list,
max_tokens=max_tokens,
speculate_metrics=speculate_metrics[idx],
)
Expand Down Expand Up @@ -749,6 +759,7 @@ async def _create_chat_completion_choice(
logprob_contents: list,
draft_logprob_contents: list,
prompt_logprobs_res_list: list,
sampling_mask_list: list,
response_processor: ChatResponseProcessor,
max_tokens: int,
speculate_metrics: SpeculateMetrics | None,
Expand Down Expand Up @@ -787,6 +798,11 @@ async def _create_chat_completion_choice(
if prompt_logprobs_res_list[idx]:
prompt_logprobs_full_res = prompt_logprobs_res_list[idx]

# Flatten per-step List[List[int]] into a single List[List[int]] over all tokens.
sampling_mask_full_res = None
if sampling_mask_list and sampling_mask_list[idx]:
sampling_mask_full_res = [mask for step in sampling_mask_list[idx] for mask in step]

num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0)
Expand All @@ -810,6 +826,7 @@ async def _create_chat_completion_choice(
logprobs=logprobs_full_res,
draft_logprobs=draft_logprobs_full_res,
prompt_logprobs=prompt_logprobs_full_res,
sampling_mask=sampling_mask_full_res,
finish_reason=finish_reason,
speculate_metrics=speculate_metrics,
)
Expand Down Expand Up @@ -1000,3 +1017,18 @@ def _make_logprob_dict(
)
for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens)
}

@staticmethod
def _make_sampling_mask_list(sampling_mask) -> List[List[int]]:
"""Wrap sampling_mask into a uniform List[List[int]] format.

sampling_mask is already in sparse-index form (no bool-to-index conversion needed):
Non-MTP: List[int] (indices for 1 token/step) → [[idx, ...]]
MTP: List[List[int]] (indices for N tokens/step) → [[idx, ...], ...]
"""
assert sampling_mask is not None
if sampling_mask and isinstance(sampling_mask[0], list):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 _make_sampling_mask_list 方法在 sampling_mask 为空列表时会抛出 IndexError。

sampling_mask 为空列表 [] 时,isinstance(sampling_mask[0], list) 会抛出 IndexError: list index out of range

建议修复方式:

@staticmethod
def _make_sampling_mask_list(sampling_mask) -> List[List[int]]:
    assert sampling_mask is not None
    if sampling_mask and isinstance(sampling_mask[0], list):
        # MTP: already List[List[int]], return as-is
        return sampling_mask
    # Non-MTP: already List[int], wrap in outer list for uniform format
    return [sampling_mask]

或者在调用方增加空列表检查。

# MTP: already List[List[int]], return as-is
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 _make_sampling_mask_list 函数中的 isinstance(sampling_mask[0], list) 访问可能在 sampling_mask 为空列表时失败。

keep_sampling_mask=True 但某个请求的输出没有实际 token 时(如 accept_num=0),sampling_mask 可能为空列表 [],导致 sampling_mask[0] 访问越界。

建议添加空列表检查:

if sampling_mask and isinstance(sampling_mask[0], list):
    return sampling_mask

return sampling_mask
# Non-MTP: already List[int], wrap in outer list for uniform format
return [sampling_mask]
11 changes: 6 additions & 5 deletions fastdeploy/model_executor/layers/sample/logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,6 @@ def build_output_logprobs(
logprobs_tensors = None
cu_batch_token_offset = None

if num_logprobs is None:
return logprobs_tensors, cu_batch_token_offset

real_bsz = share_inputs["seq_lens_this_time"].shape[0]

if is_naive:
Expand Down Expand Up @@ -208,6 +205,10 @@ def build_output_logprobs(
mask = idx < share_inputs["accept_num"].unsqueeze(1)
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)

# Adapt for sampling mask
if num_logprobs is None:
return None, None, output_logits

# Compute logprobs with temperature scaling and top_p normalization
Comment on lines 205 to 212
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

build_output_logprobs 现在在 num_logprobs is None 时返回 (None, None, output_logits),并且后续也返回 3 个值;但函数签名/类型标注/文档仍描述为返回 2-tuple。这会误导调用方并让类型检查失效。建议同步更新返回类型标注与 docstring,或拆分“提取 output_logits”逻辑到独立函数以保持接口清晰。

Copilot uses AI. Check for mistakes.
if logprobs_mode == "raw_logprobs":
raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata)
Expand All @@ -217,5 +218,5 @@ def build_output_logprobs(
raw_logprobs = F.log_softmax(output_logits, axis=-1)

logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)

return logprobs_tensors, cu_batch_token_offset
# output_logits is used to compute sampling_mask
return logprobs_tensors, cu_batch_token_offset, output_logits
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/sample/meta_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,5 @@ class SamplingMetadata:
# Add for HPU post-processing
seq_lens_encoder: Optional[paddle.Tensor] = None
seq_lens_decoder: Optional[paddle.Tensor] = None
# Add for keep sampling mask
keep_sampling_mask: Optional[bool] = None
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ def top_k_top_p_sampling(
if topp_seed is not None:
topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype)
topp_seed_device.copy_(topp_seed, False)
if top_k_list and any(x > 0 for x in top_k_list):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 请确认第 128-136 行新增的 top_k_renorm_probs 调用是否必要。

从代码结构看,第 87-97 行已经在 else 分支中对所有非 "rejection" 采样器调用了 top_k_renorm_probs(x, top_k),修改后的 probs 会被传递到后续的 if top_p_class == "air"elif top_p_class == "base_non_truncated" 分支。

第 128-136 行的新增代码位于内层 else 分支(即既不是 "air" 也不是 "base_non_truncated" 的情况),这可能导致 top_k_renorm_probs 被调用两次,或者调用逻辑与第 87-97 行重复。

建议检查代码缩进和逻辑分支,确保 top_k_renorm_probs 只在需要的位置被调用一次。

try:
from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs

x = top_k_renorm_probs(x, top_k)
except ImportError:
logger.warning(
"top_k_renorm_probs is not supported on current platform, skipping top_k_renorm_probs."
)

Comment on lines +128 to +137
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里在进入 paddle.tensor.top_p_sampling 的分支里又重复执行了一次 top_k_renorm_probs(上面行 87-97 已经做过一次)。这会造成重复计算,并且在部分平台(例如 iluvatar)还会出现前后 import 路径不一致/重复 warning 的问题。建议删除这一段重复的 top_k_renorm_probs 调用,保留前面的统一处理即可。

Suggested change
if top_k_list and any(x > 0 for x in top_k_list):
try:
from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs
x = top_k_renorm_probs(x, top_k)
except ImportError:
logger.warning(
"top_k_renorm_probs is not supported on current platform, skipping top_k_renorm_probs."
)

Copilot uses AI. Check for mistakes.
_, ids = paddle.tensor.top_p_sampling(
x,
top_p,
Expand Down
Loading
Loading