From e99d1b57a692ffa7d8ed842dbcbed575bfded14b Mon Sep 17 00:00:00 2001 From: ming1753 Date: Tue, 7 Apr 2026 19:06:19 +0800 Subject: [PATCH] [Bug Fix] Fix some bugs --- .../get_block_shape_and_split_kv_block.cu | 16 ++++++++++++++-- .../gpu_ops/flash_mask_attn/mainloop_attn.hpp | 17 +++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 2b5c1fbc7d0..799950e52ab 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -290,10 +290,16 @@ void GetBlockShapeAndSplitKVBlock( // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data // is only for branching in attention. #ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU - if (!phi::backends::gpu::IsCUDAGraphCapturing()) -#endif + if (!phi::backends::gpu::IsCUDAGraphCapturing()) { max_len_tensor_cpu.copy_( max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + cudaStreamSynchronize(stream); + } +#else + max_len_tensor_cpu.copy_( + max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + cudaStreamSynchronize(stream); +#endif auto max_len_cpu_ptr = max_len_tensor_cpu.data(); int max_len_this_time = max_len_cpu_ptr[0]; @@ -442,6 +448,12 @@ void GetBlockShapeAndSplitKVBlock( encoder_num_blocks_x_cpu.copy_( encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false); } +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) + cudaStreamSynchronize(stream); +#else + cudaStreamSynchronize(stream); +#endif } std::vector> GetBlockShapeAndSplitKVBlockInferShape( diff --git a/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp b/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp index 0816667c2a7..3bcd31651c6 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp +++ b/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp @@ -489,6 +489,23 @@ struct CollectiveMainloopAttn { tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); softmax.rescale_o(tOrO, scores_scale); consumer_wait(pipeline_v, smem_pipe_read_v); + if (seq_len_k - n_block * kBlockN < kBlockN) { + int valid_k = seq_len_k - n_block * kBlockN; + auto sVt_this = sVt(_, _, smem_pipe_read_v.index()); + constexpr int kHdLo = decltype(get<0, 0>(shape(sVt_this)))::value; + constexpr int kHdHi = decltype(get<0, 1>(shape(sVt_this)))::value; + if (thread_idx >= valid_k && thread_idx < kBlockN) { +#pragma unroll + for (int hd_hi = 0; hd_hi < kHdHi; ++hd_hi) { +#pragma unroll + for (int hd_lo = 0; hd_lo < kHdLo; ++hd_lo) { + sVt_this(make_coord(make_coord(hd_lo, hd_hi), thread_idx)) = + Element(0); + } + } + } + cutlass::arch::fence_view_async_shared(); + } gemm( tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); warp_scheduler_barrier_arrive();