
[KSM] support keep sampling mask #7222

Open

zeroRains wants to merge 3 commits into PaddlePaddle:develop from zeroRains:ksm

Conversation

@zeroRains
Contributor

Motivation

Add the keep_sampling_mask feature; see PR #6725 for details.

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Add a _compute_sampling_mask method in sampler.py
Add the launch flag --enable-keep-sampling-mask
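The PR's actual implementation lives in sampler.py; as a rough illustration of the idea (the function name and conventions below are hypothetical, not the PR's code), a sparse sampling mask for one request can be derived from the probability vector like this:

```python
import numpy as np

def compute_sampling_mask_sketch(probs, top_p, top_k=0):
    """Hypothetical sketch: indices of vocab entries retained after
    top_k truncation followed by top_p (nucleus) filtering."""
    order = np.argsort(-probs)               # highest probability first
    sorted_probs = probs[order]
    if top_k > 0:
        order, sorted_probs = order[:top_k], sorted_probs[:top_k]
        sorted_probs = sorted_probs / sorted_probs.sum()  # renormalize after top_k
    cum = np.cumsum(sorted_probs)
    # keep the smallest prefix whose cumulative mass reaches top_p
    cutoff = int(np.searchsorted(cum, top_p)) + 1
    return np.sort(order[:cutoff]).astype(np.int64)

probs = np.array([0.5, 0.3, 0.15, 0.05])
mask = compute_sampling_mask_sketch(probs, top_p=0.6)  # token 1 crosses the 0.6 mass
```

The returned array is the "sparse index list over the vocabulary" format that the new flag exposes in API responses.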

Usage or Command

Server launch command:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
MODEL_PATH="/root/paddlejob/tmpspace/GLM-4.5-Air/"
python -m fastdeploy.entrypoints.openai.api_server \
    --port 9293 \
    --host $(hostname -i) \
    --model "$MODEL_PATH" \
    --disable-custom-all-reduce \
    --tensor-parallel-size 8 \
    --max-model-len 131072 \
    --max-num-seqs 32 \
    --gpu-memory-utilization 0.9 \
    --graph-optimization-config '{"use_cudagraph":true}' \
    --enable-logprob \
    --enable-keep-sampling-mask \
    --speculative-config '{"method":"mtp","num_speculative_tokens":1,"num_model_steps":1,"model":"'$MODEL_PATH'"}'

Accuracy Tests

yes

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 7, 2026 11:53

paddle-bot bot commented Apr 7, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment

Pull request overview

This PR adds a keep_sampling_mask capability to FastDeploy's sampling and OpenAI API output pipeline: for each generated token, the vocabulary positions retained after top_p/top_k filtering are returned as a sparse index list (with support for MTP/speculative decoding scenarios).

Changes:

  • Compute and propagate a sparse sampling_mask at the sampling stage (non-speculative: per request; speculative: flattened per accepted token and regrouped as needed).
  • Extend the OpenAI Chat API response protocol (streaming and non-streaming) and serving logic to assemble and aggregate the sampling_mask field.
  • Add the launch flag --enable-keep-sampling-mask, plus new/updated unit tests and e2e cases.
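As a rough sketch of the serving-side aggregation step (the helper name is hypothetical; the real field assembly lives in serving_chat.py), the per-token sparse index arrays need to become JSON-serializable nested lists before they reach the response body:

```python
import numpy as np

def aggregate_sampling_masks(per_token_masks):
    """Hypothetical helper: convert a list of 1-D int64 index arrays
    (one per generated token) into plain nested lists for the response."""
    return [mask.astype(int).tolist() for mask in per_token_masks]

masks = [np.array([0, 5, 9], dtype=np.int64), np.array([2], dtype=np.int64)]
payload = aggregate_sampling_masks(masks)
```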

Reviewed changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 8 comments.

File Description
tests/output/test_process_batch_output.py Add the use_sampling_mask switch to test initialization
tests/entrypoints/openai/test_max_streaming_tokens.py Adapt to the new sampling_mask_list parameter
tests/e2e/test_ernie_21b_mtp.py e2e coverage of streaming/non-streaming sampling_mask format and top_p behavior
fastdeploy/worker/worker_process.py Parse the new CLI switch on the worker side
fastdeploy/worker/output.py Add a sampling_mask field to SamplerOutput
fastdeploy/worker/gpu_model_runner.py Propagate keep_sampling_mask into SamplingMetadata; pass async_output_queue in speculative save_output
fastdeploy/output/token_processor.py Add a speculative branch and sampling_mask parsing to ZMQ (v1) output handling; support mtype in postprocess
fastdeploy/output/stream_transfer_data.py Add output_type and sampling_mask fields to StreamTransferData
fastdeploy/model_executor/pre_and_post_process.py Support speculative StreamTransferData construction and sampling_mask passthrough in the v1 output path
fastdeploy/model_executor/layers/sample/sampler.py Add _compute_sampling_mask and inject sampling_mask into the sampling path
fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py Adjust top-k probability renormalization logic
fastdeploy/model_executor/layers/sample/meta_data.py Add a keep_sampling_mask field to SamplingMetadata
fastdeploy/model_executor/layers/sample/logprobs.py Reuse target_logits for sampling_mask; adjust return values
fastdeploy/entrypoints/openai/serving_chat.py Aggregate sampling_mask into chat streaming/full responses
fastdeploy/entrypoints/openai/protocol.py Add a sampling_mask field to the OpenAI protocol models
fastdeploy/engine/request.py Add sampling_mask to CompletionOutput and include it in to_dict
fastdeploy/engine/engine.py Pass enable_keep_sampling_mask through when launching workers
fastdeploy/engine/common_engine.py Same as above (common_engine)
fastdeploy/engine/args_utils.py Add --enable-keep-sampling-mask to the engine CLI and write it into EngineArgs
fastdeploy/config.py Add an enable_keep_sampling_mask default to ModelConfig initialization

Comment on lines 629 to +633
"last_preempted_idx",
],
)
speculate_save_output_topk(
recover_share_inputs["sampled_token_ids"],
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
recover_share_inputs["accept_num_cpu"],
sampler_output.cu_batch_token_offset,
model_output.not_need_stop,
recover_share_inputs["seq_lens_decoder_cpu"],
recover_share_inputs["prompt_lens_cpu"],
recover_share_inputs["last_preempted_idx"],
3, # mtype
model_output.mp_rank,
save_each_rank,
# target tokens (mtype=3)
output = _build_speculative_stream_transfer_data(
Copilot AI Apr 7, 2026

In the FD_USE_GET_SAVE_OUTPUT_V1 speculative decoding output branch, recover_share_inputs includes last_preempted_idx, but the subsequent StreamTransferData construction (_build_speculative_stream_transfer_data) never uses that information to inject the PREEMPTED_TOKEN_ID (-9) preemption signal. As a result, the ZMQ-side token_processor treats a preempted request as accept_num==0 and simply ignores it, and since the batch_id remains in receive_datas, _reschedule_preempt_task_use_zmq never triggers a reschedule/recycle, which can leave requests stuck or leak resources. Consider passing last_preempted_idx through and, inside _build_speculative_stream_transfer_data, setting accept_num=PREEMPTED_TOKEN_ID for preempted slots (or marking the tokens as negative values) so the existing preemption handling can fire.
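A minimal sketch of the suggested fix (the sentinel value is taken from the comment above; the helper name and slot semantics are hypothetical):

```python
import numpy as np

PREEMPTED_TOKEN_ID = -9  # sentinel value mentioned in the review comment

def mark_preempted_slots(accept_num, preempted_slots):
    """Hypothetical helper: flag preempted batch slots so the ZMQ-side
    token_processor's existing accept_num == PREEMPTED_TOKEN_ID path fires."""
    out = np.asarray(accept_num).copy()
    for slot in preempted_slots:
        if 0 <= slot < len(out):
            out[slot] = PREEMPTED_TOKEN_ID
    return out

marked = mark_preempted_slots([2, 1, 3], preempted_slots=[1])
```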

Comment on lines +671 to +679
# Batch-level speculative decoding metrics
if self.speculative_decoding:
accept_nums = []
for sd in receive_datas:
if getattr(sd, "accept_num", None) is not None:
accept_nums.append(int(sd.accept_num[0]))
if accept_nums:
self._record_speculative_decoding_metrics(accept_nums)

Copilot AI Apr 7, 2026

The accept_nums accounting for speculative decoding in process_sampling_results_use_zmq does not distinguish output_type. With enable_draft_logprob on, the same step receives two batches: mtype=3 (target) followed by mtype=4 (draft); _record_speculative_decoding_metrics(accept_nums) is then called for both, doubling the spec_decode / Prometheus counters. Consider recording only when mtype==3 (target), or explicitly skipping mtype==4.
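The suggested guard might look like this (the MTYPE constants and field names are assumptions based on the comment above, not the actual code):

```python
from types import SimpleNamespace

MTYPE_TARGET, MTYPE_DRAFT = 3, 4  # per the review comment above

def collect_target_accept_nums(receive_datas):
    """Only gather accept_num from target-model batches (mtype == 3),
    skipping draft batches so metrics are not recorded twice per step."""
    return [
        int(sd.accept_num[0])
        for sd in receive_datas
        if getattr(sd, "accept_num", None) is not None
        and getattr(sd, "output_type", MTYPE_TARGET) == MTYPE_TARGET
    ]

batches = [
    SimpleNamespace(accept_num=[2], output_type=MTYPE_TARGET),
    SimpleNamespace(accept_num=[2], output_type=MTYPE_DRAFT),  # skipped
]
```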

Comment on lines 49 to +53
pooler_output: Optional[np.array] = None
# Sparse sampling mask(s) for top_p/top_k:
# - Non-speculative: single 1-D int64 numpy array of retained vocab indices.
# - Speculative: List[np.ndarray], one 1-D array per accepted token.
sampling_mask: Optional[np.array] = None
Copilot AI Apr 7, 2026

StreamTransferData.sampling_mask is annotated as Optional[np.array], but per the comment and the actual usage, the speculative decoding branch assigns List[np.ndarray] (see the sampling_mask slicing in _build_speculative_stream_transfer_data). The annotation therefore diverges from reality, misleading IDEs, static checkers, and callers. Consider a more accurate Union (e.g. Optional[Union[np.ndarray, list[np.ndarray]]]) or plain Any, and update the comment to match.
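The suggested annotation fix, sketched (the dataclass below is a stand-in, not the real StreamTransferData):

```python
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np

@dataclass
class StreamTransferDataSketch:
    # One array for non-speculative decoding, a list of arrays (one per
    # accepted token) for speculative decoding -- so annotate the union.
    sampling_mask: Optional[Union[np.ndarray, List[np.ndarray]]] = None

single = StreamTransferDataSketch(sampling_mask=np.array([1, 2], dtype=np.int64))
spec = StreamTransferDataSketch(sampling_mask=[np.array([3], dtype=np.int64)])
```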

Comment on lines 205 to 212
mask = idx < share_inputs["accept_num"].unsqueeze(1)
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)

# Adapt for sampling mask
if num_logprobs is None:
return None, None, output_logits

# Compute logprobs with temperature scaling and top_p normalization
Copilot AI Apr 7, 2026

build_output_logprobs now returns (None, None, output_logits) when num_logprobs is None, and the other paths also return 3 values, but the function signature, type annotation, and docstring still describe a 2-tuple. This misleads callers and defeats type checking. Consider updating the return annotation and docstring accordingly, or splitting the "extract output_logits" logic into a separate function to keep the interface clean.

Comment on lines +904 to +909
model_group.add_argument(
"--enable-keep-sampling-mask",
action="store_true",
default=EngineArgs.enable_keep_sampling_mask,
help=(
"Enable output of sampling mask as a sparse index list over the vocabulary. "
Copilot AI Apr 7, 2026

The PR title currently uses the [KSM] tag, which is not in the repository template's tag list and may break automated checks/statistics. Consider switching to the semantically closest tag from the list (e.g. [Feature], [APIServer], or [Engine]) while keeping the "[TAG] Title" format.

Comment on lines +128 to +131
if top_k_list and any(x > 0 for x in top_k_list):
from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs

x = top_k_renorm_probs(x, top_k)
Copilot AI Apr 7, 2026

top_k_renorm_probs is already applied once earlier in this function (after entering the top_p_class != "rejection" branch) as a combined top-k truncation + renormalization; repeating it here on the default CUDA path adds redundant kernel overhead and makes the logic harder to maintain. Consider removing this duplicate call, or centralizing the top-k renormalization in one place so it runs exactly once.

Suggested change
if top_k_list and any(x > 0 for x in top_k_list):
from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs
x = top_k_renorm_probs(x, top_k)

Comment on lines +666 to +673
# Compute sampling mask BEFORE top_k_top_p_sampling modifies probs.
# Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated.
sampling_mask = None
if sampling_metadata.keep_sampling_mask:
sampling_mask = _compute_sampling_mask(
probs,
sampling_metadata.top_p,
top_k=sampling_metadata.top_k,
Copilot AI Apr 7, 2026

The comment here still says "Binary mask [num_reqs, vocab_size]", but _compute_sampling_mask actually returns a sparse index list (List[np.ndarray]), not a dense bool mask. Consider updating the comment (and the sampling_mask variable naming/type hints) so callers do not misread the mask's format and dimensions.

"For non-MTP decoding, this is a list[int] per token step indicating which "
"vocabulary indices were kept after top_p/top_k sampling. "
"For MTP decoding, this is a list[list[int]] per token step, where each inner "
"list corresponds to one MTP group."
Copilot AI Apr 7, 2026

With keep_sampling_mask enabled, if top_p=1.0 and no effective top_k is set, the current implementation retains an index list over the entire vocabulary (each token may return close to vocab_size integers): the compute side needs a full argsort/copy, and the network/serialization overhead is also very large. Consider calling out this cost explicitly in the CLI help/docs, or providing an optional cap/fallback for top_p >= 1.0 on the implementation side (e.g. requiring top_k > 0, or returning only the boundary threshold instead of the full index list).

Suggested change
"list corresponds to one MTP group."
"list corresponds to one MTP group. Warning: when top_p >= 1.0 and top_k is "
"unset or non-positive, the returned index list may include nearly the entire "
"vocabulary for each token step, which can significantly increase compute, "
"memory, serialization, and network overhead. Prefer using this option with a "
"bounded top_k to avoid very large responses."
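The overhead the comment warns about is easy to demonstrate: with top_p=1.0 and no top_k, the "sparse" index list covers the entire vocabulary (toy numbers below; a real vocabulary has tens of thousands of entries per token step):

```python
import numpy as np

vocab_size = 8
probs = np.full(vocab_size, 1.0 / vocab_size)   # uniform toy distribution
order = np.argsort(-probs)
cum = np.cumsum(probs[order])
# smallest prefix whose cumulative mass reaches top_p = 1.0: everything
cutoff = int(np.searchsorted(cum, 1.0)) + 1
mask = order[:min(cutoff, vocab_size)]
```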

@codecov-commenter

Codecov Report

❌ Patch coverage is 37.16216% with 186 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@8cb417e). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/output/token_processor.py 3.68% 155 Missing and 2 partials ⚠️
fastdeploy/model_executor/pre_and_post_process.py 63.46% 12 Missing and 7 partials ⚠️
fastdeploy/entrypoints/openai/serving_chat.py 38.46% 6 Missing and 2 partials ⚠️
fastdeploy/model_executor/layers/sample/sampler.py 95.74% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7222   +/-   ##
==========================================
  Coverage           ?   73.27%           
==========================================
  Files              ?      377           
  Lines              ?    53507           
  Branches           ?     8375           
==========================================
  Hits               ?    39206           
  Misses             ?    11551           
  Partials           ?     2750           
Flag Coverage Δ
GPU 73.27% <37.16%> (?)

Flags with carried forward coverage won't be shown.



Labels

None yet

Projects

None yet


3 participants