1 change: 0 additions & 1 deletion custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu
@@ -54,7 +54,6 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
     PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
     PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
     PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
-    PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape");
     PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape");
Copilot AI Apr 7, 2026


After removing the `batch_size == seq_len_encoder.dims()[0]` check here, the kernel still indexes `seq_len_encoder[bidb]` by `bidb` (see `data_params.seq_len_encoder[bidb]` in flash_mask_attn_kernel.hpp). It would be safer to at least add a `seq_len_encoder.dims()[0] >= batch_size` assertion (rather than dropping the check entirely), to avoid out-of-bounds/illegal memory reads when `seq_len_encoder` is actually shorter than the real batch_size.

Suggested change
     PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape");
+    PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape");

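A minimal Python sketch of the failure mode this comment describes, with plain lists standing in for tensors and a hypothetical helper name. The guard mirrors the suggested `PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, ...)`: every kernel block index `bidb` in `[0, batch_size)` reads one entry, so the buffer must hold at least `batch_size` entries.

```python
def check_and_read(seq_len_encoder, batch_size):
    """Mimics the kernel's per-block read of data_params.seq_len_encoder[bidb].

    The assertion stands in for the suggested shape check:
    seq_len_encoder.dims()[0] >= batch_size.
    """
    assert len(seq_len_encoder) >= batch_size, "Unmatched shape"
    # Every block index bidb in [0, batch_size) reads one entry in bounds.
    return [seq_len_encoder[bidb] for bidb in range(batch_size)]

# Buffer preallocated to max_batch = 4, real batch_size = 2: reads stay in bounds.
lens = check_and_read([7, 5, 0, 0], batch_size=2)
```

Without the `>=` guard, a buffer shorter than the real batch_size would be read past its end, which in C++ is undefined behavior rather than a clean error.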
Comment on lines 54 to 57
Copilot AI Apr 7, 2026


This group of shape checks still depends on `batch_size` (derived above from `cu_seq_q.dims()[0] - 1`), so the fix currently works only indirectly, via the Python-side truncation of `cu_seqlens_q`. That is inconsistent with the PR description ("derive batch_size from cu_seq_k instead of cu_seq_q"), and the operator can still compute a wrong batch_size when called standalone. Consider deriving batch_size from `cu_seq_k.dims()[0] - 1` inside the operator and adjusting the checks accordingly (e.g. ensure `cu_seq_q.dims()[0] >= batch_size + 1`).



🟡 Suggestion: the batch_size derivation logic could be improved further

The PR description says "fix the source of the batch_size derivation: switch from cu_seq_q to cu_seq_k", but in the current C++ code batch_size is still derived from `cu_seq_q.dims()[0] - 1` (line 52, unchanged in this PR).

Although the Python side slices `cu_seqlens_q[: attn_cu_seqlens_k.shape[0]]` so that the dimension passed to C++ comes out correct (a workaround), deriving batch_size from `cu_seq_k` on the C++ side would make the intent clearer and match the description:

const int batch_size = cu_seq_k.dims()[0] - 1;  // replaces line 52
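A sketch of why this derivation works, in Python with hypothetical example values: a `cu_seqlens`-style tensor holds `batch_size + 1` cumulative offsets, so the batch size falls out as its length minus one, and a preallocated (longer) `cu_seq_q` still passes the relaxed `>=` check.

```python
# cu_seq_k holds batch_size + 1 cumulative key offsets (example values).
cu_seq_k = [0, 16, 48, 80]       # 3 sequences with 16, 32, and 32 keys
batch_size = len(cu_seq_k) - 1   # the suggested derivation: dims()[0] - 1

# cu_seq_q may be preallocated to max_batch + 1 entries; the adjusted check
# cu_seq_q.dims()[0] >= batch_size + 1 still holds for the real batch.
cu_seq_q = [0, 1, 2, 3, 0]       # preallocated for max_batch = 4
assert len(cu_seq_q) >= batch_size + 1
```

The point of the cumulative layout is that sequence `i` spans `[cu_seq_k[i], cu_seq_k[i + 1])`, which is why exactly one extra boundary entry exists per batch.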


constexpr int kBlockM = 128;
@@ -309,7 +309,7 @@ def forward_mixed(
     q,
     k,
     v,
-    forward_meta.cu_seqlens_q,
+    forward_meta.cu_seqlens_q[: forward_meta.attn_cu_seqlens_k.shape[0]],
Copilot AI Apr 7, 2026


This adds a truncation of `cu_seqlens_q` to match the changed length of `attn_cu_seqlens_k`, but the current flash_mask_attention unit test (tests/operators/test_flash_mask_attn.py) only covers the path where `cu_seq_q`/`cu_seq_k`/`seq_len_encoder` are all built with the real bsz. Consider adding a case where `cu_seqlens_q`/`seq_lens_encoder` are preallocated to max_batch (larger than the real bsz) while `cu_seqlens_k` uses the real bsz, to verify that the truncation and the operator behave correctly in that scenario.

Copilot generated this review using guidance from repository custom instructions.
forward_meta.attn_cu_seqlens_k,
forward_meta.seq_lens_encoder,
res_encoder,
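A small Python sketch of the truncation the diff introduces, with hypothetical shapes: `cu_seqlens_q` is preallocated for `max_batch` requests, while `attn_cu_seqlens_k` reflects only the real batch, so slicing to its length yields exactly `real_bsz + 1` query offsets.

```python
# Hypothetical shapes: buffers preallocated to max_batch, real batch smaller.
max_batch, real_bsz = 8, 3
cu_seqlens_q = list(range(max_batch + 1))  # preallocated: max_batch + 1 entries
attn_cu_seqlens_k = [0, 16, 32, 64]        # real_bsz + 1 entries

# The truncation from the diff: keep only the first real_bsz + 1 offsets,
# so both cumulative-length tensors agree on the batch size.
cu_seqlens_q_trunc = cu_seqlens_q[: len(attn_cu_seqlens_k)]
assert len(cu_seqlens_q_trunc) == real_bsz + 1
```

This is the preallocated-vs-real-bsz scenario the review above asks a unit test to cover: without the slice, the operator would infer `max_batch` from `cu_seqlens_q` while only `real_bsz` key sequences exist.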