From d13746489117b821a160860ce459c19e7edd7094 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Tue, 7 Apr 2026 15:37:00 +0800 Subject: [PATCH 1/4] [BugFix] fix_flash_mask_attn_sm90 --- custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index b0ca5e2c0ce..6e14383b293 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -49,13 +49,11 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, const int head_dim) { const int q_token_num = q_input.dims()[0]; const int k_token_num = k_input.dims()[0]; - const int batch_size = cu_seq_q.dims()[0] - 1; + const int batch_size = cu_seq_k.dims()[0] - 1; PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(head_dim == 128, "Unmatched shape"); PADDLE_ENFORCE(batch_size > 0, "Unmatched shape"); - PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape"); - PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape"); constexpr int kBlockM = 128; constexpr int kBlockN = 128; From ec2a545bb4c6e2d98cd8405ad44e2661836b23b3 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Tue, 7 Apr 2026 15:49:03 +0800 Subject: [PATCH 2/4] [BugFix] fix_flash_mask_attn_sm90 --- custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index 6e14383b293..8b2e1b77a20 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -54,6 +54,8 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(head_dim == 128, "Unmatched shape"); PADDLE_ENFORCE(batch_size > 0, "Unmatched shape"); + PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape"); + PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape"); constexpr int kBlockM = 128; constexpr int kBlockN = 128; From 77d6e126d819e8c41ea335572aac43063778d539 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Wed, 8 Apr 2026 16:21:34 +0800 Subject: [PATCH 3/4] [BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn --- custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index 8b2e1b77a20..dce65b97274 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -49,13 +49,12 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, const int head_dim) { const int q_token_num = q_input.dims()[0]; const int k_token_num = k_input.dims()[0]; - const int batch_size = cu_seq_k.dims()[0] - 1; + const int batch_size = cu_seq_q.dims()[0] - 1; PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(head_dim == 128, "Unmatched shape"); PADDLE_ENFORCE(batch_size > 0, "Unmatched shape"); - PADDLE_ENFORCE(seq_len_encoder.dims()[0] >= batch_size, "Unmatched shape"); - PADDLE_ENFORCE(cu_seq_q.dims()[0] >= batch_size + 1, "Unmatched shape"); + PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape"); constexpr int kBlockM = 128; constexpr int kBlockN = 128; From baf8767059e5bea0c41d38657e686aa72e174cfd Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Wed, 8 Apr 2026 16:22:49 +0800 Subject: [PATCH 4/4] [BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn --- .../model_executor/layers/attention/flash_mask_attn_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 6e05ca0c3b8..5b23054ce70 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -311,7 +311,7 @@ def forward_mixed( q, k, v, - forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_q[: forward_meta.attn_cu_seqlens_k.shape[0]], forward_meta.attn_cu_seqlens_k, forward_meta.seq_lens_encoder, res_encoder,