From ecf34aeaa0f25fc076a1bb622c7263c529fbf4ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=92=AE=E5=9C=A3=E8=99=93?= Date: Wed, 11 Mar 2026 11:01:19 +0800 Subject: [PATCH 01/54] feat: vit seperation --- .../common/basemodel/attention_vit/fa3/fp.py | 41 +- lightllm/models/internvl/model.py | 13 +- .../qwen_vl/layer_infer/pre_layer_infer.py | 36 +- .../vit/triton_kernel/flashattention_nopad.py | 44 +- lightllm/server/api_cli.py | 46 +- lightllm/server/api_http.py | 20 +- lightllm/server/api_lightllm.py | 18 +- lightllm/server/api_server.py | 4 +- lightllm/server/api_start.py | 147 +++++- lightllm/server/config_server/api_http.py | 35 ++ .../server/core/objs/io_objs/group_req.py | 4 +- .../impl/memory_cache_with_redis.py | 74 +++ .../embed_cache/impl/naive_memory_cache.py | 29 +- lightllm/server/embed_cache/manager.py | 14 +- lightllm/server/embed_cache/utils.py | 426 ++++++++++++++++++ lightllm/server/httpserver/manager.py | 112 ++++- lightllm/server/multimodal_params.py | 29 +- .../server/router/model_infer/infer_batch.py | 4 +- .../model_infer/mode_backend/base_backend.py | 2 +- lightllm/server/visualserver/manager.py | 106 ++++- .../visualserver/model_infer/model_rpc.py | 72 +-- lightllm/server/visualserver/register_loop.py | 42 ++ lightllm/server/visualserver/vit_connect.py | 237 ++++++++++ lightllm/utils/redis_utils.py | 74 +++ lightllm/utils/start_utils.py | 8 + requirements.txt | 2 + 26 files changed, 1461 insertions(+), 178 deletions(-) create mode 100644 lightllm/server/embed_cache/impl/memory_cache_with_redis.py create mode 100644 lightllm/server/visualserver/register_loop.py create mode 100644 lightllm/server/visualserver/vit_connect.py create mode 100644 lightllm/utils/redis_utils.py diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index f804116f1f..f1bef078a7 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -1,6 +1,7 
@@ import dataclasses import torch from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.utils.sgl_utils import flash_attn_varlen_func class Fa3VitAttBackend(BaseVitAttBackend): @@ -17,42 +18,18 @@ def _vit_att_fwd( head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 window_size = (-1, -1) - torch.ops.sgl_kernel.fwd.default( + o = flash_attn_varlen_func( q, k, v, - None, # k_new - None, # v_new - None, # qv - o, # out - cu_seqlens, - cu_seqlens, - None, # cu_seqlens_k_new - None, - None, - max_seqlen, - max_seqlen, - None, # page_table, - None, # kv_batch_idx - None, # leftpad_k - None, # rotary cos - None, # rotary sin - None, # seqlens_rotary - None, - None, - None, - softmax_scale, - False, - window_size[0], - window_size[1], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=False, + window_size=window_size, attention_chunk=0, softcap=0.0, - is_rotary_interleaved=False, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - sinks=None, ) - return o diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index ccb76d3512..3e5b9c5e2a 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -68,6 +68,7 @@ def init_imageitem_extral_params( img.extra_params["image_patch_max_num"] = 6 elif num_images > 6: img.extra_params["image_patch_max_num"] = 0 + img.patch_num = self.get_image_patch(img) return def init_audioitem_extral_params( @@ -75,14 +76,14 @@ def init_audioitem_extral_params( ): return - def get_image_token_length(self, img: ImageItem): - return ( - self.get_image_patch_func( - img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True - ) - * self.image_length + def get_image_patch(self, img: ImageItem): + return self.get_image_patch_func( + img.image_w, img.image_h, 
max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True ) + def get_image_token_length(self, img: ImageItem): + return self.get_image_patch(img) * self.image_length + def get_audio_token_length(self, audio: AudioItem): L = audio.audio_length audio_token_num = 0 diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 9b9fe2569c..0127fbea8b 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -1,11 +1,15 @@ +import rpyc +import socket import torch import torch.distributed as dist from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.server.embed_cache.utils import get_shm_name_embed, load_tensor_afs from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.envs_utils import get_env_start_args """ @@ -26,17 +30,33 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config): super().__init__(network_config) + self.args = get_env_start_args() + self.cache_client = None + if self.args.enable_remote_vit: + self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return + + def _copy_loaded_embed_to_cache( + self, embed_tensor: torch.Tensor, cpu_embed_cache_tensor: torch.Tensor, start_index: int + ): + if embed_tensor.ndim == 2: + embed_tensor = embed_tensor.unsqueeze(1) + + token_num, layer_num, hidden_size = embed_tensor.shape + cpu_embed_cache_tensor[start_index : start_index + token_num, :layer_num, 
:hidden_size].copy_(embed_tensor) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): img_start_token_ids = [] img_token_lens = [] img_start_locs_in_cache = [] + unique_uids = [] device = layer_weight.wte_weight_.weight.device dtype = layer_weight.wte_weight_.weight.dtype hidden_size = layer_weight.wte_weight_.weight.shape[1] - for batch_id, p in enumerate(infer_state.multimodal_params): + for _, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: # skip the same image if img["token_id"] in img_start_token_ids: @@ -44,6 +64,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs_in_cache.append(img["start_index_in_embed_cache"]) + unique_uids.append(img["uuid"]) out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) from lightllm.server.router.model_infer.infer_batch import g_infer_context @@ -55,6 +76,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei else cpu_embed_cache_client.cpu_embed_cache_tensor ) + if self.args.enable_remote_vit: + release_ids = [] + for _, p in enumerate(infer_state.multimodal_params): + for img in p["images"] + p["audios"]: + release_ids.append(img["uuid"]) + + for uid, start_index_in_embed_cache in zip(unique_uids, img_start_locs_in_cache): + embed_tensor = load_tensor_afs(get_shm_name_embed(uid), self.args.image_embed_dir) + self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache) + + if release_ids: + self.cache_client.root.release(release_ids) + assert cpu_embed_cache_tensor.shape[2] == hidden_size, ( f"Dimension mismatch: text weight dimension is {hidden_size}, " f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}" diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py 
b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index 8428e52996..3a0b2d2069 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -167,44 +167,20 @@ def flash_attention_v3_fwd( head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 window_size = (-1, -1) - torch.ops.sgl_kernel.fwd.default( + o = flash_attn_varlen_func( q, k, v, - None, # k_new - None, # v_new - None, # qv - o, # out - cu_seqlens, - cu_seqlens, - None, # cu_seqlens_k_new - None, - None, - max_seqlen, - max_seqlen, - None, # page_table, - None, # kv_batch_idx - None, # leftpad_k - None, # rotary cos - None, # rotary sin - None, # seqlens_rotary - None, - None, - None, - softmax_scale, - False, - window_size[0], - window_size[1], - 0.0, - is_rotary_interleaved=False, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - sinks=None, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=False, + window_size=window_size, + softcap=0.0, ) - - return + return o except ImportError: print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.") diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index c8a82d3239..4e5ab7e421 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,17 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"], + choices=[ + "normal", + "prefill", + "decode", + "nixl_prefill", + "nixl_decode", + "pd_master", + "config_server", + "visual", + "visual_only", + ], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get 
pd_master node list, @@ -605,6 +615,40 @@ def make_argument_parser() -> argparse.ArgumentParser: default=0.03, help="""The interval of the schedule time, default is 30ms.""", ) + parser.add_argument( + "--image_embed_dir", + type=str, + default=None, + help="path for vit embed", + ) + parser.add_argument( + "--enable_remote_vit", + action="store_true", + help="Whether to enable remote vit for multimodal service.", + ) + parser.add_argument( + "--remote_vit_port", + type=int, + default=12346, + help="The port number for the remote vit service.", + ) + parser.add_argument( + "--redis_port", + type=int, + default=6379, + help="The port number for the redis service in config_server mode.", + ) + parser.add_argument( + "--redis_evict_fraction", + type=float, + default=0.3, + help="The evict fraction for the redis service in config_server mode.", + ) + parser.add_argument( + "--start_redis", + action="store_true", + help="Whether to start the redis service in config_server mode.", + ) parser.add_argument( "--enable_cpu_cache", action="store_true", diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 230da5b369..bf246f8f0d 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -43,7 +43,7 @@ from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster -from .api_lightllm import lightllm_get_score +from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.log_utils import init_logger from lightllm.utils.error_utils import ServerBusyError @@ -92,6 +92,8 @@ def set_args(self, args: StartArgs): self.httpserver_manager = HttpServerManagerForPDMaster( args=args, ) + elif args.run_mode == "visual": + self.metric_client = MetricClient(args.metric_port) else: init_tokenizer(args) # for openai 
api SamplingParams.load_generation_cfg(args.model_dir) @@ -136,7 +138,7 @@ def get_model_name(): @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") async def healthcheck(request: Request): - if g_objs.args.run_mode == "pd_master": + if g_objs.args.run_mode in ["pd_master", "visual"]: return JSONResponse({"message": "Ok"}, status_code=200) if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true": @@ -221,6 +223,18 @@ async def get_score(request: Request) -> Response: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) +@app.post("/get_image_embedding") +async def get_image_embed(request: Request) -> Response: + try: + return await lightllm_get_image_embedding(request, g_objs.httpserver_manager) + except ServerBusyError as e: + logger.error("%s", str(e), exc_info=True) + return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e)) + except Exception as e: + logger.error("An error occurred: %s", str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) + + @app.post("/") async def compat_generate(request: Request) -> Response: if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: @@ -359,6 +373,8 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) + if g_objs.httpserver_manager is None: + return loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f54..bfb8bff6db 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -1,7 +1,7 @@ import collections from typing import AsyncGenerator from fastapi import BackgroundTasks, Request -from fastapi.responses import Response, StreamingResponse +from fastapi.responses import Response, 
StreamingResponse, JSONResponse from lightllm.server.core.objs.sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager @@ -150,3 +150,19 @@ async def stream_results() -> AsyncGenerator[bytes, None]: background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + + +async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response: + request_dict = await request.json() + # request_dict: {'parameters': {'max_new_tokens': 128}, + # 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}} + sample_params_dict = request_dict["parameters"] + sampling_params = SamplingParams() + sampling_params.init(tokenizer=None, **sample_params_dict) + sampling_params.verify() + multimodal_params_dict = request_dict.get("multimodal_params", {}) + multimodal_params = MultimodalParams(**multimodal_params_dict) + + await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request) + + return JSONResponse({"message": "OK"}, status_code=200) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808a..c6700c0416 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -5,11 +5,13 @@ torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess parser = make_argument_parser() args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, config_server_start + from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start if args.run_mode == "pd_master": pd_master_start(args) elif args.run_mode == "config_server": config_server_start(args) + elif args.run_mode == "visual": + visual_start(args) else: normal_or_p_d_start(args) diff --git a/lightllm/server/api_start.py 
b/lightllm/server/api_start.py index 77355f0d06..ef1a521d76 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -5,7 +5,7 @@ import subprocess import signal from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker -from lightllm.utils.start_utils import process_manager, kill_recursive +from lightllm.utils.start_utils import process_manager, kill_recursive, is_multimodal_mode from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger @@ -15,12 +15,37 @@ from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip +from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size from lightllm.utils.config_utils import has_audio_module, has_vision_module logger = init_logger(__name__) +def _ensure_remote_vit_embed_dir(image_embed_dir: str) -> None: + if os.path.exists(image_embed_dir): + if not os.path.isdir(image_embed_dir): + raise ValueError(f"image_embed_dir is not a directory: {image_embed_dir}") + return + + os.makedirs(image_embed_dir, mode=0o777, exist_ok=True) + os.chmod(image_embed_dir, 0o777) + + +def _prepare_remote_vit_embed_dir(args): + remote_vit_mode = args.enable_remote_vit or args.run_mode in ["visual", "visual_only"] + if not remote_vit_mode: + return + + if not args.image_embed_dir: + raise ValueError("remote vit mode requires --image_embed_dir to be set") + + args.image_embed_dir = os.path.abspath(args.image_embed_dir) + _ensure_remote_vit_embed_dir(args.image_embed_dir) + + logger.info(f"using image_embed_dir: {args.image_embed_dir}") + + def setup_signal_handlers(http_server_process, process_manager): def signal_handler(sig, frame): if sig == signal.SIGINT: @@ -57,11 +82,12 @@ def signal_handler(sig, frame): 
signal.signal(signal.SIGINT, signal_handler) logger.info(f"start process pid {os.getpid()}") - logger.info(f"http server pid {http_server_process.pid}") + if http_server_process: + logger.info(f"http server pid {http_server_process.pid}") return -def normal_or_p_d_start(args): +def normal_or_p_d_start(args, only_prepare=False): from lightllm.server.core.objs.start_args_type import StartArgs args: StartArgs = args @@ -73,7 +99,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]: + if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual", "visual_only"]: return # 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块 @@ -168,6 +194,8 @@ def normal_or_p_d_start(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + args.enable_multimodal = is_multimodal_mode(args) + _prepare_remote_vit_embed_dir(args) # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) @@ -229,11 +257,16 @@ def normal_or_p_d_start(args): args.data_type = get_dtype(args.model_dir) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + if only_prepare: + return + already_uesd_ports = [args.port] if args.nccl_port is not None: already_uesd_ports.append(args.nccl_port) if args.pd_decode_rpyc_port is not None: already_uesd_ports.append(args.pd_decode_rpyc_port) + if args.visual_nccl_ports is not None: + already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 @@ -241,8 +274,10 @@ def normal_or_p_d_start(args): ports_locker.lock_port() node_world_size = args.tp // args.nnodes + need_visual_nccl_ports = 0 if args.visual_nccl_ports is not None else args.visual_dp can_use_ports = alloc_can_use_network_port( - num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports + num=10 + 
node_world_size + args.visual_dp * args.visual_tp + need_visual_nccl_ports, + used_ports=already_uesd_ports, ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -265,8 +300,12 @@ def normal_or_p_d_start(args): tp_ports_for_dp = can_use_ports[0 : args.visual_tp] visual_model_tp_ports.append(tp_ports_for_dp) can_use_ports = can_use_ports[args.visual_tp :] - visual_nccl_ports.append(can_use_ports[0]) - can_use_ports = can_use_ports[1:] + if args.visual_nccl_ports is None: + visual_nccl_ports.append(can_use_ports[0]) + can_use_ports = can_use_ports[1:] + + if args.visual_nccl_ports is not None: + args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] # 将申请好的端口放入args参数中 if args.nccl_port is None: @@ -316,27 +355,27 @@ def normal_or_p_d_start(args): start_args=[(args,)], ) - if not args.disable_vision: - from .visualserver.manager import start_visual_process + if not args.disable_audio: + from .audioserver.manager import start_audio_process process_manager.start_submodule_processes( start_funcs=[ - start_visual_process, + start_audio_process, ], start_args=[ - (args, visual_model_tp_ports), + (args,), ], ) - if not args.disable_audio: - from .audioserver.manager import start_audio_process + if not args.disable_vision and not args.enable_remote_vit: + from .visualserver.manager import start_visual_process process_manager.start_submodule_processes( start_funcs=[ - start_audio_process, + start_visual_process, ], start_args=[ - (args,), + (args, visual_model_tp_ports), ], ) @@ -463,6 +502,81 @@ def pd_master_start(args): http_server_process.wait() +def visual_start(args): + normal_or_p_d_start(args, only_prepare=True) + + already_uesd_ports = [args.remote_vit_port] + if args.nccl_port is not None: + already_uesd_ports.append(args.nccl_port) + if args.visual_nccl_ports is not None: + already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) + + need_visual_nccl_ports = 0 if args.visual_nccl_ports is not None else args.visual_dp + can_use_ports = 
alloc_can_use_network_port( + num=5 + args.visual_dp * args.visual_tp + need_visual_nccl_ports, + used_ports=already_uesd_ports, + ) + logger.info(f"alloced ports: {can_use_ports}") + ( + router_port, + visual_port, + audio_port, + cache_port, + metric_port, + ) = can_use_ports[0:5] + can_use_ports = can_use_ports[5:] + + visual_model_tp_ports = [] + visual_nccl_ports = [] + for _ in range(args.visual_dp): + tp_ports_for_dp = can_use_ports[0 : args.visual_tp] + visual_model_tp_ports.append(tp_ports_for_dp) + can_use_ports = can_use_ports[args.visual_tp :] + if args.visual_nccl_ports is None: + visual_nccl_ports.append(can_use_ports[0]) + can_use_ports = can_use_ports[1:] + + if args.visual_nccl_ports is not None: + args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] + + args.router_port = router_port + args.visual_port = visual_port + args.audio_port = audio_port + args.cache_port = cache_port + args.metric_port = metric_port + args.visual_node_id = uuid.uuid4().int + + logger.info(f"all start args:{args}") + + set_env_start_args(args) + + from .visualserver.manager import start_visual_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_cache_manager, + ], + start_args=[(args,)], + ) + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args, visual_model_tp_ports), + ], + ) + setup_signal_handlers(None, process_manager) + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Received keyboard interrupt, shutting down...") + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully.") + sys.exit(0) + + def config_server_start(args): set_unique_server_name(args) if args.run_mode != "config_server": @@ -470,6 +584,9 @@ def config_server_start(args): logger.info(f"all start args:{args}") + if args.start_redis: + start_redis_service(args) + set_env_start_args(args) command = [ diff --git 
a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index c5505acda4..c55b743480 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -9,6 +9,7 @@ from typing import Dict, List from fastapi.responses import JSONResponse from lightllm.utils.log_utils import init_logger +from lightllm.server.visualserver.vit_connect import VIT_Obj from ..pd_io_struct import PD_Master_Obj from .nccl_tcp_store import start_tcp_store_server from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name @@ -19,7 +20,9 @@ app = FastAPI() registered_pd_master_objs: Dict[str, PD_Master_Obj] = {} +registered_visual_server_objs: Dict[str, VIT_Obj] = {} registered_pd_master_obj_lock = Lock() +registered_visual_server_obj_lock = Lock() global_req_id = 0 global_req_id_lock = Lock() @@ -72,6 +75,30 @@ async def websocket_endpoint(websocket: WebSocket): return +@app.websocket("/visual_register") +async def visual_websocket_endpoint(websocket: WebSocket): + await websocket.accept() + client_ip, client_port = websocket.client + logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}") + registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes()) + logger.info(f"recieved registered_visual_server_obj {registered_visual_server_obj}") + with registered_visual_server_obj_lock: + registered_visual_server_objs[registered_visual_server_obj.node_id] = registered_visual_server_obj + + try: + while True: + data = await websocket.receive_text() + assert data == "heartbeat" + except (WebSocketDisconnect, Exception, RuntimeError) as e: + logger.error(f"registered_visual_server_obj {registered_visual_server_obj} has error {str(e)}") + logger.exception(str(e)) + finally: + logger.error(f"registered_visual_server_obj {registered_visual_server_obj} removed") + with registered_visual_server_obj_lock: + registered_visual_server_objs.pop(registered_visual_server_obj.node_id, 
None) + return + + @app.get("/registered_objects") async def get_registered_objects(): with registered_pd_master_obj_lock: @@ -80,6 +107,14 @@ async def get_registered_objects(): return {"data": base64_encoded} +@app.get("/registered_visual_objects") +async def get_vit_registered_objects(): + with registered_visual_server_obj_lock: + serialized_data = pickle.dumps(registered_visual_server_objs) + base64_encoded = base64.b64encode(serialized_data).decode("utf-8") + return {"data": base64_encoded} + + @app.get("/allocate_global_unique_id_range") async def allocate_global_id_range(): """ diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index dfcbdd2562..75f2c0e2f1 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -23,7 +23,9 @@ def to_group_req_index(self): return GroupReqIndexes( group_req_id=self.group_req_id, multimodal_params=self.multimodal_params, - shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], + shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs] + if self.shm_req_objs is not None + else None, time_mark=self.time_mark, ) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py new file mode 100644 index 0000000000..6ef8c66f68 --- /dev/null +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -0,0 +1,74 @@ +import uuid +import threading +import dataclasses +import requests +from typing import Union, Optional +import torch +import time +from collections import deque +import multiprocessing.shared_memory as shm +from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, EmbedRefCountRedis +from .naive_memory_cache import Record, InMemoryCache +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MemoryCacheWithRedis(InMemoryCache): + def __init__(self, args) 
-> None: + super().__init__(args) + redis_url = f"redis://{args.config_server_host}:{args.redis_port}/0" + self.redis_cache = EmbedRefCountRedis( + redis_url=redis_url, + capacity=args.cache_capacity, + evict_fraction=args.redis_evict_fraction, + image_embed_dir=args.image_embed_dir, + ) + # 这里之所以把cache * 2是因为,在分离模式下,cache 服务只是为了更新redis状态,以及维护图片cache的 token_id + # 便于 dynamic prompt cache 的使用。所以要把cache_capacity * 2,保障其保留的图片cache > redis 服务维护的 + # 硬盘里的图片image embed 数量。 + self.capacity = max(1, args.cache_capacity * 2) + + # llm 负责release + def release(self, ids: list[int]) -> None: + with self.lock: + for id in ids: + rec = self._records.get(id) + if rec is None: + continue + + redis_exist = self.redis_cache.query(str(id)) + if redis_exist: + self.redis_cache.decr(str(id)) + + # remote_vit 模式下 release 可能走“预层提前释放 + 请求结束兜底释放”两条路径, + # 这里避免本地 ref 被重复减成负数,保证 release 可重复调用。 + if rec.ref > 0: + self._update_record_ref(rec, -1) + + # vit 负责set + def set_items_embed(self, ids: list[int]) -> None: + with self.lock: + for id in ids: + self.redis_cache.insert(str(id)) + rec = self._records.get(id) + if rec is not None: + rec.embed = True + if rec.ref > 0: + self._update_record_ref_by_id(id, -1) + # 保留一份 redis 引用,直到真正的消费者读取完成后再 release, + # 避免 VIT 刚写完文件但 LLM 还没来得及读取时被 LRU 误删。 + + def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: + ret = [] + for id in ids: + if embeding_only: + exist = self.redis_cache.query(str(id)) + else: + exist = self.redis_cache.query_and_incre(str(id)) + ret.append(exist) + if exist: + rec = self._records.get(id) + if rec is not None: + rec.embed = True + return ret diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index fbce108762..f76cc5d78d 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -36,6 +36,7 @@ class InMemoryCache: def __init__(self, args) -> 
None: self.args = args self._id_to_records = dict() + self._records = self._id_to_records self._md5_to_record = dict() self._sorted_records = SortedSet(key=lambda x: (x.ref, x.visittime, x.id)) self.capacity = max(1, args.cache_capacity) @@ -125,18 +126,26 @@ def _free_to_alloc(self, free_min_count: int, new_md5_dict: Dict[str, int]) -> D def _add_ref(self, md5_sum): rec: Record = self._md5_to_record[md5_sum] - self._sorted_records.remove(rec) - rec.ref += 1 - self._sorted_records.add(rec) + self._update_record_ref(rec, 1) return def _del_ref(self, md5_sum): rec: Record = self._md5_to_record[md5_sum] + self._update_record_ref(rec, -1) + return + + def _update_record_ref(self, rec: Record, delta: int): self._sorted_records.remove(rec) - rec.ref -= 1 + rec.ref += delta + rec.visittime = time.time() self._sorted_records.add(rec) return + def _update_record_ref_by_id(self, id_: int, delta: int): + rec: Record = self._id_to_records[id_] + self._update_record_ref(rec, delta) + return + def _judge_enough_token_cache(self, md5sum_list: list[str], token_num_list: list[int]) -> bool: tmp_dict = {} for md5, token_num in zip(md5sum_list, token_num_list): @@ -160,14 +169,13 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l new_md5_dict[m] = token_need new_needed = len(new_md5_dict) - alloc_md5_dict = self._free_to_alloc( free_min_count=new_needed - (self.capacity - self.occupied), new_md5_dict=new_md5_dict ) if len(alloc_md5_dict) == len(new_md5_dict): for md5sum, mem_block in alloc_md5_dict.items(): token_num = new_md5_dict[md5sum] - uid_int = uuid.uuid1().int + uid_int = md5sum self._check_and_set_new_id_range(token_num) rec = Record( id=uid_int, @@ -207,15 +215,14 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l return results else: + for md5sum in add_ref_m_list: + self._del_ref(md5sum) return None def release(self, ids: list[int]) -> None: with self.lock: for id_ in ids: - rec: Record = 
self._id_to_records[id_] - self._sorted_records.remove(rec) - rec.ref -= 1 - self._sorted_records.add(rec) + self._update_record_ref_by_id(id_, -1) def set_items_data(self, ids: list[int]) -> None: for id_ in ids: @@ -228,5 +235,5 @@ def set_items_embed(self, ids: list[int]) -> None: for id_ in ids: self._id_to_records[id_].embed = True - def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: + def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: return [self._id_to_records.get(id_).embed if id_ in self._id_to_records else False for id_ in ids] diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index 5de4df4ab3..faf48c4085 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -6,6 +6,7 @@ from lightllm.server.core.objs import StartArgs from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache +from lightllm.server.embed_cache.impl.memory_cache_with_redis import MemoryCacheWithRedis from rpyc.utils.classic import obtain from lightllm.utils.envs_utils import get_unique_server_name @@ -47,9 +48,16 @@ def exposed_set_items_embed(self, ids: list[int]) -> None: ids = obtain(ids) return self._impl.set_items_embed(ids) - def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: + def exposed_get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[bool]: ids = obtain(ids) - return self._impl.get_items_embed(ids) + return self._impl.get_items_embed(ids, embeding_only) + + +def get_cache_manager(args): + if args.enable_remote_vit or args.run_mode == "visual": + return MemoryCacheWithRedis(args) + else: + return InMemoryCache(args) def start_cache_manager(args: StartArgs, pipe_writer): @@ -57,7 +65,7 @@ def start_cache_manager(args: StartArgs, pipe_writer): graceful_registry(inspect.currentframe().f_code.co_name) 
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::cache_manager") - manager = InMemoryCache(args) + manager = get_cache_manager(args) service = CacheServer(manager) from rpyc.utils.server import ThreadedServer import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 367bcc91a9..caeca0b2b6 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -1,4 +1,59 @@ +import os +import time +import torch +import redis +import numpy as np +from typing import List, Tuple +from io import BytesIO +from pathlib import Path import multiprocessing.shared_memory as shm +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def _get_afs_path(base_dir: str, name: str) -> Path: + if not base_dir: + raise ValueError("image_embed_dir must be set before using disk-backed embed cache") + return Path(base_dir) / name + + +def tensor2bytes(t: torch.Tensor): + buf = BytesIO() + t = t.detach().cpu() + dest = torch.empty_like(t) + dest.copy_(t) + torch.save(dest, buf, _use_new_zipfile_serialization=False, pickle_protocol=4) + buf.seek(0) + return buf.read() + + +def bytes2tensor(b): + return torch.load(BytesIO(b), weights_only=False) + + +def save_tensor_afs(name: str, tensor: torch.Tensor, base_dir: str) -> None: + target_path = _get_afs_path(base_dir, name) + tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}" + + try: + with open(tmp_path, "wb") as f: + torch.save(tensor.detach().cpu(), f, _use_new_zipfile_serialization=False, pickle_protocol=4) + os.replace(tmp_path, target_path) + os.chmod(target_path, 0o777) + except Exception: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + logger.exception(f"failed to save embed tensor file: {target_path}") + raise + + +def load_tensor_afs(name: str, base_dir: str) -> torch.Tensor: + path = _get_afs_path(base_dir, name) + with 
open(path, "rb") as f: + return torch.load(f, weights_only=False) def create_shm(name, data): @@ -11,17 +66,388 @@ def create_shm(name, data): print("Warning create shm {} failed because of FileExistsError!".format(name)) +def create_afs(name, data, path): + target_path = _get_afs_path(path, name) + data_size = len(data) + tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}" + + try: + with open(tmp_path, "wb") as f: + mem_view = memoryview(data) + f.write(mem_view[:data_size]) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, target_path) + os.chmod(target_path, 0o777) + except Exception: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + logger.exception(f"failed to create embed file: {target_path}") + raise + + def read_shm(name): shared_memory = shm.SharedMemory(name=name) data = shared_memory.buf.tobytes() return data +def read_afs(name: str, base_dir) -> bytes: + path = _get_afs_path(base_dir, name) + return path.read_bytes() + + def free_shm(name): shared_memory = shm.SharedMemory(name=name) shared_memory.close() shared_memory.unlink() +def free_afs(name: str, base_dir) -> None: + path = _get_afs_path(base_dir, name) + path.unlink(missing_ok=True) + + def get_shm_name_data(uid): return str(uid) + "-data" + + +def get_shm_name_embed(uid): + return str(uid) + "-embed" + + +""" +Importable Redis-backed MD5 refcount with LRU eviction. 
+ +Public API: + from lightllm.server.embed_cache.utils import EmbedRefCountRedis + + cache = EmbedRefCountRedis( + redis_url="redis://localhost:6379/0", + capacity=10000, + evict_fraction=0.2 + ) + + # Insert a new md5 with default ref_count=1 + success, evicted_list = cache.insert(md5) + + # Query if exists and increment ref_count if found + exists = cache.query_and_incre(md5) + + # Decrement ref_count + rc, deleted = cache.decr(md5) + + s = cache.stats() +""" + + +class EmbedRefCountRedis: + def __init__( + self, + redis_url: str = "redis://localhost:6379/0", + capacity: int = 50000, + evict_fraction: float = 0.1, + key_prefix: str = "md5:", + image_embed_dir: str = None, + path_ext: str = "-embed", + **redis_kwargs, + ) -> None: + """ + - capacity: max count of md5 entries allowed in Redis + - evict_fraction: fraction to evict when inserting a NEW md5 and at capacity + - image_embed_dir: base directory for image embed files (e.g., "/afs/embeds") + - path_ext: filename suffix appended to the md5 for embed files (default: "-embed") + """ + if not (0.0 <= evict_fraction <= 1.0): + raise ValueError("evict_fraction must be 0..1") + if capacity < 1: + raise ValueError("capacity must be >=1") + + self.capacity = int(capacity) + self.evict_fraction = float(evict_fraction) + self.zset_key = f"{key_prefix}lru" + self.ref_prefix = f"{key_prefix}rc:" + self.lock_key = f"{key_prefix}evict:lock" + self.image_embed_dir = image_embed_dir + self.path_ext = path_ext + + self.r = redis.Redis.from_url(redis_url, decode_responses=True, **redis_kwargs) + + # Register Lua scripts + self._insert_script = self.r.register_script(self._INSERT_LUA) + self._query_incre_script = self.r.register_script(self._QUERY_INCRE_LUA) + self._decr_script = self.r.register_script(self._DECR_LUA) + self._evict_and_insert_script = self.r.register_script(self._EVICT_AND_INSERT_LUA) + + def insert(self, md5: str) -> Tuple[bool, List[str]]: + """Insert a new md5 with default ref_count=1.
May trigger LRU eviction.""" + # Wait for any in-progress eviction to finish + self._wait_if_eviction() + + res = self._insert_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5, self.capacity, self.evict_fraction], + ) + + if res[0] == 0: # No eviction needed + return True, [] + + # Need eviction - use atomic eviction script + try: + if self._try_acquire_lock(): + try: + # Atomically evict and insert + evict_res = self._evict_and_insert_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5, self.capacity, self.evict_fraction], + ) + success = bool(evict_res[0]) + victims = evict_res[1:] if len(evict_res) > 1 else [] + + if success: + # Delete the AFS files belonging to the evicted md5s + if victims and self.image_embed_dir: + self._delete_afs_files(victims) + return True, victims + else: + # Eviction failed; back off, then retry. NOTE(review): lock is still held here - verify. + time.sleep(0.01) + return self.insert(md5) + finally: + self._release_lock() + else: + # Wait for the lock to be released, then retry + time.sleep(0.01) + return self.insert(md5) + except Exception as e: + self._release_lock() + raise e + + def query(self, md5: str) -> bool: + """Query whether md5 exists.""" + self._wait_if_eviction() + return bool(self.r.exists(self.ref_prefix + md5)) + + def query_and_incre(self, md5: str) -> bool: + """Query if md5 exists and increment ref_count if found.""" + self._wait_if_eviction() + res = self._query_incre_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5], + ) + return bool(res[0]) + + def decr(self, md5: str) -> Tuple[int, bool]: + """Decrement ref_count for md5.
Returns (ref_count, deleted).""" + self._wait_if_eviction() + + res = self._decr_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5], + ) + if res[0] == -1: + raise KeyError("md5 not found") + return int(res[0]), bool(res[1]) + + def stats(self) -> dict: + self._wait_if_eviction() + + size = self.r.zcard(self.zset_key) + return { + "items": size, + "capacity": self.capacity, + "evict_fraction": self.evict_fraction, + } + + def get_ref(self, md5: str) -> int | None: + self._wait_if_eviction() + val = self.r.get(self.ref_prefix + md5) + return int(val) if val is not None else None + + def _wait_if_eviction(self) -> None: + max_wait = 30 + start_time = time.time() + + while self.r.exists(self.lock_key): + if time.time() - start_time > max_wait: + raise TimeoutError("Eviction operation timeout, waited too long") + time.sleep(0.01) # 短暂等待 + + def _try_acquire_lock(self) -> bool: + return bool(self.r.set(self.lock_key, "1", nx=True, ex=30)) + + def _release_lock(self) -> None: + try: + self.r.delete(self.lock_key) + except Exception: + pass + + def _md5_to_afs_path(self, md5: str) -> str: + """Convert md5 to AFS file path.""" + if not self.image_embed_dir: + return None + return str(_get_afs_path(self.image_embed_dir, f"{md5}{self.path_ext}")) + + def _delete_afs_files(self, victims: List[str]) -> None: + """Delete AFS files for evicted md5s.""" + if not self.image_embed_dir: + return + + for md5 in victims: + try: + file_path = self._md5_to_afs_path(md5) + if file_path and os.path.exists(file_path): + os.remove(file_path) + logger.debug(f"Deleted AFS file: {file_path}") + except Exception as e: + logger.debug(f"Warning: Failed to delete AFS file for {md5}: {e}") + + # ---------------- Lua scripts ---------------- + _INSERT_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = md5, ARGV[2] = capacity, ARGV[3] = evict_fraction +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local md5 = ARGV[1] +local capacity = tonumber(ARGV[2]) + +local 
unpack = unpack or table.unpack +local ref_key = ref_prefix .. md5 +if redis.call('GET', ref_key) then + return {0} -- Already exists +end + +local size = redis.call('ZCARD', zset) +if size < capacity then + -- Insert with ref_count=1 + redis.call('SET', ref_key, 1) + local now = redis.call('TIME')[1] * 1000 + redis.call('ZADD', zset, now, md5) + return {0} -- Success, no eviction +end + +return {1} -- Need eviction +""" + + _QUERY_INCRE_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = md5 +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local md5 = ARGV[1] + +local ref_key = ref_prefix .. md5 +local val = redis.call('GET', ref_key) + +if not val then + return {0} -- Not found +end + +-- Found, increment ref_count and update LRU +local rc = tonumber(val) + 1 +redis.call('SET', ref_key, rc) +local now = redis.call('TIME')[1] * 1000 +redis.call('ZADD', zset, now, md5) +return {1} -- Found and incremented +""" + + _DECR_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = md5 +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local md5 = ARGV[1] + +local ref_key = ref_prefix .. 
md5 +local val = redis.call('GET', ref_key) + +if not val then + return {-1, 0} -- Not found +end + +--ref 递减到 0 时保留键,只更新计数与 LRU +local rc = tonumber(val) - 1 +if rc < 0 then rc = 0 end +redis.call('SET', ref_key, rc) + +if rc > 0 then + -- 只有仍被引用时才更新 LRU + local now = redis.call('TIME')[1] * 1000 + redis.call('ZADD', zset, now, md5) +end + +return {rc, 0} +""" + + _EVICT_AND_INSERT_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = new_md5, ARGV[2] = capacity, ARGV[3] = evict_fraction +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local new_md5 = ARGV[1] +local capacity = tonumber(ARGV[2]) +local evict_fraction = tonumber(ARGV[3]) + +local unpack = unpack or table.unpack + +-- helper: now millis +local function now_ms() + local t = redis.call('TIME') + return t[1] * 1000 + math.floor(t[2] / 1000) +end + +local new_ref_key = ref_prefix .. new_md5 + +-- If already exists, treat as a hit: bump ref_count and refresh LRU +local cur = redis.call('GET', new_ref_key) +if cur then + local rc = tonumber(cur) + 1 + redis.call('SET', new_ref_key, rc) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1} -- success, no victims +end + +-- If not at capacity, just insert +local size = redis.call('ZCARD', zset) +if size < capacity then + redis.call('SET', new_ref_key, 1) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1} -- success, no victims +end + +-- At capacity: try to evict up to max_try items with rc==0, but success if at least 1 is freed +local max_try = math.max(1, math.floor(size * evict_fraction + 0.5)) +local victims = {} +local freed = 0 + +-- Scan from LRU (smallest score) to MRU +local all_keys = redis.call('ZRANGE', zset, 0, -1, 'WITHSCORES') +local i = 1 +while freed < 1 and i <= #all_keys and #victims < max_try do + local md5 = all_keys[i] + local ref_key = ref_prefix .. 
md5 + local v = redis.call('GET', ref_key) + if v and tonumber(v) <= 0 then + table.insert(victims, md5) + freed = freed + 1 + end + i = i + 2 -- skip score +end + +if freed >= 1 then + -- delete victims + for _, v in ipairs(victims) do + redis.call('DEL', ref_prefix .. v) + redis.call('ZREM', zset, v) + end + -- insert new + redis.call('SET', new_ref_key, 1) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1, unpack(victims)} +else + -- no zero-ref items found + return {0} +end +""" diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e28e4c93ad..8b7c85aeae 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -82,14 +82,14 @@ def __init__( if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if not self.args.disable_vision: + from lightllm.server.visualserver.vit_connect import VITConnectionManager - if not self.args.disable_vision: - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + self.vit_manager = VITConnectionManager(args, context, args.visual_port, self.cache_client) - if not self.args.disable_audio: - self.send_to_audio = context.socket(zmq.PUSH) - self.send_to_audio.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") + if not self.args.disable_audio: + self.send_to_audio = context.socket(zmq.PUSH) + self.send_to_audio.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") if args.enable_cpu_cache and not self.args.enable_multimodal: self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH) @@ -124,10 +124,10 @@ def __init__( self.latest_success_infer_time_mark.set_value(int(time.time())) return - async def _alloc_resource(self, items, md5sums, token_nums, datas): + async def _alloc_resource(self, items, uuids, 
token_nums, datas): while True: - records = obtain(self.cache_client.root.alloc(md5sums, token_nums)) + records = obtain(self.cache_client.root.alloc(uuids, token_nums)) if records is None: await asyncio.sleep(0.1) @@ -147,6 +147,12 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas): uid_list.append(rec["id"]) + # # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server + if self.args.enable_remote_vit: + # 避免远端lru被逐出 + self.cache_client.root.get_items_embed(uid_list, False) + return + ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) update_data_ids = [] @@ -166,14 +172,15 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, # 如果不加任何锁,假如请求1和请求2都有6张图片,而cache_capacity为10, # 那么如果某一时刻shm中存在请求1的5张图和请求2的5张图,将会资源竞争产生死锁。 async with self._resource_lock: - items, md5sums, tokens_nums, datas = [], [], [], [] + items, uuids, tokens_nums, datas = [], [], [], [] for img in multimodal_params.images: self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) - md5sums.append(md5sum) + md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), img.patch_num) + uuid = int(md5sum, 16) + uuids.append(uuid) tokens_nums.append(token_num) datas.append(data) items.append(img) @@ -181,13 +188,17 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) data = audio.read() token_num = self.tokenizer.get_audio_token_length(audio) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(audio.extra_params))) - md5sums.append(md5sum) + md5sum = "{}_{}".format( + hashlib.md5(data).hexdigest(), + 
hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(), + ) + uuid = int(md5sum, 16) + uuids.append(uuid) tokens_nums.append(token_num) datas.append(data) items.append(audio) - await self._alloc_resource(items, md5sums, tokens_nums, datas) + await self._alloc_resource(items, uuids, tokens_nums, datas) return async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): @@ -408,6 +419,48 @@ async def generate( raise e return + async def get_image_embeding( + self, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + is_health_req: bool = False, + ) -> Tuple[int, str, dict, FinishStatus]: + start_time = time.time() + request_headers = request.headers if request is not None else {} + group_request_id = self.alloc_req_id(sampling_params, is_health_req) + + try: + original_multimodal_params = None + if self.is_multinode_tp_master: + original_multimodal_params = copy.deepcopy(multimodal_params) + + await multimodal_params.verify_and_preload(request) + image_count = len(multimodal_params.images) + # 记录请求到达的相关信息 + + await self._log_req_header(request_headers, group_request_id) + logger.info(f"image_count:{image_count}") + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" 
+ + await self._alloc_multimodal_resources(multimodal_params, sampling_params) + + visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time) + + await self.transfer_to_next_module_or_node( + None, sampling_params, original_multimodal_params, visual_req_status + ) + await self._release_multimodal_resources(multimodal_params) + + except Exception as e: + logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + await self._release_multimodal_resources(multimodal_params) + await self.abort(group_request_id) + raise e + return + def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]: image_tokens = 0 audio_tokens = 0 @@ -538,23 +591,37 @@ async def transfer_to_next_module( ): if self.pd_mode.is_P_or_NORMAL(): - if not self.args.disable_vision: - self.send_to_visual.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL) - return + group_req_index = group_req_objs.to_group_req_index() + has_images = len(group_req_index.multimodal_params.images) > 0 + has_audios = len(group_req_index.multimodal_params.audios) > 0 - if not self.args.disable_audio: - self.send_to_audio.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL) + if has_images and not self.args.disable_vision: + free_mode = "all" + if self.args.enable_remote_vit and has_audios and not self.args.disable_audio: + free_mode = "images" + + await self.vit_manager.send_to_vit( + group_req_index, protocol=pickle.HIGHEST_PROTOCOL, free_mode=free_mode + ) + + if not self.args.enable_remote_vit: + return + + if has_audios and not self.args.disable_audio: + self.send_to_audio.send_pyobj(group_req_index, protocol=pickle.HIGHEST_PROTOCOL) + if self.args.enable_remote_vit: + group_req_index.multimodal_params.free() return if self.args.enable_cpu_cache: self.send_to_multi_level_kv_cache.send_pyobj( - group_req_objs.to_group_req_index(), + group_req_index, protocol=pickle.HIGHEST_PROTOCOL, 
) return self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), + group_req_index, protocol=pickle.HIGHEST_PROTOCOL, ) return @@ -753,6 +820,9 @@ async def handle_loop(self): asyncio.create_task(pd_handle_loop(self)) + if hasattr(self, "vit_manager"): + asyncio.create_task(self.vit_manager.vit_handle_loop()) + while True: try: await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 09a07455b3..0e6e9eca1d 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -26,6 +26,7 @@ def __init__(self, **kwargs): self.token_num = None # the audio length self.audio_length = None + self.afs_embed = False self._preload_data = None self.extra_params = {} @@ -54,10 +55,11 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None - ans = self._preload_data + return self._preload_data + + def free(self): self._preload_data = None self._data = None - return ans def to_dict(self): ret = {} @@ -95,6 +97,7 @@ def __init__(self, **kwargs): self.grid_thwd = None self.image_w = 0 self.image_h = 0 + self.patch_num = 0 self._preload_data = None self.extra_params = {} @@ -128,10 +131,11 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None - ans = self._preload_data + return self._preload_data + + def free(self): self._preload_data = None self._data = None - return ans def to_dict(self): ret = {} @@ -163,6 +167,23 @@ def __init__( self.audios = [AudioItem(**a) for a in audios] return + def free(self): + for image in self.images: + image.free() + for audio in self.audios: + audio.free() + + def free_images(self): + for image in self.images: + image.free() + + def free_audios(self): + for audio in self.audios: + audio.free() + + def get_all_uuids(self): + return [image.uuid for image in self.images] + [audio.uuid for audio in self.audios] + 
async def verify_and_preload(self, request: Request): for image in self.images: await image.preload(request) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 0a83b101be..c3134eebcc 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -59,8 +59,8 @@ def register( self.vocab_size = vocab_size return - def init_cpu_embed_cache_client(self): - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + def init_cpu_embed_cache_client(self, init_shm_data: bool = False): + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=init_shm_data) return def get_overlap_stream(self) -> torch.cuda.Stream: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 8b085c45ed..eb6fd904f2 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -145,7 +145,7 @@ def init_model(self, kvargs): wait_events.append(self.multi_level_cache_module) if self.args.enable_multimodal: - g_infer_context.init_cpu_embed_cache_client() + g_infer_context.init_cpu_embed_cache_client(init_shm_data=False) model_cfg, _ = PretrainedConfig.get_config_dict(self.weight_dir) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 8fba9f08d7..508860e899 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -20,7 +20,7 @@ from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name from rpyc.utils.classic import obtain - +from lightllm.server.embed_cache.utils import create_shm, get_shm_name_data logger = init_logger(__name__) @@ -31,13 +31,16 @@ def __init__( args: 
StartArgs, visual_model_rpc_ports, ): + self.args = args + self.visual_only = args.run_mode in ["visual", "visual_only"] + self.remote_vit = args.enable_remote_vit or self.visual_only + context = zmq.Context(2) - enable_audio = not args.disable_audio - if enable_audio: - self.send_to_next_module = context.socket(zmq.PUSH) - self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") - else: - if args.enable_cpu_cache: + if not self.visual_only: + if not args.disable_audio: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") + elif args.enable_cpu_cache: self.send_to_next_module = context.socket(zmq.PUSH) self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") else: @@ -45,7 +48,11 @@ def __init__( self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") self.zmq_recv_socket = context.socket(zmq.PULL) - self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + if self.remote_vit: + self.zmq_recv_socket.bind(f"tcp://*:{args.remote_vit_port}") + else: + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.cache_port = args.cache_port @@ -56,13 +63,11 @@ def __init__( self.vit_tp = args.visual_tp self.infer_batch_size = args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code - self.args = args self.visual_model_rpc_ports = visual_model_rpc_ports self.send_batch_size = args.visual_send_batch_size self.shm_req_manager = ShmReqManager() async def wait_to_model_ready(self): - self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] self.vit_attn_backend = init_vit_att_backend(index=0) for dp_rank_id in range(self.vit_dp): @@ -146,13 
+151,12 @@ def flush_ready(force: bool = False): continue multimodal_params = group_req_indexes.multimodal_params - img_uuids = [img.uuid for img in multimodal_params.images] # disable prompt cache通常用来测试,需要也去掉image cache的影响 if disable_prompt_cache: ready_image = [False] * len(img_uuids) else: - ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) + ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids, True)) for img, ready in zip(multimodal_params.images, ready_image): if not ready: @@ -180,6 +184,43 @@ def flush_ready(force: bool = False): processing_group_reqs = [] flush_ready(force=True) + async def _recv_reqs(self): + recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if not self.remote_vit: + return recv_req + + uuids = [img.uuid for img in recv_req.multimodal_params.images] + already_embed = await asyncio.to_thread(self.cache_client.root.get_items_embed, uuids, True) + if all(already_embed): + return None + + missing_uuids = [] + token_nums = [] + datas = [] + for img, embed_ready in zip(recv_req.multimodal_params.images, already_embed): + if embed_ready: + continue + missing_uuids.append(img.uuid) + token_nums.append(img.token_num) + datas.append(img.read()) + img.free() + + while True: + if await asyncio.to_thread(self.cache_client.root.alloc, missing_uuids, token_nums) is not None: + break + await asyncio.sleep(0.01) + + ready_flags = obtain(self.cache_client.root.get_items_data(missing_uuids)) + update_data_ids = [] + for uid, ready, data in zip(missing_uuids, ready_flags, datas): + if not ready: + create_shm(get_shm_name_data(uid), data) + update_data_ids.append(uid) + + if update_data_ids: + await asyncio.to_thread(self.cache_client.root.set_items_data, update_data_ids) + return recv_req + async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): self.visual_recv_max_count = 64 @@ -187,7 +228,9 @@ async def loop_for_netio_req(self): while True: try: for _ in 
range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + recv_req = await self._recv_reqs() + if recv_req is None: + continue if isinstance(recv_req, GroupReqIndexes): logger.info( f"visual recv req id {recv_req.group_req_id} " @@ -196,12 +239,31 @@ async def loop_for_netio_req(self): self.waiting_reqs.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" - self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) + self.visual_recv_max_count = min(int(self.visual_recv_max_count * 1.3), 256) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 self.visual_recv_max_count = 64 await asyncio.sleep(0.01) + async def loop_for_fwd_visual_only(self): + while True: + if len(self.waiting_reqs) == 0: + await asyncio.sleep(0.01) + continue + + images_need_infer = [] + while len(self.waiting_reqs) > 0: + visual_req = self.waiting_reqs.pop(0) + for img in visual_req.multimodal_params.images: + images_need_infer.append(img) + if len(images_need_infer) == self.infer_batch_size: + await self.infer_imgs(images_need_infer) + images_need_infer = [] + + if len(images_need_infer) > 0: + await self.infer_imgs(images_need_infer) + images_need_infer = [] + def clean_up(self): for model_rpc in self.model_rpcs: model_rpc.rpc_server_process.kill() @@ -210,17 +272,29 @@ def clean_up(self): return +def create_forward_loop(args, visualserver: VisualManager, loop: asyncio.AbstractEventLoop): + if args.run_mode in ["visual", "visual_only"]: + from .register_loop import register_loop + + loop.create_task(visualserver.loop_for_fwd_visual_only()) + loop.create_task(register_loop(args)) + else: + loop.create_task(visualserver.loop_for_fwd()) + + def start_visual_process(args, model_rpc_ports, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") start_parent_check_thread() + visualserver = None try: 
visualserver = VisualManager(args=args, visual_model_rpc_ports=model_rpc_ports) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) - visualserver.clean_up() + if visualserver is not None: + visualserver.clean_up() raise e pipe_writer.send("init ok") @@ -231,6 +305,6 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - loop.create_task(visualserver.loop_for_fwd()) + create_forward_loop(args, visualserver, loop) loop.run_until_complete(visualserver.loop_for_netio_req()) return diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 6355ac2dbf..8c94b821d1 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -6,25 +6,27 @@ import inspect from datetime import timedelta from typing import Dict, List, Tuple -from transformers.configuration_utils import PretrainedConfig from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer +from transformers.configuration_utils import PretrainedConfig + from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer -from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel +from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel from lightllm.models.vit.model import VisionTransformer from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel +from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from 
lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel -from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.server.embed_cache.utils import create_afs, get_shm_name_embed, tensor2bytes, save_tensor_afs from lightllm.server.visualserver import set_vit_att_backend @@ -34,6 +36,7 @@ def exposed_init_model(self, kvargs): import torch import torch.distributed as dist + self.args = get_env_start_args() self.vit_dp = kvargs["vit_dp"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = kvargs["dp_rank_id"] @@ -41,6 +44,11 @@ def exposed_init_model(self, kvargs): self.cache_port = kvargs["cache_port"] weight_dir = kvargs["weight_dir"] self.vit_rank_id = kvargs["vit_rank_id"] + self.image_embed_dir = self.args.image_embed_dir + self.remote_vit = self.args.enable_remote_vit or self.args.run_mode in ["visual", "visual_only"] + if self.remote_vit and not self.image_embed_dir: + raise ValueError("remote vit mode requires image_embed_dir") + self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] @@ -56,6 +64,7 @@ def exposed_init_model(self, kvargs): "quant_type": kvargs["quant_type"], "quant_cfg": kvargs["quant_cfg"], "max_batch_size": kvargs["max_batch_size"], + "remote_vit": self.remote_vit, } self.model_type = model_cfg["model_type"] if self.model_type == "qwen": @@ -92,10 +101,10 @@ def 
exposed_init_model(self, kvargs): ) else: raise Exception(f"can not support {self.model_type} now") - self.model.load_model(weight_dir) self.model = self.model.cuda() - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + if not self.remote_vit: + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -116,33 +125,47 @@ def forward(self, images: List[ImageItem]): def exposed_encode(self, images: List[ImageItem]): images = obtain(images) all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cuda")) - - if self.tp_rank_id == 0: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) - ids_to_set = [] - for i, ready in enumerate(ready_flags): - if ready: - continue - uid = uuids[i] - start, end = valid_ids[i] - image = images[i] + + if self.tp_rank_id != 0: + return + + ready_flags = obtain(self.cache_client.root.get_items_embed(uuids, True)) + ids_to_set = [] + cpu_embeds = None + if self.remote_vit: + cpu_embeds = all_img_embeds.to(torch.device("cpu"), non_blocking=True) + + for i, ready in enumerate(ready_flags): + if ready: + continue + uid = uuids[i] + start, end = valid_ids[i] + image = images[i] + if self.remote_vit: + save_tensor_afs(get_shm_name_embed(uid), cpu_embeds[start:end], self.image_embed_dir) + else: self.cpu_embed_cache_client.copy_vision_to_cache( - embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache + embed_tensor=all_img_embeds[start:end], + start_index_in_cache=image.start_index_in_embed_cache, ) - ids_to_set.append(uid) - if ids_to_set: - self.cache_client.root.set_items_embed(ids_to_set) + ids_to_set.append(uid) + + if ids_to_set: + self.cache_client.root.set_items_embed(ids_to_set) + if not self.remote_vit: torch.cuda.current_stream().synchronize() return class 
VisualModelRpcClient: - def __init__(self, model_rpc, vit_tp, rpc_server_process=None): - self.model: VisualModelRpcServer = model_rpc + def __init__(self, conn, vit_tp, rpc_server_process=None): + self.conn = conn + self.model: VisualModelRpcServer = conn.root self.vit_tp = vit_tp self.rpc_server_process = rpc_server_process self.use_rpc = True + self._bg = rpyc.BgServingThread(self.conn) + if self.use_rpc: def async_wrap(f): @@ -161,15 +184,12 @@ async def _func(*args, **kwargs): else: self._init_model = self.model.exposed_init_model self._encode = self.model.exposed_encode - return async def init_model(self, kvargs): ans: rpyc.AsyncResult = self._init_model(kvargs) if self.use_rpc: await ans return - else: - return async def encode(self, images: List[ImageItem]): ans = self._encode(images) @@ -215,4 +235,4 @@ async def start_model_process(port, vit_tp, device_id): raise Exception("init rpc env error!") assert proc.is_alive() - return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc) + return VisualModelRpcClient(con, vit_tp, rpc_server_process=proc) diff --git a/lightllm/server/visualserver/register_loop.py b/lightllm/server/visualserver/register_loop.py new file mode 100644 index 0000000000..31d0f7b8ac --- /dev/null +++ b/lightllm/server/visualserver/register_loop.py @@ -0,0 +1,42 @@ +import asyncio +import pickle +import websockets +import socket +from lightllm.utils.net_utils import get_hostname_ip +from lightllm.utils.log_utils import init_logger +from .vit_connect import VIT_Obj + +logger = init_logger(__name__) + + +async def register_loop(args): + assert args.host not in ["127.0.0.1", "localhost"], "remote visual server must specify host ip" + + if args.host in ["0.0.0.0"]: + host_ip = get_hostname_ip() + else: + host_ip = args.host + + while True: + + try: + uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_register" + async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket: + + sock = 
websocket.transport.get_extra_info("socket") + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip_port=f"{host_ip}:{args.remote_vit_port}") + + await websocket.send(pickle.dumps(vit_obj)) + logger.info(f"Sent registration vit_obj: {vit_obj}") + + while True: + await websocket.send("heartbeat") + await asyncio.sleep(40) + + except Exception as e: + logger.error("connetion to config_server has error") + logger.exception(str(e)) + await asyncio.sleep(10) + logger.info("reconnection to config_server") diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py new file mode 100644 index 0000000000..c30ba65cba --- /dev/null +++ b/lightllm/server/visualserver/vit_connect.py @@ -0,0 +1,237 @@ +import asyncio +import base64 +import pickle +import time +from dataclasses import dataclass +from typing import Dict, Optional + +import httpx +import rpyc +import zmq +import zmq.asyncio + +from lightllm.utils.log_utils import init_logger +from lightllm.server.core.objs.io_objs import GroupReqIndexes + +logger = init_logger(__name__) + + +@dataclass +class VIT_Obj: + node_id: int + host_ip_port: str + + def to_log_str(self): + return f"VIT host_ip_port: {self.host_ip_port} node_id: {self.node_id}" + + +class VITConnectionManager: + """VIT连接管理器""" + + def __init__(self, args, context, local_visual_port: int, cache_client: rpyc.Connection): + self.args = args + self.context = context + self.local_visual_port = local_visual_port + + self.send_to_visual = None + self.remote_vit_instances = {} + self.current_vit_index = 0 + self.remote_vit = args.enable_remote_vit + self.remote_vit_port = args.remote_vit_port + self.cache_client = cache_client + + self._setup_vit_connections() + + def _setup_vit_connections(self): + """ + 设置VIT连接,支持本地和远程VIT实例 + 支持多种连接模式: + 1. 本地VIT实例 (默认) + 2. 
远程多个VIT实例 (负载均衡) + """ + if self.remote_vit: + # 远程VIT实例模式 + self._setup_remote_vit_connections() + else: + self._setup_local_vit_connection() + + def _setup_local_vit_connection(self): + self.send_to_visual = self.context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") + logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") + + def _setup_remote_vit_connections(self): + """ + 初始化远程VIT连接,同步获取初始实例 + """ + logger.info("Setting up remote VIT connections...") + + self._sync_init_vit_instances() + + retry_count = 0 + max_retries = 30 # 最多等待30秒 + while len(self.remote_vit_instances) == 0 and retry_count < max_retries: + logger.info(f"Waiting for VIT instances... (attempt {retry_count + 1}/{max_retries})") + time.sleep(1) + retry_count += 1 + self._sync_init_vit_instances() + + if len(self.remote_vit_instances) == 0: + logger.warning("No VIT instances available after initialization") + else: + logger.info(f"Successfully connected to {len(self.remote_vit_instances)} VIT instances") + + def _sync_init_vit_instances(self): + """ + 同步初始化VIT实例连接 + """ + try: + # 使用同步方式获取VIT实例 + vit_objs = self._sync_get_vit_objs() + if vit_objs: + self._update_vit_connections(vit_objs) + except Exception as e: + logger.error(f"Failed to initialize VIT instances: {e}") + + def _sync_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: + """ + 同步获取VIT实例信息 + """ + import requests + + uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + try: + response = requests.get(uri, timeout=10) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"Failed to get VIT instances: {response.status_code}") + return None + except Exception as e: + logger.error(f"Error getting VIT instances: {e}") + return None + + def 
_update_vit_connections(self, id_to_vit_obj: Dict[int, VIT_Obj]): + """ + 更新VIT连接,添加新的连接,关闭失效的连接 + """ + # 关闭不再存在的连接 + closed_ids = [] + for id, remote_instance in self.remote_vit_instances.items(): + if id not in id_to_vit_obj: + try: + remote_instance.close() + except: + pass + closed_ids.append(id) + logger.info(f"Closed VIT connection {id}") + + for id in closed_ids: + self.remote_vit_instances.pop(id) + + # 建立新的连接 + for id, vit_obj in id_to_vit_obj.items(): + if id not in self.remote_vit_instances: + try: + socket = self.context.socket(zmq.PUSH) + # print(vit_obj.host_ip_port, self.args.remote_vit_port, flush=True) + ip, port = vit_obj.host_ip_port.split(":") + socket.connect(f"tcp://{ip}:{port}") + self.remote_vit_instances[id] = socket + logger.info(f"Connected to VIT instance {id} at {vit_obj.host_ip_port}") + except Exception as e: + logger.error(f"Failed to connect to VIT instance {id}: {e}") + + def _get_vit_instance(self): + """ + 获取下一个可用的VIT实例 (轮询负载均衡) + """ + if not self.remote_vit: + return self.send_to_visual + + if len(self.remote_vit_instances) == 0: + raise Exception("No available VIT instances") + + # 简单的轮询负载均衡 + index = (self.current_vit_index + 1) % len(self.remote_vit_instances) + self.current_vit_index = index + return list(self.remote_vit_instances.values())[index] + + async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL, free_mode: str = "all"): + """ + 发送数据到VIT实例,支持本地和远程模式 + """ + instance = self._get_vit_instance() + + try: + instance.send_pyobj(req, protocol=protocol) + except Exception as e: + logger.error(f"Failed to send to VIT instance: {e}") + raise Exception(f"Failed to send to VIT instance: {e}") + + if self.remote_vit: + await self._wait_visual_embed_ready(req) + + if free_mode == "all": + req.multimodal_params.free() + elif free_mode == "images": + req.multimodal_params.free_images() + + async def vit_handle_loop(self): + """ + 异步VIT连接管理循环,由外部启动 + """ + if not self.remote_vit: + return + 
logger.info("Starting VIT connection management loop") + while True: + try: + id_to_vit_obj = await self._async_get_vit_objs() + if id_to_vit_obj: + self._update_vit_connections(id_to_vit_obj) + await asyncio.sleep(30) + except Exception as e: + logger.exception(f"Error in VIT handle loop: {e}") + await asyncio.sleep(10) + + async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: + """ + 异步获取VIT实例信息 + """ + uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get(uri) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"Failed to get VIT instances: {response.status_code}") + return None + except Exception as e: + logger.exception(f"Error getting VIT instances: {e}") + return None + + async def _wait_visual_embed_ready( + self, + req: GroupReqIndexes, + timeout_seconds: int = 1000, + ): + # 本地模式不需要等待 + if not self.remote_vit: + return + uuids = req.multimodal_params.get_all_uuids() + + async def wait_for_embeds(): + while not all(self.cache_client.root.get_items_embed(uuids, True)): + await asyncio.sleep(0.01) + + try: + await asyncio.wait_for(wait_for_embeds(), timeout=timeout_seconds) + except asyncio.TimeoutError: + raise TimeoutError( + f"Req {req.group_req_id}: timeout waiting for visual embed ready after {timeout_seconds} seconds" + ) diff --git a/lightllm/utils/redis_utils.py b/lightllm/utils/redis_utils.py new file mode 100644 index 0000000000..acc4deb589 --- /dev/null +++ b/lightllm/utils/redis_utils.py @@ -0,0 +1,74 @@ +import subprocess +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def start_redis_service(args): + """launch redis service""" + if not hasattr(args, "start_redis") or not args.start_redis: + return None + + 
config_server_host = args.config_server_host + redis_port = args.redis_port + try: + subprocess.run( + ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "FLUSHALL", "ASYNC"], check=False, timeout=2 + ) + subprocess.run( + ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "SHUTDOWN", "NOSAVE"], check=False, timeout=2 + ) + except Exception: + pass + + try: + redis_command = [ + "redis-server", + "--port", + str(redis_port), + "--bind", + f"{config_server_host}", + "--daemonize", + "no", + "--logfile", + "/dev/stdout", + "--loglevel", + "notice", + "--save", + '""', # 不触发 RDB 快照 + "--appendonly", + "no", # 关闭 AOF + ] + + logger.info(f"Starting Redis service on port {redis_port}") + redis_process = subprocess.Popen(redis_command) + + import redis + import time + + max_wait = 10 + start_time = time.time() + + while time.time() - start_time < max_wait: + try: + r = redis.Redis(host=args.config_server_host, port=redis_port, socket_connect_timeout=1) + r.ping() + logger.info(f"Redis service started successfully on port {redis_port}") + del r + break + except Exception: + time.sleep(0.5) + if redis_process.poll() is not None: + logger.error("Redis service failed to start") + return None + else: + logger.error("Redis service startup timeout") + if redis_process.poll() is None: + redis_process.terminate() + return None + + return redis_process + + except Exception as e: + logger.error(f"Failed to start Redis service: {e}") + return None diff --git a/lightllm/utils/start_utils.py b/lightllm/utils/start_utils.py index 372b7e1cfa..8245431084 100644 --- a/lightllm/utils/start_utils.py +++ b/lightllm/utils/start_utils.py @@ -111,4 +111,12 @@ def kill_recursive(proc): logger.warning(f"Process {proc.pid} does not exist.") +def is_multimodal_mode(args): + from transformers import PretrainedConfig + + model_cfg, _ = PretrainedConfig.get_config_dict(args.model_dir) + is_multimodal = "visual" in model_cfg or "vision_config" in model_cfg + return is_multimodal 
+ + process_manager = SubmoduleManager() diff --git a/requirements.txt b/requirements.txt index 5b0b201ae3..5331227586 100644 --- a/requirements.txt +++ b/requirements.txt @@ -94,3 +94,5 @@ partial_json_parser==0.2.1.1.post6 websockets==15.0.1 cupy-cuda12x==13.6.0 nixl==0.8.0 +xformers==0.0.33.post2 +redis==7.3.0 From 9f434002eefaf4cc666c274731280ba82eff35fd Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Fri, 20 Mar 2026 18:16:29 +0800 Subject: [PATCH 02/54] refine --- .../qwen_vl/layer_infer/pre_layer_infer.py | 12 +++---- lightllm/server/api_http.py | 4 +-- lightllm/server/api_server.py | 2 +- lightllm/server/api_start.py | 19 ++++++----- lightllm/server/embed_cache/manager.py | 2 +- lightllm/server/httpserver/manager.py | 32 ++++++------------- lightllm/server/multimodal_params.py | 19 ++--------- lightllm/server/visualserver/manager.py | 2 +- lightllm/server/visualserver/vit_connect.py | 9 ++---- lightllm/utils/start_utils.py | 8 ----- 10 files changed, 33 insertions(+), 76 deletions(-) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 0127fbea8b..96f9fd26f4 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -31,7 +31,6 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config): super().__init__(network_config) self.args = get_env_start_args() - self.cache_client = None if self.args.enable_remote_vit: self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -52,12 +51,14 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_token_lens = [] img_start_locs_in_cache = [] unique_uids = [] + all_uids = [] device = layer_weight.wte_weight_.weight.device dtype = layer_weight.wte_weight_.weight.dtype 
hidden_size = layer_weight.wte_weight_.weight.shape[1] for _, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: + all_uids.append(img["uuid"]) # skip the same image if img["token_id"] in img_start_token_ids: continue @@ -77,17 +78,12 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei ) if self.args.enable_remote_vit: - release_ids = [] - for _, p in enumerate(infer_state.multimodal_params): - for img in p["images"] + p["audios"]: - release_ids.append(img["uuid"]) - for uid, start_index_in_embed_cache in zip(unique_uids, img_start_locs_in_cache): embed_tensor = load_tensor_afs(get_shm_name_embed(uid), self.args.image_embed_dir) self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache) - if release_ids: - self.cache_client.root.release(release_ids) + if all_uids: + self.cache_client.root.release(all_uids) assert cpu_embed_cache_tensor.shape[2] == hidden_size, ( f"Dimension mismatch: text weight dimension is {hidden_size}, " diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index bf246f8f0d..c803de7db3 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -92,7 +92,7 @@ def set_args(self, args: StartArgs): self.httpserver_manager = HttpServerManagerForPDMaster( args=args, ) - elif args.run_mode == "visual": + elif args.run_mode in ["visual", "visual_only"]: self.metric_client = MetricClient(args.metric_port) else: init_tokenizer(args) # for openai api @@ -138,7 +138,7 @@ def get_model_name(): @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") async def healthcheck(request: Request): - if g_objs.args.run_mode in ["pd_master", "visual"]: + if g_objs.args.run_mode in ["pd_master", "visual", "visual_only"]: return JSONResponse({"message": "Ok"}, status_code=200) if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true": diff --git a/lightllm/server/api_server.py 
b/lightllm/server/api_server.py index c6700c0416..7542f7be6c 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -11,7 +11,7 @@ pd_master_start(args) elif args.run_mode == "config_server": config_server_start(args) - elif args.run_mode == "visual": + elif args.run_mode in ["visual", "visual_only"]: visual_start(args) else: normal_or_p_d_start(args) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index ef1a521d76..dc696e9171 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -5,7 +5,7 @@ import subprocess import signal from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker -from lightllm.utils.start_utils import process_manager, kill_recursive, is_multimodal_mode +from lightllm.utils.start_utils import process_manager, kill_recursive from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger @@ -194,7 +194,6 @@ def normal_or_p_d_start(args, only_prepare=False): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 - args.enable_multimodal = is_multimodal_mode(args) _prepare_remote_vit_embed_dir(args) # 检查GPU数量是否足够 if args.visual_gpu_ids is None: @@ -355,27 +354,27 @@ def normal_or_p_d_start(args, only_prepare=False): start_args=[(args,)], ) - if not args.disable_audio: - from .audioserver.manager import start_audio_process + if not args.disable_vision and not args.enable_remote_vit: + from .visualserver.manager import start_visual_process process_manager.start_submodule_processes( start_funcs=[ - start_audio_process, + start_visual_process, ], start_args=[ - (args,), + (args, visual_model_tp_ports), ], ) - if not args.disable_vision and not args.enable_remote_vit: - from .visualserver.manager import start_visual_process + if not args.disable_audio: + from .audioserver.manager import start_audio_process process_manager.start_submodule_processes( start_funcs=[ - 
start_visual_process, + start_audio_process, ], start_args=[ - (args, visual_model_tp_ports), + (args,), ], ) diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index faf48c4085..ebf57f6594 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -54,7 +54,7 @@ def exposed_get_items_embed(self, ids: list[int], embeding_only: bool = False) - def get_cache_manager(args): - if args.enable_remote_vit or args.run_mode == "visual": + if args.enable_remote_vit or args.run_mode in ["visual", "visual_only"]: return MemoryCacheWithRedis(args) else: return InMemoryCache(args) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 8b7c85aeae..d57becefa2 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -82,14 +82,15 @@ def __init__( if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - if not self.args.disable_vision: - from lightllm.server.visualserver.vit_connect import VITConnectionManager - self.vit_manager = VITConnectionManager(args, context, args.visual_port, self.cache_client) + if not self.args.disable_vision: + from lightllm.server.visualserver.vit_connect import VITConnectionManager - if not self.args.disable_audio: - self.send_to_audio = context.socket(zmq.PUSH) - self.send_to_audio.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") + self.vit_manager = VITConnectionManager(args, context, args.visual_port, self.cache_client) + + if not self.args.disable_audio: + self.send_to_audio = context.socket(zmq.PUSH) + self.send_to_audio.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") if args.enable_cpu_cache and not self.args.enable_multimodal: self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH) @@ -151,7 +152,6 @@ async 
def _alloc_resource(self, items, uuids, token_nums, datas): if self.args.enable_remote_vit: # 避免远端lru被逐出 self.cache_client.root.get_items_embed(uid_list, False) - return ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) update_data_ids = [] @@ -592,25 +592,13 @@ async def transfer_to_next_module( if self.pd_mode.is_P_or_NORMAL(): group_req_index = group_req_objs.to_group_req_index() - has_images = len(group_req_index.multimodal_params.images) > 0 - has_audios = len(group_req_index.multimodal_params.audios) > 0 - - if has_images and not self.args.disable_vision: - free_mode = "all" - if self.args.enable_remote_vit and has_audios and not self.args.disable_audio: - free_mode = "images" - - await self.vit_manager.send_to_vit( - group_req_index, protocol=pickle.HIGHEST_PROTOCOL, free_mode=free_mode - ) - + if not self.args.disable_vision: + await self.vit_manager.send_to_vit(group_req_index, protocol=pickle.HIGHEST_PROTOCOL) if not self.args.enable_remote_vit: return - if has_audios and not self.args.disable_audio: + if not self.args.disable_audio: self.send_to_audio.send_pyobj(group_req_index, protocol=pickle.HIGHEST_PROTOCOL) - if self.args.enable_remote_vit: - group_req_index.multimodal_params.free() return if self.args.enable_cpu_cache: diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 0e6e9eca1d..7afd95d3c0 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -26,7 +26,6 @@ def __init__(self, **kwargs): self.token_num = None # the audio length self.audio_length = None - self.afs_embed = False self._preload_data = None self.extra_params = {} @@ -55,11 +54,10 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None - return self._preload_data - - def free(self): + ans = self._preload_data self._preload_data = None self._data = None + return ans def to_dict(self): ret = {} @@ -167,23 +165,10 @@ def __init__( self.audios 
= [AudioItem(**a) for a in audios] return - def free(self): - for image in self.images: - image.free() - for audio in self.audios: - audio.free() - def free_images(self): for image in self.images: image.free() - def free_audios(self): - for audio in self.audios: - audio.free() - - def get_all_uuids(self): - return [image.uuid for image in self.images] + [audio.uuid for audio in self.audios] - async def verify_and_preload(self, request: Request): for image in self.images: await image.preload(request) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 508860e899..607888f6b7 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -239,7 +239,7 @@ async def loop_for_netio_req(self): self.waiting_reqs.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" - self.visual_recv_max_count = min(int(self.visual_recv_max_count * 1.3), 256) + self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 self.visual_recv_max_count = 64 diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index c30ba65cba..9720eb6698 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -159,7 +159,7 @@ def _get_vit_instance(self): self.current_vit_index = index return list(self.remote_vit_instances.values())[index] - async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL, free_mode: str = "all"): + async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL): """ 发送数据到VIT实例,支持本地和远程模式 """ @@ -174,10 +174,7 @@ async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOC if self.remote_vit: await self._wait_visual_embed_ready(req) - if free_mode == "all": - req.multimodal_params.free() - elif free_mode == "images": - req.multimodal_params.free_images() + 
req.multimodal_params.free_images() async def vit_handle_loop(self): """ @@ -223,7 +220,7 @@ async def _wait_visual_embed_ready( # 本地模式不需要等待 if not self.remote_vit: return - uuids = req.multimodal_params.get_all_uuids() + uuids = [image.uuid for image in req.multimodal_params.images] async def wait_for_embeds(): while not all(self.cache_client.root.get_items_embed(uuids, True)): diff --git a/lightllm/utils/start_utils.py b/lightllm/utils/start_utils.py index 8245431084..372b7e1cfa 100644 --- a/lightllm/utils/start_utils.py +++ b/lightllm/utils/start_utils.py @@ -111,12 +111,4 @@ def kill_recursive(proc): logger.warning(f"Process {proc.pid} does not exist.") -def is_multimodal_mode(args): - from transformers import PretrainedConfig - - model_cfg, _ = PretrainedConfig.get_config_dict(args.model_dir) - is_multimodal = "visual" in model_cfg or "vision_config" in model_cfg - return is_multimodal - - process_manager = SubmoduleManager() From a9db45d86797ac043edfc8848e9b7d131a1d3998 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=92=AE=E5=9C=A3=E8=99=93?= Date: Mon, 23 Mar 2026 15:58:18 +0800 Subject: [PATCH 03/54] fix --- lightllm/models/internvl/model.py | 2 -- .../qwen_vl/layer_infer/pre_layer_infer.py | 23 ++++++++------- lightllm/server/api_lightllm.py | 7 +++-- lightllm/server/api_start.py | 2 ++ .../impl/memory_cache_with_redis.py | 7 +++++ lightllm/server/httpserver/manager.py | 29 +++++++++++++++---- 6 files changed, 51 insertions(+), 19 deletions(-) diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index 3e5b9c5e2a..cbe9d4dfa6 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -56,10 +56,8 @@ def init_imageitem_extral_params( ): if sampling_params.image_max_patch_num > 0: img.extra_params["image_patch_max_num"] = sampling_params.image_max_patch_num - return elif os.getenv("MAX_PATCH_NUM"): img.extra_params["image_patch_max_num"] = int(os.getenv("MAX_PATCH_NUM")) - return else: num_images = 
len(multi_params.images) if num_images == 1: diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 96f9fd26f4..867f2a7d3a 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -50,22 +50,18 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_token_ids = [] img_token_lens = [] img_start_locs_in_cache = [] - unique_uids = [] - all_uids = [] device = layer_weight.wte_weight_.weight.device dtype = layer_weight.wte_weight_.weight.dtype hidden_size = layer_weight.wte_weight_.weight.shape[1] - for _, p in enumerate(infer_state.multimodal_params): + for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: - all_uids.append(img["uuid"]) # skip the same image if img["token_id"] in img_start_token_ids: continue img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs_in_cache.append(img["start_index_in_embed_cache"]) - unique_uids.append(img["uuid"]) out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) from lightllm.server.router.model_infer.infer_batch import g_infer_context @@ -78,12 +74,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei ) if self.args.enable_remote_vit: - for uid, start_index_in_embed_cache in zip(unique_uids, img_start_locs_in_cache): - embed_tensor = load_tensor_afs(get_shm_name_embed(uid), self.args.image_embed_dir) - self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache) + unique_image_uids = [] + for _, p in enumerate(infer_state.multimodal_params): + for img in p["images"]: + if img["uuid"] in unique_image_uids: + continue + img_uid = img["uuid"] + img_idx = img["start_index_in_embed_cache"] + unique_image_uids.append(img_uid) + embed_tensor = 
load_tensor_afs(get_shm_name_embed(img_uid), self.args.image_embed_dir) + self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, img_idx) - if all_uids: - self.cache_client.root.release(all_uids) + if unique_image_uids: + self.cache_client.root.release(unique_image_uids) assert cpu_embed_cache_tensor.shape[2] == hidden_size, ( f"Dimension mismatch: text weight dimension is {hidden_size}, " diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index bfb8bff6db..3ad7b49daf 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -5,6 +5,7 @@ from lightllm.server.core.objs.sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager +from lightllm.utils.envs_utils import get_env_start_args import ujson as json @@ -154,13 +155,15 @@ async def stream_results() -> AsyncGenerator[bytes, None]: async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response: request_dict = await request.json() - # request_dict: {'parameters': {'max_new_tokens': 128}, - # 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}} + args = get_env_start_args() + assert not args.disable_vision + assert args.enable_remote_vit sample_params_dict = request_dict["parameters"] sampling_params = SamplingParams() sampling_params.init(tokenizer=None, **sample_params_dict) sampling_params.verify() multimodal_params_dict = request_dict.get("multimodal_params", {}) + assert not multimodal_params_dict.get("audios") multimodal_params = MultimodalParams(**multimodal_params_dict) await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index dc696e9171..0d0e744fae 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -537,6 +537,8 @@ def visual_start(args): if 
args.visual_nccl_ports is not None: args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] + else: + args.visual_nccl_ports = visual_nccl_ports args.router_port = router_port args.visual_port = visual_port diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 6ef8c66f68..f8d9da5314 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -54,6 +54,13 @@ def set_items_embed(self, ids: list[int]) -> None: rec = self._records.get(id) if rec is not None: rec.embed = True + # Before the embed becomes ready, concurrent miss requests are only + # tracked by the local record refcount. Materialize the remaining + # pending readers into Redis so each later release has a matching + # remote ref to consume. + pending_remote_readers = max(rec.ref - 1, 0) + for _ in range(pending_remote_readers): + self.redis_cache.query_and_incre(str(id)) if rec.ref > 0: self._update_record_ref_by_id(id, -1) # 保留一份 redis 引用,直到真正的消费者读取完成后再 release, diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index d57becefa2..1e96d4a737 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -139,6 +139,7 @@ async def _alloc_resource(self, items, uuids, token_nums, datas): raise Exception(str(records) + "and try to set --embed_cache_storage_size bigger") uid_list = [] + unique_image_uids = [] for item, rec in zip(items, records): item: Union[ImageItem, AudioItem] = item item.uuid = rec["id"] @@ -147,11 +148,13 @@ async def _alloc_resource(self, items, uuids, token_nums, datas): item.start_index_in_embed_cache = rec["start_index_in_embed_cache"] uid_list.append(rec["id"]) + if isinstance(item, ImageItem) and rec["id"] not in unique_image_uids: + unique_image_uids.append(rec["id"]) - # # If enable the vit/audio-llm disaggregation, no need to cache 
the data in the memory of the server + # # If enable the vit-llm disaggregation, no need to cache the data in the memory of the server if self.args.enable_remote_vit: # 避免远端lru被逐出 - self.cache_client.root.get_items_embed(uid_list, False) + self.cache_client.root.get_items_embed(unique_image_uids, False) ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) update_data_ids = [] @@ -251,6 +254,16 @@ async def loop_for_request(self): sampling_params, multimodal_params, ) = await self.multinode_req_manager.recv_pyobj() + + # 多机tp下,slave节点收到/get_image_embedding请求,无prompt + if prompt is None: + + async def image_embedding_wrapper(sampling_params, multimodal_params): + await self.get_image_embeding(sampling_params, multimodal_params, None) + + asyncio.create_task(image_embedding_wrapper(sampling_params, multimodal_params)) + continue + results_generator = self.generate(prompt, sampling_params, multimodal_params, None) async def generate_wrapper(results_generator): @@ -450,7 +463,11 @@ async def get_image_embeding( visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time) await self.transfer_to_next_module_or_node( - None, sampling_params, original_multimodal_params, visual_req_status + None, + sampling_params, + original_multimodal_params, + visual_req_status, + only_visual=True, ) await self._release_multimodal_resources(multimodal_params) @@ -573,6 +590,7 @@ async def transfer_to_next_module_or_node( sampling_params: SamplingParams, original_multimodal_params: MultimodalParams, group_req_objs: Optional[GroupReqObjs] = None, + only_visual: bool = False, ): # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点. 
if self.is_multinode_tp_master: @@ -582,19 +600,20 @@ async def transfer_to_next_module_or_node( protocol=pickle.HIGHEST_PROTOCOL, ) - await self.transfer_to_next_module(group_req_objs) + await self.transfer_to_next_module(group_req_objs, only_visual=only_visual) return async def transfer_to_next_module( self, group_req_objs: Optional[GroupReqObjs] = None, + only_visual: bool = False, ): if self.pd_mode.is_P_or_NORMAL(): group_req_index = group_req_objs.to_group_req_index() if not self.args.disable_vision: await self.vit_manager.send_to_vit(group_req_index, protocol=pickle.HIGHEST_PROTOCOL) - if not self.args.enable_remote_vit: + if only_visual or not self.args.enable_remote_vit: return if not self.args.disable_audio: From 9ec3d88df6ba8a9181ad2616641780b1659f5cbf Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 24 Mar 2026 08:41:03 +0000 Subject: [PATCH 04/54] fix --- lightllm/models/internvl/model.py | 15 ++++++++------- lightllm/server/multimodal_params.py | 10 ++-------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index cbe9d4dfa6..ccb76d3512 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -56,8 +56,10 @@ def init_imageitem_extral_params( ): if sampling_params.image_max_patch_num > 0: img.extra_params["image_patch_max_num"] = sampling_params.image_max_patch_num + return elif os.getenv("MAX_PATCH_NUM"): img.extra_params["image_patch_max_num"] = int(os.getenv("MAX_PATCH_NUM")) + return else: num_images = len(multi_params.images) if num_images == 1: @@ -66,7 +68,6 @@ def init_imageitem_extral_params( img.extra_params["image_patch_max_num"] = 6 elif num_images > 6: img.extra_params["image_patch_max_num"] = 0 - img.patch_num = self.get_image_patch(img) return def init_audioitem_extral_params( @@ -74,13 +75,13 @@ def init_audioitem_extral_params( ): return - def get_image_patch(self, img: ImageItem): - return 
self.get_image_patch_func( - img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True - ) - def get_image_token_length(self, img: ImageItem): - return self.get_image_patch(img) * self.image_length + return ( + self.get_image_patch_func( + img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True + ) + * self.image_length + ) def get_audio_token_length(self, audio: AudioItem): L = audio.audio_length diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 7afd95d3c0..09a07455b3 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -95,7 +95,6 @@ def __init__(self, **kwargs): self.grid_thwd = None self.image_w = 0 self.image_h = 0 - self.patch_num = 0 self._preload_data = None self.extra_params = {} @@ -129,11 +128,10 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None - return self._preload_data - - def free(self): + ans = self._preload_data self._preload_data = None self._data = None + return ans def to_dict(self): ret = {} @@ -165,10 +163,6 @@ def __init__( self.audios = [AudioItem(**a) for a in audios] return - def free_images(self): - for image in self.images: - image.free() - async def verify_and_preload(self, request: Request): for image in self.images: await image.preload(request) From 155e8ee5fa08994955ac7342b5a0f3896c370783 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 24 Mar 2026 08:46:52 +0000 Subject: [PATCH 05/54] fix --- .../common/basemodel/attention_vit/fa3/fp.py | 40 +++++++++++++++---- .../vit/triton_kernel/flashattention_nopad.py | 40 +++++++++++++++---- 2 files changed, 64 insertions(+), 16 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index f1bef078a7..d3a5b5166b 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ 
b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -18,18 +18,42 @@ def _vit_att_fwd( head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 window_size = (-1, -1) - o = flash_attn_varlen_func( + torch.ops.sgl_kernel.fwd.default( q, k, v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - softmax_scale=softmax_scale, - causal=False, - window_size=window_size, + None, # k_new + None, # v_new + None, # qv + o, # out + cu_seqlens, + cu_seqlens, + None, # cu_seqlens_k_new + None, + None, + max_seqlen, + max_seqlen, + None, # page_table, + None, # kv_batch_idx + None, # leftpad_k + None, # rotary cos + None, # rotary sin + None, # seqlens_rotary + None, + None, + None, + softmax_scale, + False, + window_size[0], + window_size[1], attention_chunk=0, softcap=0.0, + is_rotary_interleaved=False, + scheduler_metadata=None, + num_splits=1, + pack_gqa=None, + sm_margin=0, + sinks=None, ) + return o diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index 3a0b2d2069..768ebd9139 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -167,18 +167,42 @@ def flash_attention_v3_fwd( head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 window_size = (-1, -1) - o = flash_attn_varlen_func( + torch.ops.sgl_kernel.fwd.default( q, k, v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - softmax_scale=softmax_scale, - causal=False, - window_size=window_size, + None, # k_new + None, # v_new + None, # qv + o, # out + cu_seqlens, + cu_seqlens, + None, # cu_seqlens_k_new + None, + None, + max_seqlen, + max_seqlen, + None, # page_table, + None, # kv_batch_idx + None, # leftpad_k + None, # rotary cos + None, # rotary sin + None, # seqlens_rotary + None, + None, + None, + softmax_scale, + False, + window_size[0], 
+ window_size[1], + attention_chunk=0, softcap=0.0, + is_rotary_interleaved=False, + scheduler_metadata=None, + num_splits=1, + pack_gqa=None, + sm_margin=0, + sinks=None, ) return o From d2e3919eaf4878fb82e891a321281afa2cb2fac4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 24 Mar 2026 08:52:16 +0000 Subject: [PATCH 06/54] fix --- lightllm/server/visualserver_proxy/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 lightllm/server/visualserver_proxy/__init__.py diff --git a/lightllm/server/visualserver_proxy/__init__.py b/lightllm/server/visualserver_proxy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From 0369dec023744d98543f80d6c20c89c27aceb037 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 24 Mar 2026 09:02:22 +0000 Subject: [PATCH 07/54] fix --- lightllm/common/basemodel/attention_vit/fa3/fp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index d3a5b5166b..f804116f1f 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -1,7 +1,6 @@ import dataclasses import torch from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend -from lightllm.utils.sgl_utils import flash_attn_varlen_func class Fa3VitAttBackend(BaseVitAttBackend): From 958cfc162ee24f1a16fb2d0ea6018e1e728e7d54 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 24 Mar 2026 09:04:24 +0000 Subject: [PATCH 08/54] fix --- .../qwen_vl/layer_infer/pre_layer_infer.py | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 867f2a7d3a..9b9fe2569c 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -1,15 +1,11 @@ -import rpyc -import socket import torch 
import torch.distributed as dist from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.server.embed_cache.utils import get_shm_name_embed, load_tensor_afs from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce -from lightllm.utils.envs_utils import get_env_start_args """ @@ -30,20 +26,6 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config): super().__init__(network_config) - self.args = get_env_start_args() - if self.args.enable_remote_vit: - self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True}) - self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - return - - def _copy_loaded_embed_to_cache( - self, embed_tensor: torch.Tensor, cpu_embed_cache_tensor: torch.Tensor, start_index: int - ): - if embed_tensor.ndim == 2: - embed_tensor = embed_tensor.unsqueeze(1) - - token_num, layer_num, hidden_size = embed_tensor.shape - cpu_embed_cache_tensor[start_index : start_index + token_num, :layer_num, :hidden_size].copy_(embed_tensor) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): @@ -73,21 +55,6 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei else cpu_embed_cache_client.cpu_embed_cache_tensor ) - if self.args.enable_remote_vit: - unique_image_uids = [] - for _, p in enumerate(infer_state.multimodal_params): - for img in p["images"]: - if img["uuid"] in unique_image_uids: - continue - img_uid = img["uuid"] - img_idx = img["start_index_in_embed_cache"] - unique_image_uids.append(img_uid) - embed_tensor = load_tensor_afs(get_shm_name_embed(img_uid), 
self.args.image_embed_dir) - self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, img_idx) - - if unique_image_uids: - self.cache_client.root.release(unique_image_uids) - assert cpu_embed_cache_tensor.shape[2] == hidden_size, ( f"Dimension mismatch: text weight dimension is {hidden_size}, " f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}" From d95cba6e3f27a594b7efe5190b0b4d82ffb72274 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 24 Mar 2026 09:09:01 +0000 Subject: [PATCH 09/54] fix --- lightllm/server/core/objs/io_objs/group_req.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index 75f2c0e2f1..dfcbdd2562 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -23,9 +23,7 @@ def to_group_req_index(self): return GroupReqIndexes( group_req_id=self.group_req_id, multimodal_params=self.multimodal_params, - shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs] - if self.shm_req_objs is not None - else None, + shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], time_mark=self.time_mark, ) From 778c4831bb453add7c645b695912631c49450afc Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 24 Mar 2026 09:29:18 +0000 Subject: [PATCH 10/54] fix --- .../embed_cache/impl/naive_memory_cache.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index f76cc5d78d..9251b87149 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -36,7 +36,6 @@ class InMemoryCache: def __init__(self, args) -> None: self.args = args self._id_to_records = dict() - self._records = self._id_to_records self._md5_to_record = dict() 
self._sorted_records = SortedSet(key=lambda x: (x.ref, x.visittime, x.id)) self.capacity = max(1, args.cache_capacity) @@ -169,13 +168,18 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l new_md5_dict[m] = token_need new_needed = len(new_md5_dict) + alloc_md5_dict = self._free_to_alloc( free_min_count=new_needed - (self.capacity - self.occupied), new_md5_dict=new_md5_dict ) + for md5 in add_ref_m_list: + # 解锁 + self._del_ref(md5) + if len(alloc_md5_dict) == len(new_md5_dict): for md5sum, mem_block in alloc_md5_dict.items(): token_num = new_md5_dict[md5sum] - uid_int = md5sum + uid_int = uuid.uuid1().int self._check_and_set_new_id_range(token_num) rec = Record( id=uid_int, @@ -195,10 +199,6 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l self._sorted_records.add(rec) self.occupied += 1 - for md5 in add_ref_m_list: - # 解锁 - self._del_ref(md5) - # 遍历加 ref results = [] for md5 in md5sum_list: @@ -215,8 +215,6 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l return results else: - for md5sum in add_ref_m_list: - self._del_ref(md5sum) return None def release(self, ids: list[int]) -> None: @@ -235,5 +233,5 @@ def set_items_embed(self, ids: list[int]) -> None: for id_ in ids: self._id_to_records[id_].embed = True - def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: + def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: return [self._id_to_records.get(id_).embed if id_ in self._id_to_records else False for id_ in ids] From d9baab30556778332d69d7e9847f19ca2fa4ee4e Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 24 Mar 2026 09:31:40 +0000 Subject: [PATCH 11/54] fix --- lightllm/server/embed_cache/manager.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index ebf57f6594..5de4df4ab3 100644 --- 
a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -6,7 +6,6 @@ from lightllm.server.core.objs import StartArgs from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache -from lightllm.server.embed_cache.impl.memory_cache_with_redis import MemoryCacheWithRedis from rpyc.utils.classic import obtain from lightllm.utils.envs_utils import get_unique_server_name @@ -48,16 +47,9 @@ def exposed_set_items_embed(self, ids: list[int]) -> None: ids = obtain(ids) return self._impl.set_items_embed(ids) - def exposed_get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[bool]: + def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: ids = obtain(ids) - return self._impl.get_items_embed(ids, embeding_only) - - -def get_cache_manager(args): - if args.enable_remote_vit or args.run_mode in ["visual", "visual_only"]: - return MemoryCacheWithRedis(args) - else: - return InMemoryCache(args) + return self._impl.get_items_embed(ids) def start_cache_manager(args: StartArgs, pipe_writer): @@ -65,7 +57,7 @@ def start_cache_manager(args: StartArgs, pipe_writer): graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::cache_manager") - manager = get_cache_manager(args) + manager = InMemoryCache(args) service = CacheServer(manager) from rpyc.utils.server import ThreadedServer import lightllm.utils.rpyc_fix_utils as _ From 05d5094308aabee81f3a0d51a634ac229de17d6a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 24 Mar 2026 09:42:03 +0000 Subject: [PATCH 12/54] fix --- lightllm/server/httpserver/manager.py | 109 +++--------------- .../server/router/model_infer/infer_batch.py | 4 +- .../model_infer/mode_backend/base_backend.py | 2 +- 3 files changed, 19 insertions(+), 96 deletions(-) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 
1e96d4a737..e28e4c93ad 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -84,9 +84,8 @@ def __init__( self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) if not self.args.disable_vision: - from lightllm.server.visualserver.vit_connect import VITConnectionManager - - self.vit_manager = VITConnectionManager(args, context, args.visual_port, self.cache_client) + self.send_to_visual = context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") if not self.args.disable_audio: self.send_to_audio = context.socket(zmq.PUSH) @@ -125,10 +124,10 @@ def __init__( self.latest_success_infer_time_mark.set_value(int(time.time())) return - async def _alloc_resource(self, items, uuids, token_nums, datas): + async def _alloc_resource(self, items, md5sums, token_nums, datas): while True: - records = obtain(self.cache_client.root.alloc(uuids, token_nums)) + records = obtain(self.cache_client.root.alloc(md5sums, token_nums)) if records is None: await asyncio.sleep(0.1) @@ -139,7 +138,6 @@ async def _alloc_resource(self, items, uuids, token_nums, datas): raise Exception(str(records) + "and try to set --embed_cache_storage_size bigger") uid_list = [] - unique_image_uids = [] for item, rec in zip(items, records): item: Union[ImageItem, AudioItem] = item item.uuid = rec["id"] @@ -148,13 +146,6 @@ async def _alloc_resource(self, items, uuids, token_nums, datas): item.start_index_in_embed_cache = rec["start_index_in_embed_cache"] uid_list.append(rec["id"]) - if isinstance(item, ImageItem) and rec["id"] not in unique_image_uids: - unique_image_uids.append(rec["id"]) - - # # If enable the vit-llm disaggregation, no need to cache the data in the memory of the server - if self.args.enable_remote_vit: - # 避免远端lru被逐出 - self.cache_client.root.get_items_embed(unique_image_uids, False) ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) update_data_ids = 
[] @@ -175,15 +166,14 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, # 如果不加任何锁,假如请求1和请求2都有6张图片,而cache_capacity为10, # 那么如果某一时刻shm中存在请求1的5张图和请求2的5张图,将会资源竞争产生死锁。 async with self._resource_lock: - items, uuids, tokens_nums, datas = [], [], [], [] + items, md5sums, tokens_nums, datas = [], [], [], [] for img in multimodal_params.images: self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) - md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), img.patch_num) - uuid = int(md5sum, 16) - uuids.append(uuid) + md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) + md5sums.append(md5sum) tokens_nums.append(token_num) datas.append(data) items.append(img) @@ -191,17 +181,13 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) data = audio.read() token_num = self.tokenizer.get_audio_token_length(audio) - md5sum = "{}_{}".format( - hashlib.md5(data).hexdigest(), - hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(), - ) - uuid = int(md5sum, 16) - uuids.append(uuid) + md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(audio.extra_params))) + md5sums.append(md5sum) tokens_nums.append(token_num) datas.append(data) items.append(audio) - await self._alloc_resource(items, uuids, tokens_nums, datas) + await self._alloc_resource(items, md5sums, tokens_nums, datas) return async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): @@ -254,16 +240,6 @@ async def loop_for_request(self): sampling_params, multimodal_params, ) = await self.multinode_req_manager.recv_pyobj() - - # 多机tp下,slave节点收到/get_image_embedding请求,无prompt - if prompt is None: - - async def image_embedding_wrapper(sampling_params, 
multimodal_params): - await self.get_image_embeding(sampling_params, multimodal_params, None) - - asyncio.create_task(image_embedding_wrapper(sampling_params, multimodal_params)) - continue - results_generator = self.generate(prompt, sampling_params, multimodal_params, None) async def generate_wrapper(results_generator): @@ -432,52 +408,6 @@ async def generate( raise e return - async def get_image_embeding( - self, - sampling_params: SamplingParams, - multimodal_params: MultimodalParams, - request: Request, - is_health_req: bool = False, - ) -> Tuple[int, str, dict, FinishStatus]: - start_time = time.time() - request_headers = request.headers if request is not None else {} - group_request_id = self.alloc_req_id(sampling_params, is_health_req) - - try: - original_multimodal_params = None - if self.is_multinode_tp_master: - original_multimodal_params = copy.deepcopy(multimodal_params) - - await multimodal_params.verify_and_preload(request) - image_count = len(multimodal_params.images) - # 记录请求到达的相关信息 - - await self._log_req_header(request_headers, group_request_id) - logger.info(f"image_count:{image_count}") - assert ( - len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity - ), "too many multimodal items!" 
- - await self._alloc_multimodal_resources(multimodal_params, sampling_params) - - visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time) - - await self.transfer_to_next_module_or_node( - None, - sampling_params, - original_multimodal_params, - visual_req_status, - only_visual=True, - ) - await self._release_multimodal_resources(multimodal_params) - - except Exception as e: - logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") - await self._release_multimodal_resources(multimodal_params) - await self.abort(group_request_id) - raise e - return - def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]: image_tokens = 0 audio_tokens = 0 @@ -590,7 +520,6 @@ async def transfer_to_next_module_or_node( sampling_params: SamplingParams, original_multimodal_params: MultimodalParams, group_req_objs: Optional[GroupReqObjs] = None, - only_visual: bool = False, ): # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点. if self.is_multinode_tp_master: @@ -600,35 +529,32 @@ async def transfer_to_next_module_or_node( protocol=pickle.HIGHEST_PROTOCOL, ) - await self.transfer_to_next_module(group_req_objs, only_visual=only_visual) + await self.transfer_to_next_module(group_req_objs) return async def transfer_to_next_module( self, group_req_objs: Optional[GroupReqObjs] = None, - only_visual: bool = False, ): if self.pd_mode.is_P_or_NORMAL(): - group_req_index = group_req_objs.to_group_req_index() if not self.args.disable_vision: - await self.vit_manager.send_to_vit(group_req_index, protocol=pickle.HIGHEST_PROTOCOL) - if only_visual or not self.args.enable_remote_vit: - return + self.send_to_visual.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL) + return if not self.args.disable_audio: - self.send_to_audio.send_pyobj(group_req_index, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_audio.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL) return 
if self.args.enable_cpu_cache: self.send_to_multi_level_kv_cache.send_pyobj( - group_req_index, + group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) return self.send_to_router.send_pyobj( - group_req_index, + group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) return @@ -827,9 +753,6 @@ async def handle_loop(self): asyncio.create_task(pd_handle_loop(self)) - if hasattr(self, "vit_manager"): - asyncio.create_task(self.vit_manager.vit_handle_loop()) - while True: try: await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index c3134eebcc..0a83b101be 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -59,8 +59,8 @@ def register( self.vocab_size = vocab_size return - def init_cpu_embed_cache_client(self, init_shm_data: bool = False): - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=init_shm_data) + def init_cpu_embed_cache_client(self): + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) return def get_overlap_stream(self) -> torch.cuda.Stream: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index eb6fd904f2..8b085c45ed 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -145,7 +145,7 @@ def init_model(self, kvargs): wait_events.append(self.multi_level_cache_module) if self.args.enable_multimodal: - g_infer_context.init_cpu_embed_cache_client(init_shm_data=False) + g_infer_context.init_cpu_embed_cache_client() model_cfg, _ = PretrainedConfig.get_config_dict(self.weight_dir) From b27633dcd29858ba6c8a5fd296150c2f03f1c96c Mon Sep 17 00:00:00 2001 From: wzj 
Date: Tue, 24 Mar 2026 14:32:12 +0000 Subject: [PATCH 13/54] fix --- lightllm/server/api_cli.py | 43 +++++--------------- lightllm/server/core/objs/start_args_type.py | 2 + 2 files changed, 13 insertions(+), 32 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 4e5ab7e421..7d1cdb136e 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -15,8 +15,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "nixl_decode", "pd_master", "config_server", - "visual", - "visual_only", + "only_visual_infer", ], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, @@ -71,6 +70,14 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="The port number for the config server in config_server mode.", ) + parser.add_argument( + "--config_server_vit_redis_port", + type=int, + default=None, + help="""when run_mode is config_server, set this params will start a redis server, + when a llm infer node start to set this params, the visual infer module will start a + proxy module use config server to find remote vit infer nodes to infer img""" + ) parser.add_argument( "--nixl_pd_kv_page_num", type=int, @@ -616,38 +623,10 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""The interval of the schedule time, default is 30ms.""", ) parser.add_argument( - "--image_embed_dir", + "--afs_image_embed_dir", type=str, default=None, - help="path for vit embed", - ) - parser.add_argument( - "--enable_remote_vit", - action="store_true", - help="Whether to enable remote vit for multimodal service.", - ) - parser.add_argument( - "--remote_vit_port", - type=int, - default=12346, - help="The port number for the remote vit service.", - ) - parser.add_argument( - "--redis_port", - type=int, - default=6379, - help="The port number for the redis service in config_server mode.", - ) - parser.add_argument( - "--redis_evict_fraction", - 
type=float, - default=0.3, - help="The evict fraction for the redis service in config_server mode.", - ) - parser.add_argument( - "--start_redis", - action="store_true", - help="Whether to start the redis service in config_server mode.", + help="path for vit embed, when use vit remote infer mode", ) parser.add_argument( "--enable_cpu_cache", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index d3dc849664..6948bceb3b 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -20,6 +20,8 @@ class StartArgs: pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) + config_server_vit_redis_port: int = field(default=None) + afs_image_embed_dir: str = field(default=None) pd_decode_rpyc_port: int = field(default=None) select_p_d_node_strategy: str = field(default=None) model_name: str = field(default="default_model_name") From b62231a8679c933c99bdad315860c2d235613530 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 26 Mar 2026 03:19:03 +0000 Subject: [PATCH 14/54] fix --- lightllm/server/embed_cache/redis_utils.py | 180 +++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 lightllm/server/embed_cache/redis_utils.py diff --git a/lightllm/server/embed_cache/redis_utils.py b/lightllm/server/embed_cache/redis_utils.py new file mode 100644 index 0000000000..517f629c81 --- /dev/null +++ b/lightllm/server/embed_cache/redis_utils.py @@ -0,0 +1,180 @@ +import redis +from typing import List, Tuple, Union, Optional + + +class RedisMetadataClient: + """ + # 代码任务 + 创建一个基于redis 管理的元数据操作库代码。 + 要求: + 2. 
提供一个包装的 redis 操作client 库,提供以下功能: + (1) 提供输入为(md5, token, time_out) 为其创建一个零时具有超时时间的记录,同时提供输入为(md5, token)的解锁接口,防止多线程的异步操作出现问题。 + (2) 提供一个时间排序队列,当出现对md5的任何操作的时候,向队列中插入md5,并更新时间错(单位为s即可), 当时创建锁和解锁不更新时间错。 + (3) 输入为(md5,token), 先校验 md5锁对应的内容为token, 然后标记 md5 对应的资源已经准备就绪, 向时间排序队列插入更新md5的时间错。 + (4) 输入为(md5, token) 先校验 md5锁对应的内容为token, 当 md5 对应的资源存在的时候,同时更新排序队列中的时间错,同时返回True, 否则返回False,不更新时间错。 + (5) 输入为(md5, token), 先校验 md5锁对应的内容为token, 移除标记 md5 对应的资源已经准备就绪,并同时从时间排序队列中移除对应的md5。 + (6) 输入为(remove_size, capcity), 当时间排序队列中的元素数量大于等于capcity, 返回时间排序队列中排在前面的 remove_size 个元素,其内容为 md5。 + (7) 所有操作都使用lua 脚本,以实现原子化操作,同时返回的错误要能区分具体错误的原因,注意lua脚本的可读性,和相关函数的输入输出测试。 + """ + + def __init__(self, redis_url: str = "redis://localhost:6379/0", prefix: str = "meta"): + self.r = redis.Redis.from_url(redis_url, decode_responses=True) + self.prefix = prefix + self.lru_key = f"{prefix}:queue:lru" + self._register_scripts() + + def _register_scripts(self): + """注册 Lua 脚本""" + + # (1) 解锁脚本 (不更新时间戳) + self._lua_unlock = self.r.register_script( + """ + local lock_key = KEYS[1] + local token = ARGV[1] + if redis.call("GET", lock_key) == token then + return redis.call("DEL", lock_key) + elseif redis.call("EXISTS", lock_key) == 0 then + return -2 + else + return -1 + end + """ + ) + + # (3, 4, 5) 元数据操作脚本 + # 内部通过 redis.call('TIME') 获取服务器时间 + self._lua_meta_op = self.r.register_script( + """ + local lock_key = KEYS[1] + local ready_key = KEYS[2] + local lru_key = KEYS[3] + local op = ARGV[1] + local token = ARGV[2] + local md5 = ARGV[3] + + -- 校验锁 + local current_token = redis.call("GET", lock_key) + if not current_token then return -2 end + if current_token ~= token then return -1 end + + -- 获取服务器时间 (秒) + local server_time = redis.call('TIME')[1] + + if op == "mark_ready" then + redis.call("SET", ready_key, "1") + redis.call("ZADD", lru_key, server_time, md5) + return 1 + elseif op == "check_touch" then + if redis.call("EXISTS", ready_key) == 1 then + redis.call("ZADD", lru_key, server_time, md5) + return 1 
+ else + return 0 + end + elseif op == "remove_ready" then + redis.call("DEL", ready_key) + redis.call("ZREM", lru_key, md5) + return 1 + end + """ + ) + + # (6) 逐出检查脚本 + self._lua_evict = self.r.register_script( + """ + local lru_key = KEYS[1] + local remove_size = tonumber(ARGV[1]) + local capacity = tonumber(ARGV[2]) + + local current_size = redis.call("ZCARD", lru_key) + if current_size >= capacity then + return redis.call("ZRANGE", lru_key, 0, remove_size - 1) + else + return {} + end + """ + ) + + def _get_keys(self, md5: str): + return [f"{self.prefix}:lock:{md5}", f"{self.prefix}:ready:{md5}", self.lru_key] + + def _handle_res(self, res: int): + """映射错误原因""" + errors = { + 1: (True, "Success"), + 0: (False, "Resource not ready"), + -1: (False, "Error: Token mismatch (Permission denied)"), + -2: (False, "Error: Lock missing or expired"), + } + return errors.get(res, (False, f"Unknown error code: {res}")) + + # (1) 创建锁 + def acquire_lock(self, md5: str, token: str, time_out: int) -> bool: + """创建临时超时记录 (不更新排序队列)""" + lock_key = self._get_keys(md5)[0] + return bool(self.r.set(lock_key, token, nx=True, ex=time_out)) + + # (1) 解锁 + def release_lock(self, md5: str, token: str) -> Tuple[bool, str]: + """解锁 (不更新排序队列)""" + res = self._lua_unlock(keys=[self._get_keys(md5)[0]], args=[token]) + return self._handle_res(res) + + # (3) 标记就绪 + def mark_ready(self, md5: str, token: str) -> Tuple[bool, str]: + """标记就绪并在 Lua 内部更新服务器时间戳""" + keys = self._get_keys(md5) + # 不再传入 now,Lua 脚本内部自行获取 + res = self._lua_meta_op(keys=keys, args=["mark_ready", token, md5]) + return self._handle_res(res) + + # (4) 检查就绪并 Touch + def check_ready_and_touch(self, md5: str, token: str) -> Tuple[bool, str]: + """校验锁和就绪状态,并在 Lua 内部更新服务器时间戳""" + keys = self._get_keys(md5) + res = self._lua_meta_op(keys=keys, args=["check_touch", token, md5]) + return self._handle_res(res) + + # (5) 移除就绪 + def remove_ready(self, md5: str, token: str) -> Tuple[bool, str]: + """移除就绪状态并从队列删除""" + keys = 
self._get_keys(md5) + res = self._lua_meta_op(keys=keys, args=["remove_ready", token, md5]) + return self._handle_res(res) + + # (6) 获取逐出列表 + def get_eviction_candidates(self, remove_size: int, capacity: int) -> List[str]: + """当数量达到上限,返回最旧的元素""" + return self._lua_evict(keys=[self.lru_key], args=[remove_size, capacity]) + + +# ---------------- 测试验证 ---------------- + + +def test_client(): + + client = RedisMetadataClient() + md5 = "test_file_server_time" + token = "secure_token_123" + + print("Step 1: Acquire Lock") + client.acquire_lock(md5, token, 60) + + print("Step 2: Mark Ready (Updates time inside Lua)") + ok, msg = client.mark_ready(md5, token) + print(f"Result: {ok}, {msg}") + + # 检查 Redis 内部 ZSet 存储的时间戳 + score = client.r.zscore(client.lru_key, md5) + print(f"Server Timestamp in ZSet: {score}") + + print("\nStep 3: Check and Touch (Updates time inside Lua)") + ok, msg = client.check_ready_and_touch(md5, token) + print(f"Result: {ok}, {msg}") + + new_score = client.r.zscore(client.lru_key, md5) + print(f"Updated Server Timestamp: {new_score}") + + +if __name__ == "__main__": + test_client() From 05389e358f68ed49299e7d0bbd12bf100cddd806 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 26 Mar 2026 07:31:19 +0000 Subject: [PATCH 15/54] fix --- lightllm/server/embed_cache/afs_utils.py | 132 +++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 lightllm/server/embed_cache/afs_utils.py diff --git a/lightllm/server/embed_cache/afs_utils.py b/lightllm/server/embed_cache/afs_utils.py new file mode 100644 index 0000000000..0982cbf31d --- /dev/null +++ b/lightllm/server/embed_cache/afs_utils.py @@ -0,0 +1,132 @@ +import os +import time +import torch +import uuid +from typing import List, Tuple, Optional +from pathlib import Path +from .redis_utils import RedisMetadataClient + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class AfsUtils: + def __init__(self, base_dir: str): + self.base_dir = 
base_dir + # 判断 base_dir 是否存在,不存在则创建并赋予777权限,让其他人也可以写入 + if not os.path.exists(base_dir): + os.makedirs(base_dir, exist_ok=True) + os.chmod(base_dir, 0o777) + return + + def _get_afs_path(self, name: str) -> Path: + return Path(self.base_dir) / name + + def save_tensor_afs(self, name: str, tensor: torch.Tensor) -> None: + target_path = self._get_afs_path(name) + + try: + with open(target_path, "wb") as f: + tensor = tensor.detach().cpu() + dest = torch.empty_like(tensor) + dest.copy_(tensor) + torch.save(dest, f, _use_new_zipfile_serialization=False, pickle_protocol=4) + + os.chmod(target_path, 0o777) + except Exception as e: + try: + target_path.unlink(missing_ok=True) + except Exception: + pass + logger.exception(f"failed to save embed tensor file: {target_path} excetion {str(e)}") + raise e + + def load_tensor_afs(self, name: str) -> torch.Tensor: + path = self._get_afs_path(name) + with open(path, "rb") as f: + return torch.load(f, weights_only=False) + + def free_afs(self, name: str) -> None: + path = self._get_afs_path(name) + path.unlink(missing_ok=True) + return + + +class SepEmbedManager: + def __init__( + self, + afs_embed_dir: str, + redis_url: str = "redis://localhost:6379/0", + capacity: int = 50000, + evict_fraction: float = 0.1, + ) -> None: + if not (0.0 <= evict_fraction <= 1.0): + raise ValueError("evict_fraction must be 0..1") + if capacity < 1: + raise ValueError("capacity must be >=1") + + self.redis_client = RedisMetadataClient(redis_url=redis_url) + self.capacity = capacity + self.remove_count = min(int(self.capacity * evict_fraction), 1000) # full的时候,每次清理的数量 + self.afs_embed_dir = afs_embed_dir + self.afs_utils = AfsUtils(self.afs_embed_dir) + + def full_to_clean(self): + remove_objs: List[str] = self.redis_client.get_eviction_candidates( + remove_size=self.remove_count, capcity=self.capacity + ) + for obj in remove_objs: + _token = str(uuid.uuid4()) + try: + if self.redis_client.acquire_lock(md5=obj, token=_token, time_out=10): + if 
self.redis_client.remove_ready(md5=obj, token=_token)[0]: + self.afs_utils.free_afs(obj) + self.redis_client.release_lock(md5=obj, token=_token) + except BaseException as e: + logger.warning(f"full_to_clean md5 {obj} error {str(e)}") + + def insert(self, md5: str, tensor: torch.Tensor) -> bool: + for _ in range(3): + if self._insert(md5, tensor): + return True + else: + time.sleep(30) + return False + + def _insert(self, md5: str, tensor: torch.Tensor) -> bool: + self.full_to_clean() + try: + _token = str(uuid.uuid4()) + if self.redis_client.acquire_lock(md5=md5, token=_token, time_out=30): + self.afs_utils.save_tensor_afs(md5, tensor) + ret = self.redis_client.mark_ready(md5=md5, token=_token) + if ret[0]: + self.redis_client.release_lock(md5=md5, token=_token) + return True + else: + self.redis_client.release_lock(md5=md5, token=_token) + logger.warning(f"insert {md5} failed error {ret[1]}") + return False + except: + return False + + def query_to_lock(self, md5: str) -> Optional[str]: + """ + 返回 None, 或者 token, 返回token代表可以去afs中读取数据了, + """ + try: + _token = str(uuid.uuid4()) + if self.redis_client.acquire_lock(md5=md5, token=_token, time_out=60): + ret = self.redis_client.check_ready_and_touch(md5=md5, token=_token) + if ret[0]: + return _token + else: + logger.warning(f"query_to_lock {md5} failed {ret[1]}") + self.redis_client.release_lock(md5=md5, token=_token) + except: + try: + self.redis_client.release_lock(md5=md5, token=_token) + except: + pass + return None From 012597813cad6103d7de9553070e362813c47473 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 26 Mar 2026 07:35:16 +0000 Subject: [PATCH 16/54] fix --- lightllm/server/embed_cache/utils.py | 426 --------------------------- 1 file changed, 426 deletions(-) diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index caeca0b2b6..367bcc91a9 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -1,59 +1,4 @@ -import os 
-import time -import torch -import redis -import numpy as np -from typing import List, Tuple -from io import BytesIO -from pathlib import Path import multiprocessing.shared_memory as shm -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -def _get_afs_path(base_dir: str, name: str) -> Path: - if not base_dir: - raise ValueError("image_embed_dir must be set before using disk-backed embed cache") - return Path(base_dir) / name - - -def tensor2bytes(t: torch.Tensor): - buf = BytesIO() - t = t.detach().cpu() - dest = torch.empty_like(t) - dest.copy_(t) - torch.save(dest, buf, _use_new_zipfile_serialization=False, pickle_protocol=4) - buf.seek(0) - return buf.read() - - -def bytes2tensor(b): - return torch.load(BytesIO(b), weights_only=False) - - -def save_tensor_afs(name: str, tensor: torch.Tensor, base_dir: str) -> None: - target_path = _get_afs_path(base_dir, name) - tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}" - - try: - with open(tmp_path, "wb") as f: - torch.save(tensor.detach().cpu(), f, _use_new_zipfile_serialization=False, pickle_protocol=4) - os.replace(tmp_path, target_path) - os.chmod(target_path, 0o777) - except Exception: - try: - tmp_path.unlink(missing_ok=True) - except Exception: - pass - logger.exception(f"failed to save embed tensor file: {target_path}") - raise - - -def load_tensor_afs(name: str, base_dir: str) -> torch.Tensor: - path = _get_afs_path(base_dir, name) - with open(path, "rb") as f: - return torch.load(f, weights_only=False) def create_shm(name, data): @@ -66,388 +11,17 @@ def create_shm(name, data): print("Warning create shm {} failed because of FileExistsError!".format(name)) -def create_afs(name, data, path): - target_path = _get_afs_path(path, name) - data_size = len(data) - tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}" - - try: - with open(tmp_path, "wb") as f: - mem_view = memoryview(data) - 
f.write(mem_view[:data_size]) - f.flush() - os.fsync(f.fileno()) - os.replace(tmp_path, target_path) - os.chmod(target_path, 0o777) - except Exception: - try: - tmp_path.unlink(missing_ok=True) - except Exception: - pass - logger.exception(f"failed to create embed file: {target_path}") - raise - - def read_shm(name): shared_memory = shm.SharedMemory(name=name) data = shared_memory.buf.tobytes() return data -def read_afs(name: str, base_dir) -> bytes: - path = _get_afs_path(base_dir, name) - return path.read_bytes() - - def free_shm(name): shared_memory = shm.SharedMemory(name=name) shared_memory.close() shared_memory.unlink() -def free_afs(name: str, base_dir) -> None: - path = _get_afs_path(base_dir, name) - path.unlink(missing_ok=True) - - def get_shm_name_data(uid): return str(uid) + "-data" - - -def get_shm_name_embed(uid): - return str(uid) + "-embed" - - -""" -Importable Redis-backed MD5 refcount with LRU eviction. - -Public API: - from md5_refcount import EmbedRefCountRedis - - cache = EmbedRefCountRedis( - redis_url="redis://localhost:6379/0", - capacity=10000, - evict_fraction=0.2 - ) - - # Insert a new md5 with default ref_count=0 - success, evicted_list = cache.insert(md5) - - # Query if exists and increment ref_count if found - exists = cache.query_and_incre(md5) - - # Decrement ref_count - rc, deleted = cache.decr(md5) - - s = cache.stats() -""" - - -class EmbedRefCountRedis: - def __init__( - self, - redis_url: str = "redis://localhost:6379/0", - capacity: int = 50000, - evict_fraction: float = 0.1, - key_prefix: str = "md5:", - image_embed_dir: str = None, - path_ext: str = "-embed", - **redis_kwargs, - ) -> None: - """ - - capacity: max count of md5 entries allowed in Redis - - evict_fraction: fraction to evict when inserting a NEW md5 and at capacity - - image_embed_dir: base directory for image embed files (e.g., "/afs/embeds") - - path_ext: file extension for embed files (default: "-embed") - """ - if not (0.0 <= evict_fraction <= 1.0): - raise 
ValueError("evict_fraction must be 0..1") - if capacity < 1: - raise ValueError("capacity must be >=1") - - self.capacity = int(capacity) - self.evict_fraction = float(evict_fraction) - self.zset_key = f"{key_prefix}lru" - self.ref_prefix = f"{key_prefix}rc:" - self.lock_key = f"{key_prefix}evict:lock" - self.image_embed_dir = image_embed_dir - self.path_ext = path_ext - - self.r = redis.Redis.from_url(redis_url, decode_responses=True, **redis_kwargs) - - # Register Lua scripts - self._insert_script = self.r.register_script(self._INSERT_LUA) - self._query_incre_script = self.r.register_script(self._QUERY_INCRE_LUA) - self._decr_script = self.r.register_script(self._DECR_LUA) - self._evict_and_insert_script = self.r.register_script(self._EVICT_AND_INSERT_LUA) - - def insert(self, md5: str) -> Tuple[bool, List[str]]: - """Insert a new md5 with default ref_count=1. May trigger LRU eviction.""" - # 等待任何正在进行的逐出操作 - self._wait_if_eviction() - - res = self._insert_script( - keys=[self.zset_key, self.ref_prefix], - args=[md5, self.capacity, self.evict_fraction], - ) - - if res[0] == 0: # No eviction needed - return True, [] - - # Need eviction - use atomic eviction script - try: - if self._try_acquire_lock(): - try: - # 原子执行逐出和插入 - evict_res = self._evict_and_insert_script( - keys=[self.zset_key, self.ref_prefix], - args=[md5, self.capacity, self.evict_fraction], - ) - success = bool(evict_res[0]) - victims = evict_res[1:] if len(evict_res) > 1 else [] - - if success: - # 删除被逐出md5对应的AFS文件 - if victims and self.image_embed_dir: - self._delete_afs_files(victims) - return True, victims - else: - # 逐出失败,短暂退避后重试 - time.sleep(0.01) - return self.insert(md5) - finally: - self._release_lock() - else: - # 等待锁释放后重试 - time.sleep(0.01) - return self.insert(md5) - except Exception as e: - self._release_lock() - raise e - - def query(self, md5: str) -> bool: - """Quert if md5 exists.""" - self._wait_if_eviction() - return bool(self.r.exists(self.ref_prefix + md5)) - - def 
query_and_incre(self, md5: str) -> bool: - """Query if md5 exists and increment ref_count if found.""" - self._wait_if_eviction() - res = self._query_incre_script( - keys=[self.zset_key, self.ref_prefix], - args=[md5], - ) - return bool(res[0]) - - def decr(self, md5: str) -> Tuple[int, bool]: - """Decrement ref_count for md5. Returns (ref_count, deleted).""" - self._wait_if_eviction() - - res = self._decr_script( - keys=[self.zset_key, self.ref_prefix], - args=[md5], - ) - if res[0] == -1: - raise KeyError("md5 not found") - return int(res[0]), bool(res[1]) - - def stats(self) -> dict: - self._wait_if_eviction() - - size = self.r.zcard(self.zset_key) - return { - "items": size, - "capacity": self.capacity, - "evict_fraction": self.evict_fraction, - } - - def get_ref(self, md5: str) -> int | None: - self._wait_if_eviction() - val = self.r.get(self.ref_prefix + md5) - return int(val) if val is not None else None - - def _wait_if_eviction(self) -> None: - max_wait = 30 - start_time = time.time() - - while self.r.exists(self.lock_key): - if time.time() - start_time > max_wait: - raise TimeoutError("Eviction operation timeout, waited too long") - time.sleep(0.01) # 短暂等待 - - def _try_acquire_lock(self) -> bool: - return bool(self.r.set(self.lock_key, "1", nx=True, ex=30)) - - def _release_lock(self) -> None: - try: - self.r.delete(self.lock_key) - except Exception: - pass - - def _md5_to_afs_path(self, md5: str) -> str: - """Convert md5 to AFS file path.""" - if not self.image_embed_dir: - return None - return str(_get_afs_path(self.image_embed_dir, f"{md5}{self.path_ext}")) - - def _delete_afs_files(self, victims: List[str]) -> None: - """Delete AFS files for evicted md5s.""" - if not self.image_embed_dir: - return - - for md5 in victims: - try: - file_path = self._md5_to_afs_path(md5) - if file_path and os.path.exists(file_path): - os.remove(file_path) - logger.debug(f"Deleted AFS file: {file_path}") - except Exception as e: - logger.debug(f"Warning: Failed to delete 
AFS file for {md5}: {e}") - - # ---------------- Lua scripts ---------------- - _INSERT_LUA = r""" --- KEYS[1] = zset key, KEYS[2] = ref_prefix --- ARGV[1] = md5, ARGV[2] = capacity, ARGV[3] = evict_fraction -local zset = KEYS[1] -local ref_prefix = KEYS[2] -local md5 = ARGV[1] -local capacity = tonumber(ARGV[2]) - -local unpack = unpack or table.unpack -local ref_key = ref_prefix .. md5 -if redis.call('GET', ref_key) then - return {0} -- Already exists -end - -local size = redis.call('ZCARD', zset) -if size < capacity then - -- Insert with ref_count=1 - redis.call('SET', ref_key, 1) - local now = redis.call('TIME')[1] * 1000 - redis.call('ZADD', zset, now, md5) - return {0} -- Success, no eviction -end - -return {1} -- Need eviction -""" - - _QUERY_INCRE_LUA = r""" --- KEYS[1] = zset key, KEYS[2] = ref_prefix --- ARGV[1] = md5 -local zset = KEYS[1] -local ref_prefix = KEYS[2] -local md5 = ARGV[1] - -local ref_key = ref_prefix .. md5 -local val = redis.call('GET', ref_key) - -if not val then - return {0} -- Not found -end - --- Found, increment ref_count and update LRU -local rc = tonumber(val) + 1 -redis.call('SET', ref_key, rc) -local now = redis.call('TIME')[1] * 1000 -redis.call('ZADD', zset, now, md5) -return {1} -- Found and incremented -""" - - _DECR_LUA = r""" --- KEYS[1] = zset key, KEYS[2] = ref_prefix --- ARGV[1] = md5 -local zset = KEYS[1] -local ref_prefix = KEYS[2] -local md5 = ARGV[1] - -local ref_key = ref_prefix .. 
md5 -local val = redis.call('GET', ref_key) - -if not val then - return {-1, 0} -- Not found -end - ---ref 递减到 0 时保留键,只更新计数与 LRU -local rc = tonumber(val) - 1 -if rc < 0 then rc = 0 end -redis.call('SET', ref_key, rc) - -if rc > 0 then - -- 只有仍被引用时才更新 LRU - local now = redis.call('TIME')[1] * 1000 - redis.call('ZADD', zset, now, md5) -end - -return {rc, 0} -""" - - _EVICT_AND_INSERT_LUA = r""" --- KEYS[1] = zset key, KEYS[2] = ref_prefix --- ARGV[1] = new_md5, ARGV[2] = capacity, ARGV[3] = evict_fraction -local zset = KEYS[1] -local ref_prefix = KEYS[2] -local new_md5 = ARGV[1] -local capacity = tonumber(ARGV[2]) -local evict_fraction = tonumber(ARGV[3]) - -local unpack = unpack or table.unpack - --- helper: now millis -local function now_ms() - local t = redis.call('TIME') - return t[1] * 1000 + math.floor(t[2] / 1000) -end - -local new_ref_key = ref_prefix .. new_md5 - --- If already exists, treat as a hit: bump ref_count and refresh LRU -local cur = redis.call('GET', new_ref_key) -if cur then - local rc = tonumber(cur) + 1 - redis.call('SET', new_ref_key, rc) - redis.call('ZADD', zset, now_ms(), new_md5) - return {1} -- success, no victims -end - --- If not at capacity, just insert -local size = redis.call('ZCARD', zset) -if size < capacity then - redis.call('SET', new_ref_key, 1) - redis.call('ZADD', zset, now_ms(), new_md5) - return {1} -- success, no victims -end - --- At capacity: try to evict up to max_try items with rc==0, but success if at least 1 is freed -local max_try = math.max(1, math.floor(size * evict_fraction + 0.5)) -local victims = {} -local freed = 0 - --- Scan from LRU (smallest score) to MRU -local all_keys = redis.call('ZRANGE', zset, 0, -1, 'WITHSCORES') -local i = 1 -while freed < 1 and i <= #all_keys and #victims < max_try do - local md5 = all_keys[i] - local ref_key = ref_prefix .. 
md5 - local v = redis.call('GET', ref_key) - if v and tonumber(v) <= 0 then - table.insert(victims, md5) - freed = freed + 1 - end - i = i + 2 -- skip score -end - -if freed >= 1 then - -- delete victims - for _, v in ipairs(victims) do - redis.call('DEL', ref_prefix .. v) - redis.call('ZREM', zset, v) - end - -- insert new - redis.call('SET', new_ref_key, 1) - redis.call('ZADD', zset, now_ms(), new_md5) - return {1, unpack(victims)} -else - -- no zero-ref items found - return {0} -end -""" From 6af47b6539c8f71b6a36157b818e64729f672006 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 26 Mar 2026 07:38:13 +0000 Subject: [PATCH 17/54] fix --- .../impl/memory_cache_with_redis.py | 81 ------------------- 1 file changed, 81 deletions(-) delete mode 100644 lightllm/server/embed_cache/impl/memory_cache_with_redis.py diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py deleted file mode 100644 index f8d9da5314..0000000000 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ /dev/null @@ -1,81 +0,0 @@ -import uuid -import threading -import dataclasses -import requests -from typing import Union, Optional -import torch -import time -from collections import deque -import multiprocessing.shared_memory as shm -from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, EmbedRefCountRedis -from .naive_memory_cache import Record, InMemoryCache -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class MemoryCacheWithRedis(InMemoryCache): - def __init__(self, args) -> None: - super().__init__(args) - redis_url = f"redis://{args.config_server_host}:{args.redis_port}/0" - self.redis_cache = EmbedRefCountRedis( - redis_url=redis_url, - capacity=args.cache_capacity, - evict_fraction=args.redis_evict_fraction, - image_embed_dir=args.image_embed_dir, - ) - # 这里之所以把cache * 2是因为,在分离模式下,cache 服务只是为了更新redis状态,以及维护图片cache的 token_id - # 便于 dynamic 
prompt cache 的使用。所以要把cache_capacity * 2,保障其保留的图片cache > redis 服务维护的 - # 硬盘里的图片image embed 数量。 - self.capacity = max(1, args.cache_capacity * 2) - - # llm 负责release - def release(self, ids: list[int]) -> None: - with self.lock: - for id in ids: - rec = self._records.get(id) - if rec is None: - continue - - redis_exist = self.redis_cache.query(str(id)) - if redis_exist: - self.redis_cache.decr(str(id)) - - # remote_vit 模式下 release 可能走“预层提前释放 + 请求结束兜底释放”两条路径, - # 这里避免本地 ref 被重复减成负数,保证 release 可重复调用。 - if rec.ref > 0: - self._update_record_ref(rec, -1) - - # vit 负责set - def set_items_embed(self, ids: list[int]) -> None: - with self.lock: - for id in ids: - self.redis_cache.insert(str(id)) - rec = self._records.get(id) - if rec is not None: - rec.embed = True - # Before the embed becomes ready, concurrent miss requests are only - # tracked by the local record refcount. Materialize the remaining - # pending readers into Redis so each later release has a matching - # remote ref to consume. - pending_remote_readers = max(rec.ref - 1, 0) - for _ in range(pending_remote_readers): - self.redis_cache.query_and_incre(str(id)) - if rec.ref > 0: - self._update_record_ref_by_id(id, -1) - # 保留一份 redis 引用,直到真正的消费者读取完成后再 release, - # 避免 VIT 刚写完文件但 LLM 还没来得及读取时被 LRU 误删。 - - def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: - ret = [] - for id in ids: - if embeding_only: - exist = self.redis_cache.query(str(id)) - else: - exist = self.redis_cache.query_and_incre(str(id)) - ret.append(exist) - if exist: - rec = self._records.get(id) - if rec is not None: - rec.embed = True - return ret From cee7837c08f19f34c97857f3335f16b1053fde15 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 26 Mar 2026 07:46:01 +0000 Subject: [PATCH 18/54] fix --- .../visualserver/model_infer/model_rpc.py | 72 +++++++------------ 1 file changed, 26 insertions(+), 46 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py 
b/lightllm/server/visualserver/model_infer/model_rpc.py index 8c94b821d1..6355ac2dbf 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -6,27 +6,25 @@ import inspect from datetime import timedelta from typing import Dict, List, Tuple +from transformers.configuration_utils import PretrainedConfig from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer -from transformers.configuration_utils import PretrainedConfig - from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer -from lightllm.models.internvl.internvl_visual import InternVLVisionModel from lightllm.models.llava.llava_visual import LlavaVisionModel +from lightllm.models.internvl.internvl_visual import InternVLVisionModel from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel from lightllm.models.vit.model import VisionTransformer from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel -from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel +from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient -from lightllm.server.embed_cache.utils import create_afs, get_shm_name_embed, 
tensor2bytes, save_tensor_afs from lightllm.server.visualserver import set_vit_att_backend @@ -36,7 +34,6 @@ def exposed_init_model(self, kvargs): import torch import torch.distributed as dist - self.args = get_env_start_args() self.vit_dp = kvargs["vit_dp"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = kvargs["dp_rank_id"] @@ -44,11 +41,6 @@ def exposed_init_model(self, kvargs): self.cache_port = kvargs["cache_port"] weight_dir = kvargs["weight_dir"] self.vit_rank_id = kvargs["vit_rank_id"] - self.image_embed_dir = self.args.image_embed_dir - self.remote_vit = self.args.enable_remote_vit or self.args.run_mode in ["visual", "visual_only"] - if self.remote_vit and not self.image_embed_dir: - raise ValueError("remote vit mode requires image_embed_dir") - self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] @@ -64,7 +56,6 @@ def exposed_init_model(self, kvargs): "quant_type": kvargs["quant_type"], "quant_cfg": kvargs["quant_cfg"], "max_batch_size": kvargs["max_batch_size"], - "remote_vit": self.remote_vit, } self.model_type = model_cfg["model_type"] if self.model_type == "qwen": @@ -101,10 +92,10 @@ def exposed_init_model(self, kvargs): ) else: raise Exception(f"can not support {self.model_type} now") + self.model.load_model(weight_dir) self.model = self.model.cuda() - if not self.remote_vit: - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -125,47 +116,33 @@ def forward(self, images: List[ImageItem]): def exposed_encode(self, images: List[ImageItem]): images = obtain(images) all_img_embeds, uuids, valid_ids = self.forward(images) - - if self.tp_rank_id != 0: - return 
- - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids, True)) - ids_to_set = [] - cpu_embeds = None - if self.remote_vit: - cpu_embeds = all_img_embeds.to(torch.device("cpu"), non_blocking=True) - - for i, ready in enumerate(ready_flags): - if ready: - continue - uid = uuids[i] - start, end = valid_ids[i] - image = images[i] - if self.remote_vit: - save_tensor_afs(get_shm_name_embed(uid), cpu_embeds[start:end], self.image_embed_dir) - else: + all_img_embeds = all_img_embeds.to(torch.device("cuda")) + + if self.tp_rank_id == 0: + ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) + ids_to_set = [] + for i, ready in enumerate(ready_flags): + if ready: + continue + uid = uuids[i] + start, end = valid_ids[i] + image = images[i] self.cpu_embed_cache_client.copy_vision_to_cache( - embed_tensor=all_img_embeds[start:end], - start_index_in_cache=image.start_index_in_embed_cache, + embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache ) - ids_to_set.append(uid) - - if ids_to_set: - self.cache_client.root.set_items_embed(ids_to_set) - if not self.remote_vit: + ids_to_set.append(uid) + if ids_to_set: + self.cache_client.root.set_items_embed(ids_to_set) torch.cuda.current_stream().synchronize() return class VisualModelRpcClient: - def __init__(self, conn, vit_tp, rpc_server_process=None): - self.conn = conn - self.model: VisualModelRpcServer = conn.root + def __init__(self, model_rpc, vit_tp, rpc_server_process=None): + self.model: VisualModelRpcServer = model_rpc self.vit_tp = vit_tp self.rpc_server_process = rpc_server_process self.use_rpc = True - self._bg = rpyc.BgServingThread(self.conn) - if self.use_rpc: def async_wrap(f): @@ -184,12 +161,15 @@ async def _func(*args, **kwargs): else: self._init_model = self.model.exposed_init_model self._encode = self.model.exposed_encode + return async def init_model(self, kvargs): ans: rpyc.AsyncResult = self._init_model(kvargs) if self.use_rpc: await ans return + 
else: + return async def encode(self, images: List[ImageItem]): ans = self._encode(images) @@ -235,4 +215,4 @@ async def start_model_process(port, vit_tp, device_id): raise Exception("init rpc env error!") assert proc.is_alive() - return VisualModelRpcClient(con, vit_tp, rpc_server_process=proc) + return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc) From b620e0ef6a5f667b8d88ae1ce1564890f992dc8a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 26 Mar 2026 08:11:02 +0000 Subject: [PATCH 19/54] fix --- lightllm/server/visualserver/manager.py | 104 ++++-------------------- 1 file changed, 15 insertions(+), 89 deletions(-) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 607888f6b7..8fba9f08d7 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -20,7 +20,7 @@ from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name from rpyc.utils.classic import obtain -from lightllm.server.embed_cache.utils import create_shm, get_shm_name_data + logger = init_logger(__name__) @@ -31,16 +31,13 @@ def __init__( args: StartArgs, visual_model_rpc_ports, ): - self.args = args - self.visual_only = args.run_mode in ["visual", "visual_only"] - self.remote_vit = args.enable_remote_vit or self.visual_only - context = zmq.Context(2) - if not self.visual_only: - if not args.disable_audio: - self.send_to_next_module = context.socket(zmq.PUSH) - self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") - elif args.enable_cpu_cache: + enable_audio = not args.disable_audio + if enable_audio: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") + else: + if args.enable_cpu_cache: self.send_to_next_module = context.socket(zmq.PUSH) 
self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") else: @@ -48,11 +45,7 @@ def __init__( self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") self.zmq_recv_socket = context.socket(zmq.PULL) - if self.remote_vit: - self.zmq_recv_socket.bind(f"tcp://*:{args.remote_vit_port}") - else: - self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") - + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.cache_port = args.cache_port @@ -63,11 +56,13 @@ def __init__( self.vit_tp = args.visual_tp self.infer_batch_size = args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code + self.args = args self.visual_model_rpc_ports = visual_model_rpc_ports self.send_batch_size = args.visual_send_batch_size self.shm_req_manager = ShmReqManager() async def wait_to_model_ready(self): + self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] self.vit_attn_backend = init_vit_att_backend(index=0) for dp_rank_id in range(self.vit_dp): @@ -151,12 +146,13 @@ def flush_ready(force: bool = False): continue multimodal_params = group_req_indexes.multimodal_params + img_uuids = [img.uuid for img in multimodal_params.images] # disable prompt cache通常用来测试,需要也去掉image cache的影响 if disable_prompt_cache: ready_image = [False] * len(img_uuids) else: - ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids, True)) + ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) for img, ready in zip(multimodal_params.images, ready_image): if not ready: @@ -184,43 +180,6 @@ def flush_ready(force: bool = False): processing_group_reqs = [] flush_ready(force=True) - async def _recv_reqs(self): - recv_req: GroupReqIndexes = 
self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) - if not self.remote_vit: - return recv_req - - uuids = [img.uuid for img in recv_req.multimodal_params.images] - already_embed = await asyncio.to_thread(self.cache_client.root.get_items_embed, uuids, True) - if all(already_embed): - return None - - missing_uuids = [] - token_nums = [] - datas = [] - for img, embed_ready in zip(recv_req.multimodal_params.images, already_embed): - if embed_ready: - continue - missing_uuids.append(img.uuid) - token_nums.append(img.token_num) - datas.append(img.read()) - img.free() - - while True: - if await asyncio.to_thread(self.cache_client.root.alloc, missing_uuids, token_nums) is not None: - break - await asyncio.sleep(0.01) - - ready_flags = obtain(self.cache_client.root.get_items_data(missing_uuids)) - update_data_ids = [] - for uid, ready, data in zip(missing_uuids, ready_flags, datas): - if not ready: - create_shm(get_shm_name_data(uid), data) - update_data_ids.append(uid) - - if update_data_ids: - await asyncio.to_thread(self.cache_client.root.set_items_data, update_data_ids) - return recv_req - async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): self.visual_recv_max_count = 64 @@ -228,9 +187,7 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req = await self._recv_reqs() - if recv_req is None: - continue + recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): logger.info( f"visual recv req id {recv_req.group_req_id} " @@ -245,25 +202,6 @@ async def loop_for_netio_req(self): self.visual_recv_max_count = 64 await asyncio.sleep(0.01) - async def loop_for_fwd_visual_only(self): - while True: - if len(self.waiting_reqs) == 0: - await asyncio.sleep(0.01) - continue - - images_need_infer = [] - while len(self.waiting_reqs) > 0: - visual_req = self.waiting_reqs.pop(0) - for img in visual_req.multimodal_params.images: - 
images_need_infer.append(img) - if len(images_need_infer) == self.infer_batch_size: - await self.infer_imgs(images_need_infer) - images_need_infer = [] - - if len(images_need_infer) > 0: - await self.infer_imgs(images_need_infer) - images_need_infer = [] - def clean_up(self): for model_rpc in self.model_rpcs: model_rpc.rpc_server_process.kill() @@ -272,29 +210,17 @@ def clean_up(self): return -def create_forward_loop(args, visualserver: VisualManager, loop: asyncio.AbstractEventLoop): - if args.run_mode in ["visual", "visual_only"]: - from .register_loop import register_loop - - loop.create_task(visualserver.loop_for_fwd_visual_only()) - loop.create_task(register_loop(args)) - else: - loop.create_task(visualserver.loop_for_fwd()) - - def start_visual_process(args, model_rpc_ports, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") start_parent_check_thread() - visualserver = None try: visualserver = VisualManager(args=args, visual_model_rpc_ports=model_rpc_ports) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) - if visualserver is not None: - visualserver.clean_up() + visualserver.clean_up() raise e pipe_writer.send("init ok") @@ -305,6 +231,6 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - create_forward_loop(args, visualserver, loop) + loop.create_task(visualserver.loop_for_fwd()) loop.run_until_complete(visualserver.loop_for_netio_req()) return From 46133a71fd5546a1e138ccc0a2837a563fa177b4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 26 Mar 2026 08:29:39 +0000 Subject: [PATCH 20/54] fix --- lightllm/server/embed_cache/afs_utils.py | 4 +++- .../visualserver/model_infer/model_rpc.py | 18 +++++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git 
a/lightllm/server/embed_cache/afs_utils.py b/lightllm/server/embed_cache/afs_utils.py index 0982cbf31d..12643fe14c 100644 --- a/lightllm/server/embed_cache/afs_utils.py +++ b/lightllm/server/embed_cache/afs_utils.py @@ -57,7 +57,8 @@ class SepEmbedManager: def __init__( self, afs_embed_dir: str, - redis_url: str = "redis://localhost:6379/0", + redis_host: str, + redis_port: int, capacity: int = 50000, evict_fraction: float = 0.1, ) -> None: @@ -66,6 +67,7 @@ def __init__( if capacity < 1: raise ValueError("capacity must be >=1") + redis_url = f"redis://{redis_host}:{redis_port}/0" self.redis_client = RedisMetadataClient(redis_url=redis_url) self.capacity = capacity self.remove_count = min(int(self.capacity * evict_fraction), 1000) # full的时候,每次清理的数量 diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 6355ac2dbf..54d8579814 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -26,6 +26,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient from lightllm.server.visualserver import set_vit_att_backend +from lightllm.server.embed_cache.afs_utils import SepEmbedManager class VisualModelRpcServer(rpyc.Service): @@ -41,8 +42,7 @@ def exposed_init_model(self, kvargs): self.cache_port = kvargs["cache_port"] weight_dir = kvargs["weight_dir"] self.vit_rank_id = kvargs["vit_rank_id"] - self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) - self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.is_visual_only_mode = get_env_start_args().run_mode == "visual_only" self.data_type = kvargs["data_type"] self.vit_attn_backend = kvargs["vit_attn_backend"] set_vit_att_backend(self.vit_attn_backend) @@ -95,7 +95,19 @@ def exposed_init_model(self, kvargs): 
self.model.load_model(weight_dir) self.model = self.model.cuda() - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + if not self.is_visual_only_mode: + # 独立部署vit模式下,不需要连接 cache_client + self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + else: + args = get_env_start_args() + self.redis_afs_client = SepEmbedManager( + afs_embed_dir=args.afs_embed_dir, + redis_host=args.config_server_host, + redis_port=args.config_server_vit_redis_port, + capacity=args.afs_embed_capacity, + ) except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) From 07a46c72b1db20ca9316f757dde0442c7df3856c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 26 Mar 2026 10:43:58 +0000 Subject: [PATCH 21/54] fix --- .../visualserver/model_infer/model_rpc.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 54d8579814..95a010c106 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -148,6 +148,23 @@ def exposed_encode(self, images: List[ImageItem]): torch.cuda.current_stream().synchronize() return + def exposed_encode_visual_only(self, images: List[ImageItem]): + images = obtain(images) + all_img_embeds, uuids, valid_ids = self.forward(images) + all_img_embeds = all_img_embeds.detach().cpu() + + if self.tp_rank_id == 0: + for i in range(len(uuids)): + # uid = uuids[i] + start, end = valid_ids[i] + image = images[i] + embed_tensor = all_img_embeds[start:end] + try: + self.redis_afs_client.insert(image.md5, tensor=embed_tensor) + except: + pass + return + class VisualModelRpcClient: def 
__init__(self, model_rpc, vit_tp, rpc_server_process=None): From 4d63f47413c810215609eeac90f12fa2f79dd4a8 Mon Sep 17 00:00:00 2001 From: wzj Date: Thu, 26 Mar 2026 15:24:58 +0000 Subject: [PATCH 22/54] fix --- lightllm/server/embed_cache/afs_utils.py | 51 ++-- lightllm/server/embed_cache/redis_utils.py | 259 +++++++++------------ 2 files changed, 150 insertions(+), 160 deletions(-) diff --git a/lightllm/server/embed_cache/afs_utils.py b/lightllm/server/embed_cache/afs_utils.py index 12643fe14c..4e50714891 100644 --- a/lightllm/server/embed_cache/afs_utils.py +++ b/lightllm/server/embed_cache/afs_utils.py @@ -20,37 +20,56 @@ def __init__(self, base_dir: str): os.chmod(base_dir, 0o777) return - def _get_afs_path(self, name: str) -> Path: - return Path(self.base_dir) / name - - def save_tensor_afs(self, name: str, tensor: torch.Tensor) -> None: + def save_tensor_afs(self, name: str, tensor: torch.Tensor) -> bool: target_path = self._get_afs_path(name) - + if target_path.exists(): + return True + tmp_path = self._get_afs_path(name=name, uuid_tail_str=str(uuid.uuid4())) try: - with open(target_path, "wb") as f: + with open(tmp_path, "wb") as f: tensor = tensor.detach().cpu() dest = torch.empty_like(tensor) dest.copy_(tensor) torch.save(dest, f, _use_new_zipfile_serialization=False, pickle_protocol=4) - + os.rename(tmp_path, target_path) os.chmod(target_path, 0o777) + return True except Exception as e: try: target_path.unlink(missing_ok=True) - except Exception: + except: pass - logger.exception(f"failed to save embed tensor file: {target_path} excetion {str(e)}") - raise e + logger.warning(f"failed to save embed tensor file: {target_path} tmp_path: {tmp_path} excetion {str(e)}") + return False + finally: + try: + tmp_path.unlink(missing_ok=True) + except: + pass + - def load_tensor_afs(self, name: str) -> torch.Tensor: - path = self._get_afs_path(name) - with open(path, "rb") as f: - return torch.load(f, weights_only=False) + def load_tensor_afs(self, name: str) -> 
Optional[torch.Tensor]: + try: + path = self._get_afs_path(name) + with open(path, "rb") as f: + return torch.load(f, weights_only=False) + except Exception as e: + logger.warning(f"fail to load afs file {name} error: {str(e)}") + return None def free_afs(self, name: str) -> None: - path = self._get_afs_path(name) - path.unlink(missing_ok=True) + try: + path = self._get_afs_path(name) + path.unlink(missing_ok=True) + except Exception as e: + logger.warning(f"free_afs name: {name} error: {str(e)}") return + + def _get_afs_path(self, name: str, uuid_tail_str: Optional[str] = None) -> Path: + if uuid_tail_str is None: + return Path(self.base_dir) / name + else: + return Path(self.base_dir) / f"{name}.{uuid_tail_str}" class SepEmbedManager: diff --git a/lightllm/server/embed_cache/redis_utils.py b/lightllm/server/embed_cache/redis_utils.py index 517f629c81..2dd266cacb 100644 --- a/lightllm/server/embed_cache/redis_utils.py +++ b/lightllm/server/embed_cache/redis_utils.py @@ -1,180 +1,151 @@ import redis from typing import List, Tuple, Union, Optional - -class RedisMetadataClient: +class RedisMetadataLib: """ # 代码任务 创建一个基于redis 管理的元数据操作库代码。 要求: 2. 提供一个包装的 redis 操作client 库,提供以下功能: - (1) 提供输入为(md5, token, time_out) 为其创建一个零时具有超时时间的记录,同时提供输入为(md5, token)的解锁接口,防止多线程的异步操作出现问题。 - (2) 提供一个时间排序队列,当出现对md5的任何操作的时候,向队列中插入md5,并更新时间错(单位为s即可), 当时创建锁和解锁不更新时间错。 - (3) 输入为(md5,token), 先校验 md5锁对应的内容为token, 然后标记 md5 对应的资源已经准备就绪, 向时间排序队列插入更新md5的时间错。 - (4) 输入为(md5, token) 先校验 md5锁对应的内容为token, 当 md5 对应的资源存在的时候,同时更新排序队列中的时间错,同时返回True, 否则返回False,不更新时间错。 - (5) 输入为(md5, token), 先校验 md5锁对应的内容为token, 移除标记 md5 对应的资源已经准备就绪,并同时从时间排序队列中移除对应的md5。 - (6) 输入为(remove_size, capcity), 当时间排序队列中的元素数量大于等于capcity, 返回时间排序队列中排在前面的 remove_size 个元素,其内容为 md5。 - (7) 所有操作都使用lua 脚本,以实现原子化操作,同时返回的错误要能区分具体错误的原因,注意lua脚本的可读性,和相关函数的输入输出测试。 + (1) 提供一个时间排序队列,向队列中插入md5,并更新时间错(单位为s即可). 
+ (2) 输入为(md5_list,), 向队列中插入所有的md5, 并更新其对应时间错。 + (3) 输入为(md5_list,), 将队列中的md5进行删除。 + (4) 输入为(md5_list,), 返回 md5_list 中所有md5 每个是否在链表中存在,返回一个bool list来标识,同时对所有存在的md5,更新时间错到最新。 + (5) 输入为(remove_size, capcity), 当时间排序队列中的元素数量大于等于capcity, 返回时间排序队列中排在前面的 remove_size 个元素,其内容为 md5。 + (6) 所有操作都使用lua 脚本,以实现原子化操作,同时返回的错误要能区分具体错误的原因,注意lua脚本的可读性,和相关函数的输入输出测试。时间错为server端s级别的参数。 """ - def __init__(self, redis_url: str = "redis://localhost:6379/0", prefix: str = "meta"): + # decode_responses=True 确保返回的是字符串而非字节 self.r = redis.Redis.from_url(redis_url, decode_responses=True) - self.prefix = prefix self.lru_key = f"{prefix}:queue:lru" self._register_scripts() def _register_scripts(self): - """注册 Lua 脚本""" - - # (1) 解锁脚本 (不更新时间戳) - self._lua_unlock = self.r.register_script( - """ - local lock_key = KEYS[1] - local token = ARGV[1] - if redis.call("GET", lock_key) == token then - return redis.call("DEL", lock_key) - elseif redis.call("EXISTS", lock_key) == 0 then - return -2 - else - return -1 + """注册 Lua 脚本实现原子化操作""" + + # (1) & (2) 更新/插入:支持传入单个或多个 MD5 + # 逻辑:获取服务器时间,循环执行 ZADD + self._lua_update = self.r.register_script(""" + local lru_key = KEYS[1] + local now = redis.call('TIME')[1] + local count = 0 + for i, md5 in ipairs(ARGV) do + redis.call('ZADD', lru_key, now, md5) + count = count + 1 end - """ - ) - - # (3, 4, 5) 元数据操作脚本 - # 内部通过 redis.call('TIME') 获取服务器时间 - self._lua_meta_op = self.r.register_script( - """ - local lock_key = KEYS[1] - local ready_key = KEYS[2] - local lru_key = KEYS[3] - local op = ARGV[1] - local token = ARGV[2] - local md5 = ARGV[3] - - -- 校验锁 - local current_token = redis.call("GET", lock_key) - if not current_token then return -2 end - if current_token ~= token then return -1 end - - -- 获取服务器时间 (秒) - local server_time = redis.call('TIME')[1] - - if op == "mark_ready" then - redis.call("SET", ready_key, "1") - redis.call("ZADD", lru_key, server_time, md5) - return 1 - elseif op == "check_touch" then - if redis.call("EXISTS", ready_key) == 1 then - 
redis.call("ZADD", lru_key, server_time, md5) - return 1 + return count + """) + + # (3) 删除:从队列中移除指定的 MD5 + self._lua_remove = self.r.register_script(""" + local lru_key = KEYS[1] + local count = 0 + for i, md5 in ipairs(ARGV) do + count = count + redis.call('ZREM', lru_key, md5) + end + return count + """) + + # (4) 检查并更新:判断是否存在,存在则刷新时间,返回 bool 状态列表 + self._lua_check_update = self.r.register_script(""" + local lru_key = KEYS[1] + local now = redis.call('TIME')[1] + local results = {} + for i, md5 in ipairs(ARGV) do + if redis.call('ZSCORE', lru_key, md5) then + redis.call('ZADD', lru_key, now, md5) + table.insert(results, 1) else - return 0 + table.insert(results, 0) end - elseif op == "remove_ready" then - redis.call("DEL", ready_key) - redis.call("ZREM", lru_key, md5) - return 1 end - """ - ) + return results + """) - # (6) 逐出检查脚本 - self._lua_evict = self.r.register_script( - """ + # (5) 容量清理:检查容量并获取候选列表 + self._lua_evict = self.r.register_script(""" local lru_key = KEYS[1] local remove_size = tonumber(ARGV[1]) local capacity = tonumber(ARGV[2]) - local current_size = redis.call("ZCARD", lru_key) + local current_size = redis.call('ZCARD', lru_key) if current_size >= capacity then - return redis.call("ZRANGE", lru_key, 0, remove_size - 1) + -- 按照分数(时间戳)从小到大排列,获取最旧的 N 个 + return redis.call('ZRANGE', lru_key, 0, remove_size - 1) else return {} end - """ - ) - - def _get_keys(self, md5: str): - return [f"{self.prefix}:lock:{md5}", f"{self.prefix}:ready:{md5}", self.lru_key] - - def _handle_res(self, res: int): - """映射错误原因""" - errors = { - 1: (True, "Success"), - 0: (False, "Resource not ready"), - -1: (False, "Error: Token mismatch (Permission denied)"), - -2: (False, "Error: Lock missing or expired"), - } - return errors.get(res, (False, f"Unknown error code: {res}")) - - # (1) 创建锁 - def acquire_lock(self, md5: str, token: str, time_out: int) -> bool: - """创建临时超时记录 (不更新排序队列)""" - lock_key = self._get_keys(md5)[0] - return bool(self.r.set(lock_key, token, nx=True, 
ex=time_out)) - - # (1) 解锁 - def release_lock(self, md5: str, token: str) -> Tuple[bool, str]: - """解锁 (不更新排序队列)""" - res = self._lua_unlock(keys=[self._get_keys(md5)[0]], args=[token]) - return self._handle_res(res) - - # (3) 标记就绪 - def mark_ready(self, md5: str, token: str) -> Tuple[bool, str]: - """标记就绪并在 Lua 内部更新服务器时间戳""" - keys = self._get_keys(md5) - # 不再传入 now,Lua 脚本内部自行获取 - res = self._lua_meta_op(keys=keys, args=["mark_ready", token, md5]) - return self._handle_res(res) - - # (4) 检查就绪并 Touch - def check_ready_and_touch(self, md5: str, token: str) -> Tuple[bool, str]: - """校验锁和就绪状态,并在 Lua 内部更新服务器时间戳""" - keys = self._get_keys(md5) - res = self._lua_meta_op(keys=keys, args=["check_touch", token, md5]) - return self._handle_res(res) - - # (5) 移除就绪 - def remove_ready(self, md5: str, token: str) -> Tuple[bool, str]: - """移除就绪状态并从队列删除""" - keys = self._get_keys(md5) - res = self._lua_meta_op(keys=keys, args=["remove_ready", token, md5]) - return self._handle_res(res) - - # (6) 获取逐出列表 - def get_eviction_candidates(self, remove_size: int, capacity: int) -> List[str]: - """当数量达到上限,返回最旧的元素""" - return self._lua_evict(keys=[self.lru_key], args=[remove_size, capacity]) + """) + def _to_list(self, data: Union[str, List[str]]) -> List[str]: + """内部工具:将输入统一转为列表形式""" + if isinstance(data, str): + return [data] + return data -# ---------------- 测试验证 ---------------- - - -def test_client(): - - client = RedisMetadataClient() - md5 = "test_file_server_time" - token = "secure_token_123" - - print("Step 1: Acquire Lock") - client.acquire_lock(md5, token, 60) - - print("Step 2: Mark Ready (Updates time inside Lua)") - ok, msg = client.mark_ready(md5, token) - print(f"Result: {ok}, {msg}") + def update(self, md5_list: Union[str, List[str]]) -> int: + """ + 功能 (1) & (2):插入或更新 md5 的时间戳。 + 支持传入单个字符串或字符串列表。 + """ + items = self._to_list(md5_list) + if not items: return 0 + return self._lua_update(keys=[self.lru_key], args=items) - # 检查 Redis 内部 ZSet 存储的时间戳 - score = 
client.r.zscore(client.lru_key, md5) - print(f"Server Timestamp in ZSet: {score}") + def remove(self, md5_list: Union[str, List[str]]) -> int: + """ + 功能 (3):将队列中的 md5 进行删除。 + 支持传入单个字符串或字符串列表。 + """ + items = self._to_list(md5_list) + if not items: return 0 + return self._lua_remove(keys=[self.lru_key], args=items) - print("\nStep 3: Check and Touch (Updates time inside Lua)") - ok, msg = client.check_ready_and_touch(md5, token) - print(f"Result: {ok}, {msg}") + def check_and_update(self, md5_list: List[str]) -> List[bool]: + """ + 功能 (4):返回 md5_list 中每个 md5 是否在队列中存在。 + 对存在的 md5 会同时更新时间戳到最新。 + """ + if not md5_list: return [] + raw_res = self._lua_check_update(keys=[self.lru_key], args=md5_list) + return [res == 1 for res in raw_res] - new_score = client.r.zscore(client.lru_key, md5) - print(f"Updated Server Timestamp: {new_score}") + def get_eviction_candidates(self, remove_size: int, capacity: int) -> List[str]: + """ + 功能 (5):当队列数量 >= capacity 时,返回排在前面的 remove_size 个 md5。 + """ + return self._lua_evict(keys=[self.lru_key], args=[remove_size, capacity]) +# ---------------- 功能测试 ---------------- + +def test_meta_lib(): + lib = RedisMetadataLib(prefix="test_service") + # 清理历史数据 + lib.r.delete(lib.lru_key) + + print("1. 测试更新 (update)") + lib.update("file_0") # 单个 + lib.update(["file_1", "file_2", "file_3"]) # 批量 + print(f"当前队列大小: {lib.r.zcard(lib.lru_key)}") + + print("\n2. 测试检查并更新 (check_and_update)") + # file_1 存在,file_none 不存在,file_3 存在 + check_list = ["file_1", "file_none", "file_3"] + exists_results = lib.check_and_update(check_list) + for m, exists in zip(check_list, exists_results): + print(f"MD5: {m}, 存在状态: {exists}") + + print("\n3. 测试容量逐出 (get_eviction_candidates)") + # 当前有 4 个元素,设容量为 3,要求返回最旧的 2 个 + candidates = lib.get_eviction_candidates(remove_size=2, capacity=3) + print(f"容量达到3时,建议删除的最旧2个元素: {candidates}") + + print("\n4. 
测试删除 (remove)") + removed_count = lib.remove(["file_0", "file_1"]) + print(f"成功移除数量: {removed_count}") + + final_check = lib.check_and_update(["file_1", "file_2"]) + print(f"最终检查 [file_1, file_2]: {final_check}") if __name__ == "__main__": - test_client() + test_meta_lib() \ No newline at end of file From 22f7a1cc584a34a442b17a0babb86d9d8d551065 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 27 Mar 2026 02:46:41 +0000 Subject: [PATCH 23/54] fix --- lightllm/server/embed_cache/afs_utils.py | 44 +++++++++++++----------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/lightllm/server/embed_cache/afs_utils.py b/lightllm/server/embed_cache/afs_utils.py index 4e50714891..b68f162ac1 100644 --- a/lightllm/server/embed_cache/afs_utils.py +++ b/lightllm/server/embed_cache/afs_utils.py @@ -2,9 +2,10 @@ import time import torch import uuid +import itertools from typing import List, Tuple, Optional from pathlib import Path -from .redis_utils import RedisMetadataClient +from .redis_utils import RedisMetadataLib from lightllm.utils.log_utils import init_logger @@ -12,20 +13,26 @@ class AfsUtils: - def __init__(self, base_dir: str): + def __init__(self, base_dir: str, dir_depth: int = 2): self.base_dir = base_dir # 判断 base_dir 是否存在,不存在则创建并赋予777权限,让其他人也可以写入 if not os.path.exists(base_dir): - os.makedirs(base_dir, exist_ok=True) - os.chmod(base_dir, 0o777) + os.makedirs(base_dir, mode=0o777, exist_ok=True) + + # build sub dirs + parent_dir = Path(base_dir) + subdirs = ["".join(p) for p in itertools.product("0123456789abcdef", repeat=dir_depth)] + for sub in subdirs: + sub_dir_path = parent_dir / sub + os.makedirs(sub_dir_path, mode=0o777, exist_ok=True) return def save_tensor_afs(self, name: str, tensor: torch.Tensor) -> bool: - target_path = self._get_afs_path(name) - if target_path.exists(): - return True - tmp_path = self._get_afs_path(name=name, uuid_tail_str=str(uuid.uuid4())) try: + target_path = self._get_afs_path(name) + if target_path.exists(): + return 
True + tmp_path = self._get_afs_path(name=name, uuid_tail_str=str(uuid.uuid4())) with open(tmp_path, "wb") as f: tensor = tensor.detach().cpu() dest = torch.empty_like(tensor) @@ -35,10 +42,6 @@ def save_tensor_afs(self, name: str, tensor: torch.Tensor) -> bool: os.chmod(target_path, 0o777) return True except Exception as e: - try: - target_path.unlink(missing_ok=True) - except: - pass logger.warning(f"failed to save embed tensor file: {target_path} tmp_path: {tmp_path} excetion {str(e)}") return False finally: @@ -47,7 +50,6 @@ def save_tensor_afs(self, name: str, tensor: torch.Tensor) -> bool: except: pass - def load_tensor_afs(self, name: str) -> Optional[torch.Tensor]: try: path = self._get_afs_path(name) @@ -57,19 +59,21 @@ def load_tensor_afs(self, name: str) -> Optional[torch.Tensor]: logger.warning(f"fail to load afs file {name} error: {str(e)}") return None - def free_afs(self, name: str) -> None: + def free_afs(self, name: str) -> bool: try: path = self._get_afs_path(name) path.unlink(missing_ok=True) + return True except Exception as e: logger.warning(f"free_afs name: {name} error: {str(e)}") + return False return - + def _get_afs_path(self, name: str, uuid_tail_str: Optional[str] = None) -> Path: if uuid_tail_str is None: - return Path(self.base_dir) / name + return Path(self.base_dir) / name[0:2] / name else: - return Path(self.base_dir) / f"{name}.{uuid_tail_str}" + return Path(self.base_dir) / name[0:2] / f"{name}.{uuid_tail_str}" class SepEmbedManager: @@ -78,16 +82,16 @@ def __init__( afs_embed_dir: str, redis_host: str, redis_port: int, - capacity: int = 50000, + capacity: int = 250000, evict_fraction: float = 0.1, ) -> None: if not (0.0 <= evict_fraction <= 1.0): raise ValueError("evict_fraction must be 0..1") if capacity < 1: - raise ValueError("capacity must be >=1") + raise ValueError("capacity must be >= 1") redis_url = f"redis://{redis_host}:{redis_port}/0" - self.redis_client = RedisMetadataClient(redis_url=redis_url) + self.redis_client = 
RedisMetadataLib(redis_url=redis_url) self.capacity = capacity self.remove_count = min(int(self.capacity * evict_fraction), 1000) # full的时候,每次清理的数量 self.afs_embed_dir = afs_embed_dir From dbb0ef391795ac4b549b5c3819345062ffd22abc Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 27 Mar 2026 03:14:02 +0000 Subject: [PATCH 24/54] fix --- lightllm/server/embed_cache/afs_utils.py | 83 ++++++++++++------------ 1 file changed, 41 insertions(+), 42 deletions(-) diff --git a/lightllm/server/embed_cache/afs_utils.py b/lightllm/server/embed_cache/afs_utils.py index b68f162ac1..3bd55609dc 100644 --- a/lightllm/server/embed_cache/afs_utils.py +++ b/lightllm/server/embed_cache/afs_utils.py @@ -62,6 +62,8 @@ def load_tensor_afs(self, name: str) -> Optional[torch.Tensor]: def free_afs(self, name: str) -> bool: try: path = self._get_afs_path(name) + if not path.exists(): + return True path.unlink(missing_ok=True) return True except Exception as e: @@ -69,6 +71,14 @@ def free_afs(self, name: str) -> bool: return False return + def exist_afs(self, name: str) -> bool: + try: + path = self._get_afs_path(name) + return path.exists() + except Exception as e: + logger.warning(f"exist_afs name: {name} error: {str(e)}") + return False + def _get_afs_path(self, name: str, uuid_tail_str: Optional[str] = None) -> Path: if uuid_tail_str is None: return Path(self.base_dir) / name[0:2] / name @@ -102,56 +112,45 @@ def full_to_clean(self): remove_size=self.remove_count, capcity=self.capacity ) for obj in remove_objs: - _token = str(uuid.uuid4()) try: - if self.redis_client.acquire_lock(md5=obj, token=_token, time_out=10): - if self.redis_client.remove_ready(md5=obj, token=_token)[0]: - self.afs_utils.free_afs(obj) - self.redis_client.release_lock(md5=obj, token=_token) + if self.afs_utils.free_afs(obj): + self.redis_client.remove([obj]) except BaseException as e: logger.warning(f"full_to_clean md5 {obj} error {str(e)}") def insert(self, md5: str, tensor: torch.Tensor) -> bool: - for _ in 
range(3): - if self._insert(md5, tensor): - return True - else: - time.sleep(30) - return False - - def _insert(self, md5: str, tensor: torch.Tensor) -> bool: self.full_to_clean() try: - _token = str(uuid.uuid4()) - if self.redis_client.acquire_lock(md5=md5, token=_token, time_out=30): - self.afs_utils.save_tensor_afs(md5, tensor) - ret = self.redis_client.mark_ready(md5=md5, token=_token) - if ret[0]: - self.redis_client.release_lock(md5=md5, token=_token) - return True - else: - self.redis_client.release_lock(md5=md5, token=_token) - logger.warning(f"insert {md5} failed error {ret[1]}") - return False + # 保证一定会有清理的可能性 + self.redis_client.update(md5) + self.afs_utils.save_tensor_afs(md5, tensor) + self.redis_client.update(md5) except: return False - def query_to_lock(self, md5: str) -> Optional[str]: - """ - 返回 None, 或者 token, 返回token代表可以去afs中读取数据了, - """ + def load(self, md5: str) -> Optional[torch.Tensor]: try: - _token = str(uuid.uuid4()) - if self.redis_client.acquire_lock(md5=md5, token=_token, time_out=60): - ret = self.redis_client.check_ready_and_touch(md5=md5, token=_token) - if ret[0]: - return _token - else: - logger.warning(f"query_to_lock {md5} failed {ret[1]}") - self.redis_client.release_lock(md5=md5, token=_token) - except: - try: - self.redis_client.release_lock(md5=md5, token=_token) - except: - pass - return None + ans = self.afs_utils.load_tensor_afs(md5) + if ans: + self.redis_client.update(md5) + return ans + else: + return None + except Exception as e: + logger.warning(f"load md5 {md5} error {str(e)}") + return None + + def check_ready(self, md5_list: List[str]) -> List[bool]: + try: + tmp1 = self.redis_client.check_and_update(md5_list) + start = time.time() + tmp2 = [self.afs_utils.exist_afs(md5) for md5 in md5_list] + cost_time = time.time() - start + if cost_time > 0.05: + logger.warning(f"slow afs check exist {cost_time} seconds, md5_list size: {len(md5_list)}") + assert len(tmp1) == len(tmp2) + ans = [a and b for a, b in zip(tmp1, 
tmp2)] + return ans + except Exception as e: + logger.warning(f"check_ready error {str(e)}") + return [False] * len(md5_list) From 4ba14b5062ddd342c2526e50a6042ea31f04a727 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 27 Mar 2026 06:54:41 +0000 Subject: [PATCH 25/54] fix --- lightllm/server/embed_cache/afs_utils.py | 5 +- lightllm/server/visualserver/manager.py | 26 +-- .../visualserver/model_infer/model_rpc.py | 118 +++++-------- .../model_infer/visual_only_model_rpc.py | 160 ++++++++++++++++++ 4 files changed, 218 insertions(+), 91 deletions(-) create mode 100644 lightllm/server/visualserver/model_infer/visual_only_model_rpc.py diff --git a/lightllm/server/embed_cache/afs_utils.py b/lightllm/server/embed_cache/afs_utils.py index 3bd55609dc..9281559921 100644 --- a/lightllm/server/embed_cache/afs_utils.py +++ b/lightllm/server/embed_cache/afs_utils.py @@ -86,7 +86,7 @@ def _get_afs_path(self, name: str, uuid_tail_str: Optional[str] = None) -> Path: return Path(self.base_dir) / name[0:2] / f"{name}.{uuid_tail_str}" -class SepEmbedManager: +class SepEmbedHandler: def __init__( self, afs_embed_dir: str, @@ -143,8 +143,9 @@ def load(self, md5: str) -> Optional[torch.Tensor]: def check_ready(self, md5_list: List[str]) -> List[bool]: try: tmp1 = self.redis_client.check_and_update(md5_list) + assert len(tmp1) == len(md5_list) start = time.time() - tmp2 = [self.afs_utils.exist_afs(md5) for md5 in md5_list] + tmp2 = [exists and self.afs_utils.exist_afs(md5) for md5, exists in zip(md5_list, tmp1)] cost_time = time.time() - start if cost_time > 0.05: logger.warning(f"slow afs check exist {cost_time} seconds, md5_list size: {len(md5_list)}") diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 8fba9f08d7..c28432de6b 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -31,6 +31,7 @@ def __init__( args: StartArgs, visual_model_rpc_ports, ): + self.args = args context = 
zmq.Context(2) enable_audio = not args.disable_audio if enable_audio: @@ -48,15 +49,12 @@ def __init__( self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.cache_port = args.cache_port self.waiting_reqs: List[GroupReqIndexes] = [] self.model_weightdir = args.model_dir - self.tp_world_size = args.tp self.vit_dp = args.visual_dp self.vit_tp = args.visual_tp + # image 最大推理 batch size self.infer_batch_size = args.visual_infer_batch_size - self.trust_remote_code = args.trust_remote_code - self.args = args self.visual_model_rpc_ports = visual_model_rpc_ports self.send_batch_size = args.visual_send_batch_size self.shm_req_manager = ShmReqManager() @@ -66,29 +64,25 @@ async def wait_to_model_ready(self): self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] self.vit_attn_backend = init_vit_att_backend(index=0) for dp_rank_id in range(self.vit_dp): - tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id] for tp_rank_id in range(self.vit_tp): - device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] - rpc_model = await start_model_process( - port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp, device_id=device_id - ) + + rpc_model = await start_model_process() self.model_rpcs[dp_rank_id].append(rpc_model) init_model_ret = [] for dp_rank_id in range(self.vit_dp): # async init model process for tp_rank_id in range(self.vit_tp): + device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] kvargs = { "weight_dir": self.model_weightdir, - "trust_remote_code": self.trust_remote_code, - "vit_dp": self.vit_dp, + "device_id": device_id, + "trust_remote_code": self.args.trust_remote_code, "vit_tp": self.vit_tp, - "cache_port": self.cache_port, + "cache_port": self.args.cache_port, "tp_rank_id": 
tp_rank_id, "dp_rank_id": dp_rank_id, - "vit_rank_id": dp_rank_id * self.vit_tp + tp_rank_id, "data_type": self.args.data_type, "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], - "visual_gpu_ids": self.args.visual_gpu_ids, "quant_type": self.args.vit_quant_type, "quant_cfg": self.args.vit_quant_cfg, "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), @@ -203,10 +197,6 @@ async def loop_for_netio_req(self): await asyncio.sleep(0.01) def clean_up(self): - for model_rpc in self.model_rpcs: - model_rpc.rpc_server_process.kill() - for model_rpc in self.model_rpcs: - model_rpc.rpc_server_process.join() return diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 95a010c106..58f6dc8290 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -4,10 +4,14 @@ import torch import socket import inspect +import uuid +import os +import torch.multiprocessing as mp from datetime import timedelta from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig -from rpyc.utils.classic import obtain +from lightllm.utils.retry_utils import retry +from rpyc.utils.classic import obtain, unix_connect from rpyc.utils.server import ThreadedServer from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer from lightllm.models.llava.llava_visual import LlavaVisionModel @@ -26,7 +30,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient from lightllm.server.visualserver import set_vit_att_backend -from lightllm.server.embed_cache.afs_utils import SepEmbedManager +from lightllm.server.embed_cache.afs_utils import SepEmbedHandler class VisualModelRpcServer(rpyc.Service): @@ -102,7 +106,7 @@ def exposed_init_model(self, kvargs): self.cpu_embed_cache_client = 
CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) else: args = get_env_start_args() - self.redis_afs_client = SepEmbedManager( + self.redis_afs_client = SepEmbedHandler( afs_embed_dir=args.afs_embed_dir, redis_host=args.config_server_host, redis_port=args.config_server_vit_redis_port, @@ -148,100 +152,72 @@ def exposed_encode(self, images: List[ImageItem]): torch.cuda.current_stream().synchronize() return - def exposed_encode_visual_only(self, images: List[ImageItem]): - images = obtain(images) - all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.detach().cpu() - if self.tp_rank_id == 0: - for i in range(len(uuids)): - # uid = uuids[i] - start, end = valid_ids[i] - image = images[i] - embed_tensor = all_img_embeds[start:end] - try: - self.redis_afs_client.insert(image.md5, tensor=embed_tensor) - except: - pass - return +class VisualModelRpcClient: + def __init__(self, rpc_conn): + self.rpc_conn: VisualModelRpcServer = rpc_conn + def async_wrap(f): + f = rpyc.async_(f) + + async def _func(*args, **kwargs): + ans = f(*args, **kwargs) + await asyncio.to_thread(ans.wait) + # raise if exception + return ans.value + + return _func + + self._init_model = async_wrap(self.rpc_conn.init_model) + self._encode = async_wrap(self.rpc_conn.encode) -class VisualModelRpcClient: - def __init__(self, model_rpc, vit_tp, rpc_server_process=None): - self.model: VisualModelRpcServer = model_rpc - self.vit_tp = vit_tp - self.rpc_server_process = rpc_server_process - self.use_rpc = True - if self.use_rpc: - - def async_wrap(f): - f = rpyc.async_(f) - - async def _func(*args, **kwargs): - ans = f(*args, **kwargs) - await asyncio.to_thread(ans.wait) - # raise if exception - return ans.value - - return _func - - self._init_model = async_wrap(self.model.init_model) - self._encode = async_wrap(self.model.encode) - else: - self._init_model = self.model.exposed_init_model - self._encode = self.model.exposed_encode return async def 
init_model(self, kvargs): ans: rpyc.AsyncResult = self._init_model(kvargs) - if self.use_rpc: - await ans - return - else: - return + await ans + return async def encode(self, images: List[ImageItem]): ans = self._encode(images) - if self.use_rpc: - return await ans - else: - return ans + return await ans -def _init_env(port, device_id): +def _init_env(scoket_path: str, success_event: "mp.Event"): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) import lightllm.utils.rpyc_fix_utils as _ - t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True}) + t = ThreadedServer(VisualModelRpcServer(), socket_path=scoket_path, protocol_config={"allow_pickle": True}) + success_event.set() t.start() return -async def start_model_process(port, vit_tp, device_id): - import multiprocessing +async def start_model_process(): + socket_path = _generate_unix_socket_path() + if os.path.exists(socket_path): + os.remove(socket_path) - proc = multiprocessing.Process( + success_event = mp.Event() + proc = mp.Process( target=_init_env, args=( - port, - device_id, + socket_path, + success_event, ), ) proc.start() - await asyncio.sleep(2) - repeat_count = 0 - while repeat_count < 20: - try: - con = rpyc.connect("localhost", port, config={"allow_pickle": True}) - con._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - break - except BaseException: - await asyncio.sleep(1) - repeat_count += 1 - if repeat_count == 20: - raise Exception("init rpc env error!") + await asyncio.to_thread(success_event.wait, timeout=40) + + conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) assert proc.is_alive() - return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc) + return VisualModelRpcClient(conn.root) + + +def _generate_unix_socket_path() -> str: + """Generate a random Unix socket path""" + unique_id = uuid.uuid4().hex[:8] + return 
f"/tmp/lightllm_model_infer_{unique_id}.sock" diff --git a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py new file mode 100644 index 0000000000..1b8350af69 --- /dev/null +++ b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py @@ -0,0 +1,160 @@ +import os +import queue +import threading +import dataclasses +import rpyc +import socket +import asyncio +import inspect +import uuid +from lightllm.utils.retry_utils import retry +from rpyc.utils.factory import unix_connect +from typing import List, Any +from .model_rpc import VisualModelRpcServer, VisualModelRpcClient +from lightllm.server.multimodal_params import ImageItem +from lightllm.server.embed_cache.afs_utils import SepEmbedHandler +from rpyc.utils.server import ThreadedServer +from lightllm.utils.envs_utils import get_env_start_args +from rpyc.utils.classic import obtain +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class _Task: + images: List["ImageItem"] + ret: Any + event: threading.Event + hasError: bool = False + + def wait(self, timeout: float = None): + self.event.wait(timeout=timeout) + + +class VisualOnlyModelRpcServer(VisualModelRpcServer): + """ + 完善这个代码: + 1. 创建一个队列, 用于接受别人放入的task, + 2. 创建一个线程,从队列中取出任务,完成后,修改task中的event,让放入的人得到结果和通知。这是任务循环。 + 3. 
能不能封装比较易读的流程。 + """ + + def __init__(self): + super().__init__() + + # 异步队列, 用于接受任务 + self.task_queue = queue.Queue() + # 限制并发, 主要控制内存用量,防止过多照成爆炸。 + self.sempare = threading.Semaphore(3) + + self.afs_handler = SepEmbedHandler( + afs_embed_dir=get_env_start_args().afs_embed_dir, + redis_host=get_env_start_args().config_server_host, + redis_port=get_env_start_args().config_server_vit_redis_port, + capacity=get_env_start_args().afs_embed_capacity, + ) + + # 启动任务处理线程 + self.worker_thread = threading.Thread(target=self._task_worker, daemon=True) + self.worker_thread.start() + + def _task_worker(self): + """ + 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 + """ + while True: + try: + # 从队列获取任务, 阻塞等待 + task: _Task = self.task_queue.get() + + # 执行任务: 调用父类的forward方法处理图像 + try: + all_img_embeds, uuids, valid_ids = self.forward(task.images) + all_img_embeds = all_img_embeds.detach().cpu() + + # 存储结果到task.ret + task.ret = {"embeds": all_img_embeds, "valid_ids": valid_ids} + except Exception as e: + task.hasError = True + logger.exception(str(e)) + raise e + finally: + # 标记任务完成, 唤醒等待的调用者 + task.event.set() + self.task_queue.task_done() + + except Exception as e: + logger.exception(str(e)) + raise e + + def exposed_run_task(self, images: List["ImageItem"]): + """ + 添加任务到队列 + + Args: + images: 要处理的图像列表 + + Returns: + _Task: 任务对象, 包含ret和event + """ + images = obtain(images) + with self.sempare: + event = threading.Event() + task = _Task(images=images, ret=None, event=event) + self.task_queue.put(task) + task.event.wait(timeout=8888) + + all_img_embeds = task.ret["embeds"] + valid_ids = task.ret["valid_ids"] + + if self.tp_rank_id == 0: + for i in enumerate(len(images)): + start, end = valid_ids[i] + image = images[i] + self.afs_handler.insert(image.md5, all_img_embeds[start:end]) + return + + +def _init_env(socket_path: str, device_id: int, success_event): + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + + import lightllm.utils.rpyc_fix_utils as _ + + t = 
ThreadedServer(VisualOnlyModelRpcServer(), socket_path=socket_path, protocol_config={"allow_pickle": True}) + success_event.set() + t.start() + return + + +async def start_model_process(vit_tp, device_id): + import multiprocessing + + socket_path = _generate_unix_socket_path() + if os.path.exists(socket_path): + os.remove(socket_path) + + success_event = multiprocessing.Event() + proc = multiprocessing.Process( + target=_init_env, + args=( + socket_path, + device_id, + success_event, + ), + ) + proc.start() + await asyncio.to_thread(success_event.wait, timeout=40) + assert proc.is_alive() + + conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) + assert proc.is_alive() + return VisualModelRpcClient(conn.root, vit_tp, rpc_server_process=proc) + + +def _generate_unix_socket_path() -> str: + """Generate a random Unix socket path""" + unique_id = uuid.uuid4().hex[:8] + return f"/tmp/lightllm_model_infer_{unique_id}.sock" From 48be1264d39a3b5e0ffd04456d56afd1b9d30dc9 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 27 Mar 2026 07:00:28 +0000 Subject: [PATCH 26/54] fix --- lightllm/server/api_start.py | 6 +----- lightllm/server/visualserver/manager.py | 6 ++---- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 0d0e744fae..a0a74d81bf 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -525,12 +525,8 @@ def visual_start(args): ) = can_use_ports[0:5] can_use_ports = can_use_ports[5:] - visual_model_tp_ports = [] visual_nccl_ports = [] for _ in range(args.visual_dp): - tp_ports_for_dp = can_use_ports[0 : args.visual_tp] - visual_model_tp_ports.append(tp_ports_for_dp) - can_use_ports = can_use_ports[args.visual_tp :] if args.visual_nccl_ports is None: visual_nccl_ports.append(can_use_ports[0]) can_use_ports = can_use_ports[1:] @@ -564,7 +560,7 @@ def visual_start(args): start_visual_process, ], start_args=[ - (args, 
visual_model_tp_ports), + (args,), ], ) setup_signal_handlers(None, process_manager) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index c28432de6b..dd134cd944 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -29,7 +29,6 @@ class VisualManager: def __init__( self, args: StartArgs, - visual_model_rpc_ports, ): self.args = args context = zmq.Context(2) @@ -55,7 +54,6 @@ def __init__( self.vit_tp = args.visual_tp # image 最大推理 batch size self.infer_batch_size = args.visual_infer_batch_size - self.visual_model_rpc_ports = visual_model_rpc_ports self.send_batch_size = args.visual_send_batch_size self.shm_req_manager = ShmReqManager() @@ -200,13 +198,13 @@ def clean_up(self): return -def start_visual_process(args, model_rpc_ports, pipe_writer): +def start_visual_process(args, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") start_parent_check_thread() try: - visualserver = VisualManager(args=args, visual_model_rpc_ports=model_rpc_ports) + visualserver = VisualManager(args=args) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) From c95b7f8230fb5acbbf410c75e36e57f4bec9d745 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 27 Mar 2026 07:09:43 +0000 Subject: [PATCH 27/54] fix --- lightllm/server/visualserver/manager.py | 1 - .../visualserver/model_infer/model_rpc.py | 18 ++++++++++++++++-- lightllm/utils/dist_utils.py | 19 +++++++++++++++++-- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index dd134cd944..e7a081f42f 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -74,7 +74,6 @@ async def wait_to_model_ready(self): kvargs = { "weight_dir": 
self.model_weightdir, "device_id": device_id, - "trust_remote_code": self.args.trust_remote_code, "vit_tp": self.vit_tp, "cache_port": self.args.cache_port, "tp_rank_id": tp_rank_id, diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 58f6dc8290..f61e010ca5 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -39,12 +39,26 @@ def exposed_init_model(self, kvargs): import torch import torch.distributed as dist - self.vit_dp = kvargs["vit_dp"] + # kvargs = { + # "weight_dir": self.model_weightdir, + # "device_id": device_id, + # "vit_tp": self.vit_tp, + # "cache_port": self.args.cache_port, + # "tp_rank_id": tp_rank_id, + # "dp_rank_id": dp_rank_id, + # "data_type": self.args.data_type, + # "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], + # "quant_type": self.args.vit_quant_type, + # "quant_cfg": self.args.vit_quant_cfg, + # "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + # "vit_attn_backend": self.vit_attn_backend, + # } + + weight_dir = kvargs["weight_dir"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = kvargs["dp_rank_id"] self.tp_rank_id = kvargs["tp_rank_id"] self.cache_port = kvargs["cache_port"] - weight_dir = kvargs["weight_dir"] self.vit_rank_id = kvargs["vit_rank_id"] self.is_visual_only_mode = get_env_start_args().run_mode == "visual_only" self.data_type = kvargs["data_type"] diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 65ac401d4c..bec02cd4cc 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -55,14 +55,29 @@ def get_environ(environ_name): def init_vision_distributed_env(kvargs): + """ + # kvargs = { + # "weight_dir": self.model_weightdir, + # "device_id": device_id, + # "vit_tp": self.vit_tp, + # "cache_port": self.args.cache_port, + # "tp_rank_id": tp_rank_id, + # "dp_rank_id": dp_rank_id, + # "data_type": 
self.args.data_type, + # "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], + # "quant_type": self.args.vit_quant_type, + # "quant_cfg": self.args.vit_quant_cfg, + # "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + # "vit_attn_backend": self.vit_attn_backend, + # } + """ tp_world_size = kvargs["vit_tp"] dp_size = 1 tp_rank_id = kvargs["tp_rank_id"] set_dp_size(dp_size) set_dp_world_size(tp_world_size) set_current_rank_in_dp(tp_rank_id) - visual_gpu_ids = kvargs["visual_gpu_ids"] - device_id = visual_gpu_ids[kvargs["vit_rank_id"]] + device_id = kvargs["device_id"] set_current_device_id(device_id) torch.cuda.set_device(device_id) dist.init_process_group( From abfaa6f69d8c09bb8b363e199f24c40b8fc5bc4a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 27 Mar 2026 07:42:27 +0000 Subject: [PATCH 28/54] fix --- .../visualserver/model_infer/model_rpc.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index f61e010ca5..837beca58f 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -36,8 +36,6 @@ class VisualModelRpcServer(rpyc.Service): def exposed_init_model(self, kvargs): kvargs = obtain(kvargs) - import torch - import torch.distributed as dist # kvargs = { # "weight_dir": self.model_weightdir, @@ -148,6 +146,13 @@ def exposed_encode(self, images: List[ImageItem]): all_img_embeds, uuids, valid_ids = self.forward(images) all_img_embeds = all_img_embeds.to(torch.device("cuda")) + if not self.is_visual_only_mode: + self._not_visual_only_mode_handle(all_img_embeds, uuids, valid_ids, images) + else: + self._visual_only_mode_handle(all_img_embeds, uuids, valid_ids, images) + return + + def _not_visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images): if self.tp_rank_id == 0: ready_flags = 
obtain(self.cache_client.root.get_items_embed(uuids)) ids_to_set = [] @@ -164,7 +169,14 @@ def exposed_encode(self, images: List[ImageItem]): if ids_to_set: self.cache_client.root.set_items_embed(ids_to_set) torch.cuda.current_stream().synchronize() - return + + def _visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images): + if self.tp_rank_id == 0: + all_img_embeds = all_img_embeds.detach().cpu() + for i in enumerate(len(images)): + start, end = valid_ids[i] + image = images[i] + self.redis_afs_client.insert(image.md5, all_img_embeds[start:end]) class VisualModelRpcClient: From b42f4d5ef59e328c0d0493595af77f4fd44bb8df Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 27 Mar 2026 08:05:55 +0000 Subject: [PATCH 29/54] fix --- .../visualserver/model_infer/model_rpc.py | 37 +++++++++++++++---- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 837beca58f..ef0530397c 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -7,8 +7,9 @@ import uuid import os import torch.multiprocessing as mp +import collections from datetime import timedelta -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Deque from transformers.configuration_utils import PretrainedConfig from lightllm.utils.retry_utils import retry from rpyc.utils.classic import obtain, unix_connect @@ -118,12 +119,14 @@ def exposed_init_model(self, kvargs): self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) else: args = get_env_start_args() + assert args.visual_dp == 1 self.redis_afs_client = SepEmbedHandler( afs_embed_dir=args.afs_embed_dir, redis_host=args.config_server_host, redis_port=args.config_server_vit_redis_port, capacity=args.afs_embed_capacity, ) + self.async_ret_handle_list: Deque[tuple] = collections.deque() except 
Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -173,10 +176,16 @@ def _not_visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images) def _visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images): if self.tp_rank_id == 0: all_img_embeds = all_img_embeds.detach().cpu() - for i in enumerate(len(images)): - start, end = valid_ids[i] - image = images[i] - self.redis_afs_client.insert(image.md5, all_img_embeds[start:end]) + self.async_ret_handle_list.append((all_img_embeds, valid_ids, images)) + + def exposed_put_to_afs(self): + assert self.tp_rank_id == 0 + assert len(self.async_ret_handle_list) > 0 + all_img_embeds, valid_ids, images = self.async_ret_handle_list.popleft() + for i in enumerate(len(images)): + start, end = valid_ids[i] + image = images[i] + self.redis_afs_client.insert(image.md5, all_img_embeds[start:end]) class VisualModelRpcClient: @@ -196,6 +205,7 @@ async def _func(*args, **kwargs): self._init_model = async_wrap(self.rpc_conn.init_model) self._encode = async_wrap(self.rpc_conn.encode) + self._put_to_afs = async_wrap(self.rpc_conn.put_to_afs) return @@ -208,6 +218,10 @@ async def encode(self, images: List[ImageItem]): ans = self._encode(images) return await ans + async def put_to_afs(self): + ans = self._put_to_afs() + return await ans + def _init_env(scoket_path: str, success_event: "mp.Event"): # 注册graceful 退出的处理 @@ -237,10 +251,17 @@ async def start_model_process(): proc.start() await asyncio.to_thread(success_event.wait, timeout=40) - conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) + if get_env_start_args().run_mode != "visual_only": + conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) + + assert proc.is_alive() + return VisualModelRpcClient(conn.root) + else: + conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) + conn1 = 
retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) - assert proc.is_alive() - return VisualModelRpcClient(conn.root) + assert proc.is_alive() + return VisualModelRpcClient(conn.root), VisualModelRpcClient(conn1.root) def _generate_unix_socket_path() -> str: From 9ce0edccebd0008d1557f4081c701c49b216c92a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 27 Mar 2026 09:08:44 +0000 Subject: [PATCH 30/54] fix --- .../visualserver/visual_only_manager.py | 189 ++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 lightllm/server/visualserver/visual_only_manager.py diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py new file mode 100644 index 0000000000..ac4a0a1af8 --- /dev/null +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -0,0 +1,189 @@ +import asyncio +import uvloop +import inspect +import setproctitle +import threading +import queue +import dataclasses +import rpyc +from typing import List, Any +from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes +from lightllm.server.core.objs import StartArgs + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +from lightllm.server.multimodal_params import MultimodalParams, ImageItem +from .model_infer.model_rpc import start_model_process, VisualModelRpcClient +from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend +from lightllm.utils.log_utils import init_logger +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name +from rpyc.utils.classic import obtain + + +logger = init_logger(__name__) + + +class VisualManager(rpyc.Service): + def __init__( + self, + args: StartArgs, + ): + self.args = args + self.waiting_reqs: List[GroupReqIndexes] = [] + self.model_weightdir = args.model_dir + self.vit_dp = 
args.visual_dp + self.vit_tp = args.visual_tp + assert self.vit_dp == 1 + # image 最大推理 batch size + self.infer_batch_size = args.visual_infer_batch_size + self.send_batch_size = args.visual_send_batch_size + + # 工作线程 + self.task_queue = queue.Queue() + # 限制并发, 主要控制内存用量,防止过多照成爆炸。 + self.sempare = threading.Semaphore(3) + # 启动任务处理线程 + self.worker_thread = threading.Thread(target=self._task_worker, daemon=True) + self.worker_thread.start() + + async def wait_to_model_ready(self): + + self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] + self.model_rpcs_1: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] + self.vit_attn_backend = init_vit_att_backend(index=0) + for dp_rank_id in range(self.vit_dp): + for tp_rank_id in range(self.vit_tp): + + rpc_model = await start_model_process() + self.model_rpcs[dp_rank_id].append(rpc_model[0]) + self.model_rpcs_1[dp_rank_id].append(rpc_model[1]) + + init_model_ret = [] + for dp_rank_id in range(self.vit_dp): # async init model process + for tp_rank_id in range(self.vit_tp): + device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] + kvargs = { + "weight_dir": self.model_weightdir, + "device_id": device_id, + "vit_tp": self.vit_tp, + "cache_port": self.args.cache_port, + "tp_rank_id": tp_rank_id, + "dp_rank_id": dp_rank_id, + "data_type": self.args.data_type, + "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], + "quant_type": self.args.vit_quant_type, + "quant_cfg": self.args.vit_quant_cfg, + "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + "vit_attn_backend": self.vit_attn_backend, + } + init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) + await asyncio.gather(*init_model_ret) + return + + async def infer_imgs(self, images: List[ImageItem]): + if len(images) == 0: + return + + tasks = [] + for vit_dp_rank in range(self.vit_dp): + assigned_images = [images[i] for i in range(vit_dp_rank, len(images), 
self.vit_dp)] + if assigned_images: + for vit_tp_rank in range(self.vit_tp): + task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images)) + tasks.append(task) + + await asyncio.gather(*tasks) + return + + def _task_worker(self): + """ + 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 + """ + while True: + try: + # 从队列获取任务, 阻塞等待 + task: _Task = self.task_queue.get() + + # 执行任务: 调用父类的forward方法处理图像 + try: + all_img_embeds, uuids, valid_ids = self.forward(task.images) + all_img_embeds = all_img_embeds.detach().cpu() + + # 存储结果到task.ret + task.ret = {"embeds": all_img_embeds, "valid_ids": valid_ids} + except Exception as e: + task.hasError = True + logger.exception(str(e)) + raise e + finally: + # 标记任务完成, 唤醒等待的调用者 + task.event.set() + self.task_queue.task_done() + + except Exception as e: + logger.exception(str(e)) + raise e + + def exposed_run_task(self, images: List["ImageItem"]): + """ + 添加任务到队列 + + Args: + images: 要处理的图像列表 + + Returns: + _Task: 任务对象, 包含ret和event + """ + images = obtain(images) + with self.sempare: + event = threading.Event() + task = _Task(images=images, ret=None, event=event) + self.task_queue.put(task) + task.event.wait(timeout=8888) + + all_img_embeds = task.ret["embeds"] + valid_ids = task.ret["valid_ids"] + + if self.tp_rank_id == 0: + for i in enumerate(len(images)): + start, end = valid_ids[i] + image = images[i] + self.afs_handler.insert(image.md5, all_img_embeds[start:end]) + return + + def clean_up(self): + return + + +def start_visual_process(args, pipe_writer): + import lightllm.utils.rpyc_fix_utils as _ + + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") + start_parent_check_thread() + try: + visualserver = VisualManager(args=args) + asyncio.run(visualserver.wait_to_model_ready()) + t = rpyc.ThreadedServer(visualserver, port=None, protocol_config={"allow_pickle": True}) + except Exception as e: + 
logger.exception(str(e)) + visualserver.clean_up() + raise e + + pipe_writer.send("init ok") + + t.start() + return + + +@dataclasses.dataclass +class _Task: + images: List["ImageItem"] + ret: Any + event: threading.Event + hasError: bool = False + + def wait(self, timeout: float = None): + self.event.wait(timeout=timeout) From 83cf458c9c907c36f2e30f659580e0f232391a94 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 27 Mar 2026 09:31:47 +0000 Subject: [PATCH 31/54] fix --- .../visualserver/visual_only_manager.py | 63 +++++++++---------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index ac4a0a1af8..0aa90bd5ca 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -6,6 +6,7 @@ import queue import dataclasses import rpyc +import uuid from typing import List, Any from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import StartArgs @@ -35,9 +36,6 @@ def __init__( self.vit_dp = args.visual_dp self.vit_tp = args.visual_tp assert self.vit_dp == 1 - # image 最大推理 batch size - self.infer_batch_size = args.visual_infer_batch_size - self.send_batch_size = args.visual_send_batch_size # 工作线程 self.task_queue = queue.Queue() @@ -81,21 +79,20 @@ async def wait_to_model_ready(self): await asyncio.gather(*init_model_ret) return - async def infer_imgs(self, images: List[ImageItem]): - if len(images) == 0: - return - + async def infer_imgs(self, images: List[ImageItem], infer_uids: str): + assert len(images) != 0 tasks = [] - for vit_dp_rank in range(self.vit_dp): - assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.vit_dp)] - if assigned_images: - for vit_tp_rank in range(self.vit_tp): - task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images)) - tasks.append(task) + for 
vit_tp_rank in range(self.vit_tp): + task = asyncio.create_task(self.model_rpcs[0][vit_tp_rank].encode(images, infer_uids=infer_uids)) + tasks.append(task) await asyncio.gather(*tasks) return + async def put_to_afs(self, infer_uids: str): + await self.model_rpcs_1[0][0].put_to_afs(infer_uids) + return + def _task_worker(self): """ 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 @@ -107,11 +104,7 @@ def _task_worker(self): # 执行任务: 调用父类的forward方法处理图像 try: - all_img_embeds, uuids, valid_ids = self.forward(task.images) - all_img_embeds = all_img_embeds.detach().cpu() - - # 存储结果到task.ret - task.ret = {"embeds": all_img_embeds, "valid_ids": valid_ids} + asyncio.run(self.infer_imgs(task.images)) except Exception as e: task.hasError = True logger.exception(str(e)) @@ -135,21 +128,23 @@ def exposed_run_task(self, images: List["ImageItem"]): Returns: _Task: 任务对象, 包含ret和event """ - images = obtain(images) - with self.sempare: - event = threading.Event() - task = _Task(images=images, ret=None, event=event) - self.task_queue.put(task) - task.event.wait(timeout=8888) - - all_img_embeds = task.ret["embeds"] - valid_ids = task.ret["valid_ids"] - - if self.tp_rank_id == 0: - for i in enumerate(len(images)): - start, end = valid_ids[i] - image = images[i] - self.afs_handler.insert(image.md5, all_img_embeds[start:end]) + try: + images = obtain(images) + # 写入 shm, 然后 + + with self.sempare: + event = threading.Event() + task = _Task(images=images, infer_uid=uuid.uuid4().hex, vent=event) + self.task_queue.put(task) + task.event.wait(timeout=8888) + + asyncio.run(self.put_to_afs(infer_uids=task.infer_uid)) + + # 将 shm 进行删除 + + except BaseException as e: + logger.exception(str(e)) + raise e return def clean_up(self): @@ -181,8 +176,8 @@ def start_visual_process(args, pipe_writer): @dataclasses.dataclass class _Task: images: List["ImageItem"] - ret: Any event: threading.Event + infer_uid: str hasError: bool = False def wait(self, timeout: float = None): From ea6530f69f33e84f6f88ebafda2c3975f5c358d1 Mon Sep 
17 00:00:00 2001 From: wangzaijun Date: Fri, 27 Mar 2026 09:43:45 +0000 Subject: [PATCH 32/54] fix --- .../visualserver/model_infer/model_rpc.py | 34 ++++++----- .../visualserver/visual_only_manager.py | 58 +++++++++---------- 2 files changed, 48 insertions(+), 44 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index ef0530397c..5065806d7e 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -9,7 +9,7 @@ import torch.multiprocessing as mp import collections from datetime import timedelta -from typing import Dict, List, Tuple, Deque +from typing import Dict, List, Tuple, Deque, Optional from transformers.configuration_utils import PretrainedConfig from lightllm.utils.retry_utils import retry from rpyc.utils.classic import obtain, unix_connect @@ -126,7 +126,7 @@ def exposed_init_model(self, kvargs): redis_port=args.config_server_vit_redis_port, capacity=args.afs_embed_capacity, ) - self.async_ret_handle_list: Deque[tuple] = collections.deque() + self.async_ret_handle_dict: Dict[str, tuple] = {} except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -144,15 +144,20 @@ def forward(self, images: List[ImageItem]): return self.model.encode(images) # @calculate_time(show=False, min_cost_ms=300) - def exposed_encode(self, images: List[ImageItem]): + def exposed_encode(self, images: List[ImageItem], infer_uid: Optional[str] = None): images = obtain(images) all_img_embeds, uuids, valid_ids = self.forward(images) all_img_embeds = all_img_embeds.to(torch.device("cuda")) if not self.is_visual_only_mode: - self._not_visual_only_mode_handle(all_img_embeds, uuids, valid_ids, images) + assert infer_uid is None + self._not_visual_only_mode_handle( + all_img_embeds=all_img_embeds, uuids=uuids, valid_ids=valid_ids, images=images + ) else: - self._visual_only_mode_handle(all_img_embeds, 
uuids, valid_ids, images) + self._visual_only_mode_handle( + all_img_embeds=all_img_embeds, uuids=uuids, valid_ids=valid_ids, images=images, infer_uid=infer_uid + ) return def _not_visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images): @@ -173,19 +178,20 @@ def _not_visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images) self.cache_client.root.set_items_embed(ids_to_set) torch.cuda.current_stream().synchronize() - def _visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images): + def _visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images, infer_uid): if self.tp_rank_id == 0: all_img_embeds = all_img_embeds.detach().cpu() - self.async_ret_handle_list.append((all_img_embeds, valid_ids, images)) + self.async_ret_handle_dict[infer_uid] = (all_img_embeds, valid_ids, images) - def exposed_put_to_afs(self): + def exposed_put_to_afs(self, infer_uid: str): assert self.tp_rank_id == 0 - assert len(self.async_ret_handle_list) > 0 - all_img_embeds, valid_ids, images = self.async_ret_handle_list.popleft() - for i in enumerate(len(images)): - start, end = valid_ids[i] - image = images[i] - self.redis_afs_client.insert(image.md5, all_img_embeds[start:end]) + ret = self.async_ret_handle_dict.pop(infer_uid, None) + if ret is not None: + all_img_embeds, valid_ids, images = ret + for i in enumerate(len(images)): + start, end = valid_ids[i] + image = images[i] + self.redis_afs_client.insert(image.md5, all_img_embeds[start:end]) class VisualModelRpcClient: diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index 0aa90bd5ca..e44aece533 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -47,50 +47,48 @@ def __init__( async def wait_to_model_ready(self): - self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] - self.model_rpcs_1: 
List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] + self.model_rpcs: List[List[VisualModelRpcClient]] = [] + self.model_rpcs_1: List[List[VisualModelRpcClient]] = [] self.vit_attn_backend = init_vit_att_backend(index=0) - for dp_rank_id in range(self.vit_dp): - for tp_rank_id in range(self.vit_tp): - - rpc_model = await start_model_process() - self.model_rpcs[dp_rank_id].append(rpc_model[0]) - self.model_rpcs_1[dp_rank_id].append(rpc_model[1]) + for tp_rank_id in range(self.vit_tp): + rpc_model = await start_model_process() + self.model_rpcs.append(rpc_model[0]) + self.model_rpcs_1.append(rpc_model[1]) init_model_ret = [] - for dp_rank_id in range(self.vit_dp): # async init model process - for tp_rank_id in range(self.vit_tp): - device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] - kvargs = { - "weight_dir": self.model_weightdir, - "device_id": device_id, - "vit_tp": self.vit_tp, - "cache_port": self.args.cache_port, - "tp_rank_id": tp_rank_id, - "dp_rank_id": dp_rank_id, - "data_type": self.args.data_type, - "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], - "quant_type": self.args.vit_quant_type, - "quant_cfg": self.args.vit_quant_cfg, - "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), - "vit_attn_backend": self.vit_attn_backend, - } - init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) + + for tp_rank_id in range(self.vit_tp): + device_id = self.args.visual_gpu_ids[tp_rank_id] + kvargs = { + "weight_dir": self.model_weightdir, + "device_id": device_id, + "vit_tp": self.vit_tp, + "cache_port": self.args.cache_port, + "tp_rank_id": tp_rank_id, + "dp_rank_id": 0, + "data_type": self.args.data_type, + "visual_nccl_port": self.args.visual_nccl_ports[0], + "quant_type": self.args.vit_quant_type, + "quant_cfg": self.args.vit_quant_cfg, + "max_batch_size": min(self.args.visual_infer_batch_size // self.vit_dp, 1), + "vit_attn_backend": self.vit_attn_backend, + } + 
init_model_ret.append(self.model_rpcs[tp_rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) return - async def infer_imgs(self, images: List[ImageItem], infer_uids: str): + async def infer_imgs(self, images: List[ImageItem], infer_uid: str): assert len(images) != 0 tasks = [] for vit_tp_rank in range(self.vit_tp): - task = asyncio.create_task(self.model_rpcs[0][vit_tp_rank].encode(images, infer_uids=infer_uids)) + task = asyncio.create_task(self.model_rpcs[vit_tp_rank].encode(images, infer_uid=infer_uid)) tasks.append(task) await asyncio.gather(*tasks) return - async def put_to_afs(self, infer_uids: str): - await self.model_rpcs_1[0][0].put_to_afs(infer_uids) + async def put_to_afs(self, infer_uid: str): + await self.model_rpcs_1[0].put_to_afs(infer_uid) return def _task_worker(self): From cd051048930f529920046868d4f5b1da09579975 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 05:20:58 +0000 Subject: [PATCH 33/54] fix --- .../model_infer/visual_only_model_rpc.py | 145 ++++++++++-------- 1 file changed, 79 insertions(+), 66 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py index 1b8350af69..853ebd4482 100644 --- a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py +++ b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py @@ -1,15 +1,13 @@ import os import queue import threading -import dataclasses -import rpyc -import socket import asyncio import inspect import uuid +import rpyc from lightllm.utils.retry_utils import retry from rpyc.utils.factory import unix_connect -from typing import List, Any +from typing import List, Any, Deque, Tuple from .model_rpc import VisualModelRpcServer, VisualModelRpcClient from lightllm.server.multimodal_params import ImageItem from lightllm.server.embed_cache.afs_utils import SepEmbedHandler @@ -21,18 +19,6 @@ logger = init_logger(__name__) - -@dataclasses.dataclass -class 
_Task: - images: List["ImageItem"] - ret: Any - event: threading.Event - hasError: bool = False - - def wait(self, timeout: float = None): - self.event.wait(timeout=timeout) - - class VisualOnlyModelRpcServer(VisualModelRpcServer): """ 完善这个代码: @@ -44,10 +30,16 @@ class VisualOnlyModelRpcServer(VisualModelRpcServer): def __init__(self): super().__init__() + # 控制每次的最大推理图片数量,防止爆显存 + self.max_infer_batch_size = get_env_start_args().visual_infer_batch_size + # 异步队列, 用于接受任务 - self.task_queue = queue.Queue() - # 限制并发, 主要控制内存用量,防止过多照成爆炸。 - self.sempare = threading.Semaphore(3) + self.infer_queue = queue.Queue() + # 将计算得到的结果放入 afs 的queue + self.put_afs_queue = queue.Queue() + + # 限制并发, 主要控制内存用量,防止过多造成内存OOM + self.sempare = threading.Semaphore(self.max_infer_batch_size * 8) self.afs_handler = SepEmbedHandler( afs_embed_dir=get_env_start_args().afs_embed_dir, @@ -57,67 +49,87 @@ def __init__(self): ) # 启动任务处理线程 - self.worker_thread = threading.Thread(target=self._task_worker, daemon=True) - self.worker_thread.start() + self._infer_thread = threading.Thread(target=self._infer_worker, daemon=True) + self._infer_thread.start() + + self._put_afs_thread = threading.Thread(target=self._put_afs_worker, daemon=True) + self._put_afs_thread.start() + + def exposed_run_task(self, images: List["ImageItem"], ref_event: threading.Event): + try: + images = obtain(images) + images[-1].event = ref_event + + for image in images: + self.infer_queue.put(image) + + except BaseException as e: + logger.exception(str(e)) + raise e + return + + def _get_image_items_from_queue(self, max_num: int) -> List[ImageItem]: + """ + 从队列中批量获取任务,直到达到 max_num 或队列为空。 + """ + tasks = [] + # 至少获取一个任务,阻塞 + self.sempare.acquire() + task = self.infer_queue.get(block=True) + tasks.append(task) + + # 尝试继续获取更多任务,直到达到 max_num + while len(tasks) < max_num: + try: + self.sempare.acquire() + task = self.infer_queue.get(block=False) + tasks.append(task) + except queue.Empty: + self.sempare.release() + break + + return tasks - 
def _task_worker(self): + def _infer_worker(self): """ 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 """ while True: try: # 从队列获取任务, 阻塞等待 - task: _Task = self.task_queue.get() + images = self._get_image_items_from_queue(max_num=self.max_infer_batch_size) # 执行任务: 调用父类的forward方法处理图像 - try: - all_img_embeds, uuids, valid_ids = self.forward(task.images) - all_img_embeds = all_img_embeds.detach().cpu() - - # 存储结果到task.ret - task.ret = {"embeds": all_img_embeds, "valid_ids": valid_ids} - except Exception as e: - task.hasError = True - logger.exception(str(e)) - raise e - finally: - # 标记任务完成, 唤醒等待的调用者 - task.event.set() - self.task_queue.task_done() + all_img_embeds, uuids, valid_ids = self.forward(images) + all_img_embeds = all_img_embeds.detach().cpu() + for image, valid_id in zip(images, valid_ids): + start, end = valid_id + self.put_afs_queue.put((image, all_img_embeds[start:end])) except Exception as e: logger.exception(str(e)) raise e - - def exposed_run_task(self, images: List["ImageItem"]): + + def _put_afs_worker(self): """ - 添加任务到队列 - - Args: - images: 要处理的图像列表 - - Returns: - _Task: 任务对象, 包含ret和event + 任务处理循环: 从队列中取出ImageItem和embed 放入 afs中, 执行完成后通知调用者 """ - images = obtain(images) - with self.sempare: - event = threading.Event() - task = _Task(images=images, ret=None, event=event) - self.task_queue.put(task) - task.event.wait(timeout=8888) - - all_img_embeds = task.ret["embeds"] - valid_ids = task.ret["valid_ids"] - - if self.tp_rank_id == 0: - for i in enumerate(len(images)): - start, end = valid_ids[i] - image = images[i] - self.afs_handler.insert(image.md5, all_img_embeds[start:end]) - return + while True: + try: + # 从队列获取任务, 阻塞等待 + image, embed = self.put_afs_queue.get(block=True) + # 只有 0 rank 执行真的写入操作。 + if self.tp_rank_id == 0: + self.afs_handler.insert(image.md5, embed) + if hasattr(image, "event"): + image.event.set() + self.sempare.release() + except Exception as e: + logger.exception(str(e)) + raise e -def _init_env(socket_path: str, device_id: int, success_event): 
+def _init_env(socket_path: str, success_event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) @@ -129,7 +141,7 @@ def _init_env(socket_path: str, device_id: int, success_event): return -async def start_model_process(vit_tp, device_id): +async def start_model_process(): import multiprocessing socket_path = _generate_unix_socket_path() @@ -141,7 +153,6 @@ async def start_model_process(vit_tp, device_id): target=_init_env, args=( socket_path, - device_id, success_event, ), ) @@ -151,7 +162,9 @@ async def start_model_process(vit_tp, device_id): conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) assert proc.is_alive() - return VisualModelRpcClient(conn.root, vit_tp, rpc_server_process=proc) + # 服务端需要调用event所以,客户端需要一个后台线程进行相关的处理。 + conn._bg_thread = rpyc.BgServingThread(conn) + return VisualModelRpcClient(conn.root) def _generate_unix_socket_path() -> str: From 19f6cb6aa5de55fe03e7e0bc27a0cd5b699a62fb Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 05:42:40 +0000 Subject: [PATCH 34/54] fix --- .../visualserver/model_infer/__init__.py | 61 +++++++++ .../visualserver/model_infer/model_rpc.py | 117 ++---------------- .../model_infer/model_rpc_client.py | 44 +++++++ .../model_infer/visual_only_model_rpc.py | 2 +- 4 files changed, 113 insertions(+), 111 deletions(-) create mode 100644 lightllm/server/visualserver/model_infer/model_rpc_client.py diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index e69de29bb2..43d1d90df3 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -0,0 +1,61 @@ +import asyncio +import rpyc +import inspect +import uuid +import os +from lightllm.utils.retry_utils import retry +from rpyc.utils.classic import obtain, unix_connect +from rpyc.utils.server import ThreadedServer +from lightllm.utils.graceful_utils import 
graceful_registry +from lightllm.utils.envs_utils import get_env_start_args + +from .model_rpc import VisualModelRpcServer, VisualModelRpcClient +from .visual_only_model_rpc import VisualOnlyModelRpcServer + +def _init_env(socket_path: str, success_event): + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + + import lightllm.utils.rpyc_fix_utils as _ + if get_env_start_args().run_mode == "visual_only": + t = ThreadedServer(VisualOnlyModelRpcServer(), socket_path=socket_path, protocol_config={"allow_pickle": True}) + else: + t = ThreadedServer(VisualModelRpcServer(), socket_path=socket_path, protocol_config={"allow_pickle": True}) + success_event.set() + t.start() + return + + +async def start_model_process(): + import multiprocessing + + socket_path = _generate_unix_socket_path() + if os.path.exists(socket_path): + os.remove(socket_path) + + success_event = multiprocessing.Event() + proc = multiprocessing.Process( + target=_init_env, + args=( + socket_path, + success_event, + ), + ) + proc.start() + await asyncio.to_thread(success_event.wait, timeout=40) + assert proc.is_alive() + + conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) + assert proc.is_alive() + + if get_env_start_args().run_mode == "visual_only": + # 服务端需要调用event所以,客户端需要一个后台线程进行相关的处理。 + conn._bg_thread = rpyc.BgServingThread(conn) + + return VisualModelRpcClient(conn) + + +def _generate_unix_socket_path() -> str: + """Generate a random Unix socket path""" + unique_id = uuid.uuid4().hex[:8] + return f"/tmp/lightllm_model_infer_{unique_id}.sock" \ No newline at end of file diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 5065806d7e..449f97cefa 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -126,7 +126,6 @@ def exposed_init_model(self, kvargs): 
redis_port=args.config_server_vit_redis_port, capacity=args.afs_embed_capacity, ) - self.async_ret_handle_dict: Dict[str, tuple] = {} except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -144,23 +143,18 @@ def forward(self, images: List[ImageItem]): return self.model.encode(images) # @calculate_time(show=False, min_cost_ms=300) - def exposed_encode(self, images: List[ImageItem], infer_uid: Optional[str] = None): + def exposed_encode(self, images: List[ImageItem]): images = obtain(images) all_img_embeds, uuids, valid_ids = self.forward(images) all_img_embeds = all_img_embeds.to(torch.device("cuda")) - if not self.is_visual_only_mode: - assert infer_uid is None - self._not_visual_only_mode_handle( - all_img_embeds=all_img_embeds, uuids=uuids, valid_ids=valid_ids, images=images - ) - else: - self._visual_only_mode_handle( - all_img_embeds=all_img_embeds, uuids=uuids, valid_ids=valid_ids, images=images, infer_uid=infer_uid - ) + assert not self.is_visual_only_mode + self._put_image_embed_to_cpu_cache( + all_img_embeds=all_img_embeds, uuids=uuids, valid_ids=valid_ids, images=images + ) return - def _not_visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images): + def _put_image_embed_to_cpu_cache(self, all_img_embeds, uuids, valid_ids, images): if self.tp_rank_id == 0: ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) ids_to_set = [] @@ -176,101 +170,4 @@ def _not_visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images) ids_to_set.append(uid) if ids_to_set: self.cache_client.root.set_items_embed(ids_to_set) - torch.cuda.current_stream().synchronize() - - def _visual_only_mode_handle(self, all_img_embeds, uuids, valid_ids, images, infer_uid): - if self.tp_rank_id == 0: - all_img_embeds = all_img_embeds.detach().cpu() - self.async_ret_handle_dict[infer_uid] = (all_img_embeds, valid_ids, images) - - def exposed_put_to_afs(self, infer_uid: str): - assert self.tp_rank_id == 0 - ret = 
self.async_ret_handle_dict.pop(infer_uid, None) - if ret is not None: - all_img_embeds, valid_ids, images = ret - for i in enumerate(len(images)): - start, end = valid_ids[i] - image = images[i] - self.redis_afs_client.insert(image.md5, all_img_embeds[start:end]) - - -class VisualModelRpcClient: - def __init__(self, rpc_conn): - self.rpc_conn: VisualModelRpcServer = rpc_conn - - def async_wrap(f): - f = rpyc.async_(f) - - async def _func(*args, **kwargs): - ans = f(*args, **kwargs) - await asyncio.to_thread(ans.wait) - # raise if exception - return ans.value - - return _func - - self._init_model = async_wrap(self.rpc_conn.init_model) - self._encode = async_wrap(self.rpc_conn.encode) - self._put_to_afs = async_wrap(self.rpc_conn.put_to_afs) - - return - - async def init_model(self, kvargs): - ans: rpyc.AsyncResult = self._init_model(kvargs) - await ans - return - - async def encode(self, images: List[ImageItem]): - ans = self._encode(images) - return await ans - - async def put_to_afs(self): - ans = self._put_to_afs() - return await ans - - -def _init_env(scoket_path: str, success_event: "mp.Event"): - # 注册graceful 退出的处理 - graceful_registry(inspect.currentframe().f_code.co_name) - - import lightllm.utils.rpyc_fix_utils as _ - - t = ThreadedServer(VisualModelRpcServer(), socket_path=scoket_path, protocol_config={"allow_pickle": True}) - success_event.set() - t.start() - return - - -async def start_model_process(): - socket_path = _generate_unix_socket_path() - if os.path.exists(socket_path): - os.remove(socket_path) - - success_event = mp.Event() - proc = mp.Process( - target=_init_env, - args=( - socket_path, - success_event, - ), - ) - proc.start() - await asyncio.to_thread(success_event.wait, timeout=40) - - if get_env_start_args().run_mode != "visual_only": - conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) - - assert proc.is_alive() - return VisualModelRpcClient(conn.root) - else: - conn = 
retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) - conn1 = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) - - assert proc.is_alive() - return VisualModelRpcClient(conn.root), VisualModelRpcClient(conn1.root) - - -def _generate_unix_socket_path() -> str: - """Generate a random Unix socket path""" - unique_id = uuid.uuid4().hex[:8] - return f"/tmp/lightllm_model_infer_{unique_id}.sock" + torch.cuda.current_stream().synchronize() \ No newline at end of file diff --git a/lightllm/server/visualserver/model_infer/model_rpc_client.py b/lightllm/server/visualserver/model_infer/model_rpc_client.py new file mode 100644 index 0000000000..e4d3457859 --- /dev/null +++ b/lightllm/server/visualserver/model_infer/model_rpc_client.py @@ -0,0 +1,44 @@ +import asyncio +import rpyc +import threading +from typing import Dict, List, Tuple, Deque, Optional, Union +from lightllm.server.multimodal_params import ImageItem +from .model_rpc import VisualModelRpcServer +from .visual_only_model_rpc import VisualOnlyModelRpcServer +from lightllm.utils.envs_utils import get_env_start_args + + +class VisualModelRpcClient: + def __init__(self, rpc_conn): + self.rpc_conn: Union[VisualModelRpcServer, VisualOnlyModelRpcServer] = rpc_conn + + def async_wrap(f): + f = rpyc.async_(f) + + async def _func(*args, **kwargs): + ans = f(*args, **kwargs) + await asyncio.to_thread(ans.wait) + # raise if exception + return ans.value + + return _func + + self._init_model = async_wrap(self.rpc_conn.root.init_model) + self._encode = async_wrap(self.rpc_conn.root.encode) + if get_env_start_args().run_mode == "visual_only": + self._run_task = async_wrap(self.rpc_conn.root.run_task) + + return + + async def init_model(self, kvargs): + ans: rpyc.AsyncResult = self._init_model(kvargs) + await ans + return + + async def encode(self, images: List[ImageItem]): + ans = self._encode(images) + return await ans + + async def 
run_task(self, images: List[ImageItem], ref_event: threading.Event): + ans = self._run_task(images, ref_event) + return await ans diff --git a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py index 853ebd4482..5d4f86cae3 100644 --- a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py +++ b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py @@ -164,7 +164,7 @@ async def start_model_process(): assert proc.is_alive() # 服务端需要调用event所以,客户端需要一个后台线程进行相关的处理。 conn._bg_thread = rpyc.BgServingThread(conn) - return VisualModelRpcClient(conn.root) + return VisualModelRpcClient(conn) def _generate_unix_socket_path() -> str: From d70d437b834d2d9bfe36c83de7290e333c14b919 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 05:43:42 +0000 Subject: [PATCH 35/54] fix --- .../model_infer/visual_only_model_rpc.py | 46 +------------------ 1 file changed, 1 insertion(+), 45 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py index 5d4f86cae3..2042382357 100644 --- a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py +++ b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py @@ -126,48 +126,4 @@ def _put_afs_worker(self): self.sempare.release() except Exception as e: logger.exception(str(e)) - raise e - - -def _init_env(socket_path: str, success_event): - # 注册graceful 退出的处理 - graceful_registry(inspect.currentframe().f_code.co_name) - - import lightllm.utils.rpyc_fix_utils as _ - - t = ThreadedServer(VisualOnlyModelRpcServer(), socket_path=socket_path, protocol_config={"allow_pickle": True}) - success_event.set() - t.start() - return - - -async def start_model_process(): - import multiprocessing - - socket_path = _generate_unix_socket_path() - if os.path.exists(socket_path): - os.remove(socket_path) - - success_event = 
multiprocessing.Event() - proc = multiprocessing.Process( - target=_init_env, - args=( - socket_path, - success_event, - ), - ) - proc.start() - await asyncio.to_thread(success_event.wait, timeout=40) - assert proc.is_alive() - - conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) - assert proc.is_alive() - # 服务端需要调用event所以,客户端需要一个后台线程进行相关的处理。 - conn._bg_thread = rpyc.BgServingThread(conn) - return VisualModelRpcClient(conn) - - -def _generate_unix_socket_path() -> str: - """Generate a random Unix socket path""" - unique_id = uuid.uuid4().hex[:8] - return f"/tmp/lightllm_model_infer_{unique_id}.sock" + raise e \ No newline at end of file From cbacd0a0a6b7cb14db883274050113958a43d8dc Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 08:50:23 +0000 Subject: [PATCH 36/54] fix --- .../visualserver/model_infer/__init__.py | 9 +- .../model_infer/model_rpc_client.py | 13 +- .../model_infer/visual_only_model_rpc.py | 147 +++++++++++++----- 3 files changed, 117 insertions(+), 52 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index 43d1d90df3..60c9de817e 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -8,8 +8,8 @@ from rpyc.utils.server import ThreadedServer from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args - -from .model_rpc import VisualModelRpcServer, VisualModelRpcClient +from .model_rpc_client import VisualModelRpcClient +from .model_rpc import VisualModelRpcServer from .visual_only_model_rpc import VisualOnlyModelRpcServer def _init_env(socket_path: str, success_event): @@ -48,9 +48,8 @@ async def start_model_process(): conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) assert proc.is_alive() - if get_env_start_args().run_mode == 
"visual_only": - # 服务端需要调用event所以,客户端需要一个后台线程进行相关的处理。 - conn._bg_thread = rpyc.BgServingThread(conn) + # 服务端需要调用event所以,客户端需要一个后台线程进行相关的处理。 + conn._bg_thread = rpyc.BgServingThread(conn) return VisualModelRpcClient(conn) diff --git a/lightllm/server/visualserver/model_infer/model_rpc_client.py b/lightllm/server/visualserver/model_infer/model_rpc_client.py index e4d3457859..c69cac2b61 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc_client.py +++ b/lightllm/server/visualserver/model_infer/model_rpc_client.py @@ -5,7 +5,6 @@ from lightllm.server.multimodal_params import ImageItem from .model_rpc import VisualModelRpcServer from .visual_only_model_rpc import VisualOnlyModelRpcServer -from lightllm.utils.envs_utils import get_env_start_args class VisualModelRpcClient: @@ -24,9 +23,7 @@ async def _func(*args, **kwargs): return _func self._init_model = async_wrap(self.rpc_conn.root.init_model) - self._encode = async_wrap(self.rpc_conn.root.encode) - if get_env_start_args().run_mode == "visual_only": - self._run_task = async_wrap(self.rpc_conn.root.run_task) + self._run_task = async_wrap(self.rpc_conn.root.run_task) return @@ -34,11 +31,7 @@ async def init_model(self, kvargs): ans: rpyc.AsyncResult = self._init_model(kvargs) await ans return - - async def encode(self, images: List[ImageItem]): - ans = self._encode(images) - return await ans - async def run_task(self, images: List[ImageItem], ref_event: threading.Event): - ans = self._run_task(images, ref_event) + async def run_task(self, images: List[ImageItem], ref_event_list: List[threading.Event]): + ans = self._run_task(images, ref_event_list) return await ans diff --git a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py index 2042382357..72b4520074 100644 --- a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py +++ b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py @@ -1,22 +1,17 @@ 
-import os import queue import threading -import asyncio -import inspect -import uuid -import rpyc -from lightllm.utils.retry_utils import retry -from rpyc.utils.factory import unix_connect +import torch.distributed as dist +import torch from typing import List, Any, Deque, Tuple -from .model_rpc import VisualModelRpcServer, VisualModelRpcClient +from .model_rpc import VisualModelRpcServer from lightllm.server.multimodal_params import ImageItem from lightllm.server.embed_cache.afs_utils import SepEmbedHandler from rpyc.utils.server import ThreadedServer from lightllm.utils.envs_utils import get_env_start_args from rpyc.utils.classic import obtain -from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.log_utils import init_logger + logger = init_logger(__name__) class VisualOnlyModelRpcServer(VisualModelRpcServer): @@ -30,17 +25,37 @@ class VisualOnlyModelRpcServer(VisualModelRpcServer): def __init__(self): super().__init__() + + + + def exposed_run_task(self, images: List["ImageItem"], ref_event_list: List[threading.Event]): + try: + images = obtain(images) + for i in range(len(images)): + images[i].event = ref_event_list[i] + self.infer_queue.put(images[i]) + + except BaseException as e: + logger.exception(str(e)) + raise e + return + + def init_taskes(self): # 控制每次的最大推理图片数量,防止爆显存 self.max_infer_batch_size = get_env_start_args().visual_infer_batch_size # 异步队列, 用于接受任务 self.infer_queue = queue.Queue() - # 将计算得到的结果放入 afs 的queue - self.put_afs_queue = queue.Queue() + self.infer_queue_lock = threading.Lock() + # 将计算得到的结果放入 afs 或者 embed cache 的 queue + self.store_queue = queue.Queue() # 限制并发, 主要控制内存用量,防止过多造成内存OOM self.sempare = threading.Semaphore(self.max_infer_batch_size * 8) + # 用于同步各个推理tp每次拿到一样的image数量建立的gloo通信组 + self.gloo_group = dist.new_group(ranks=list(range(self.vit_tp)), backend="gloo") + self.afs_handler = SepEmbedHandler( afs_embed_dir=get_env_start_args().afs_embed_dir, redis_host=get_env_start_args().config_server_host, @@ -52,23 
+67,12 @@ def __init__(self): self._infer_thread = threading.Thread(target=self._infer_worker, daemon=True) self._infer_thread.start() - self._put_afs_thread = threading.Thread(target=self._put_afs_worker, daemon=True) - self._put_afs_thread.start() + self._store_thread = threading.Thread(target=self._store_worker, daemon=True) + self._store_thread.start() + pass - def exposed_run_task(self, images: List["ImageItem"], ref_event: threading.Event): - try: - images = obtain(images) - images[-1].event = ref_event - for image in images: - self.infer_queue.put(image) - - except BaseException as e: - logger.exception(str(e)) - raise e - return - - def _get_image_items_from_queue(self, max_num: int) -> List[ImageItem]: + def _get_image_items_from_infer_queue(self, max_num: int, force_same: bool = False) -> List[ImageItem]: """ 从队列中批量获取任务,直到达到 max_num 或队列为空。 """ @@ -78,29 +82,72 @@ def _get_image_items_from_queue(self, max_num: int) -> List[ImageItem]: task = self.infer_queue.get(block=True) tasks.append(task) - # 尝试继续获取更多任务,直到达到 max_num + if not force_same: + # 尝试继续获取更多任务,直到达到 max_num + while len(tasks) < max_num: + try: + self.sempare.acquire() + task = self.infer_queue.get(block=False) + tasks.append(task) + except queue.Empty: + self.sempare.release() + break + else: + while len(tasks) < max_num: + self.sempare.acquire() + task = self.infer_queue.get(block=True) + tasks.append(task) + + return tasks + + def _get_image_items_from_store_queue(self, max_num: int) -> List[ImageItem]: + """ + 从队列中批量获取任务,直到达到 max_num 或队列为空。 + """ + tasks = [] + # 至少获取一个任务,阻塞 + task = self.store_queue.get(block=True) + tasks.append(task) + while len(tasks) < max_num: try: - self.sempare.acquire() - task = self.infer_queue.get(block=False) + task = self.store_queue.get(block=False) tasks.append(task) except queue.Empty: - self.sempare.release() break return tasks + def _infer_worker(self): """ 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 """ + torch.cuda.set_device(self.device_id) while True: try: # 从队列获取任务, 
阻塞等待 - images = self._get_image_items_from_queue(max_num=self.max_infer_batch_size) + if self.tp_rank_id == 0: + images = self._get_image_items_from_infer_queue(max_num=self.max_infer_batch_size) + dist.broadcast_object_list([len(images)], src=0, group=self.gloo_group) + else: + ans = [None] + dist.broadcast_object_list(ans, src=0, group=self.gloo_group) + images = self._get_image_items_from_infer_queue(max_num=ans[0], force_same=True) # 执行任务: 调用父类的forward方法处理图像 all_img_embeds, uuids, valid_ids = self.forward(images) + all_img_embeds = all_img_embeds.to(torch.device("cuda")) + + if self.is_visual_only_mode: + all_img_embeds = all_img_embeds.detach().cpu() + for image, valid_id in zip(images, valid_ids): + start, end = valid_id + gen_embed = all_img_embeds[start:end] + image.gen_embed = gen_embed + self.store_queue.put(image) + else: + self._store_to_cpu_cache(all_img_embeds, valid_ids, images) all_img_embeds = all_img_embeds.detach().cpu() for image, valid_id in zip(images, valid_ids): start, end = valid_id @@ -110,20 +157,46 @@ def _infer_worker(self): logger.exception(str(e)) raise e - def _put_afs_worker(self): + def _store_to_cpu_cache(self, all_img_embeds, valid_ids, images): + for i in range(len(images)): + start, end = valid_ids[i] + image = images[i] + if self.tp_rank_id == 0: + self.cpu_embed_cache_client.copy_vision_to_cache( + embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache + ) + cuda_event = torch.cuda.Event() + cuda_event.record() + image.cuda_event = cuda_event + self.store_queue.put(image) + + def _store_worker(self): """ 任务处理循环: 从队列中取出ImageItem和embed 放入 afs中, 执行完成后通知调用者 """ while True: try: # 从队列获取任务, 阻塞等待 - image, embed = self.put_afs_queue.get(block=True) + images: List[ImageItem] = self._get_image_items_from_store_queue(max_num=self.max_infer_batch_size) # 只有 0 rank 执行真的写入操作。 if self.tp_rank_id == 0: - self.afs_handler.insert(image.md5, embed) - if hasattr(image, "event"): - image.event.set() - 
self.sempare.release() + if self.is_visual_only_mode: + for image in images: + self.afs_handler.insert(image.md5, image.gen_embed) + image.event.set() + else: + for image in images: + # 等待拷贝到cpu cache 完成。 + image.cuda_event.synchronize() + + uuids = [image.uuid for image in images] + self.cache_client.root.set_items_embed(uuids) + + for image in images: + image.event.set() + + for _ in images: + self.sempare.release() except Exception as e: logger.exception(str(e)) raise e \ No newline at end of file From bd3af5f697e0f68df0bfeda81c5dacca4ceaaee3 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 09:01:03 +0000 Subject: [PATCH 37/54] fix --- .../visualserver/model_infer/model_rpc.py | 30 --------- .../model_infer/visual_only_model_rpc.py | 64 ++++++++++--------- 2 files changed, 35 insertions(+), 59 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 449f97cefa..dd74a2f24d 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -141,33 +141,3 @@ def exposed_init_model(self, kvargs): @torch.no_grad() def forward(self, images: List[ImageItem]): return self.model.encode(images) - - # @calculate_time(show=False, min_cost_ms=300) - def exposed_encode(self, images: List[ImageItem]): - images = obtain(images) - all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cuda")) - - assert not self.is_visual_only_mode - self._put_image_embed_to_cpu_cache( - all_img_embeds=all_img_embeds, uuids=uuids, valid_ids=valid_ids, images=images - ) - return - - def _put_image_embed_to_cpu_cache(self, all_img_embeds, uuids, valid_ids, images): - if self.tp_rank_id == 0: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) - ids_to_set = [] - for i, ready in enumerate(ready_flags): - if ready: - continue - uid = uuids[i] - start, end = valid_ids[i] - 
image = images[i] - self.cpu_embed_cache_client.copy_vision_to_cache( - embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache - ) - ids_to_set.append(uid) - if ids_to_set: - self.cache_client.root.set_items_embed(ids_to_set) - torch.cuda.current_stream().synchronize() \ No newline at end of file diff --git a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py index 72b4520074..13fee51270 100644 --- a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py +++ b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py @@ -140,19 +140,10 @@ def _infer_worker(self): all_img_embeds = all_img_embeds.to(torch.device("cuda")) if self.is_visual_only_mode: - all_img_embeds = all_img_embeds.detach().cpu() - for image, valid_id in zip(images, valid_ids): - start, end = valid_id - gen_embed = all_img_embeds[start:end] - image.gen_embed = gen_embed - self.store_queue.put(image) + self._store_to_afs(all_img_embeds, valid_ids, images) else: self._store_to_cpu_cache(all_img_embeds, valid_ids, images) - all_img_embeds = all_img_embeds.detach().cpu() - for image, valid_id in zip(images, valid_ids): - start, end = valid_id - self.put_afs_queue.put((image, all_img_embeds[start:end])) - + except Exception as e: logger.exception(str(e)) raise e @@ -169,6 +160,14 @@ def _store_to_cpu_cache(self, all_img_embeds, valid_ids, images): cuda_event.record() image.cuda_event = cuda_event self.store_queue.put(image) + + def _store_to_afs(self, all_img_embeds, valid_ids, images): + all_img_embeds = all_img_embeds.detach().cpu() + for image, valid_id in zip(images, valid_ids): + start, end = valid_id + gen_embed = all_img_embeds[start:end] + image.gen_embed = gen_embed + self.store_queue.put(image) def _store_worker(self): """ @@ -178,25 +177,32 @@ def _store_worker(self): try: # 从队列获取任务, 阻塞等待 images: List[ImageItem] = 
self._get_image_items_from_store_queue(max_num=self.max_infer_batch_size) - # 只有 0 rank 执行真的写入操作。 - if self.tp_rank_id == 0: - if self.is_visual_only_mode: - for image in images: - self.afs_handler.insert(image.md5, image.gen_embed) - image.event.set() - else: - for image in images: - # 等待拷贝到cpu cache 完成。 - image.cuda_event.synchronize() - - uuids = [image.uuid for image in images] - self.cache_client.root.set_items_embed(uuids) - - for image in images: - image.event.set() - + + if self.is_visual_only_mode: + self._commit_to_afs(images=images) + else: + self._commit_to_cpu_cache(images=images) + for _ in images: self.sempare.release() except Exception as e: logger.exception(str(e)) - raise e \ No newline at end of file + raise e + + def _commit_to_afs(self, images): + if self.tp_rank_id == 0: + for image in images: + self.afs_handler.insert(image.md5, image.gen_embed) + image.event.set() + + def _commit_to_cpu_cache(self, images): + if self.tp_rank_id == 0: + for image in images: + # 等待拷贝到cpu cache 完成。 + image.cuda_event.synchronize() + + uuids = [image.uuid for image in images] + self.cache_client.root.set_items_embed(uuids) + + for image in images: + image.event.set() \ No newline at end of file From fce01360ece4348ed374650ee5efc2eb7fec743f Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 09:55:50 +0000 Subject: [PATCH 38/54] fix --- .../visualserver/model_infer/model_rpc.py | 213 ++++++++++++++++-- .../model_infer/visual_only_model_rpc.py | 208 ----------------- 2 files changed, 195 insertions(+), 226 deletions(-) delete mode 100644 lightllm/server/visualserver/model_infer/visual_only_model_rpc.py diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index dd74a2f24d..f947c5b1fb 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -1,19 +1,14 @@ -import asyncio -import numpy as np import rpyc import torch 
import socket -import inspect -import uuid -import os import torch.multiprocessing as mp -import collections -from datetime import timedelta +import queue +import threading +import torch.distributed as dist +import torch from typing import Dict, List, Tuple, Deque, Optional from transformers.configuration_utils import PretrainedConfig -from lightllm.utils.retry_utils import retry -from rpyc.utils.classic import obtain, unix_connect -from rpyc.utils.server import ThreadedServer +from rpyc.utils.classic import obtain from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel @@ -27,14 +22,21 @@ from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env -from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient from lightllm.server.visualserver import set_vit_att_backend from lightllm.server.embed_cache.afs_utils import SepEmbedHandler +from lightllm.server.multimodal_params import ImageItem +from lightllm.server.embed_cache.afs_utils import SepEmbedHandler +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger + + +logger = init_logger(__name__) class VisualModelRpcServer(rpyc.Service): + def exposed_init_model(self, kvargs): kvargs = obtain(kvargs) @@ -120,12 +122,8 @@ def exposed_init_model(self, kvargs): else: args = get_env_start_args() assert args.visual_dp == 1 - self.redis_afs_client = SepEmbedHandler( - afs_embed_dir=args.afs_embed_dir, - redis_host=args.config_server_host, - redis_port=args.config_server_vit_redis_port, - 
capacity=args.afs_embed_capacity, - ) + + self._init_taskes() except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -136,8 +134,187 @@ def exposed_init_model(self, kvargs): set_random_seed(2147483647) return + + def exposed_run_task(self, images: List["ImageItem"], ref_event_list: List[threading.Event]): + try: + images = obtain(images) + for i in range(len(images)): + images[i].event = ref_event_list[i] + self.infer_queue.put(images[i]) + + except BaseException as e: + logger.exception(str(e)) + raise e + return + + def _init_taskes(self): + # 控制每次的最大推理图片数量,防止爆显存 + self.max_infer_batch_size = get_env_start_args().visual_infer_batch_size + + # 异步队列, 用于接受任务 + self.infer_queue = queue.Queue() + self.infer_queue_lock = threading.Lock() + # 将计算得到的结果放入 afs 或者 embed cache 的 queue + self.store_queue = queue.Queue() + + # 限制并发, 主要控制内存用量,防止过多造成内存OOM + self.sempare = threading.Semaphore(self.max_infer_batch_size * 8) + + # 用于同步各个推理tp每次拿到一样的image数量建立的gloo通信组 + self.gloo_group = dist.new_group(ranks=list(range(self.vit_tp)), backend="gloo") + + self.afs_handler = SepEmbedHandler( + afs_embed_dir=get_env_start_args().afs_embed_dir, + redis_host=get_env_start_args().config_server_host, + redis_port=get_env_start_args().config_server_vit_redis_port, + capacity=get_env_start_args().afs_embed_capacity, + ) + + # 启动任务处理线程 + self._infer_thread = threading.Thread(target=self._infer_worker, daemon=True) + self._infer_thread.start() + + self._store_thread = threading.Thread(target=self._store_worker, daemon=True) + self._store_thread.start() + pass # @calculate_time(show=True, min_cost_ms=150) @torch.no_grad() - def forward(self, images: List[ImageItem]): + def _forward(self, images: List[ImageItem]): return self.model.encode(images) + + + def _get_image_items_from_infer_queue(self, max_num: int, force_same: bool = False) -> List[ImageItem]: + """ + 从队列中批量获取任务,直到达到 max_num 或队列为空。 + """ + tasks = [] + # 至少获取一个任务,阻塞 + self.sempare.acquire() + task = 
self.infer_queue.get(block=True) + tasks.append(task) + + if not force_same: + # 尝试继续获取更多任务,直到达到 max_num + while len(tasks) < max_num: + try: + self.sempare.acquire() + task = self.infer_queue.get(block=False) + tasks.append(task) + except queue.Empty: + self.sempare.release() + break + else: + while len(tasks) < max_num: + self.sempare.acquire() + task = self.infer_queue.get(block=True) + tasks.append(task) + + return tasks + + def _get_image_items_from_store_queue(self, max_num: int) -> List[ImageItem]: + """ + 从队列中批量获取任务,直到达到 max_num 或队列为空。 + """ + tasks = [] + # 至少获取一个任务,阻塞 + task = self.store_queue.get(block=True) + tasks.append(task) + + while len(tasks) < max_num: + try: + task = self.store_queue.get(block=False) + tasks.append(task) + except queue.Empty: + break + + return tasks + + + def _infer_worker(self): + """ + 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 + """ + torch.cuda.set_device(self.device_id) + while True: + try: + # 从队列获取任务, 阻塞等待 + if self.tp_rank_id == 0: + images = self._get_image_items_from_infer_queue(max_num=self.max_infer_batch_size) + dist.broadcast_object_list([len(images)], src=0, group=self.gloo_group) + else: + ans = [None] + dist.broadcast_object_list(ans, src=0, group=self.gloo_group) + images = self._get_image_items_from_infer_queue(max_num=ans[0], force_same=True) + + # 执行任务: 调用父类的forward方法处理图像 + all_img_embeds, uuids, valid_ids = self._forward(images) + all_img_embeds = all_img_embeds.to(torch.device("cuda")) + + if self.is_visual_only_mode: + self._store_to_afs(all_img_embeds, valid_ids, images) + else: + self._store_to_cpu_cache(all_img_embeds, valid_ids, images) + + except Exception as e: + logger.exception(str(e)) + raise e + + def _store_to_cpu_cache(self, all_img_embeds, valid_ids, images): + for i in range(len(images)): + start, end = valid_ids[i] + image = images[i] + if self.tp_rank_id == 0: + self.cpu_embed_cache_client.copy_vision_to_cache( + embed_tensor=all_img_embeds[start:end], 
start_index_in_cache=image.start_index_in_embed_cache + ) + cuda_event = torch.cuda.Event() + cuda_event.record() + image.cuda_event = cuda_event + self.store_queue.put(image) + + def _store_to_afs(self, all_img_embeds, valid_ids, images): + all_img_embeds = all_img_embeds.detach().cpu() + for image, valid_id in zip(images, valid_ids): + start, end = valid_id + gen_embed = all_img_embeds[start:end] + image.gen_embed = gen_embed + self.store_queue.put(image) + + def _store_worker(self): + """ + 任务处理循环: 从队列中取出ImageItem和embed 放入 afs中, 执行完成后通知调用者 + """ + while True: + try: + # 从队列获取任务, 阻塞等待 + images: List[ImageItem] = self._get_image_items_from_store_queue(max_num=self.max_infer_batch_size) + + if self.is_visual_only_mode: + self._commit_to_afs(images=images) + else: + self._commit_to_cpu_cache(images=images) + + for _ in images: + self.sempare.release() + except Exception as e: + logger.exception(str(e)) + raise e + + def _commit_to_afs(self, images): + if self.tp_rank_id == 0: + for image in images: + self.afs_handler.insert(image.md5, image.gen_embed) + image.event.set() + + def _commit_to_cpu_cache(self, images): + if self.tp_rank_id == 0: + for image in images: + # 等待拷贝到cpu cache 完成。 + image.cuda_event.synchronize() + + uuids = [image.uuid for image in images] + self.cache_client.root.set_items_embed(uuids) + + for image in images: + image.event.set() diff --git a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py b/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py deleted file mode 100644 index 13fee51270..0000000000 --- a/lightllm/server/visualserver/model_infer/visual_only_model_rpc.py +++ /dev/null @@ -1,208 +0,0 @@ -import queue -import threading -import torch.distributed as dist -import torch -from typing import List, Any, Deque, Tuple -from .model_rpc import VisualModelRpcServer -from lightllm.server.multimodal_params import ImageItem -from lightllm.server.embed_cache.afs_utils import SepEmbedHandler -from rpyc.utils.server 
import ThreadedServer -from lightllm.utils.envs_utils import get_env_start_args -from rpyc.utils.classic import obtain -from lightllm.utils.log_utils import init_logger - - -logger = init_logger(__name__) - -class VisualOnlyModelRpcServer(VisualModelRpcServer): - """ - 完善这个代码: - 1. 创建一个队列, 用于接受别人放入的task, - 2. 创建一个线程,从队列中取出任务,完成后,修改task中的event,让放入的人得到结果和通知。这是任务循环。 - 3. 能不能封装比较易读的流程。 - """ - - def __init__(self): - super().__init__() - - - - - def exposed_run_task(self, images: List["ImageItem"], ref_event_list: List[threading.Event]): - try: - images = obtain(images) - for i in range(len(images)): - images[i].event = ref_event_list[i] - self.infer_queue.put(images[i]) - - except BaseException as e: - logger.exception(str(e)) - raise e - return - - def init_taskes(self): - # 控制每次的最大推理图片数量,防止爆显存 - self.max_infer_batch_size = get_env_start_args().visual_infer_batch_size - - # 异步队列, 用于接受任务 - self.infer_queue = queue.Queue() - self.infer_queue_lock = threading.Lock() - # 将计算得到的结果放入 afs 或者 embed cache 的 queue - self.store_queue = queue.Queue() - - # 限制并发, 主要控制内存用量,防止过多造成内存OOM - self.sempare = threading.Semaphore(self.max_infer_batch_size * 8) - - # 用于同步各个推理tp每次拿到一样的image数量建立的gloo通信组 - self.gloo_group = dist.new_group(ranks=list(range(self.vit_tp)), backend="gloo") - - self.afs_handler = SepEmbedHandler( - afs_embed_dir=get_env_start_args().afs_embed_dir, - redis_host=get_env_start_args().config_server_host, - redis_port=get_env_start_args().config_server_vit_redis_port, - capacity=get_env_start_args().afs_embed_capacity, - ) - - # 启动任务处理线程 - self._infer_thread = threading.Thread(target=self._infer_worker, daemon=True) - self._infer_thread.start() - - self._store_thread = threading.Thread(target=self._store_worker, daemon=True) - self._store_thread.start() - pass - - - def _get_image_items_from_infer_queue(self, max_num: int, force_same: bool = False) -> List[ImageItem]: - """ - 从队列中批量获取任务,直到达到 max_num 或队列为空。 - """ - tasks = [] - # 至少获取一个任务,阻塞 - self.sempare.acquire() - 
task = self.infer_queue.get(block=True) - tasks.append(task) - - if not force_same: - # 尝试继续获取更多任务,直到达到 max_num - while len(tasks) < max_num: - try: - self.sempare.acquire() - task = self.infer_queue.get(block=False) - tasks.append(task) - except queue.Empty: - self.sempare.release() - break - else: - while len(tasks) < max_num: - self.sempare.acquire() - task = self.infer_queue.get(block=True) - tasks.append(task) - - return tasks - - def _get_image_items_from_store_queue(self, max_num: int) -> List[ImageItem]: - """ - 从队列中批量获取任务,直到达到 max_num 或队列为空。 - """ - tasks = [] - # 至少获取一个任务,阻塞 - task = self.store_queue.get(block=True) - tasks.append(task) - - while len(tasks) < max_num: - try: - task = self.store_queue.get(block=False) - tasks.append(task) - except queue.Empty: - break - - return tasks - - - def _infer_worker(self): - """ - 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 - """ - torch.cuda.set_device(self.device_id) - while True: - try: - # 从队列获取任务, 阻塞等待 - if self.tp_rank_id == 0: - images = self._get_image_items_from_infer_queue(max_num=self.max_infer_batch_size) - dist.broadcast_object_list([len(images)], src=0, group=self.gloo_group) - else: - ans = [None] - dist.broadcast_object_list(ans, src=0, group=self.gloo_group) - images = self._get_image_items_from_infer_queue(max_num=ans[0], force_same=True) - - # 执行任务: 调用父类的forward方法处理图像 - all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cuda")) - - if self.is_visual_only_mode: - self._store_to_afs(all_img_embeds, valid_ids, images) - else: - self._store_to_cpu_cache(all_img_embeds, valid_ids, images) - - except Exception as e: - logger.exception(str(e)) - raise e - - def _store_to_cpu_cache(self, all_img_embeds, valid_ids, images): - for i in range(len(images)): - start, end = valid_ids[i] - image = images[i] - if self.tp_rank_id == 0: - self.cpu_embed_cache_client.copy_vision_to_cache( - embed_tensor=all_img_embeds[start:end], 
start_index_in_cache=image.start_index_in_embed_cache - ) - cuda_event = torch.cuda.Event() - cuda_event.record() - image.cuda_event = cuda_event - self.store_queue.put(image) - - def _store_to_afs(self, all_img_embeds, valid_ids, images): - all_img_embeds = all_img_embeds.detach().cpu() - for image, valid_id in zip(images, valid_ids): - start, end = valid_id - gen_embed = all_img_embeds[start:end] - image.gen_embed = gen_embed - self.store_queue.put(image) - - def _store_worker(self): - """ - 任务处理循环: 从队列中取出ImageItem和embed 放入 afs中, 执行完成后通知调用者 - """ - while True: - try: - # 从队列获取任务, 阻塞等待 - images: List[ImageItem] = self._get_image_items_from_store_queue(max_num=self.max_infer_batch_size) - - if self.is_visual_only_mode: - self._commit_to_afs(images=images) - else: - self._commit_to_cpu_cache(images=images) - - for _ in images: - self.sempare.release() - except Exception as e: - logger.exception(str(e)) - raise e - - def _commit_to_afs(self, images): - if self.tp_rank_id == 0: - for image in images: - self.afs_handler.insert(image.md5, image.gen_embed) - image.event.set() - - def _commit_to_cpu_cache(self, images): - if self.tp_rank_id == 0: - for image in images: - # 等待拷贝到cpu cache 完成。 - image.cuda_event.synchronize() - - uuids = [image.uuid for image in images] - self.cache_client.root.set_items_embed(uuids) - - for image in images: - image.event.set() \ No newline at end of file From 53b9127c214963680205ae55687b75646e45d7f6 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 09:58:33 +0000 Subject: [PATCH 39/54] fix --- lightllm/server/visualserver/model_infer/__init__.py | 12 +++++------- .../visualserver/model_infer/model_rpc_client.py | 5 ++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index 60c9de817e..aeb8382e74 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ 
b/lightllm/server/visualserver/model_infer/__init__.py @@ -10,17 +10,15 @@ from lightllm.utils.envs_utils import get_env_start_args from .model_rpc_client import VisualModelRpcClient from .model_rpc import VisualModelRpcServer -from .visual_only_model_rpc import VisualOnlyModelRpcServer + def _init_env(socket_path: str, success_event): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) import lightllm.utils.rpyc_fix_utils as _ - if get_env_start_args().run_mode == "visual_only": - t = ThreadedServer(VisualOnlyModelRpcServer(), socket_path=socket_path, protocol_config={"allow_pickle": True}) - else: - t = ThreadedServer(VisualModelRpcServer(), socket_path=socket_path, protocol_config={"allow_pickle": True}) + + t = ThreadedServer(VisualModelRpcServer(), socket_path=socket_path, protocol_config={"allow_pickle": True}) success_event.set() t.start() return @@ -50,11 +48,11 @@ async def start_model_process(): # 服务端需要调用event所以,客户端需要一个后台线程进行相关的处理。 conn._bg_thread = rpyc.BgServingThread(conn) - + return VisualModelRpcClient(conn) def _generate_unix_socket_path() -> str: """Generate a random Unix socket path""" unique_id = uuid.uuid4().hex[:8] - return f"/tmp/lightllm_model_infer_{unique_id}.sock" \ No newline at end of file + return f"/tmp/lightllm_model_infer_{unique_id}.sock" diff --git a/lightllm/server/visualserver/model_infer/model_rpc_client.py b/lightllm/server/visualserver/model_infer/model_rpc_client.py index c69cac2b61..682d6affcc 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc_client.py +++ b/lightllm/server/visualserver/model_infer/model_rpc_client.py @@ -4,12 +4,11 @@ from typing import Dict, List, Tuple, Deque, Optional, Union from lightllm.server.multimodal_params import ImageItem from .model_rpc import VisualModelRpcServer -from .visual_only_model_rpc import VisualOnlyModelRpcServer class VisualModelRpcClient: def __init__(self, rpc_conn): - self.rpc_conn: Union[VisualModelRpcServer, VisualOnlyModelRpcServer] = 
rpc_conn + self.rpc_conn: VisualModelRpcServer = rpc_conn def async_wrap(f): f = rpyc.async_(f) @@ -31,7 +30,7 @@ async def init_model(self, kvargs): ans: rpyc.AsyncResult = self._init_model(kvargs) await ans return - + async def run_task(self, images: List[ImageItem], ref_event_list: List[threading.Event]): ans = self._run_task(images, ref_event_list) return await ans From 2df2ce36d4aef0d1e9b79c42b458cc78eca677ed Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 10:02:05 +0000 Subject: [PATCH 40/54] fix --- lightllm/server/visualserver/model_infer/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index aeb8382e74..bc7f7a76d5 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -4,7 +4,8 @@ import uuid import os from lightllm.utils.retry_utils import retry -from rpyc.utils.classic import obtain, unix_connect +from rpyc.utils.factory import unix_connect +from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args @@ -46,7 +47,7 @@ async def start_model_process(): conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) assert proc.is_alive() - # 服务端需要调用event所以,客户端需要一个后台线程进行相关的处理。 + # 服务端需要调用客户端传入的event所以,客户端需要一个后台线程进行相关的处理。 conn._bg_thread = rpyc.BgServingThread(conn) return VisualModelRpcClient(conn) From 3ad7daa9d8ac43a034eaf7c0139d0910b0922c97 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 10:21:47 +0000 Subject: [PATCH 41/54] fix --- .../visualserver/model_infer/model_rpc.py | 56 +++++++++---------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py 
b/lightllm/server/visualserver/model_infer/model_rpc.py index f947c5b1fb..e0d4eb2c23 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -5,7 +5,6 @@ import queue import threading import torch.distributed as dist -import torch from typing import Dict, List, Tuple, Deque, Optional from transformers.configuration_utils import PretrainedConfig from rpyc.utils.classic import obtain @@ -26,9 +25,6 @@ from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient from lightllm.server.visualserver import set_vit_att_backend from lightllm.server.embed_cache.afs_utils import SepEmbedHandler -from lightllm.server.multimodal_params import ImageItem -from lightllm.server.embed_cache.afs_utils import SepEmbedHandler -from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger @@ -36,7 +32,6 @@ class VisualModelRpcServer(rpyc.Service): - def exposed_init_model(self, kvargs): kvargs = obtain(kvargs) @@ -56,6 +51,7 @@ def exposed_init_model(self, kvargs): # } weight_dir = kvargs["weight_dir"] + self.device_id = kvargs["device_id"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = kvargs["dp_rank_id"] self.tp_rank_id = kvargs["tp_rank_id"] @@ -115,13 +111,19 @@ def exposed_init_model(self, kvargs): self.model.load_model(weight_dir) self.model = self.model.cuda() if not self.is_visual_only_mode: - # 独立部署vit模式下,不需要连接 cache_client self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) else: + # 独立部署vit模式下,不需要连接 cache_client, 结果是写入 afs args = get_env_start_args() assert args.visual_dp == 1 + self.afs_handler = SepEmbedHandler( + afs_embed_dir=self.args.afs_embed_dir, + redis_host=self.args.config_server_host, + 
redis_port=self.args.config_server_vit_redis_port, + capacity=self.args.afs_embed_capacity, + ) self._init_taskes() except Exception as e: @@ -134,7 +136,7 @@ def exposed_init_model(self, kvargs): set_random_seed(2147483647) return - + def exposed_run_task(self, images: List["ImageItem"], ref_event_list: List[threading.Event]): try: images = obtain(images) @@ -146,44 +148,36 @@ def exposed_run_task(self, images: List["ImageItem"], ref_event_list: List[threa logger.exception(str(e)) raise e return - + def _init_taskes(self): + self.args = get_env_start_args() # 控制每次的最大推理图片数量,防止爆显存 - self.max_infer_batch_size = get_env_start_args().visual_infer_batch_size + self.max_infer_batch_size = self.args.visual_infer_batch_size # 异步队列, 用于接受任务 self.infer_queue = queue.Queue() - self.infer_queue_lock = threading.Lock() # 将计算得到的结果放入 afs 或者 embed cache 的 queue self.store_queue = queue.Queue() - # 限制并发, 主要控制内存用量,防止过多造成内存OOM + # 限制并发, 主要是为了控制内存用量,防止过多造成内存OOM self.sempare = threading.Semaphore(self.max_infer_batch_size * 8) # 用于同步各个推理tp每次拿到一样的image数量建立的gloo通信组 self.gloo_group = dist.new_group(ranks=list(range(self.vit_tp)), backend="gloo") - self.afs_handler = SepEmbedHandler( - afs_embed_dir=get_env_start_args().afs_embed_dir, - redis_host=get_env_start_args().config_server_host, - redis_port=get_env_start_args().config_server_vit_redis_port, - capacity=get_env_start_args().afs_embed_capacity, - ) - # 启动任务处理线程 self._infer_thread = threading.Thread(target=self._infer_worker, daemon=True) self._infer_thread.start() self._store_thread = threading.Thread(target=self._store_worker, daemon=True) self._store_thread.start() - pass + return # @calculate_time(show=True, min_cost_ms=150) @torch.no_grad() def _forward(self, images: List[ImageItem]): return self.model.encode(images) - def _get_image_items_from_infer_queue(self, max_num: int, force_same: bool = False) -> List[ImageItem]: """ 从队列中批量获取任务,直到达到 max_num 或队列为空。 @@ -192,8 +186,8 @@ def _get_image_items_from_infer_queue(self, max_num: 
int, force_same: bool = Fal # 至少获取一个任务,阻塞 self.sempare.acquire() task = self.infer_queue.get(block=True) - tasks.append(task) - + tasks.append(task) + if not force_same: # 尝试继续获取更多任务,直到达到 max_num while len(tasks) < max_num: @@ -211,7 +205,7 @@ def _get_image_items_from_infer_queue(self, max_num: int, force_same: bool = Fal tasks.append(task) return tasks - + def _get_image_items_from_store_queue(self, max_num: int) -> List[ImageItem]: """ 从队列中批量获取任务,直到达到 max_num 或队列为空。 @@ -219,8 +213,8 @@ def _get_image_items_from_store_queue(self, max_num: int) -> List[ImageItem]: tasks = [] # 至少获取一个任务,阻塞 task = self.store_queue.get(block=True) - tasks.append(task) - + tasks.append(task) + while len(tasks) < max_num: try: task = self.store_queue.get(block=False) @@ -229,7 +223,6 @@ def _get_image_items_from_store_queue(self, max_num: int) -> List[ImageItem]: break return tasks - def _infer_worker(self): """ @@ -255,11 +248,11 @@ def _infer_worker(self): self._store_to_afs(all_img_embeds, valid_ids, images) else: self._store_to_cpu_cache(all_img_embeds, valid_ids, images) - + except Exception as e: logger.exception(str(e)) raise e - + def _store_to_cpu_cache(self, all_img_embeds, valid_ids, images): for i in range(len(images)): start, end = valid_ids[i] @@ -280,7 +273,7 @@ def _store_to_afs(self, all_img_embeds, valid_ids, images): gen_embed = all_img_embeds[start:end] image.gen_embed = gen_embed self.store_queue.put(image) - + def _store_worker(self): """ 任务处理循环: 从队列中取出ImageItem和embed 放入 afs中, 执行完成后通知调用者 @@ -297,16 +290,17 @@ def _store_worker(self): for _ in images: self.sempare.release() + except Exception as e: logger.exception(str(e)) raise e - + def _commit_to_afs(self, images): if self.tp_rank_id == 0: for image in images: self.afs_handler.insert(image.md5, image.gen_embed) image.event.set() - + def _commit_to_cpu_cache(self, images): if self.tp_rank_id == 0: for image in images: From c03c7d0acaff1c69a89ab9fbf95120b159560d57 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 
Mar 2026 10:26:40 +0000 Subject: [PATCH 42/54] fix --- lightllm/server/visualserver/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index e7a081f42f..b78365789c 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -13,7 +13,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.server.multimodal_params import MultimodalParams, ImageItem -from .model_infer.model_rpc import start_model_process, VisualModelRpcClient +from .model_infer import start_model_process, VisualModelRpcClient from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry From 36b4bb7b4eab6fb983ef0cccba922ee65ee7c53e Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 11:08:51 +0000 Subject: [PATCH 43/54] fix --- lightllm/server/visualserver/manager.py | 141 +++++++++++------------- 1 file changed, 63 insertions(+), 78 deletions(-) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index b78365789c..efe8cba783 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -7,6 +7,8 @@ import pickle import inspect import setproctitle +import threading +import collections from typing import List from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs @@ -48,7 +50,6 @@ def __init__( self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.waiting_reqs: List[GroupReqIndexes] = [] self.model_weightdir = args.model_dir 
self.vit_dp = args.visual_dp self.vit_tp = args.visual_tp @@ -56,6 +57,8 @@ def __init__( self.infer_batch_size = args.visual_infer_batch_size self.send_batch_size = args.visual_send_batch_size self.shm_req_manager = ShmReqManager() + self.cur_dp_index = 0 + self.lock = threading.Lock() async def wait_to_model_ready(self): @@ -89,87 +92,70 @@ async def wait_to_model_ready(self): await asyncio.gather(*init_model_ret) return - async def infer_imgs(self, images: List[ImageItem]): - if len(images) == 0: + async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): + shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) + is_aborted = shm_req.is_aborted + disable_prompt_cache = shm_req.sample_params.disable_prompt_cache + self.shm_req_manager.put_back_req_obj(shm_req) + # case 0 + if is_aborted: + # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 + # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 + # 需要一些一致的流程来保证不出现异步问题。 + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) return - tasks = [] - for vit_dp_rank in range(self.vit_dp): - assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.vit_dp)] - if assigned_images: - for vit_tp_rank in range(self.vit_tp): - task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images)) - tasks.append(task) + multimodal_params = group_req_indexes.multimodal_params + img_uuids = [img.uuid for img in multimodal_params.images] + # disable prompt cache通常用来测试,需要也去掉image cache的影响 + if disable_prompt_cache: + ready_image = [False] * len(img_uuids) + else: + if len(img_uuids) > 0: + ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) + else: + ready_image = [] - await asyncio.gather(*tasks) - return + images_need_infer = [] + for img, ready in zip(multimodal_params.images, ready_image): + if not ready: + images_need_infer.append(img) - async def loop_for_fwd(self): - while True: - if len(self.waiting_reqs) == 
0: - await asyncio.sleep(0.01) # 10ms - else: - processing_group_reqs = [] - images_need_infer = [] - ready_to_send = [] - - def flush_ready(force: bool = False): - if not ready_to_send: - return - if not force and len(ready_to_send) < self.send_batch_size: - return - - for group_req_indexes in ready_to_send: - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) - ready_to_send.clear() - - while len(self.waiting_reqs) > 0: - group_req_indexes = self.waiting_reqs.pop(0) - shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) - is_aborted = shm_req.is_aborted - disable_prompt_cache = shm_req.sample_params.disable_prompt_cache - self.shm_req_manager.put_back_req_obj(shm_req) - if is_aborted: - # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 - # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 - # 需要一些一致的流程来保证不出现异步问题。 - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) - continue - - multimodal_params = group_req_indexes.multimodal_params - - img_uuids = [img.uuid for img in multimodal_params.images] - # disable prompt cache通常用来测试,需要也去掉image cache的影响 - if disable_prompt_cache: - ready_image = [False] * len(img_uuids) - else: - ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) - - for img, ready in zip(multimodal_params.images, ready_image): - if not ready: - images_need_infer.append(img) - - if len(images_need_infer) == self.infer_batch_size: - await self.infer_imgs(images_need_infer) - images_need_infer = [] - ready_to_send.extend(processing_group_reqs) - processing_group_reqs = [] - flush_ready(force=False) - - if len(images_need_infer) == 0: - ready_to_send.append(group_req_indexes) - flush_ready(force=False) - else: - processing_group_reqs.append(group_req_indexes) + # case 1 + if len(images_need_infer) == 0: + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + # case 2 + dp_to_handle_images = 
collections.defaultdict(list) + for image in images_need_infer: + self.cur_dp_index += 1 + select_dp = self.cur_dp_index % self.vit_dp + dp_to_handle_images[select_dp].append((image, threading.Event())) + + taskes = [] + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + taskes.extend(self.run_task(dp_index, images=[e[0] for e in _images], events=[e[1] for e in _images])) - if len(images_need_infer) > 0: - await self.infer_imgs(images_need_infer) - images_need_infer = [] + with self.lock: + await asyncio.gather(*taskes) + + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + await asyncio.to_thread(_images[-1][1].wait) + + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return - # 这些处理完 image 的 group 也 ready 了 - ready_to_send.extend(processing_group_reqs) - processing_group_reqs = [] - flush_ready(force=True) + def run_task(self, dp_index: int, images, events): + taskes = [] + for vit_tp_rank in range(self.vit_tp): + task = self.model_rpcs[dp_index][vit_tp_rank].run_task(images, events) + taskes.append(task) + return taskes async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): @@ -184,7 +170,7 @@ async def loop_for_netio_req(self): f"visual recv req id {recv_req.group_req_id} " f"img count {len(recv_req.multimodal_params.images)}" ) - self.waiting_reqs.append(recv_req) + asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req)) else: assert False, f"Error Req Inf {recv_req}" self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) @@ -218,6 +204,5 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - loop.create_task(visualserver.loop_for_fwd()) loop.run_until_complete(visualserver.loop_for_netio_req()) return From 2cf51dd376d57787cd316369d8f5853fbcce4604 Mon Sep 17 00:00:00 2001 
From: wzj Date: Sat, 28 Mar 2026 11:17:40 +0000 Subject: [PATCH 44/54] fix --- lightllm/server/visualserver/model_infer/model_rpc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index e0d4eb2c23..c46443f373 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -51,6 +51,7 @@ def exposed_init_model(self, kvargs): # } weight_dir = kvargs["weight_dir"] + self.infer_max_batch_size = kvargs["max_batch_size"] self.device_id = kvargs["device_id"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = kvargs["dp_rank_id"] @@ -151,8 +152,6 @@ def exposed_run_task(self, images: List["ImageItem"], ref_event_list: List[threa def _init_taskes(self): self.args = get_env_start_args() - # 控制每次的最大推理图片数量,防止爆显存 - self.max_infer_batch_size = self.args.visual_infer_batch_size # 异步队列, 用于接受任务 self.infer_queue = queue.Queue() @@ -160,7 +159,7 @@ def _init_taskes(self): self.store_queue = queue.Queue() # 限制并发, 主要是为了控制内存用量,防止过多造成内存OOM - self.sempare = threading.Semaphore(self.max_infer_batch_size * 8) + self.sempare = threading.Semaphore(self.infer_max_batch_size * 8) # 用于同步各个推理tp每次拿到一样的image数量建立的gloo通信组 self.gloo_group = dist.new_group(ranks=list(range(self.vit_tp)), backend="gloo") From f866f7bd131612d3926e2b7505f399b072b0a3c6 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 11:18:32 +0000 Subject: [PATCH 45/54] fix --- lightllm/server/visualserver/model_infer/model_rpc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index c46443f373..39ce7819d3 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -232,7 +232,7 @@ def _infer_worker(self): try: # 从队列获取任务, 阻塞等待 if 
self.tp_rank_id == 0: - images = self._get_image_items_from_infer_queue(max_num=self.max_infer_batch_size) + images = self._get_image_items_from_infer_queue(max_num=self.infer_max_batch_size) dist.broadcast_object_list([len(images)], src=0, group=self.gloo_group) else: ans = [None] From 7419e799ed3fd5afcb900d225bf8e25f56f35600 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 11:23:05 +0000 Subject: [PATCH 46/54] fix --- lightllm/server/visualserver/model_infer/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index bc7f7a76d5..c186856442 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -3,6 +3,7 @@ import inspect import uuid import os +import multiprocessing from lightllm.utils.retry_utils import retry from rpyc.utils.factory import unix_connect from rpyc.utils.classic import obtain @@ -26,7 +27,7 @@ def _init_env(socket_path: str, success_event): async def start_model_process(): - import multiprocessing + import lightllm.utils.rpyc_fix_utils as _ socket_path = _generate_unix_socket_path() if os.path.exists(socket_path): From 5ddb9159e831727c18f37a71839ac7b041fcdafb Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 11:25:29 +0000 Subject: [PATCH 47/54] fix --- .../server/visualserver/model_infer/model_rpc.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 39ce7819d3..95415e4b4e 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -119,12 +119,13 @@ def exposed_init_model(self, kvargs): # 独立部署vit模式下,不需要连接 cache_client, 结果是写入 afs args = get_env_start_args() assert args.visual_dp == 1 - self.afs_handler = 
SepEmbedHandler( - afs_embed_dir=self.args.afs_embed_dir, - redis_host=self.args.config_server_host, - redis_port=self.args.config_server_vit_redis_port, - capacity=self.args.afs_embed_capacity, - ) + if self.tp_rank_id == 0: + self.afs_handler = SepEmbedHandler( + afs_embed_dir=self.args.afs_embed_dir, + redis_host=self.args.config_server_host, + redis_port=self.args.config_server_vit_redis_port, + capacity=self.args.afs_embed_capacity, + ) self._init_taskes() except Exception as e: From bb363bbe837f0ab9fd95d9305ffcf2e0e7420727 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 11:28:45 +0000 Subject: [PATCH 48/54] fix --- lightllm/server/visualserver/manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index efe8cba783..6e90d91848 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -184,6 +184,8 @@ def clean_up(self): def start_visual_process(args, pipe_writer): + import lightllm.utils.rpyc_fix_utils as _ + # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") From 110c2caee8a7723c2b86fd97db0a67d99d0b251d Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 12:06:56 +0000 Subject: [PATCH 49/54] fix --- .../visualserver/visual_only_manager.py | 192 ++++++++---------- 1 file changed, 82 insertions(+), 110 deletions(-) diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index e44aece533..1fd7d7f1ba 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -1,19 +1,16 @@ import asyncio import uvloop +import rpyc import inspect import setproctitle import threading -import queue -import dataclasses -import rpyc -import uuid -from typing import List, Any -from 
lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes +import collections +from typing import List from lightllm.server.core.objs import StartArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.server.multimodal_params import MultimodalParams, ImageItem -from .model_infer.model_rpc import start_model_process, VisualModelRpcClient +from .model_infer import start_model_process, VisualModelRpcClient from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry @@ -31,121 +28,105 @@ def __init__( args: StartArgs, ): self.args = args - self.waiting_reqs: List[GroupReqIndexes] = [] self.model_weightdir = args.model_dir self.vit_dp = args.visual_dp - self.vit_tp = args.visual_tp assert self.vit_dp == 1 - - # 工作线程 - self.task_queue = queue.Queue() - # 限制并发, 主要控制内存用量,防止过多照成爆炸。 - self.sempare = threading.Semaphore(3) - # 启动任务处理线程 - self.worker_thread = threading.Thread(target=self._task_worker, daemon=True) - self.worker_thread.start() + self.vit_tp = args.visual_tp + # image 最大推理 batch size + self.infer_batch_size = args.visual_infer_batch_size + self.cur_dp_index = 0 + self.lock = threading.Lock() + + self.new_loop = asyncio.new_event_loop() + t = threading.Thread(target=self.event_loop, args=(self.new_loop,), daemon=True) + t.start() + + def event_loop(self, loop: asyncio.AbstractEventLoop): + asyncio.set_event_loop(loop) + loop.run_forever() + return async def wait_to_model_ready(self): - self.model_rpcs: List[List[VisualModelRpcClient]] = [] - self.model_rpcs_1: List[List[VisualModelRpcClient]] = [] + self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] self.vit_attn_backend = init_vit_att_backend(index=0) - for tp_rank_id in range(self.vit_tp): - rpc_model = await start_model_process() - self.model_rpcs.append(rpc_model[0]) - self.model_rpcs_1.append(rpc_model[1]) + for 
dp_rank_id in range(self.vit_dp): + for tp_rank_id in range(self.vit_tp): - init_model_ret = [] + rpc_model = await start_model_process() + self.model_rpcs[dp_rank_id].append(rpc_model) - for tp_rank_id in range(self.vit_tp): - device_id = self.args.visual_gpu_ids[tp_rank_id] - kvargs = { - "weight_dir": self.model_weightdir, - "device_id": device_id, - "vit_tp": self.vit_tp, - "cache_port": self.args.cache_port, - "tp_rank_id": tp_rank_id, - "dp_rank_id": 0, - "data_type": self.args.data_type, - "visual_nccl_port": self.args.visual_nccl_ports[0], - "quant_type": self.args.vit_quant_type, - "quant_cfg": self.args.vit_quant_cfg, - "max_batch_size": min(self.args.visual_infer_batch_size // self.vit_dp, 1), - "vit_attn_backend": self.vit_attn_backend, - } - init_model_ret.append(self.model_rpcs[tp_rank_id].init_model(kvargs)) + init_model_ret = [] + for dp_rank_id in range(self.vit_dp): # async init model process + for tp_rank_id in range(self.vit_tp): + device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] + kvargs = { + "weight_dir": self.model_weightdir, + "device_id": device_id, + "vit_tp": self.vit_tp, + "cache_port": self.args.cache_port, + "tp_rank_id": tp_rank_id, + "dp_rank_id": dp_rank_id, + "data_type": self.args.data_type, + "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], + "quant_type": self.args.vit_quant_type, + "quant_cfg": self.args.vit_quant_cfg, + "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + "vit_attn_backend": self.vit_attn_backend, + } + init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) return - async def infer_imgs(self, images: List[ImageItem], infer_uid: str): - assert len(images) != 0 - tasks = [] - for vit_tp_rank in range(self.vit_tp): - task = asyncio.create_task(self.model_rpcs[vit_tp_rank].encode(images, infer_uid=infer_uid)) - tasks.append(task) - - await asyncio.gather(*tasks) + async def handle_reqs(self, 
images_need_infer: List[ImageItem]): + # case 2 + dp_to_handle_images = collections.defaultdict(list) + for image in images_need_infer: + self.cur_dp_index += 1 + select_dp = self.cur_dp_index % self.vit_dp + dp_to_handle_images[select_dp].append((image, threading.Event())) + + taskes = [] + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + taskes.extend(self.run_task(dp_index, images=[e[0] for e in _images], events=[e[1] for e in _images])) + + with self.lock: + await asyncio.gather(*taskes) + + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + await asyncio.to_thread(_images[-1][1].wait) return - async def put_to_afs(self, infer_uid: str): - await self.model_rpcs_1[0].put_to_afs(infer_uid) + def run_task(self, dp_index: int, images, events): + taskes = [] + for vit_tp_rank in range(self.vit_tp): + task = self.model_rpcs[dp_index][vit_tp_rank].run_task(images, events) + taskes.append(task) + return taskes + + def clean_up(self): return - def _task_worker(self): - """ - 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 - """ - while True: - try: - # 从队列获取任务, 阻塞等待 - task: _Task = self.task_queue.get() - - # 执行任务: 调用父类的forward方法处理图像 - try: - asyncio.run(self.infer_imgs(task.images)) - except Exception as e: - task.hasError = True - logger.exception(str(e)) - raise e - finally: - # 标记任务完成, 唤醒等待的调用者 - task.event.set() - self.task_queue.task_done() - - except Exception as e: - logger.exception(str(e)) - raise e - - def exposed_run_task(self, images: List["ImageItem"]): - """ - 添加任务到队列 - - Args: - images: 要处理的图像列表 - - Returns: - _Task: 任务对象, 包含ret和event - """ + def exposed_infer_images(self, images: List[ImageItem], ref_event: threading.Event): try: images = obtain(images) - # 写入 shm, 然后 - - with self.sempare: - event = threading.Event() - task = _Task(images=images, infer_uid=uuid.uuid4().hex, vent=event) - self.task_queue.put(task) - task.event.wait(timeout=8888) - - 
asyncio.run(self.put_to_afs(infer_uids=task.infer_uid)) + # 将 images 的内容写入到 shm 中, - # 将 shm 进行删除 + handle = asyncio.run_coroutine_threadsafe(self.handle_reqs(images_need_infer=images), loop=self.new_loop) + handle.result() + ref_event.set() except BaseException as e: logger.exception(str(e)) raise e - return + finally: + # 将 shm 进行删除 + pass - def clean_up(self): return @@ -156,9 +137,11 @@ def start_visual_process(args, pipe_writer): graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") start_parent_check_thread() + try: visualserver = VisualManager(args=args) - asyncio.run(visualserver.wait_to_model_ready()) + future = asyncio.run_coroutine_threadsafe(visualserver.wait_to_model_ready(), loop=visualserver.new_loop) + future.result() t = rpyc.ThreadedServer(visualserver, port=None, protocol_config={"allow_pickle": True}) except Exception as e: logger.exception(str(e)) @@ -169,14 +152,3 @@ def start_visual_process(args, pipe_writer): t.start() return - - -@dataclasses.dataclass -class _Task: - images: List["ImageItem"] - event: threading.Event - infer_uid: str - hasError: bool = False - - def wait(self, timeout: float = None): - self.event.wait(timeout=timeout) From 526500fabba879f952ead0ea3612e6d0633e0f88 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 13:03:48 +0000 Subject: [PATCH 50/54] fix --- lightllm/server/visualserver/visual_only_manager.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index 1fd7d7f1ba..d722116e36 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -5,6 +5,7 @@ import setproctitle import threading import collections +import uuid from typing import List from lightllm.server.core.objs import StartArgs @@ -17,6 +18,7 @@ from 
lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name from rpyc.utils.classic import obtain +from lightllm.server.embed_cache.utils import create_shm, get_shm_name_data, free_shm logger = init_logger(__name__) @@ -114,7 +116,13 @@ def clean_up(self): def exposed_infer_images(self, images: List[ImageItem], ref_event: threading.Event): try: images = obtain(images) - # 将 images 的内容写入到 shm 中, + # 将 images 的内容写入到 shm 中,这里修改了原始的uuid,主要是在远端的vit + # 本身不具有 embed cache 的引用保证,则新的唯一标识来进行推理,最终写入的 + # 目标的 md5 一致即可,这样调用端一样可以拿到准确的数据。 + for image in images: + image.uuid = str(uuid.uuid4()) + create_shm(get_shm_name_data(image.uuid), image.data_bytes) + del image.data_bytes handle = asyncio.run_coroutine_threadsafe(self.handle_reqs(images_need_infer=images), loop=self.new_loop) handle.result() @@ -125,7 +133,8 @@ def exposed_infer_images(self, images: List[ImageItem], ref_event: threading.Eve raise e finally: # 将 shm 进行删除 - pass + for image in images: + free_shm(get_shm_name_data(image.uuid)) return From b48528ad02dd14d55276d9651765544fb31084ff Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 13:18:39 +0000 Subject: [PATCH 51/54] fix --- lightllm/server/api_cli.py | 15 ++++++++++++--- lightllm/server/core/objs/start_args_type.py | 7 +++++-- .../server/visualserver/visual_only_manager.py | 7 +++++-- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 7d1cdb136e..19aacdf849 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -15,7 +15,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "nixl_decode", "pd_master", "config_server", - "only_visual_infer", + "visual_only", ], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, @@ -71,12 +71,12 @@ def make_argument_parser() -> argparse.ArgumentParser: help="The port number for the 
config server in config_server mode.", ) parser.add_argument( - "--config_server_vit_redis_port", + "--config_server_visual_redis_port", type=int, default=None, help="""when run_mode is config_server, set this params will start a redis server, when a llm infer node start to set this params, the visual infer module will start a - proxy module use config server to find remote vit infer nodes to infer img""" + proxy module use config server to find remote vit infer nodes to infer img""", ) parser.add_argument( "--nixl_pd_kv_page_num", @@ -457,6 +457,15 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", ) + parser.add_argument( + "--visual_rpyc_port", + type=int, + default=None, + help=""" + when run_mode is visual_only, set this port, make others to call local visual infer to + transfer image to embed. + """, + ) parser.add_argument( "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" ) diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 6948bceb3b..4a3b8ee28e 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -8,7 +8,9 @@ class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + metadata={ + "choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode", "visual_only"] + }, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) @@ -20,7 +22,7 @@ class StartArgs: pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) - config_server_vit_redis_port: int = field(default=None) + config_server_visual_redis_port: int = field(default=None) afs_image_embed_dir: str = 
field(default=None) pd_decode_rpyc_port: int = field(default=None) select_p_d_node_strategy: str = field(default=None) @@ -106,6 +108,7 @@ class StartArgs: visual_tp: int = field(default=1) visual_dp: int = field(default=1) visual_nccl_ports: List[int] = field(default=None) + visual_rpyc_port: Optional[int] = field(default=None) enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index d722116e36..954e697e41 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -139,7 +139,7 @@ def exposed_infer_images(self, images: List[ImageItem], ref_event: threading.Eve return -def start_visual_process(args, pipe_writer): +def start_visual_process(args: StartArgs, pipe_writer): import lightllm.utils.rpyc_fix_utils as _ # 注册graceful 退出的处理 @@ -151,7 +151,10 @@ def start_visual_process(args, pipe_writer): visualserver = VisualManager(args=args) future = asyncio.run_coroutine_threadsafe(visualserver.wait_to_model_ready(), loop=visualserver.new_loop) future.result() - t = rpyc.ThreadedServer(visualserver, port=None, protocol_config={"allow_pickle": True}) + from .register_loop import register_loop + + asyncio.run_coroutine_threadsafe(register_loop(args=args), loop=visualserver.new_loop) + t = rpyc.ThreadedServer(visualserver, port=args.visual_rpyc_port, protocol_config={"allow_pickle": True}) except Exception as e: logger.exception(str(e)) visualserver.clean_up() From bff7bb7cc6c97c214eb2a11b80ce4b11400ba7cd Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 13:24:21 +0000 Subject: [PATCH 52/54] fix --- lightllm/server/core/objs/start_args_type.py | 1 + lightllm/server/visualserver/register_loop.py | 42 ------------------- .../visualserver/visual_only_manager.py | 41 +++++++++++++++++- 3 
files changed, 40 insertions(+), 44 deletions(-) delete mode 100644 lightllm/server/visualserver/register_loop.py diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 4a3b8ee28e..a15c66475c 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -102,6 +102,7 @@ class StartArgs: job_name: str = field(default="lightllm") grouping_key: List[str] = field(default_factory=list) push_interval: int = field(default=10) + visual_node_id: int = field(default=None) visual_infer_batch_size: int = field(default=None) visual_send_batch_size: int = field(default=1) visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) diff --git a/lightllm/server/visualserver/register_loop.py b/lightllm/server/visualserver/register_loop.py deleted file mode 100644 index 31d0f7b8ac..0000000000 --- a/lightllm/server/visualserver/register_loop.py +++ /dev/null @@ -1,42 +0,0 @@ -import asyncio -import pickle -import websockets -import socket -from lightllm.utils.net_utils import get_hostname_ip -from lightllm.utils.log_utils import init_logger -from .vit_connect import VIT_Obj - -logger = init_logger(__name__) - - -async def register_loop(args): - assert args.host not in ["127.0.0.1", "localhost"], "remote visual server must specify host ip" - - if args.host in ["0.0.0.0"]: - host_ip = get_hostname_ip() - else: - host_ip = args.host - - while True: - - try: - uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_register" - async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket: - - sock = websocket.transport.get_extra_info("socket") - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip_port=f"{host_ip}:{args.remote_vit_port}") - - await websocket.send(pickle.dumps(vit_obj)) - logger.info(f"Sent registration vit_obj: {vit_obj}") - - while True: - await 
websocket.send("heartbeat") - await asyncio.sleep(40) - - except Exception as e: - logger.error("connetion to config_server has error") - logger.exception(str(e)) - await asyncio.sleep(10) - logger.info("reconnection to config_server") diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index 954e697e41..d669dee633 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -6,6 +6,11 @@ import threading import collections import uuid +import pickle +import websockets +import socket +from lightllm.utils.net_utils import get_hostname_ip +from .vit_connect import VIT_Obj from typing import List from lightllm.server.core.objs import StartArgs @@ -48,6 +53,37 @@ def event_loop(self, loop: asyncio.AbstractEventLoop): loop.run_forever() return + async def register_to_config_server_loop(self, args: StartArgs): + assert args.host not in ["127.0.0.1", "localhost"], "remote visual server must specify host ip" + + if args.host in ["0.0.0.0"]: + host_ip = get_hostname_ip() + else: + host_ip = args.host + + while True: + try: + uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_register" + async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket: + + sock = websocket.transport.get_extra_info("socket") + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip_port=f"{host_ip}:{args.visual_rpyc_port}") + + await websocket.send(pickle.dumps(vit_obj)) + logger.info(f"Sent registration vit_obj: {vit_obj}") + + while True: + await websocket.send("heartbeat") + await asyncio.sleep(40) + + except Exception as e: + logger.error("connetion to config_server has error") + logger.exception(str(e)) + await asyncio.sleep(10) + logger.info("reconnection to config_server") + async def wait_to_model_ready(self): self.model_rpcs: 
List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] @@ -151,9 +187,10 @@ def start_visual_process(args: StartArgs, pipe_writer): visualserver = VisualManager(args=args) future = asyncio.run_coroutine_threadsafe(visualserver.wait_to_model_ready(), loop=visualserver.new_loop) future.result() - from .register_loop import register_loop - asyncio.run_coroutine_threadsafe(register_loop(args=args), loop=visualserver.new_loop) + asyncio.run_coroutine_threadsafe( + visualserver.register_to_config_server_loop(args=args), loop=visualserver.new_loop + ) t = rpyc.ThreadedServer(visualserver, port=args.visual_rpyc_port, protocol_config={"allow_pickle": True}) except Exception as e: logger.exception(str(e)) From 3fdc3a695dfbc984a7fb0bccdce1ee6f372bf240 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 14:47:41 +0000 Subject: [PATCH 53/54] fix --- lightllm/server/visualserver/manager.py | 12 +- lightllm/server/visualserver/proxy_manager.py | 125 ++++++++++++++++++ 2 files changed, 133 insertions(+), 4 deletions(-) create mode 100644 lightllm/server/visualserver/proxy_manager.py diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 6e90d91848..76acfcd49a 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -92,7 +92,7 @@ async def wait_to_model_ready(self): await asyncio.gather(*init_model_ret) return - async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): + def get_need_infer_images(self, group_req_indexes: GroupReqIndexes) -> List[ImageItem]: shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) is_aborted = shm_req.is_aborted disable_prompt_cache = shm_req.sample_params.disable_prompt_cache @@ -102,8 +102,7 @@ async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 # 需要一些一致的流程来保证不出现异步问题。 - 
self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) - return + return [] multimodal_params = group_req_indexes.multimodal_params img_uuids = [img.uuid for img in multimodal_params.images] @@ -121,7 +120,11 @@ async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): if not ready: images_need_infer.append(img) - # case 1 + return images_need_infer + + async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): + images_need_infer = self.get_need_infer_images(group_req_indexes) + if len(images_need_infer) == 0: self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) return @@ -142,6 +145,7 @@ async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): with self.lock: await asyncio.gather(*taskes) + # 等待推理通知已经 ok for dp_index in range(self.vit_dp): _images = dp_to_handle_images[dp_index] if _images: diff --git a/lightllm/server/visualserver/proxy_manager.py b/lightllm/server/visualserver/proxy_manager.py new file mode 100644 index 0000000000..8230b69b86 --- /dev/null +++ b/lightllm/server/visualserver/proxy_manager.py @@ -0,0 +1,125 @@ +import zmq +import zmq.asyncio +import asyncio +import uvloop +import rpyc +import socket +import pickle +import inspect +import setproctitle +import threading +import collections +from typing import List +from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes +from lightllm.server.core.objs import ShmReqManager, StartArgs + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +from lightllm.server.multimodal_params import MultimodalParams, ImageItem +from lightllm.utils.log_utils import init_logger +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name +from rpyc.utils.classic import obtain +from .manager import VisualManager + +logger = init_logger(__name__) + + 
+class ProxyVisualManager(VisualManager): + def __init__( + self, + args: StartArgs, + ): + super().__init__(args) + assert self.vit_dp == 1 and self.vit_tp == 1 + + async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): + images_need_infer = self.get_need_infer_images(group_req_indexes) + + # case 1 + if len(images_need_infer) == 0: + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + # case 2 + dp_to_handle_images = collections.defaultdict(list) + for image in images_need_infer: + self.cur_dp_index += 1 + select_dp = self.cur_dp_index % self.vit_dp + dp_to_handle_images[select_dp].append((image, threading.Event())) + + taskes = [] + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + taskes.extend(self.run_task(dp_index, images=[e[0] for e in _images], events=[e[1] for e in _images])) + + with self.lock: + await asyncio.gather(*taskes) + + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + await asyncio.to_thread(_images[-1][1].wait) + + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + def run_task(self, dp_index: int, images, events): + taskes = [] + for vit_tp_rank in range(self.vit_tp): + task = self.model_rpcs[dp_index][vit_tp_rank].run_task(images, events) + taskes.append(task) + return taskes + + async def loop_for_netio_req(self): + if not hasattr(self, "visual_recv_max_count"): + self.visual_recv_max_count = 64 + + while True: + try: + for _ in range(self.visual_recv_max_count): + recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if isinstance(recv_req, GroupReqIndexes): + logger.info( + f"visual recv req id {recv_req.group_req_id} " + f"img count {len(recv_req.multimodal_params.images)}" + ) + asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req)) + else: + assert False, f"Error Req Inf {recv_req}" + 
self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) + except zmq.ZMQError: + # 当队列已经开始清空的时候,将一次接受数量下调 + self.visual_recv_max_count = 64 + await asyncio.sleep(0.01) + + def clean_up(self): + return + + +def start_visual_process(args, pipe_writer): + import lightllm.utils.rpyc_fix_utils as _ + + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") + start_parent_check_thread() + try: + visualserver = VisualManager(args=args) + asyncio.run(visualserver.wait_to_model_ready()) + except Exception as e: + logger.exception(str(e)) + visualserver.clean_up() + raise e + + pipe_writer.send("init ok") + + def handle_exception(loop, context): + logger.exception(f"VisualServer Caught exception: {str(context)}") + + loop = asyncio.new_event_loop() + loop.set_exception_handler(handle_exception) + asyncio.set_event_loop(loop) + loop.run_until_complete(visualserver.loop_for_netio_req()) + return From 35e28b156636f07124e147989e28b52e0e17db16 Mon Sep 17 00:00:00 2001 From: wzj Date: Sat, 28 Mar 2026 14:58:08 +0000 Subject: [PATCH 54/54] fix --- lightllm/server/visualserver/objs.py | 14 ++ lightllm/server/visualserver/proxy_manager.py | 18 ++ .../visualserver/visual_only_manager.py | 2 +- lightllm/server/visualserver/vit_connect.py | 234 ------------------ 4 files changed, 33 insertions(+), 235 deletions(-) create mode 100644 lightllm/server/visualserver/objs.py delete mode 100644 lightllm/server/visualserver/vit_connect.py diff --git a/lightllm/server/visualserver/objs.py b/lightllm/server/visualserver/objs.py new file mode 100644 index 0000000000..cc68e88874 --- /dev/null +++ b/lightllm/server/visualserver/objs.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +@dataclass +class VIT_Obj: + node_id: int + host_ip: str + port: int + + def to_log_str(self): + 
return f"VIT host_ip_port: {self.host_ip}:{self.port}, node_id: {self.node_id}" diff --git a/lightllm/server/visualserver/proxy_manager.py b/lightllm/server/visualserver/proxy_manager.py index 8230b69b86..ea60648a55 100644 --- a/lightllm/server/visualserver/proxy_manager.py +++ b/lightllm/server/visualserver/proxy_manager.py @@ -9,6 +9,8 @@ import setproctitle import threading import collections +import base64 +import httpx from typing import List from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs @@ -94,6 +96,22 @@ async def loop_for_netio_req(self): self.visual_recv_max_count = 64 await asyncio.sleep(0.01) + async def loop_to_connect_remote_visual_server(self): + uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get(uri) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"Failed to get VIT instances: {response.status_code}") + return None + except Exception as e: + logger.exception(f"Error getting VIT instances: {e}") + return None + def clean_up(self): return diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index d669dee633..c9b381d453 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -69,7 +69,7 @@ async def register_to_config_server_loop(self, args: StartArgs): sock = websocket.transport.get_extra_info("socket") sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip_port=f"{host_ip}:{args.visual_rpyc_port}") + vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip=host_ip, port=args.visual_rpyc_port) await 
websocket.send(pickle.dumps(vit_obj)) logger.info(f"Sent registration vit_obj: {vit_obj}") diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py deleted file mode 100644 index 9720eb6698..0000000000 --- a/lightllm/server/visualserver/vit_connect.py +++ /dev/null @@ -1,234 +0,0 @@ -import asyncio -import base64 -import pickle -import time -from dataclasses import dataclass -from typing import Dict, Optional - -import httpx -import rpyc -import zmq -import zmq.asyncio - -from lightllm.utils.log_utils import init_logger -from lightllm.server.core.objs.io_objs import GroupReqIndexes - -logger = init_logger(__name__) - - -@dataclass -class VIT_Obj: - node_id: int - host_ip_port: str - - def to_log_str(self): - return f"VIT host_ip_port: {self.host_ip_port} node_id: {self.node_id}" - - -class VITConnectionManager: - """VIT连接管理器""" - - def __init__(self, args, context, local_visual_port: int, cache_client: rpyc.Connection): - self.args = args - self.context = context - self.local_visual_port = local_visual_port - - self.send_to_visual = None - self.remote_vit_instances = {} - self.current_vit_index = 0 - self.remote_vit = args.enable_remote_vit - self.remote_vit_port = args.remote_vit_port - self.cache_client = cache_client - - self._setup_vit_connections() - - def _setup_vit_connections(self): - """ - 设置VIT连接,支持本地和远程VIT实例 - 支持多种连接模式: - 1. 本地VIT实例 (默认) - 2. 
远程多个VIT实例 (负载均衡) - """ - if self.remote_vit: - # 远程VIT实例模式 - self._setup_remote_vit_connections() - else: - self._setup_local_vit_connection() - - def _setup_local_vit_connection(self): - self.send_to_visual = self.context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") - logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") - - def _setup_remote_vit_connections(self): - """ - 初始化远程VIT连接,同步获取初始实例 - """ - logger.info("Setting up remote VIT connections...") - - self._sync_init_vit_instances() - - retry_count = 0 - max_retries = 30 # 最多等待30秒 - while len(self.remote_vit_instances) == 0 and retry_count < max_retries: - logger.info(f"Waiting for VIT instances... (attempt {retry_count + 1}/{max_retries})") - time.sleep(1) - retry_count += 1 - self._sync_init_vit_instances() - - if len(self.remote_vit_instances) == 0: - logger.warning("No VIT instances available after initialization") - else: - logger.info(f"Successfully connected to {len(self.remote_vit_instances)} VIT instances") - - def _sync_init_vit_instances(self): - """ - 同步初始化VIT实例连接 - """ - try: - # 使用同步方式获取VIT实例 - vit_objs = self._sync_get_vit_objs() - if vit_objs: - self._update_vit_connections(vit_objs) - except Exception as e: - logger.error(f"Failed to initialize VIT instances: {e}") - - def _sync_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: - """ - 同步获取VIT实例信息 - """ - import requests - - uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" - try: - response = requests.get(uri, timeout=10) - if response.status_code == 200: - base64data = response.json()["data"] - id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) - return id_to_vit_obj - else: - logger.error(f"Failed to get VIT instances: {response.status_code}") - return None - except Exception as e: - logger.error(f"Error getting VIT instances: {e}") - return None - - def 
_update_vit_connections(self, id_to_vit_obj: Dict[int, VIT_Obj]): - """ - 更新VIT连接,添加新的连接,关闭失效的连接 - """ - # 关闭不再存在的连接 - closed_ids = [] - for id, remote_instance in self.remote_vit_instances.items(): - if id not in id_to_vit_obj: - try: - remote_instance.close() - except: - pass - closed_ids.append(id) - logger.info(f"Closed VIT connection {id}") - - for id in closed_ids: - self.remote_vit_instances.pop(id) - - # 建立新的连接 - for id, vit_obj in id_to_vit_obj.items(): - if id not in self.remote_vit_instances: - try: - socket = self.context.socket(zmq.PUSH) - # print(vit_obj.host_ip_port, self.args.remote_vit_port, flush=True) - ip, port = vit_obj.host_ip_port.split(":") - socket.connect(f"tcp://{ip}:{port}") - self.remote_vit_instances[id] = socket - logger.info(f"Connected to VIT instance {id} at {vit_obj.host_ip_port}") - except Exception as e: - logger.error(f"Failed to connect to VIT instance {id}: {e}") - - def _get_vit_instance(self): - """ - 获取下一个可用的VIT实例 (轮询负载均衡) - """ - if not self.remote_vit: - return self.send_to_visual - - if len(self.remote_vit_instances) == 0: - raise Exception("No available VIT instances") - - # 简单的轮询负载均衡 - index = (self.current_vit_index + 1) % len(self.remote_vit_instances) - self.current_vit_index = index - return list(self.remote_vit_instances.values())[index] - - async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL): - """ - 发送数据到VIT实例,支持本地和远程模式 - """ - instance = self._get_vit_instance() - - try: - instance.send_pyobj(req, protocol=protocol) - except Exception as e: - logger.error(f"Failed to send to VIT instance: {e}") - raise Exception(f"Failed to send to VIT instance: {e}") - - if self.remote_vit: - await self._wait_visual_embed_ready(req) - - req.multimodal_params.free_images() - - async def vit_handle_loop(self): - """ - 异步VIT连接管理循环,由外部启动 - """ - if not self.remote_vit: - return - logger.info("Starting VIT connection management loop") - while True: - try: - id_to_vit_obj = await 
self._async_get_vit_objs() - if id_to_vit_obj: - self._update_vit_connections(id_to_vit_obj) - await asyncio.sleep(30) - except Exception as e: - logger.exception(f"Error in VIT handle loop: {e}") - await asyncio.sleep(10) - - async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: - """ - 异步获取VIT实例信息 - """ - uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" - try: - async with httpx.AsyncClient(timeout=10.0) as client: - response = await client.get(uri) - if response.status_code == 200: - base64data = response.json()["data"] - id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) - return id_to_vit_obj - else: - logger.error(f"Failed to get VIT instances: {response.status_code}") - return None - except Exception as e: - logger.exception(f"Error getting VIT instances: {e}") - return None - - async def _wait_visual_embed_ready( - self, - req: GroupReqIndexes, - timeout_seconds: int = 1000, - ): - # 本地模式不需要等待 - if not self.remote_vit: - return - uuids = [image.uuid for image in req.multimodal_params.images] - - async def wait_for_embeds(): - while not all(self.cache_client.root.get_items_embed(uuids, True)): - await asyncio.sleep(0.01) - - try: - await asyncio.wait_for(wait_for_embeds(), timeout=timeout_seconds) - except asyncio.TimeoutError: - raise TimeoutError( - f"Req {req.group_req_id}: timeout waiting for visual embed ready after {timeout_seconds} seconds" - )