From 7d3110bd5b4359a6617ee18b744935eba1b4854e Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Tue, 7 Apr 2026 17:45:57 +0800 Subject: [PATCH] [Cherry-Pick][BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn --- custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu | 1 - .../model_executor/layers/attention/flash_mask_attn_backend.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index b0ca5e2c0ce..dce65b97274 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -54,7 +54,6 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(head_dim == 128, "Unmatched shape"); PADDLE_ENFORCE(batch_size > 0, "Unmatched shape"); - PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape"); constexpr int kBlockM = 128; diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 35d27504ab5..57fb6f93e25 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -309,7 +309,7 @@ def forward_mixed( q, k, v, - forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_q[: forward_meta.attn_cu_seqlens_k.shape[0]], forward_meta.attn_cu_seqlens_k, forward_meta.seq_lens_encoder, res_encoder,