Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
const int head_dim) {
const int q_token_num = q_input.dims()[0];
const int k_token_num = k_input.dims()[0];
const int batch_size = cu_seq_q.dims()[0] - 1;
const int batch_size = cu_seq_k.dims()[0] - 1;

PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里把 batch_size 改为由 cu_seq_k 推导是合理的,但同时移除了对 seq_len_encoder / cu_seq_q 与 batch_size 一致性的校验后,若出现 cu_seq_k 真实 batch_size 大于 seq_len_encoder.dims()[0] 或 cu_seq_q.dims()[0]-1 的情况,kernel 仍会按 grid_dims.z=batch_size 访问 seq_len_encoder[bidb] / cu_seq_q[bidb],会产生越界读并导致未定义行为。建议把原来的“==”校验放宽为下界校验(例如 seq_len_encoder.dims()[0] >= batch_size 且 cu_seq_q.dims()[0] >= batch_size+1),至少保证不会 OOB;如果允许更小,也需要相应收缩 params.batch_size / launch grid。

Suggested change
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape");
PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape");

Copilot uses AI. Check for mistakes.
Comment on lines +52 to 56
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

本次修改是为兼容 cu_seqlens_q / seq_lens_encoder 可能按 max_batch 预分配但有效 batch 更小的场景。建议补充一个单测覆盖该 case(例如构造 cu_seq_q/seq_len_encoder 的 first-dim > 实际 batch_size,且 cu_seq_k 仍为真实 batch_size+1),以防后续有人恢复“==”断言或再次把 batch_size 推导改回 cu_seq_q 导致回归。

Copilot generated this review using guidance from repository custom instructions.
PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape");
PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape");
PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape");
PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape");

constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
Expand Down
Loading