[Cherry-Pick][BugFix] Fix batch_size derivation and relax shape check… #7216
base: release/2.6
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -54,7 +54,6 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,` |
| 54 | 54 | `PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");` |
| 55 | 55 | `PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");` |
| 56 | 56 | `PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");` |
| 57 | | `- PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape");` |
| 58 | 57 | `PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape");` |
| 59 | 58 | |
| 60 | 59 | `constexpr int kBlockM = 128;` |

Comment on lines 54 to 57
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -309,7 +309,7 @@ def forward_mixed(` |
| 309 | 309 | `q,` |
| 310 | 310 | `k,` |
| 311 | 311 | `v,` |
| 312 | | `- forward_meta.cu_seqlens_q,` |
| | 312 | `+ forward_meta.cu_seqlens_q[: forward_meta.attn_cu_seqlens_k.shape[0]],` |
| 313 | 313 | `forward_meta.attn_cu_seqlens_k,` |
| 314 | 314 | `forward_meta.seq_lens_encoder,` |
| 315 | 315 | `res_encoder,` |
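The changed argument slices `cu_seqlens_q` down to the number of entries in `attn_cu_seqlens_k`, so both cumulative-length arrays passed to the op have the same length. A minimal sketch of that effect using plain Python lists; the helper name `align_cu_seqlens`, the sample values, and the premise that `cu_seqlens_q` is allocated for more requests than are active are assumptions for illustration, not taken from the PR:

```python
def align_cu_seqlens(cu_seqlens_q, attn_cu_seqlens_k):
    """Truncate cu_seqlens_q to the number of entries in attn_cu_seqlens_k.

    Hypothetical helper mirroring the slice
    cu_seqlens_q[: attn_cu_seqlens_k.shape[0]] from the diff.
    """
    return cu_seqlens_q[: len(attn_cu_seqlens_k)]


# Assumed scenario: the q buffer has room for 4 requests, only 2 are active.
cu_seqlens_q = [0, 3, 8, 8, 8]
attn_cu_seqlens_k = [0, 3, 8]

aligned_q = align_cu_seqlens(cu_seqlens_q, attn_cu_seqlens_k)

# Both arrays now describe the same number of sequences.
num_sequences = len(attn_cu_seqlens_k) - 1
assert len(aligned_q) - 1 == num_sequences
```

If `cu_seqlens_q` is already the same length as `attn_cu_seqlens_k`, the slice leaves it unchanged.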
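With the `batch_size == seq_len_encoder.dims()[0]` check removed here, the kernel still indexes `seq_len_encoder` by `bidb` (see `data_params.seq_len_encoder[bidb]` in flash_mask_attn_kernel.hpp). I suggest adding at least a `seq_len_encoder.dims()[0] >= batch_size` assertion instead of dropping the check entirely, to avoid an out-of-bounds access or illegal memory read when the actual length of `seq_len_encoder` is smaller than the real batch_size.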
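The relaxed check suggested above can be sketched as follows. This is a Python stand-in for illustration only, not the C++ `PADDLE_ENFORCE` code; the function name `check_shapes` and its parameters are hypothetical, and the lengths stand in for `seq_len_encoder.dims()[0]` and `cu_seq_k.dims()[0]`:

```python
def check_shapes(batch_size, seq_len_encoder_len, cu_seq_k_len):
    """Validate shapes before a kernel reads seq_len_encoder[bidb].

    Hypothetical stand-in: bidb ranges over [0, batch_size), so
    seq_len_encoder needs at least batch_size entries.
    """
    if batch_size <= 0:
        raise ValueError("Unmatched shape: batch_size must be positive")
    # Relaxed from equality (==) to a lower bound (>=).
    if seq_len_encoder_len < batch_size:
        raise ValueError("Unmatched shape: seq_len_encoder is shorter than batch_size")
    # Unchanged check retained by the diff.
    if batch_size != cu_seq_k_len - 1:
        raise ValueError("Unmatched shape: batch_size != len(cu_seq_k) - 1")


# A longer seq_len_encoder buffer is accepted under the relaxed check.
check_shapes(batch_size=2, seq_len_encoder_len=4, cu_seq_k_len=3)
```

The lower bound still permits an oversized `seq_len_encoder` buffer, which the strict equality rejected, while ruling out the case where a per-batch index would read past the end of the buffer.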