From c8b388845b80335e532f930ff09e8be41d050eaa Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Thu, 26 Mar 2026 08:17:31 +0000
Subject: [PATCH] qwen3_vl_moe support prefill_cudagraph

---
 .../layer_infer/transformer_layer_infer.py    | 39 ++++++++++++++++++-
 1 file changed, 37 insertions(+), 2 deletions(-)

diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py
index 391ee8bf6b..40d4bbc0ad 100644
--- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py
@@ -1,12 +1,14 @@
 import torch
 import torch.distributed as dist
 from typing import Tuple
+from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused
 from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
 from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
 from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
 from lightllm.distributed import all_reduce
 from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features
+from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor
 
 
 class Qwen3VLMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
@@ -48,7 +50,7 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
         q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
         input1 = None
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
-        o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
+        o = self._context_attention_wrapper_run(q, cache_kv, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
@@ -62,9 +64,42 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
         if self.tp_world_size_ > 1:
             all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
-        apply_deepstack_features(
+        self._apply_deepstack_features_wrapper_run(
             input_embeddings=input_embdings,
             infer_state=infer_state,
             layer_num=self.layer_num_,
         )
         return input_embdings
+
+    def _apply_deepstack_features_wrapper_run(
+        self,
+        input_embeddings: torch.Tensor,
+        infer_state: InferStateInfo,
+        layer_num: int,
+    ):
+        if torch.cuda.is_current_stream_capturing():
+            input_embeddings = input_embeddings.contiguous()
+            _input_embeddings = tensor_to_no_ref_tensor(input_embeddings)
+            pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
+            pre_capture_graph.__exit__(None, None, None)
+
+            infer_state.prefill_cuda_graph_create_graph_obj()
+            infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()
+
+            def apply_func(new_infer_state: InferStateInfo):
+                apply_deepstack_features(
+                    input_embeddings=_input_embeddings,
+                    infer_state=new_infer_state,
+                    layer_num=layer_num,
+                )
+                return
+
+            infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=apply_func, after_graph=pre_capture_graph)
+        else:
+            apply_deepstack_features(
+                input_embeddings=input_embeddings,
+                infer_state=infer_state,
+                layer_num=layer_num,
+            )
+
+        return