Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import torch
import torch.distributed as dist
from typing import Tuple
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused
from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
from lightllm.distributed import all_reduce
from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features
from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor


class Qwen3VLMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
Expand Down Expand Up @@ -48,7 +50,7 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
input1 = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
o = self._context_attention_wrapper_run(q, cache_kv, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
Expand All @@ -62,9 +64,42 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
if self.tp_world_size_ > 1:
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
apply_deepstack_features(
self._apply_deepstack_features_wrapper_run(
input_embeddings=input_embdings,
infer_state=infer_state,
layer_num=self.layer_num_,
)
return input_embdings

def _apply_deepstack_features_wrapper_run(
    self,
    input_embeddings: torch.Tensor,
    infer_state: Qwen3VLInferStateInfo,
    layer_num: int,
) -> None:
    """Apply deepstack features to ``input_embeddings``, with or without CUDA-graph capture.

    When the current CUDA stream is not capturing, ``apply_deepstack_features``
    is called directly. When it is capturing, the current prefill capture graph
    is closed, a new capture graph is created and entered, and a callback that
    calls ``apply_deepstack_features`` is registered to run after the closed graph.

    Args:
        input_embeddings: Tensor handed to ``apply_deepstack_features``.
        infer_state: Inference state used for the direct call and for the
            prefill CUDA-graph bookkeeping.
        layer_num: Layer index forwarded to ``apply_deepstack_features``.
    """
    if torch.cuda.is_current_stream_capturing():
        input_embeddings = input_embeddings.contiguous()
        # NOTE(review): if ``contiguous()`` ever has to copy, the deferred call
        # below operates on that copy rather than on the caller's tensor.
        # Presumably the input is always contiguous here -- confirm.
        _input_embeddings = tensor_to_no_ref_tensor(input_embeddings)

        # Close the graph that is currently being captured ...
        pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
        pre_capture_graph.__exit__(None, None, None)

        # ... and start capturing into a fresh graph for the ops that follow.
        infer_state.prefill_cuda_graph_create_graph_obj()
        infer_state.prefill_cuda_graph_get_current_capture_graph().__enter__()

        def apply_func(new_infer_state: Qwen3VLInferStateInfo):
            # Called with the infer state supplied by whoever runs the callback,
            # not the one that was live while capturing.
            apply_deepstack_features(
                input_embeddings=_input_embeddings,
                infer_state=new_infer_state,
                layer_num=layer_num,
            )

        # Register the call to run after ``pre_capture_graph``.
        infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=apply_func, after_graph=pre_capture_graph)
    else:
        apply_deepstack_features(
            input_embeddings=input_embeddings,
            infer_state=infer_state,
            layer_num=layer_num,
        )
Comment on lines +74 to +105
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The type hints for infer_state and new_infer_state should be the more specific Qwen3VLInferStateInfo instead of InferStateInfo. The apply_deepstack_features function requires attributes that are specific to Qwen3VLInferStateInfo, so using the correct type hint improves type safety and maintainability. Additionally, the redundant return statements can be removed for cleaner code.

    def _apply_deepstack_features_wrapper_run(
        self,
        input_embeddings: torch.Tensor,
        infer_state: Qwen3VLInferStateInfo,
        layer_num: int,
    ):
        """Run ``apply_deepstack_features`` now, or defer it across a CUDA-graph split.

        Outside of stream capture the features are applied immediately. During
        capture, the current prefill capture graph is exited, a new one is created
        and entered, and the apply step is registered as a callback to run after
        the exited graph.
        """
        if not torch.cuda.is_current_stream_capturing():
            # Not capturing: apply the features right away.
            apply_deepstack_features(
                input_embeddings=input_embeddings,
                infer_state=infer_state,
                layer_num=layer_num,
            )
            return

        # Keep the contiguous tensor bound to a local for the rest of this call,
        # exactly as the rebinding in the original form did.
        contiguous_embeddings = input_embeddings.contiguous()
        detached_embeddings = tensor_to_no_ref_tensor(contiguous_embeddings)

        # Exit the graph that is currently being captured.
        finished_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
        finished_graph.__exit__(None, None, None)

        # Create a new graph object and enter it for the ops that follow.
        infer_state.prefill_cuda_graph_create_graph_obj()
        next_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
        next_graph.__enter__()

        def deferred_apply(new_infer_state: Qwen3VLInferStateInfo):
            apply_deepstack_features(
                input_embeddings=detached_embeddings,
                infer_state=new_infer_state,
                layer_num=layer_num,
            )

        infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=deferred_apply, after_graph=finished_graph)

Loading