[KSM] support keep sampling mask #7222

zeroRains wants to merge 3 commits into PaddlePaddle:develop
Conversation

Thanks for your contribution!

Pull request overview
This PR adds a keep_sampling_mask capability to FastDeploy's sampling and OpenAI API output path: for each generated token, the vocabulary positions retained after top_p/top_k filtering are returned as a sparse index list, with support for MTP/speculative decoding.

Changes:
- Compute the sparse sampling_mask during sampling and propagate it downstream (non-speculative: per request; speculative: flattened per accepted token, then regrouped as needed).
- Extend the OpenAI Chat API (streaming and non-streaming) response protocol and serving logic to assemble and aggregate the sampling_mask field.
- Add the startup flag --enable-keep-sampling-mask, and add/update the related unit tests and e2e cases.
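Conceptually, the retained-index computation for a single token can be sketched as follows (a minimal numpy sketch with a hypothetical helper name, not the PR's actual kernel; it assumes top_k is applied before top_p, with 0 meaning "top_k disabled"):

```python
import numpy as np

def retained_vocab_indices(probs: np.ndarray, top_p: float, top_k: int) -> np.ndarray:
    """Sparse sampling mask for one token: the sorted vocab indices that
    survive top_k truncation followed by top_p (nucleus) truncation."""
    order = np.argsort(-probs)                  # candidates, most probable first
    if top_k > 0:
        order = order[:top_k]                   # keep at most top_k candidates
    cdf = np.cumsum(probs[order])
    # smallest prefix whose cumulative mass reaches top_p (the crossing
    # token is kept, matching common nucleus-sampling behavior)
    cutoff = int(np.searchsorted(cdf, top_p)) + 1
    return np.sort(order[:cutoff]).astype(np.int64)

probs = np.array([0.4, 0.3, 0.2, 0.1])
print(retained_vocab_indices(probs, top_p=0.75, top_k=0))  # [0 1 2]
```

The real implementation works on batched Paddle tensors; this sketch only illustrates the sparse-index output format that replaces a dense vocab-sized mask.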
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tests/output/test_process_batch_output.py | Add the use_sampling_mask switch to test initialization |
| tests/entrypoints/openai/test_max_streaming_tokens.py | Adapt to the new sampling_mask_list parameter |
| tests/e2e/test_ernie_21b_mtp.py | e2e coverage of the streaming/non-streaming sampling_mask format and top_p behavior |
| fastdeploy/worker/worker_process.py | Parse the new CLI switch on the worker side |
| fastdeploy/worker/output.py | Add a sampling_mask field to SamplerOutput |
| fastdeploy/worker/gpu_model_runner.py | Propagate keep_sampling_mask into SamplingMetadata and pass async_output_queue into speculative save_output |
| fastdeploy/output/token_processor.py | Add a speculative branch and sampling_mask parsing to ZMQ (v1) output handling; support mtype in postprocess |
| fastdeploy/output/stream_transfer_data.py | Add output_type and sampling_mask fields to StreamTransferData |
| fastdeploy/model_executor/pre_and_post_process.py | Support speculative StreamTransferData construction and sampling_mask passthrough in the v1 output path |
| fastdeploy/model_executor/layers/sample/sampler.py | Add _compute_sampling_mask and inject sampling_mask into the sampling path |
| fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py | Adjust the top-k probability renormalization logic |
| fastdeploy/model_executor/layers/sample/meta_data.py | Add a keep_sampling_mask field to SamplingMetadata |
| fastdeploy/model_executor/layers/sample/logprobs.py | Reuse target_logits for sampling_mask and adjust the return value |
| fastdeploy/entrypoints/openai/serving_chat.py | Aggregate sampling_mask into chat streaming/full responses |
| fastdeploy/entrypoints/openai/protocol.py | Add a sampling_mask field to the OpenAI protocol models |
| fastdeploy/engine/request.py | Add sampling_mask to CompletionOutput and include it in to_dict |
| fastdeploy/engine/engine.py | Pass enable_keep_sampling_mask through when launching workers |
| fastdeploy/engine/common_engine.py | Same as above (common_engine) |
| fastdeploy/engine/args_utils.py | Add --enable-keep-sampling-mask to the engine CLI and store it in EngineArgs |
| fastdeploy/config.py | Add the enable_keep_sampling_mask default to ModelConfig initialization |
```python
        "last_preempted_idx",
    ],
)
speculate_save_output_topk(
    recover_share_inputs["sampled_token_ids"],
    sampler_output.logprobs_tensors.logprob_token_ids,
    sampler_output.logprobs_tensors.logprobs,
    sampler_output.logprobs_tensors.selected_token_ranks,
    recover_share_inputs["accept_num_cpu"],
    sampler_output.cu_batch_token_offset,
    model_output.not_need_stop,
    recover_share_inputs["seq_lens_decoder_cpu"],
    recover_share_inputs["prompt_lens_cpu"],
    recover_share_inputs["last_preempted_idx"],
    3,  # mtype
    model_output.mp_rank,
    save_each_rank,
# target tokens (mtype=3)
output = _build_speculative_stream_transfer_data(
```
In the FD_USE_GET_SAVE_OUTPUT_V1 speculative decoding output branch, recover_share_inputs includes last_preempted_idx, but the subsequent StreamTransferData construction (_build_speculative_stream_transfer_data) never uses it to inject the PREEMPTED_TOKEN_ID (-9) preemption signal. As a result, the ZMQ-side token_processor treats preempted requests as accept_num==0 and simply ignores them; and because the batch_id is still present in receive_datas, _reschedule_preempt_task_use_zmq never triggers a reschedule/recycle, which can leave requests stuck or leak resources. Suggestion: pass last_preempted_idx through and, inside _build_speculative_stream_transfer_data, set accept_num=PREEMPTED_TOKEN_ID for preempted slots (or mark their tokens as negative values) so the existing preemption handling is reused.
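The fix suggested in this comment could look roughly like the following (an illustrative sketch only: the helper name and the assumption that slots at index >= last_preempted_idx were preempted are hypothetical; only the -9 sentinel comes from the existing codebase as described above):

```python
PREEMPTED_TOKEN_ID = -9  # sentinel assumed from the existing preemption path

def mark_preempted_slots(accept_nums: list[int], last_preempted_idx: int) -> list[int]:
    """Overwrite accept_num with the preemption sentinel for preempted slots
    so the ZMQ-side token processor reuses its existing reschedule/recycle
    logic instead of treating them as accept_num == 0.

    Assumes (for illustration) that last_preempted_idx < 0 means "no
    preemption" and that slots at index >= last_preempted_idx were preempted.
    """
    if last_preempted_idx < 0:
        return accept_nums
    return [
        PREEMPTED_TOKEN_ID if i >= last_preempted_idx else n
        for i, n in enumerate(accept_nums)
    ]
```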
```python
# Batch-level speculative decoding metrics
if self.speculative_decoding:
    accept_nums = []
    for sd in receive_datas:
        if getattr(sd, "accept_num", None) is not None:
            accept_nums.append(int(sd.accept_num[0]))
    if accept_nums:
        self._record_speculative_decoding_metrics(accept_nums)
```
The accept_nums aggregation for speculative decoding in process_sampling_results_use_zmq does not distinguish output_type. With enable_draft_logprob on, the same step first receives an mtype=3 (target) batch and then an mtype=4 (draft) batch; _record_speculative_decoding_metrics(accept_nums) is called for both, doubling the spec_decode metrics/Prometheus counters. Suggestion: record only when mtype==3 (target), or explicitly skip mtype==4.
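A minimal sketch of the suggested gating (field names mirror the snippet above; the mtype constant and surrounding structure are assumptions for illustration):

```python
TARGET_MTYPE = 3  # target tokens; mtype 4 carries draft logprobs

def collect_accept_nums(receive_datas) -> list[int]:
    """Gather accept_num only from target (mtype=3) batches so the
    spec_decode metrics are not double-counted when a draft (mtype=4)
    batch arrives in the same step."""
    accept_nums = []
    for sd in receive_datas:
        if getattr(sd, "output_type", None) != TARGET_MTYPE:
            continue  # skip draft batches (and anything untyped)
        if getattr(sd, "accept_num", None) is not None:
            accept_nums.append(int(sd.accept_num[0]))
    return accept_nums
```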
```python
pooler_output: Optional[np.array] = None
# Sparse sampling mask(s) for top_p/top_k:
# - Non-speculative: single 1-D int64 numpy array of retained vocab indices.
# - Speculative: List[np.ndarray], one 1-D array per accepted token.
sampling_mask: Optional[np.array] = None
```
StreamTransferData.sampling_mask is annotated as Optional[np.array], but per the comment and actual usage, the speculative decoding branch assigns a List[np.ndarray] (see the sampling_mask slicing in _build_speculative_stream_transfer_data). The annotation therefore disagrees with reality, which misleads IDEs/static checkers and is easy for callers to misuse (note also that np.array is a function, not a type; np.ndarray is the correct annotation). Suggestion: use a more precise Union (e.g. Optional[Union[np.ndarray, list[np.ndarray]]]) or Any, and update the comment to match.
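The suggested annotation, as a trimmed sketch (the real dataclass has many more fields):

```python
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np

@dataclass
class StreamTransferData:  # trimmed: only the field under discussion
    # Non-speculative: a single 1-D int64 array of retained vocab indices.
    # Speculative: one 1-D array per accepted token.
    sampling_mask: Optional[Union[np.ndarray, list[np.ndarray]]] = None
```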
```python
mask = idx < share_inputs["accept_num"].unsqueeze(1)
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)

# Adapt for sampling mask
if num_logprobs is None:
    return None, None, output_logits

# Compute logprobs with temperature scaling and top_p normalization
```
build_output_logprobs now returns (None, None, output_logits) when num_logprobs is None, and the later return statements are also 3-tuples; but the function signature/type annotation/docstring still describe a 2-tuple. This misleads callers and defeats type checking. Suggestion: update the return annotation and docstring in step, or split the output_logits extraction into a separate function to keep the interface clean.
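One way to make the 3-tuple contract explicit (a hypothetical signature sketch; the log-softmax/top-k body below is a stand-in, not the real temperature-scaled logprob computation):

```python
from typing import Optional, Tuple

import numpy as np

def build_output_logprobs(
    logits: np.ndarray, num_logprobs: Optional[int]
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], np.ndarray]:
    """Return (topk_token_ids, topk_logprobs, output_logits); the first two
    entries are None when logprobs were not requested."""
    output_logits = logits  # stand-in for the real extraction logic
    if num_logprobs is None:
        return None, None, output_logits
    # stand-in computation: plain log-softmax, then top-k by logprob
    logprobs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
    topk_ids = np.argsort(-logprobs, axis=-1)[:, :num_logprobs]
    return topk_ids, np.take_along_axis(logprobs, topk_ids, axis=-1), output_logits
```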
```python
model_group.add_argument(
    "--enable-keep-sampling-mask",
    action="store_true",
    default=EngineArgs.enable_keep_sampling_mask,
    help=(
        "Enable output of sampling mask as a sparse index list over the vocabulary. "
```
The PR title uses the [KSM] tag, which is not in the repository template's tag list and may break automated checks/statistics. Suggestion: switch the title tag to the closest listed type (e.g. [Feature], [APIServer], or [Engine]) and keep the "[TAG] Title" format.
```python
if top_k_list and any(x > 0 for x in top_k_list):
    from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs

    x = top_k_renorm_probs(x, top_k)
```
top_k_renorm_probs is already applied once earlier in this function (after entering the top_p_class != "rejection" branch) to perform top-k truncation plus renormalization; repeating it here on the default CUDA path adds extra kernel overhead and makes the logic harder to maintain. Suggestion: delete this duplicated call, or centralize the top-k renormalization in one place so it runs exactly once.
Suggested change: remove the duplicated top_k_renorm_probs block shown above.
```python
# Compute sampling mask BEFORE top_k_top_p_sampling modifies probs.
# Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated.
sampling_mask = None
if sampling_metadata.keep_sampling_mask:
    sampling_mask = _compute_sampling_mask(
        probs,
        sampling_metadata.top_p,
        top_k=sampling_metadata.top_k,
```
The comment here still says "Binary mask [num_reqs, vocab_size]", but _compute_sampling_mask actually returns a sparse index list (List[np.ndarray]), not a dense bool mask. Suggestion: update the comment (and the sampling_mask variable naming/type hints) so callers don't misread the mask's format and dimensions.
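In line with this comment, a batched sketch of the shape _compute_sampling_mask actually returns (a hypothetical pure-numpy re-implementation covering only the top_p part, one index array per request row):

```python
import numpy as np

def compute_sampling_mask(probs: np.ndarray, top_p: np.ndarray) -> list[np.ndarray]:
    """Sparse sampling masks: for each request row of probs, the sorted vocab
    indices whose cumulative probability (in descending order) is needed to
    reach that row's top_p. Returns List[np.ndarray], not a dense bool mask."""
    masks = []
    for row, p in zip(probs, top_p):
        order = np.argsort(-row)                      # most probable first
        cutoff = int(np.searchsorted(np.cumsum(row[order]), p)) + 1
        masks.append(np.sort(order[:cutoff]).astype(np.int64))
    return masks
```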
```python
        "For non-MTP decoding, this is a list[int] per token step indicating which "
        "vocabulary indices were kept after top_p/top_k sampling. "
        "For MTP decoding, this is a list[list[int]] per token step, where each inner "
        "list corresponds to one MTP group."
```
With keep_sampling_mask enabled, if top_p=1.0 and no effective top_k is set, the current implementation returns index lists covering the entire vocabulary (each token may return close to vocab_size integers): the compute side needs a full argsort/copy, and the network/serialization overhead is also very large. Suggestion: call out this cost in the CLI help/docs, or add a cap/fallback on the implementation side for top_p>=1.0 (e.g. require top_k>0, or support returning only the boundary threshold instead of the full index list).
Suggested change:

```python
        "list corresponds to one MTP group. Warning: when top_p >= 1.0 and top_k is "
        "unset or non-positive, the returned index list may include nearly the entire "
        "vocabulary for each token step, which can significantly increase compute, "
        "memory, serialization, and network overhead. Prefer using this option with a "
        "bounded top_k to avoid very large responses."
```
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff            @@
##             develop    #7222 +/- ##
========================================
  Coverage           ?   73.27%
========================================
  Files              ?      377
  Lines              ?    53507
  Branches           ?     8375
========================================
  Hits               ?    39206
  Misses             ?    11551
  Partials           ?     2750
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Motivation

Add the keep_sampling_mask feature; see PR #6725 for details.

Modifications

- Add a _compute_sampling_mask method in sampler.py.
- Add the startup flag --enable-keep-sampling-mask.

Usage or Command

Service launch command:

Accuracy Tests

yes

Checklist

- Use one of the listed PR tags: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run pre-commit before commit.
- For a release branch, make sure the PR has been submitted to the develop branch first, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.