Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
const int head_dim) {
const int q_token_num = q_input.dims()[0];
const int k_token_num = k_input.dims()[0];
const int batch_size = cu_seq_q.dims()[0] - 1;
const int batch_size = cu_seq_k.dims()[0] - 1;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 添加针对预分配 tensor 场景的测试用例。

当前测试 (tests/operators/test_flash_mask_attn.py) 中 cu_seq_qcu_seq_k 均通过 paddle.arange(bsz + 1) 生成,shape 完全相同,无法触发此 bug。

建议添加测试用例:

  • cu_seq_q 预分配更大的维度(如 max_batch = bsz + 10
  • cu_seq_k 仍按真实 batch size 填充
  • 验证算子能正确计算并输出结果

这样可以确保修复在实际场景中有效,防止未来回归。


PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape");
PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape");
PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape");
PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape");
Comment on lines +52 to +58
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当前变更修复的场景是 cu_seqlens_q / seq_lens_encoder 可能按 max_batch 预分配且第一维大于真实 batch size。现有单测(如 tests/operators/test_flash_mask_attn.py)仍只覆盖维度严格等于 bsz 的情况,建议补充一个回归用例:构造 cu_seqlens_q/seq_lens_encoder 的第一维 > (cu_seqlens_k.dims()[0]-1) 且只填充前 real_bsz 段,验证算子不报错且输出正确,以防该问题回归。

Copilot uses AI. Check for mistakes.
Comment on lines +57 to +58
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR 描述里写的是“注释掉运行时的 PADDLE_ENFORCE shape 校验”,但这里实际仍保留了校验(只是从 == 放宽为 >=)。建议同步更新 PR 描述以匹配实际修改内容,避免后续排查问题时产生误解。

Copilot uses AI. Check for mistakes.

constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
Expand Down
Loading