From f794e43443cd3b1787de98d04a6e3ceae58fbcae Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Tue, 7 Apr 2026 16:48:41 +0800 Subject: [PATCH] [Cherry-Pick][BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn --- custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index b0ca5e2c0ce..8b2e1b77a20 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -49,13 +49,13 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, const int head_dim) { const int q_token_num = q_input.dims()[0]; const int k_token_num = k_input.dims()[0]; - const int batch_size = cu_seq_q.dims()[0] - 1; + const int batch_size = cu_seq_k.dims()[0] - 1; PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(head_dim == 128, "Unmatched shape"); PADDLE_ENFORCE(batch_size > 0, "Unmatched shape"); - PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape"); - PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape"); + PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape"); + PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape"); constexpr int kBlockM = 128; constexpr int kBlockN = 128;