-
Notifications
You must be signed in to change notification settings - Fork 736
[BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn #7210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xiaoxiaohehe001
wants to merge
2
commits into
PaddlePaddle:develop
Choose a base branch
from
xiaoxiaohehe001:fix_flash_mask_attn_sm90
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+3
−3
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,13 +49,13 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, | |
| const int head_dim) { | ||
| const int q_token_num = q_input.dims()[0]; | ||
| const int k_token_num = k_input.dims()[0]; | ||
| const int batch_size = cu_seq_q.dims()[0] - 1; | ||
| const int batch_size = cu_seq_k.dims()[0] - 1; | ||
|
|
||
| PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape"); | ||
| PADDLE_ENFORCE(head_dim == 128, "Unmatched shape"); | ||
| PADDLE_ENFORCE(batch_size > 0, "Unmatched shape"); | ||
|
Comment on lines
+52
to
56
|
||
| PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape"); | ||
| PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape"); | ||
| PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape"); | ||
| PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape"); | ||
|
|
||
| constexpr int kBlockM = 128; | ||
| constexpr int kBlockN = 128; | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里把 batch_size 改为由 cu_seq_k 推导是合理的,但同时移除了对 seq_len_encoder / cu_seq_q 与 batch_size 一致性的校验后,若出现 cu_seq_k 真实 batch_size 大于 seq_len_encoder.dims()[0] 或 cu_seq_q.dims()[0]-1 的情况,kernel 仍会按 grid_dims.z=batch_size 访问 seq_len_encoder[bidb] / cu_seq_q[bidb],会产生越界读并导致未定义行为。建议把原来的“==”校验放宽为下界校验(例如 seq_len_encoder.dims()[0] >= batch_size 且 cu_seq_q.dims()[0] >= batch_size+1),至少保证不会 OOB;如果允许更小,也需要相应收缩 params.batch_size / launch grid。