diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index b0ca5e2c0ce..8b2e1b77a20 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -49,13 +49,13 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, const int head_dim) { const int q_token_num = q_input.dims()[0]; const int k_token_num = k_input.dims()[0]; - const int batch_size = cu_seq_q.dims()[0] - 1; + const int batch_size = cu_seq_k.dims()[0] - 1; PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(head_dim == 128, "Unmatched shape"); PADDLE_ENFORCE(batch_size > 0, "Unmatched shape"); - PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape"); - PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape"); + PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape"); + PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape"); constexpr int kBlockM = 128; constexpr int kBlockN = 128;