[Cherry-Pick][BugFix] Fix batch_size derivation and relax shape check…#7210 #7212
base: release/2.6
```diff
@@ -49,13 +49,13 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
                                 const int head_dim) {
   const int q_token_num = q_input.dims()[0];
   const int k_token_num = k_input.dims()[0];
-  const int batch_size = cu_seq_q.dims()[0] - 1;
+  const int batch_size = cu_seq_k.dims()[0] - 1;

   PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
   PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
   PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
-  PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape");
-  PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape");
+  PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape");
+  PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape");

   constexpr int kBlockM = 128;
   constexpr int kBlockN = 128;
```

Comment on lines +52 to +58
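The effect of the diff above can be sketched outside the kernel: `batch_size` is now derived from `cu_seq_k`, which is sized to the real batch, while `cu_seq_q` may be pre-allocated larger and only needs to satisfy the relaxed `>=` checks. This is a minimal illustration using NumPy shapes, not the actual Paddle op; the function name is a stand-in.

```python
import numpy as np

def derive_batch_size(cu_seq_q, cu_seq_k):
    """Mimic the fixed C++ logic: batch_size comes from cu_seq_k
    (sized to the real batch), while cu_seq_q may be pre-allocated
    larger. Hypothetical sketch, not the kernel itself."""
    batch_size = cu_seq_k.shape[0] - 1
    # Relaxed checks mirroring the PADDLE_ENFORCE conditions after the fix.
    assert batch_size > 0, "Unmatched shape"
    assert cu_seq_q.shape[0] >= batch_size + 1, "Unmatched shape"
    return batch_size

# Real batch of 4; cu_seq_q pre-allocated with room for more sequences.
cu_seq_k = np.arange(5)                 # shape (5,) -> batch_size 4
cu_seq_q = np.zeros(15, dtype=np.int64) # pre-allocated, larger than needed
cu_seq_q[:5] = np.arange(5)
print(derive_batch_size(cu_seq_q, cu_seq_k))  # 4
```

Before the fix, the same inputs would have yielded `batch_size = 14` (from `cu_seq_q`) and failed the equality checks against `cu_seq_k` and `seq_len_encoder`.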
🟡 Suggestion: add a test case for the pre-allocated tensor scenario.

In the current test (`tests/operators/test_flash_mask_attn.py`), both `cu_seq_q` and `cu_seq_k` are generated via `paddle.arange(bsz + 1)`, so their shapes are identical and this bug cannot be triggered. Suggested test case:

- pre-allocate `cu_seq_q` with a larger dimension (e.g. `max_batch = bsz + 10`)
- fill `cu_seq_k` according to the real batch size

This would confirm the fix works in the actual scenario and guard against future regressions.
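The shape setup the review asks for could be sketched as follows. This only builds the mismatched cumulative-sequence tensors (in NumPy, for illustration); the real test would feed the analogous Paddle tensors into the flash-mask attention op, whose exact invocation is not shown here.

```python
import numpy as np

# Hypothetical setup mirroring the review suggestion: cu_seq_q is
# pre-allocated for max_batch = bsz + 10 sequences, while cu_seq_k
# is filled only for the real batch size bsz.
bsz, seq_len = 4, 16
max_batch = bsz + 10

# Real batch: shape (bsz + 1,), cumulative offsets 0, 16, 32, ...
cu_seq_k = np.arange(bsz + 1, dtype=np.int64) * seq_len

# Pre-allocated: shape (max_batch + 1,), only the first bsz + 1 entries valid.
cu_seq_q = np.zeros(max_batch + 1, dtype=np.int64)
cu_seq_q[:bsz + 1] = np.arange(bsz + 1) * seq_len

# Before the fix, batch_size was taken from cu_seq_q and would be
# max_batch (14) here; after the fix it comes from cu_seq_k (4).
assert cu_seq_q.shape[0] - 1 == max_batch
assert cu_seq_k.shape[0] - 1 == bsz
```

Because the two tensors now have different leading dimensions, this input would have failed the old equality-based `PADDLE_ENFORCE` checks, so it exercises exactly the path the fix relaxes.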