[Speculative Decoding]【Hackathon 10th Spring No.49】Adapt ngram_match and hybrid_mtp_ngram gpu kernels #7103

Closed
NKNaN wants to merge 9 commits into PaddlePaddle:develop from NKNaN:ngram

Conversation

@NKNaN
Contributor

@NKNaN NKNaN commented Mar 31, 2026

Motivation

rfc: PaddlePaddle/community#1213

Modifications

  1. First version:
  • Implementation: two kernels.
    • Phase 1: count_and_find_candidate_kernel, launched with grid <<<max_batch_size+1, 1024>>>.
      • Block 0 uses BlockReduce to compute the global unprocessed_batch_size.
      • Blocks 1..N each handle one batch and run the candidate search (input_ids / pre_ids) in parallel.
    • Phase 2: truncate_candidate, <<<1, 1024>>>, which applies the threshold-based truncation and write-back for all batches.
      • This phase uses CUB BlockScan prefix sums (processed_batch_size / sum_token_num) to compute each batch's token-allocation cap and perform the truncation.
  2. Second-version optimizations:
  • Core idea:

    • In speculative decoding, every active batch consumes at least 1 token (the minimum of seq_lens_this_time is 1), so the active-batch count gives a lower bound on token consumption: once the exclusive prefix sum of is_active over the first k batches reaches threshold - 1, batch k and every later batch would be allocated 0 draft tokens even if a match were found, so they need not take part in Phase 1 (find_candidate).
    • The first batch index that satisfies this condition is recorded as cutoff_batch_id; Phase 1 then launches only cutoff_batch_id blocks, while Phase 0 does the seq_lens_this_time bookkeeping for the skipped batches.
  • Dual-branch launch strategy (decided on the host, avoiding a D2H copy with no payoff):

    • When threshold >= max_batch_size: every batch fits within the token budget, the cutoff always equals max_batch_size, and Phase 0 is pointless. Phase 0 is skipped entirely and Phase 1 launches directly with max_batch_size blocks (same as v3).
    • When threshold < max_batch_size: Phase 0 first computes cutoff_batch_id with <<<1, 1024>>> (an inclusive prefix sum via cub::BlockScan, then atomicMin to find the first position crossing the threshold), copies that scalar D2H, and Phase 1 launches with <<<h_cutoff, 1024>>>, genuinely reducing the number of GPU waves.
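
The cutoff described above can be sketched as a serial host-side reference (a minimal sketch; `seq_lens_this_time` and `threshold` follow the description above, not the actual kernel code):

```python
def cutoff_batch_id(seq_lens_this_time, threshold):
    """First batch index k whose exclusive prefix sum of active batches
    has already reached threshold - 1; batches from k on get 0 draft tokens."""
    running = 0  # exclusive prefix sum of is_active
    for k, seq_len in enumerate(seq_lens_this_time):
        if running >= threshold - 1:
            return k  # Phase 1 only needs to launch k blocks
        running += 1 if seq_len > 0 else 0
    return len(seq_lens_this_time)  # no cutoff: launch all batches

# 6 active batches with threshold 4: only the first 3 blocks are launched
print(cutoff_batch_id([1, 1, 1, 1, 1, 1], threshold=4))  # 3
```

The GPU version computes the same quantity in one block with cub::BlockScan and an atomicMin over the first crossing position.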

Usage or Command

None

Accuracy Tests

Performance test (on Windows; CPU: AMD Ryzen 7; GPU: RTX 3060 Laptop)

profiling_results

Here v3 is the first-version scheme and v6 is the second-version scheme; for details see the following repo:
https://github.com/NKNaN/FastDeploy_ngram_match_kernel

v1 executes fully serially on the GPU (identical logic to the CPU path). In group 5 (a limited threshold, where allocating draft tokens to just the first few batches may already hit the token-allocation threshold), the CPU execution time is extremely short, and v1's time measures the GPU kernel overhead that cannot be avoided in that case.

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot bot commented Mar 31, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR adapts ngram_match and hybrid_mtp_ngram in speculative decoding from the original CPU/host-side logic to a two-phase GPU kernel implementation, reducing latency and host<->device copies; it is an operator performance optimization and interface adaptation on the spec_decode path.

Changes:

  • Add/replace the ngram_match CUDA implementation: split into a "count + candidate search" kernel and a "threshold truncation + write-back" kernel.
  • Update the NgramProposer / MTPProposer call sites: use GPU inputs directly and add/reuse GPU copy-buffer arguments to match the new operator signature.
  • Update the related unit tests for the new signature and switch the test device to GPU.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.

File — Description
tests/operators/test_ngram_match.py — adapts to the new ngram_match signature; runs on GPU and copies results back for checking
tests/operators/test_hybrid_mtp_ngram.py — adapts to the new hybrid_mtp_ngram signature; constructs inputs and asserts on GPU
fastdeploy/spec_decode/ngram.py — NgramProposer now uses GPU inputs plus new copy buffers (avoiding .cpu()/.cuda() round trips)
fastdeploy/spec_decode/mtp.py — MTPProposer call to hybrid_mtp_ngram adapted to the new signature, with cached copy buffers
custom_ops/gpu_ops/speculate_decoding/ngram_match.cu — new two-phase ngram_match CUDA kernel implementation with static op registration
custom_ops/gpu_ops/speculate_decoding/ngram_match.cc — removes the old host/CPU-style implementation
custom_ops/gpu_ops/speculate_decoding/ngram_match_core.cuh — extracts the sliding-window ngram search into device inline functions shared by both operators
custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu — hybrid_mtp_ngram adapted to the two-phase GPU kernels and new signature
custom_ops/gpu_ops/cpp_extensions.cc — updates the C++ extension function declarations to match the new parameters

Comment on lines 24 to +25
def setUp(self):
paddle.set_device("cpu")
paddle.set_device("gpu")

Copilot AI Apr 2, 2026


Unconditionally switching to GPU in setUp() will fail outright in CPU-only / non-CUDA builds, and CI may be unable to run the test. Suggestion: first check paddle.is_compiled_with_cuda() in setUp(), call skipTest if it fails, and only then call paddle.set_device("gpu").

Contributor Author


Fixed.

Comment on lines 26 to 28
def setUp(self):
paddle.set_device("gpu")
self.max_bsz = 2

Copilot AI Apr 2, 2026


Same as above: unconditionally calling set_device("gpu") in setUp() makes the test unrunnable in environments without CUDA/GPU. Suggest adding a paddle.is_compiled_with_cuda() check with skipTest; if necessary, also skip at finer granularity based on device capability.

Contributor Author


Fixed.

Comment on lines +284 to +325
const int NTHREADS = 1024;

int *d_unprocessed_ptr;
cudaGetSymbolAddress(reinterpret_cast<void **>(&d_unprocessed_ptr),
d_ngram_unprocessed_batch_size);

ngram_count_and_find_candidate_kernel<NTHREADS>
<<<max_batch_size + 1, NTHREADS>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int64_t *>(draft_tokens_copy.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_this_time_copy.data<int32_t>()),
const_cast<int32_t *>(seq_lens_encoder.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
max_model_len,
draft_tokens_stride,
max_ngram_size,
max_draft_tokens,
d_unprocessed_ptr,
max_batch_size);

ngram_truncate_candidate<NTHREADS><<<1, NTHREADS>>>(
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_this_time_copy.data<int32_t>()),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int64_t *>(draft_tokens_copy.data<int64_t>()),
draft_tokens_stride,
max_batch_size,
max_draft_tokens,
tokennum_threshold,
d_unprocessed_ptr);

Copilot AI Apr 2, 2026


These CUDA kernel launches do not use the stream Paddle provides via tensor.stream() (the current code runs on the default stream 0). When the Paddle/phi execution stream is not the default stream, this breaks the ordering dependencies between operators and can cause races or wrong results. Suggestion: obtain the stream from an input tensor (e.g. seq_lens_this_time.stream() / input_ids.stream()) and pass it explicitly in the launch configuration of both kernels.

Comment on lines +289 to +329
const int NTHREADS = 1024;

int *d_unprocessed_ptr;
cudaGetSymbolAddress(reinterpret_cast<void **>(&d_unprocessed_ptr),
d_mixed_unprocessed_batch_size);

mixed_count_and_find_candidate_kernel<NTHREADS>
<<<max_batch_size + 1, NTHREADS>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int64_t *>(draft_tokens_copy.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_this_time_copy.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
max_ngram_size,
min_ngram_size,
max_draft_tokens,
d_unprocessed_ptr,
max_batch_size);

mixed_truncate_candidate<NTHREADS><<<1, NTHREADS>>>(
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_this_time_copy.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int64_t *>(draft_tokens_copy.data<int64_t>()),
draft_tokens_stride,
max_batch_size,
max_ngram_size,
min_ngram_size,
max_draft_tokens,
tokennum_threshold,
d_unprocessed_ptr);

Copilot AI Apr 2, 2026


Same issue: the HybridMtpNgram kernel launches are also not bound to Paddle's execution stream (the .stream() of input_ids/seq_lens_this_time is unused), which can cause asynchronous races with preceding and following operators. Suggest launching with <<<..., 0, cu_stream>>>, explicitly using the tensor's stream.

Comment on lines 1225 to +1228
hybrid_mtp_ngram(
self.model_inputs["input_ids_cpu"],
self.model_inputs["input_ids_len"],
self.model_inputs["pre_ids"]._copy_to(device, True),
self.model_inputs["step_idx"].cpu(),
self.target_model_inputs["actual_draft_token_num"].cpu(),
draft_tokens,
seq_lens_this_time,
seq_lens_decoder,
self.model_inputs["max_dec_len"].cpu(),
self.model_inputs["input_ids"],
self.model_inputs["input_ids_len"].cuda(),
self.model_inputs["pre_ids"],

Copilot AI Apr 2, 2026


Calling .cuda() on self.model_inputs["input_ids_len"] on every invocation incurs an extra H2D copy and tensor allocation (input_ids_len is initialized with device="cpu" in ProposerInputBatch), which hurts per-step latency. Suggest caching a one-time GPU buffer, as done for draft_tokens_copy, and updating that GPU buffer in sync whenever input_ids_len changes (or keeping input_ids_len on GPU altogether).
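
The caching pattern suggested here can be sketched framework-agnostically (a sketch; ProposerBuffers and the upload callable are illustrative names, not the actual FastDeploy code):

```python
class ProposerBuffers:
    """Cache the device-side copy of a host tensor so per-step calls
    reuse one allocation instead of re-uploading every step."""

    def __init__(self):
        self._gpu_cache = {}

    def to_device(self, name, host_value, upload):
        # 'upload' stands in for the real H2D copy (e.g. tensor.cuda());
        # it is a plain callable here so the sketch stays framework-free.
        if name not in self._gpu_cache:
            self._gpu_cache[name] = upload(host_value)  # one-time H2D copy
        return self._gpu_cache[name]


uploads = []
bufs = ProposerBuffers()
buf1 = bufs.to_device("input_ids_len", [3, 5], lambda x: uploads.append(x) or list(x))
buf2 = bufs.to_device("input_ids_len", [3, 5], lambda x: uploads.append(x) or list(x))
print(len(uploads), buf1 is buf2)  # 1 True: the second call hits the cache
```

In the real code the cached buffer would also need to be refreshed whenever the host-side input_ids_len changes, as the comment notes.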


@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-07 10:24 CST

📋 Review Summary

PR overview: migrates the ngram_match and hybrid_mtp_ngram operators from the CPU implementation to GPU CUDA kernels, using a two-phase parallel strategy to optimize performance.

Scope: custom_ops/gpu_ops/speculate_decoding/ (CUDA kernels), fastdeploy/spec_decode/ (Python call layer), test files

Impact tags: [OP] [Speculative Decoding]

Issues

Level File Summary
🔴 Bug ngram_match.cu:214,239 — uses the undefined macro NUM_THREADS; compilation will fail

Overall assessment

The PR parallelizes ngram_match on the GPU with a sound algorithm design (Phase 0 prefix optimization + Phase 1 parallel search + Phase 2 truncation) and a clear code structure. However, the ngram_truncate_candidate kernel uses an undefined macro and will not compile until fixed. The same-named kernel in ngram_match_mixed.cu passes the value correctly as a template parameter; suggest fixing this one the same way for consistency.


int is_active_here =
(tid < (int)max_batch_size && seq_lens_this_time[tid] > 0) ? 1 : 0;
typedef cub::BlockReduce<int, NUM_THREADS> BlockReduce;


🔴 Bug: the macro NUM_THREADS is undefined and will cause a compilation failure.

ngram_match_core.cuh defines NGRAM_SEARCH_THREADS and NGRAM_TRUNCATION_THREADS, but this code uses NUM_THREADS.

Suggested fix, option 1 (recommended; consistent with mixed_truncate_candidate):
Make the kernel a template function:

template <int NUM_THREADS>
__global__ void ngram_truncate_candidate(...)

Then pass the template argument at the call site:

ngram_truncate_candidate<NGRAM_TRUNCATION_THREADS><<<1, NGRAM_TRUNCATION_THREADS>>>(...);

Suggested fix, option 2 (direct replacement):

typedef cub::BlockReduce<int, NGRAM_TRUNCATION_THREADS> BlockReduce;

Contributor Author


Fixed.

}
}

typedef cub::BlockScan<int, NUM_THREADS> BlockScan;


🔴 Bug: same as above, NUM_THREADS is undefined.

Suggested fix:

typedef cub::BlockScan<int, NGRAM_TRUNCATION_THREADS> BlockScan;

Contributor Author


Fixed.


@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-07 10:39 CST

📋 Review Summary

PR overview: migrates the ngram_match and hybrid_mtp_ngram operators from the CPU implementation to GPU CUDA kernels, introducing a three-phase parallel optimization
Scope: custom_ops/gpu_ops/speculate_decoding/, fastdeploy/spec_decode/, tests/operators/
Impact tags: [OP] [Speculative Decoding]

Issues

Level File Summary
🔴 Bug ngram_match.cu:214 — the undefined macro NUM_THREADS will cause a compile error

Overall assessment

This PR migrates the ngram_match operator from CPU to GPU with an efficient three-phase pipeline (cutoff computation → candidate search → truncation), and the Python layer now avoids unnecessary CPU-GPU copies. The overall design is sound, but the undefined macro in the ngram_truncate_candidate kernel must be fixed before the PR can be merged.


int is_active_here =
(tid < (int)max_batch_size && seq_lens_this_time[tid] > 0) ? 1 : 0;
typedef cub::BlockReduce<int, NUM_THREADS> BlockReduce;


🔴 Bug: the macro NUM_THREADS is undefined and will cause a compile error.

The kernel takes no template parameter but uses the undefined NUM_THREADS on lines 214 and 239; ngram_match_core.cuh only defines NGRAM_TRUNCATION_THREADS (value 1024), not NUM_THREADS.

Suggested fix: make the kernel a template, or use the already-defined macro directly:

// Option 1: make it a template
template <int NUM_THREADS>
__global__ void ngram_truncate_candidate(...) {
  ...
}

// Option 2: use the already-defined macro
typedef cub::BlockReduce<int, NGRAM_TRUNCATION_THREADS> BlockReduce;

Contributor Author


Fixed.


@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-07 14:13 CST

📋 Review Summary

PR overview: migrates the ngram_match and hybrid_mtp_ngram operators from the CPU implementation to GPU CUDA kernels; three-phase parallelization (cutoff computation → candidate search → truncation) reduces inference latency
Scope: custom_ops/gpu_ops/speculate_decoding/, fastdeploy/spec_decode/, tests/operators/
Impact tags: [Speculative Decoding] [OP]

Issues

Level File Summary
🟡 Suggestion ngram_match_core.cuh:24 — confirm the hard-coded MAXBATCHSIZE=1024 limit covers all scenarios
🟡 Suggestion ngram_match.cu:1 — the copyright year 2026 looks like a typo

Overall assessment

This is a high-quality performance PR: the previously serial CPU ngram matching now runs as a parallel GPU implementation, and the cutoff optimization avoids unnecessary kernel launches. The code is clearly structured and the three-phase design is sound. Suggest confirming that the MAXBATCHSIZE=1024 limit meets production needs, and correcting the copyright year.

#define NGRAM_TRUNCATION_THREADS 1024
#endif
#ifndef MAXBATCHSIZE
#define MAXBATCHSIZE 1024


🟡 Suggestion: hard-coding MAXBATCHSIZE to 1024 restricts the Phase 0 and Phase 2 kernels to the first 1024 batches.

In ngram_compute_active_prefix and ngram_truncate_candidate, the kernels launch as <<<1, MAXBATCHSIZE>>> with one thread per batch (tid < max_batch_size). If the actual max_batch_size > 1024:

  1. Phase 0's cutoff computation only considers the first 1024 batches
  2. Phase 2's truncation only processes the first 1024 batches

Suggestions:

  1. If 1024 is already the system cap, add a runtime assertion assert(max_batch_size <= MAXBATCHSIZE)
  2. To support larger batches, consider a multi-block implementation or raising MAXBATCHSIZE (weighing shared-memory usage)

Contributor Author


Fixed.

@@ -0,0 +1,402 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.


🟡 Suggestion: the copyright year 2026 should be 2025 (it is currently April 2026, but the code should use the year of first submission).

The same issue appears on line 1 of ngram_match_core.cuh.


@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-07 15:07 CST

📋 Review Summary

PR overview: migrates the ngram_match and hybrid_mtp_ngram operators from the CPU implementation to GPU CUDA kernels, using a three-phase parallel optimization strategy to improve inference performance.

Scope: custom_ops/gpu_ops/speculate_decoding/ (CUDA kernels), fastdeploy/spec_decode/ (Python call sites), tests/operators/ (unit tests)

Impact tags: [OP] [Speculative Decoding]

Issues

Level File Summary
🟡 Suggestion ngram_match.cu:313 — one_wave_capacity is a static variable and may be inaccurate in multi-GPU scenarios
🟡 Suggestion ngram_match.cu:326 — CUDA API calls lack error checks
❓ Question ngram_match_core.cuh:44 — the atomicMin cast from int64_t* to unsigned long long* relies on platform behavior

Overall assessment

The implementation quality is high: the three-phase GPU kernel design (Phase 0 cutoff computation, Phase 1 parallel candidate search, Phase 2 unified truncation) is sound, and the dual-branch launch strategy effectively avoids unnecessary D2H copy overhead. Tests cover the basic and no-match scenarios. Suggest adding CUDA error checks to improve debuggability in production.

tokennum_threshold = std::stoi(env_var);
}

static int one_wave_capacity = []() {


🟡 Suggestion: one_wave_capacity is initialized as a static variable, so the device properties are fetched only on the first call.

In a multi-GPU environment, if the process runs inference on different GPUs, this value can be inaccurate. Consider querying it dynamically on each call, or caching per current device.

Reference approach:

int get_one_wave_capacity() {
    int dev = 0;
    cudaGetDevice(&dev);
// use thread_local or a per-device cache
    ...
}

Contributor Author


Fixed.

if (tokennum_threshold < static_cast<int>(max_batch_size) &&
static_cast<int>(max_batch_size) > one_wave_capacity) {
int *d_cutoff_ptr;
cudaGetSymbolAddress(reinterpret_cast<void **>(&d_cutoff_ptr),


🟡 Suggestion: CUDA API calls (e.g. cudaGetSymbolAddress, cudaMemcpyAsync, cudaStreamSynchronize) lack error checks.

In production, adding error checks makes debugging easier:

cudaError_t err = cudaGetSymbolAddress(...);
if (err != cudaSuccess) {
// handle the error or log it
}

Contributor Author


Fixed.

}
}
if (match) {
atomicMin(reinterpret_cast<unsigned long long*>(shared_start_idx),


❓ Question: atomicMin uses reinterpret_cast<unsigned long long*> to convert int64_t* to unsigned long long*.

Although int64_t and unsigned long long have the same size (8 bytes) on most platforms, the conversion relies on platform behavior. Suggest confirming correctness on all target platforms, or using an explicit 8-byte unsigned type.
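
The ordering concern can be checked numerically for the non-negative case (a sketch using Python's struct module; it illustrates the bit-level behavior, not any particular GPU target):

```python
import struct

def as_u64(x):
    # reinterpret the int64 bit pattern as uint64, like the reinterpret_cast
    return struct.unpack("<Q", struct.pack("<q", x))[0]

vals = [0, 1, 42, 2**62, 2**63 - 1]
# the ordering of non-negative int64 values survives the reinterpretation,
# so an unsigned atomicMin gives the right answer for such values
assert [as_u64(v) for v in sorted(vals)] == sorted(as_u64(v) for v in vals)
# a negative sentinel would break it: -1 maps to the largest uint64
print(as_u64(-1) == 2**64 - 1)  # True
```

So the cast is safe as long as the stored indices are guaranteed non-negative, which is the part worth confirming.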

Contributor Author


Fixed.

@freeliuzc
Collaborator

Hello, your work is great, but another contributor completed a mergeable version first and provided more detailed benchmark information, so their PR was selected for merging.
Since you also completed the task well, it still counts as a successful hackathon entry; please contact the relevant colleagues to confirm.
#7200
#7136

@freeliuzc freeliuzc closed this Apr 7, 2026

@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-08 11:36 CST

📋 Review Summary

PR overview: changes ngram_match and hybrid_mtp_ngram from a CPU implementation to three-phase GPU execution, optimizing speculative decoding performance
Scope: custom_ops/gpu_ops/speculate_decoding/, fastdeploy/spec_decode/, tests/operators/
Impact tags: [Speculative Decoding] [OP]

Issues

Level File Summary
🔴 Bug custom_ops/gpu_ops/speculate_decoding/ngram_match.cu:273 — memcpy can write out of bounds
🔴 Bug custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu:220 — the same memcpy out-of-bounds issue
🟡 Suggestion tests/operators/test_ngram_match.py — insufficient test coverage; missing prefilling-phase and edge-case tests

Overall assessment

The PR implements the GPU three-phase ngram matching optimization with a clear architecture, but two out-of-bounds memory risks must be fixed before merging. For batches in the prefilling phase, seq_lens_this_time_copy may contain the original sequence length rather than the draft-token count, so the memcpy can write past the draft_tokens buffer boundary. Suggest adding a bounds check before the memcpy.

} else {
memcpy(cur_draft_tokens + 1,
cur_draft_tokens_copy + 1,
sizeof(int64_t) * (seq_lens_this_time_copy[tid] - 1));


🔴 Bug: the memcpy can write out of bounds.

For batches in the prefilling phase (seq_lens_encoder[tid] > 0), Phase 0 sets seq_lens_this_time_copy[tid] = seq_lens_this_time[tid]. Here seq_lens_this_time[tid] is an input parameter and may represent the original sequence length rather than the draft-token count.

When seq_lens_this_time_copy[tid] - 1 exceeds max_draft_tokens, the memcpy writes past the end of the draft_tokens buffer (whose capacity is max_draft_tokens + 1).

Suggested fix:

} else {
  int copy_count = seq_lens_this_time_copy[tid] - 1;
  if (copy_count > max_draft_tokens) copy_count = max_draft_tokens;
  if (copy_count > 0) {
    memcpy(cur_draft_tokens + 1,
           cur_draft_tokens_copy + 1,
           sizeof(int64_t) * copy_count);
  }
  seq_lens_this_time[tid] = copy_count + 1;
}

int64_t remaining = max_dec_len[tid] - step_idx[tid] - 1;
if (static_cast<int64_t>(max_draft_tokens_query) > remaining)
max_draft_tokens_query = static_cast<int>(remaining);



🔴 Bug: the same memcpy out-of-bounds issue.

As in ngram_match.cu:273, for prefilling-phase batches seq_lens_this_time_copy[tid] - 1 may exceed max_draft_tokens, causing the memcpy to write out of bounds.

Suggest adding the same bounds check:

} else {
  int copy_count = seq_lens_this_time_copy[tid] - ori_seq_len;
  if (copy_count > max_draft_tokens) copy_count = max_draft_tokens;
  if (copy_count > 0) {
    memcpy(cur_draft_tokens + ori_seq_len,
           cur_draft_tokens_copy_ptr + ori_seq_len,
           sizeof(int64_t) * copy_count);
  }
  seq_lens_this_time[tid] = ori_seq_len + copy_count;
}

self.skipTest("CUDA is not available, skipping GPU test")
paddle.set_device("gpu")

def test_basic_match(self):


🟡 Suggestion: insufficient test coverage.

The current tests only cover the basic decoding-phase scenario; suggest adding the following test cases:

  1. Prefilling-phase test: set seq_lens_encoder > 0 to verify that prefilling batches are handled correctly
  2. Threshold-truncation test: construct a multi-batch scenario to verify the token-count truncation logic
  3. Edge case: the scenario where max_batch_size > one_wave_capacity triggers the Phase 0 launch


@fastdeploy-bot fastdeploy-bot left a comment


📋 Review Summary

PR overview: optimizes ngram_match and hybrid_mtp_ngram from a CPU implementation into three-phase GPU kernels, adding the cutoff optimization to reduce the GPU wave count

Scope: custom_ops/gpu_ops/speculate_decoding/, fastdeploy/spec_decode/

Impact tags: [Speculative Decoding] [OP]

📝 PR Convention Check

The title carries the valid tag [Speculative Decoding]; the description includes the RFC link and a detailed implementation write-up; the Motivation and Modifications sections are complete. The PR conforms to the conventions.

Issues

Level File Summary
🔴 Bug custom_ops/gpu_ops/speculate_decoding/ngram_match.cu:359 — launch_size may be 0, yielding an invalid CUDA kernel launch
🔴 Bug custom_ops/gpu_ops/speculate_decoding/ngram_match.cu:328 — division-by-zero error when tpm = 0

Overall assessment

The overall design is sound: the three-phase kernels and the cutoff optimization are clear, and the CUB library improves parallel efficiency. Two potential edge-condition issues should be fixed before merging. Tests cover the basic scenarios; suggest adding more edge-case tests.

launch_size = h_cutoff;
}

ngram_count_and_find_candidate_kernel<<<launch_size,


🔴 Buglaunch_size 为 0 时会导致无效的 CUDA kernel 启动 <<<0, 1024>>>

tokennum_threshold < max_batch_size 且所有 active batches 的前缀和在第一个 batch 就达到 threshold - 1 时,h_cutoff 可能为 0,导致 launch_size = 0

Suggested check:

if (launch_size > 0) {
  ngram_count_and_find_candidate_kernel<<<launch_size, ...>>>(...);
}

&sm_count, cudaDevAttrMultiProcessorCount, dev));
CUDA_CHECK(cudaDeviceGetAttribute(
&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev));
one_wave_capacity_cache[dev] = sm_count * tpm / NGRAM_SEARCH_THREADS;


🔴 Bugone_wave_capacity 为 0 时会导致除零错误。

one_wave_capacity_cache[dev] = sm_count * tpm / NGRAM_SEARCH_THREADS; 如果 tpm = 0,会导致除零。虽然这在真实硬件上不太可能,但为了健壮性应添加检查。

建议:

if (one_wave_capacity_cache[dev] < 0) {
  int sm_count = 0, tpm = 0;
  CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev));
  CUDA_CHECK(cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev));
  one_wave_capacity_cache[dev] = (tpm > 0) ? sm_count * tpm / NGRAM_SEARCH_THREADS : sm_count;
}
