diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index 8428e52996..768ebd9139 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -195,7 +195,8 @@ def flash_attention_v3_fwd( False, window_size[0], window_size[1], - 0.0, + attention_chunk=0, + softcap=0.0, is_rotary_interleaved=False, scheduler_metadata=None, num_splits=1, @@ -203,8 +204,7 @@ def flash_attention_v3_fwd( sm_margin=0, sinks=None, ) - - return + return o except ImportError: print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.") diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index c8a82d3239..19aacdf849 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,16 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"], + choices=[ + "normal", + "prefill", + "decode", + "nixl_prefill", + "nixl_decode", + "pd_master", + "config_server", + "visual_only", + ], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -61,6 +70,14 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="The port number for the config server in config_server mode.", ) + parser.add_argument( + "--config_server_visual_redis_port", + type=int, + default=None, + help="""when run_mode is config_server, set this params will start a redis server, + when a llm infer node start to set this params, the visual infer module will start a + proxy module use config server to find remote vit infer nodes to infer img""", + ) parser.add_argument( "--nixl_pd_kv_page_num", type=int, @@ -440,6 +457,15 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", ) + parser.add_argument( + "--visual_rpyc_port", + type=int, + default=None, + help=""" + when run_mode is visual_only, set this port, make others to call local visual infer to + transfer image to embed. + """, + ) parser.add_argument( "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" ) @@ -605,6 +631,12 @@ def make_argument_parser() -> argparse.ArgumentParser: default=0.03, help="""The interval of the schedule time, default is 30ms.""", ) + parser.add_argument( + "--afs_image_embed_dir", + type=str, + default=None, + help="path for vit embed, when use vit remote infer mode", + ) parser.add_argument( "--enable_cpu_cache", action="store_true", diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 230da5b369..c803de7db3 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -43,7 +43,7 @@ from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster -from .api_lightllm import lightllm_get_score +from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.log_utils import init_logger from lightllm.utils.error_utils import ServerBusyError @@ -92,6 +92,8 @@ def set_args(self, args: StartArgs): self.httpserver_manager = HttpServerManagerForPDMaster( args=args, ) + elif args.run_mode in ["visual", "visual_only"]: + self.metric_client = MetricClient(args.metric_port) else: init_tokenizer(args) # for openai api SamplingParams.load_generation_cfg(args.model_dir) @@ -136,7 +138,7 @@ def get_model_name(): @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") async def healthcheck(request: Request): - if g_objs.args.run_mode == "pd_master": + if g_objs.args.run_mode in ["pd_master", "visual", "visual_only"]: return JSONResponse({"message": "Ok"}, status_code=200) if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true": @@ -221,6 +223,18 @@ async def get_score(request: Request) -> Response: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) +@app.post("/get_image_embedding") +async def get_image_embed(request: Request) -> Response: + try: + return await lightllm_get_image_embedding(request, g_objs.httpserver_manager) + except ServerBusyError as e: + logger.error("%s", str(e), exc_info=True) + return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e)) + except Exception as e: + logger.error("An error occurred: %s", str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) + + @app.post("/") async def compat_generate(request: Request) -> Response: if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: @@ -359,6 +373,8 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) + if g_objs.httpserver_manager is None: + return loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f54..3ad7b49daf 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -1,10 +1,11 @@ import collections from typing import AsyncGenerator from fastapi import BackgroundTasks, Request -from fastapi.responses import Response, StreamingResponse +from fastapi.responses import Response, StreamingResponse, JSONResponse from lightllm.server.core.objs.sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager +from lightllm.utils.envs_utils import get_env_start_args import ujson as json @@ -150,3 +151,21 @@ async def stream_results() -> AsyncGenerator[bytes, None]: background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + + +async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response: + request_dict = await request.json() + args = get_env_start_args() + assert not args.disable_vision + assert args.enable_remote_vit + sample_params_dict = request_dict["parameters"] + sampling_params = SamplingParams() + sampling_params.init(tokenizer=None, **sample_params_dict) + sampling_params.verify() + multimodal_params_dict = request_dict.get("multimodal_params", {}) + assert not multimodal_params_dict.get("audios") + multimodal_params = MultimodalParams(**multimodal_params_dict) + + await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request) + + return JSONResponse({"message": "OK"}, status_code=200) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808a..7542f7be6c 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -5,11 +5,13 @@ torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess parser = make_argument_parser() args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, config_server_start + from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start if args.run_mode == "pd_master": pd_master_start(args) elif args.run_mode == "config_server": config_server_start(args) + elif args.run_mode in ["visual", "visual_only"]: + visual_start(args) else: normal_or_p_d_start(args) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 77355f0d06..a0a74d81bf 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -15,12 +15,37 @@ from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip +from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size from lightllm.utils.config_utils import has_audio_module, has_vision_module logger = init_logger(__name__) +def _ensure_remote_vit_embed_dir(image_embed_dir: str) -> None: + if os.path.exists(image_embed_dir): + if not os.path.isdir(image_embed_dir): + raise ValueError(f"image_embed_dir is not a directory: {image_embed_dir}") + return + + os.makedirs(image_embed_dir, mode=0o777, exist_ok=True) + os.chmod(image_embed_dir, 0o777) + + +def _prepare_remote_vit_embed_dir(args): + remote_vit_mode = args.enable_remote_vit or args.run_mode in ["visual", "visual_only"] + if not remote_vit_mode: + return + + if not args.image_embed_dir: + raise ValueError("remote vit mode requires --image_embed_dir to be set") + + args.image_embed_dir = os.path.abspath(args.image_embed_dir) + _ensure_remote_vit_embed_dir(args.image_embed_dir) + + logger.info(f"using image_embed_dir: {args.image_embed_dir}") + + def setup_signal_handlers(http_server_process, process_manager): def signal_handler(sig, frame): if sig == signal.SIGINT: @@ -57,11 +82,12 @@ def signal_handler(sig, frame): signal.signal(signal.SIGINT, signal_handler) logger.info(f"start process pid {os.getpid()}") - logger.info(f"http server pid {http_server_process.pid}") + if http_server_process: + logger.info(f"http server pid {http_server_process.pid}") return -def normal_or_p_d_start(args): +def normal_or_p_d_start(args, only_prepare=False): from lightllm.server.core.objs.start_args_type import StartArgs args: StartArgs = args @@ -73,7 +99,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]: + if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual", "visual_only"]: return # 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块 @@ -168,6 +194,7 @@ def normal_or_p_d_start(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + _prepare_remote_vit_embed_dir(args) # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) @@ -229,11 +256,16 @@ def normal_or_p_d_start(args): args.data_type = get_dtype(args.model_dir) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + if only_prepare: + return + already_uesd_ports = [args.port] if args.nccl_port is not None: already_uesd_ports.append(args.nccl_port) if args.pd_decode_rpyc_port is not None: already_uesd_ports.append(args.pd_decode_rpyc_port) + if args.visual_nccl_ports is not None: + already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 @@ -241,8 +273,10 @@ def normal_or_p_d_start(args): ports_locker.lock_port() node_world_size = args.tp // args.nnodes + need_visual_nccl_ports = 0 if args.visual_nccl_ports is not None else args.visual_dp can_use_ports = alloc_can_use_network_port( - num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports + num=10 + node_world_size + args.visual_dp * args.visual_tp + need_visual_nccl_ports, + used_ports=already_uesd_ports, ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -265,8 +299,12 @@ def normal_or_p_d_start(args): tp_ports_for_dp = can_use_ports[0 : args.visual_tp] visual_model_tp_ports.append(tp_ports_for_dp) can_use_ports = can_use_ports[args.visual_tp :] - visual_nccl_ports.append(can_use_ports[0]) - can_use_ports = can_use_ports[1:] + if args.visual_nccl_ports is None: + visual_nccl_ports.append(can_use_ports[0]) + can_use_ports = can_use_ports[1:] + + if args.visual_nccl_ports is not None: + args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] # 将申请好的端口放入args参数中 if args.nccl_port is None: @@ -316,7 +354,7 @@ def normal_or_p_d_start(args): start_args=[(args,)], ) - if not args.disable_vision: + if not args.disable_vision and not args.enable_remote_vit: from .visualserver.manager import start_visual_process process_manager.start_submodule_processes( @@ -463,6 +501,79 @@ def pd_master_start(args): http_server_process.wait() +def visual_start(args): + normal_or_p_d_start(args, only_prepare=True) + + already_uesd_ports = [args.remote_vit_port] + if args.nccl_port is not None: + already_uesd_ports.append(args.nccl_port) + if args.visual_nccl_ports is not None: + already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) + + need_visual_nccl_ports = 0 if args.visual_nccl_ports is not None else args.visual_dp + can_use_ports = alloc_can_use_network_port( + num=5 + args.visual_dp * args.visual_tp + need_visual_nccl_ports, + used_ports=already_uesd_ports, + ) + logger.info(f"alloced ports: {can_use_ports}") + ( + router_port, + visual_port, + audio_port, + cache_port, + metric_port, + ) = can_use_ports[0:5] + can_use_ports = can_use_ports[5:] + + visual_nccl_ports = [] + for _ in range(args.visual_dp): + if args.visual_nccl_ports is None: + visual_nccl_ports.append(can_use_ports[0]) + can_use_ports = can_use_ports[1:] + + if args.visual_nccl_ports is not None: + args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] + else: + args.visual_nccl_ports = visual_nccl_ports + + args.router_port = router_port + args.visual_port = visual_port + args.audio_port = audio_port + args.cache_port = cache_port + args.metric_port = metric_port + args.visual_node_id = uuid.uuid4().int + + logger.info(f"all start args:{args}") + + set_env_start_args(args) + + from .visualserver.manager import start_visual_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_cache_manager, + ], + start_args=[(args,)], + ) + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args,), + ], + ) + setup_signal_handlers(None, process_manager) + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Received keyboard interrupt, shutting down...") + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully.") + sys.exit(0) + + def config_server_start(args): set_unique_server_name(args) if args.run_mode != "config_server": @@ -470,6 +581,9 @@ def config_server_start(args): logger.info(f"all start args:{args}") + if args.start_redis: + start_redis_service(args) + set_env_start_args(args) command = [ diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index c5505acda4..c55b743480 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -9,6 +9,7 @@ from typing import Dict, List from fastapi.responses import JSONResponse from lightllm.utils.log_utils import init_logger +from lightllm.server.visualserver.vit_connect import VIT_Obj from ..pd_io_struct import PD_Master_Obj from .nccl_tcp_store import start_tcp_store_server from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name @@ -19,7 +20,9 @@ app = FastAPI() registered_pd_master_objs: Dict[str, PD_Master_Obj] = {} +registered_visual_server_objs: Dict[str, VIT_Obj] = {} registered_pd_master_obj_lock = Lock() +registered_visual_server_obj_lock = Lock() global_req_id = 0 global_req_id_lock = Lock() @@ -72,6 +75,30 @@ async def websocket_endpoint(websocket: WebSocket): return +@app.websocket("/visual_register") +async def visual_websocket_endpoint(websocket: WebSocket): + await websocket.accept() + client_ip, client_port = websocket.client + logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}") + registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes()) + logger.info(f"recieved registered_visual_server_obj {registered_visual_server_obj}") + with registered_visual_server_obj_lock: + registered_visual_server_objs[registered_visual_server_obj.node_id] = registered_visual_server_obj + + try: + while True: + data = await websocket.receive_text() + assert data == "heartbeat" + except (WebSocketDisconnect, Exception, RuntimeError) as e: + logger.error(f"registered_visual_server_obj {registered_visual_server_obj} has error {str(e)}") + logger.exception(str(e)) + finally: + logger.error(f"registered_visual_server_obj {registered_visual_server_obj} removed") + with registered_visual_server_obj_lock: + registered_visual_server_objs.pop(registered_visual_server_obj.node_id, None) + return + + @app.get("/registered_objects") async def get_registered_objects(): with registered_pd_master_obj_lock: @@ -80,6 +107,14 @@ async def get_registered_objects(): return {"data": base64_encoded} +@app.get("/registered_visual_objects") +async def get_vit_registered_objects(): + with registered_visual_server_obj_lock: + serialized_data = pickle.dumps(registered_visual_server_objs) + base64_encoded = base64.b64encode(serialized_data).decode("utf-8") + return {"data": base64_encoded} + + @app.get("/allocate_global_unique_id_range") async def allocate_global_id_range(): """ diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index d3dc849664..a15c66475c 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -8,7 +8,9 @@ class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + metadata={ + "choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode", "visual_only"] + }, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) @@ -20,6 +22,8 @@ class StartArgs: pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) + config_server_visual_redis_port: int = field(default=None) + afs_image_embed_dir: str = field(default=None) pd_decode_rpyc_port: int = field(default=None) select_p_d_node_strategy: str = field(default=None) model_name: str = field(default="default_model_name") @@ -98,12 +102,14 @@ class StartArgs: job_name: str = field(default="lightllm") grouping_key: List[str] = field(default_factory=list) push_interval: int = field(default=10) + visual_node_id: int = field(default=None) visual_infer_batch_size: int = field(default=None) visual_send_batch_size: int = field(default=1) visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) visual_tp: int = field(default=1) visual_dp: int = field(default=1) visual_nccl_ports: List[int] = field(default=None) + visual_rpyc_port: Optional[int] = field(default=None) enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) diff --git a/lightllm/server/embed_cache/afs_utils.py b/lightllm/server/embed_cache/afs_utils.py new file mode 100644 index 0000000000..9281559921 --- /dev/null +++ b/lightllm/server/embed_cache/afs_utils.py @@ -0,0 +1,157 @@ +import os +import time +import torch +import uuid +import itertools +from typing import List, Tuple, Optional +from pathlib import Path +from .redis_utils import RedisMetadataLib + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class AfsUtils: + def __init__(self, base_dir: str, dir_depth: int = 2): + self.base_dir = base_dir + # 判断 base_dir 是否存在,不存在则创建并赋予777权限,让其他人也可以写入 + if not os.path.exists(base_dir): + os.makedirs(base_dir, mode=0o777, exist_ok=True) + + # build sub dirs + parent_dir = Path(base_dir) + subdirs = ["".join(p) for p in itertools.product("0123456789abcdef", repeat=dir_depth)] + for sub in subdirs: + sub_dir_path = parent_dir / sub + os.makedirs(sub_dir_path, mode=0o777, exist_ok=True) + return + + def save_tensor_afs(self, name: str, tensor: torch.Tensor) -> bool: + try: + target_path = self._get_afs_path(name) + if target_path.exists(): + return True + tmp_path = self._get_afs_path(name=name, uuid_tail_str=str(uuid.uuid4())) + with open(tmp_path, "wb") as f: + tensor = tensor.detach().cpu() + dest = torch.empty_like(tensor) + dest.copy_(tensor) + torch.save(dest, f, _use_new_zipfile_serialization=False, pickle_protocol=4) + os.rename(tmp_path, target_path) + os.chmod(target_path, 0o777) + return True + except Exception as e: + logger.warning(f"failed to save embed tensor file: {target_path} tmp_path: {tmp_path} excetion {str(e)}") + return False + finally: + try: + tmp_path.unlink(missing_ok=True) + except: + pass + + def load_tensor_afs(self, name: str) -> Optional[torch.Tensor]: + try: + path = self._get_afs_path(name) + with open(path, "rb") as f: + return torch.load(f, weights_only=False) + except Exception as e: + logger.warning(f"fail to load afs file {name} error: {str(e)}") + return None + + def free_afs(self, name: str) -> bool: + try: + path = self._get_afs_path(name) + if not path.exists(): + return True + path.unlink(missing_ok=True) + return True + except Exception as e: + logger.warning(f"free_afs name: {name} error: {str(e)}") + return False + return + + def exist_afs(self, name: str) -> bool: + try: + path = self._get_afs_path(name) + return path.exists() + except Exception as e: + logger.warning(f"exist_afs name: {name} error: {str(e)}") + return False + + def _get_afs_path(self, name: str, uuid_tail_str: Optional[str] = None) -> Path: + if uuid_tail_str is None: + return Path(self.base_dir) / name[0:2] / name + else: + return Path(self.base_dir) / name[0:2] / f"{name}.{uuid_tail_str}" + + +class SepEmbedHandler: + def __init__( + self, + afs_embed_dir: str, + redis_host: str, + redis_port: int, + capacity: int = 250000, + evict_fraction: float = 0.1, + ) -> None: + if not (0.0 <= evict_fraction <= 1.0): + raise ValueError("evict_fraction must be 0..1") + if capacity < 1: + raise ValueError("capacity must be >= 1") + + redis_url = f"redis://{redis_host}:{redis_port}/0" + self.redis_client = RedisMetadataLib(redis_url=redis_url) + self.capacity = capacity + self.remove_count = min(int(self.capacity * evict_fraction), 1000) # full的时候,每次清理的数量 + self.afs_embed_dir = afs_embed_dir + self.afs_utils = AfsUtils(self.afs_embed_dir) + + def full_to_clean(self): + remove_objs: List[str] = self.redis_client.get_eviction_candidates( + remove_size=self.remove_count, capcity=self.capacity + ) + for obj in remove_objs: + try: + if self.afs_utils.free_afs(obj): + self.redis_client.remove([obj]) + except BaseException as e: + logger.warning(f"full_to_clean md5 {obj} error {str(e)}") + + def insert(self, md5: str, tensor: torch.Tensor) -> bool: + self.full_to_clean() + try: + # 保证一定会有清理的可能性 + self.redis_client.update(md5) + self.afs_utils.save_tensor_afs(md5, tensor) + self.redis_client.update(md5) + except: + return False + + def load(self, md5: str) -> Optional[torch.Tensor]: + try: + ans = self.afs_utils.load_tensor_afs(md5) + if ans: + self.redis_client.update(md5) + return ans + else: + return None + except Exception as e: + logger.warning(f"load md5 {md5} error {str(e)}") + return None + + def check_ready(self, md5_list: List[str]) -> List[bool]: + try: + tmp1 = self.redis_client.check_and_update(md5_list) + assert len(tmp1) == len(md5_list) + start = time.time() + tmp2 = [exists and self.afs_utils.exist_afs(md5) for md5, exists in zip(md5_list, tmp1)] + cost_time = time.time() - start + if cost_time > 0.05: + logger.warning(f"slow afs check exist {cost_time} seconds, md5_list size: {len(md5_list)}") + assert len(tmp1) == len(tmp2) + ans = [a and b for a, b in zip(tmp1, tmp2)] + return ans + except Exception as e: + logger.warning(f"check_ready error {str(e)}") + return [False] * len(md5_list) diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index fbce108762..9251b87149 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -125,18 +125,26 @@ def _free_to_alloc(self, free_min_count: int, new_md5_dict: Dict[str, int]) -> D def _add_ref(self, md5_sum): rec: Record = self._md5_to_record[md5_sum] - self._sorted_records.remove(rec) - rec.ref += 1 - self._sorted_records.add(rec) + self._update_record_ref(rec, 1) return def _del_ref(self, md5_sum): rec: Record = self._md5_to_record[md5_sum] + self._update_record_ref(rec, -1) + return + + def _update_record_ref(self, rec: Record, delta: int): self._sorted_records.remove(rec) - rec.ref -= 1 + rec.ref += delta + rec.visittime = time.time() self._sorted_records.add(rec) return + def _update_record_ref_by_id(self, id_: int, delta: int): + rec: Record = self._id_to_records[id_] + self._update_record_ref(rec, delta) + return + def _judge_enough_token_cache(self, md5sum_list: list[str], token_num_list: list[int]) -> bool: tmp_dict = {} for md5, token_num in zip(md5sum_list, token_num_list): @@ -164,6 +172,10 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l alloc_md5_dict = self._free_to_alloc( free_min_count=new_needed - (self.capacity - self.occupied), new_md5_dict=new_md5_dict ) + for md5 in add_ref_m_list: + # 解锁 + self._del_ref(md5) + if len(alloc_md5_dict) == len(new_md5_dict): for md5sum, mem_block in alloc_md5_dict.items(): token_num = new_md5_dict[md5sum] @@ -187,10 +199,6 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l self._sorted_records.add(rec) self.occupied += 1 - for md5 in add_ref_m_list: - # 解锁 - self._del_ref(md5) - # 遍历加 ref results = [] for md5 in md5sum_list: @@ -212,10 +220,7 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l def release(self, ids: list[int]) -> None: with self.lock: for id_ in ids: - rec: Record = self._id_to_records[id_] - self._sorted_records.remove(rec) - rec.ref -= 1 - self._sorted_records.add(rec) + self._update_record_ref_by_id(id_, -1) def set_items_data(self, ids: list[int]) -> None: for id_ in ids: diff --git a/lightllm/server/embed_cache/redis_utils.py b/lightllm/server/embed_cache/redis_utils.py new file mode 100644 index 0000000000..2dd266cacb --- /dev/null +++ b/lightllm/server/embed_cache/redis_utils.py @@ -0,0 +1,151 @@ +import redis +from typing import List, Tuple, Union, Optional + +class RedisMetadataLib: + """ + # 代码任务 + 创建一个基于redis 管理的元数据操作库代码。 + 要求: + 2. 提供一个包装的 redis 操作client 库,提供以下功能: + (1) 提供一个时间排序队列,向队列中插入md5,并更新时间错(单位为s即可). + (2) 输入为(md5_list,), 向队列中插入所有的md5, 并更新其对应时间错。 + (3) 输入为(md5_list,), 将队列中的md5进行删除。 + (4) 输入为(md5_list,), 返回 md5_list 中所有md5 每个是否在链表中存在,返回一个bool list来标识,同时对所有存在的md5,更新时间错到最新。 + (5) 输入为(remove_size, capcity), 当时间排序队列中的元素数量大于等于capcity, 返回时间排序队列中排在前面的 remove_size 个元素,其内容为 md5。 + (6) 所有操作都使用lua 脚本,以实现原子化操作,同时返回的错误要能区分具体错误的原因,注意lua脚本的可读性,和相关函数的输入输出测试。时间错为server端s级别的参数。 + """ + def __init__(self, redis_url: str = "redis://localhost:6379/0", prefix: str = "meta"): + # decode_responses=True 确保返回的是字符串而非字节 + self.r = redis.Redis.from_url(redis_url, decode_responses=True) + self.lru_key = f"{prefix}:queue:lru" + self._register_scripts() + + def _register_scripts(self): + """注册 Lua 脚本实现原子化操作""" + + # (1) & (2) 更新/插入:支持传入单个或多个 MD5 + # 逻辑:获取服务器时间,循环执行 ZADD + self._lua_update = self.r.register_script(""" + local lru_key = KEYS[1] + local now = redis.call('TIME')[1] + local count = 0 + for i, md5 in ipairs(ARGV) do + redis.call('ZADD', lru_key, now, md5) + count = count + 1 + end + return count + """) + + # (3) 删除:从队列中移除指定的 MD5 + self._lua_remove = self.r.register_script(""" + local lru_key = KEYS[1] + local count = 0 + for i, md5 in ipairs(ARGV) do + count = count + redis.call('ZREM', lru_key, md5) + end + return count + """) + + # (4) 检查并更新:判断是否存在,存在则刷新时间,返回 bool 状态列表 + self._lua_check_update = self.r.register_script(""" + local lru_key = KEYS[1] + local now = redis.call('TIME')[1] + local results = {} + for i, md5 in ipairs(ARGV) do + if redis.call('ZSCORE', lru_key, md5) then + redis.call('ZADD', lru_key, now, md5) + table.insert(results, 1) + else + table.insert(results, 0) + end + end + return results + """) + + # (5) 容量清理:检查容量并获取候选列表 + self._lua_evict = self.r.register_script(""" + local lru_key = KEYS[1] + local remove_size = tonumber(ARGV[1]) + local capacity = tonumber(ARGV[2]) + + local current_size = redis.call('ZCARD', lru_key) + if current_size >= capacity then + -- 按照分数(时间戳)从小到大排列,获取最旧的 N 个 + return redis.call('ZRANGE', lru_key, 0, remove_size - 1) + else + return {} + end + """) + + def _to_list(self, data: Union[str, List[str]]) -> List[str]: + """内部工具:将输入统一转为列表形式""" + if isinstance(data, str): + return [data] + return data + + def update(self, md5_list: Union[str, List[str]]) -> int: + """ + 功能 (1) & (2):插入或更新 md5 的时间戳。 + 支持传入单个字符串或字符串列表。 + """ + items = self._to_list(md5_list) + if not items: return 0 + return self._lua_update(keys=[self.lru_key], args=items) + + def remove(self, md5_list: Union[str, List[str]]) -> int: + """ + 功能 (3):将队列中的 md5 进行删除。 + 支持传入单个字符串或字符串列表。 + """ + items = self._to_list(md5_list) + if not items: return 0 + return self._lua_remove(keys=[self.lru_key], args=items) + + def check_and_update(self, md5_list: List[str]) -> List[bool]: + """ + 功能 (4):返回 md5_list 中每个 md5 是否在队列中存在。 + 对存在的 md5 会同时更新时间戳到最新。 + """ + if not md5_list: return [] + raw_res = self._lua_check_update(keys=[self.lru_key], args=md5_list) + return [res == 1 for res in raw_res] + + def get_eviction_candidates(self, remove_size: int, capacity: int) -> List[str]: + """ + 功能 (5):当队列数量 >= capacity 时,返回排在前面的 remove_size 个 md5。 + """ + return self._lua_evict(keys=[self.lru_key], args=[remove_size, capacity]) + +# ---------------- 功能测试 ---------------- + +def test_meta_lib(): + lib = RedisMetadataLib(prefix="test_service") + # 清理历史数据 + lib.r.delete(lib.lru_key) + + print("1. 测试更新 (update)") + lib.update("file_0") # 单个 + lib.update(["file_1", "file_2", "file_3"]) # 批量 + print(f"当前队列大小: {lib.r.zcard(lib.lru_key)}") + + print("\n2. 测试检查并更新 (check_and_update)") + # file_1 存在,file_none 不存在,file_3 存在 + check_list = ["file_1", "file_none", "file_3"] + exists_results = lib.check_and_update(check_list) + for m, exists in zip(check_list, exists_results): + print(f"MD5: {m}, 存在状态: {exists}") + + print("\n3. 测试容量逐出 (get_eviction_candidates)") + # 当前有 4 个元素,设容量为 3,要求返回最旧的 2 个 + candidates = lib.get_eviction_candidates(remove_size=2, capacity=3) + print(f"容量达到3时,建议删除的最旧2个元素: {candidates}") + + print("\n4. 测试删除 (remove)") + removed_count = lib.remove(["file_0", "file_1"]) + print(f"成功移除数量: {removed_count}") + + final_check = lib.check_and_update(["file_1", "file_2"]) + print(f"最终检查 [file_1, file_2]: {final_check}") + +if __name__ == "__main__": + test_meta_lib() \ No newline at end of file diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 8fba9f08d7..76acfcd49a 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -7,13 +7,15 @@ import pickle import inspect import setproctitle +import threading +import collections from typing import List from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.server.multimodal_params import MultimodalParams, ImageItem -from .model_infer.model_rpc import start_model_process, VisualModelRpcClient +from .model_infer import start_model_process, VisualModelRpcClient from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry @@ -29,8 +31,8 @@ class VisualManager: def __init__( self, args: StartArgs, - visual_model_rpc_ports, ): + self.args = args context = zmq.Context(2) enable_audio = not args.disable_audio if enable_audio: @@ -48,47 +50,39 @@ def __init__( self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.cache_port = args.cache_port - self.waiting_reqs: List[GroupReqIndexes] = [] self.model_weightdir = args.model_dir - self.tp_world_size = args.tp self.vit_dp = args.visual_dp self.vit_tp = args.visual_tp + # image 最大推理 batch size self.infer_batch_size = args.visual_infer_batch_size - self.trust_remote_code = args.trust_remote_code - self.args = args - self.visual_model_rpc_ports = visual_model_rpc_ports self.send_batch_size = args.visual_send_batch_size self.shm_req_manager = ShmReqManager() + self.cur_dp_index = 0 + self.lock = threading.Lock() async def wait_to_model_ready(self): self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] self.vit_attn_backend = init_vit_att_backend(index=0) for dp_rank_id in range(self.vit_dp): - tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id] for tp_rank_id in range(self.vit_tp): - device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] - rpc_model = await start_model_process( - port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp, device_id=device_id - ) + + rpc_model = await start_model_process() self.model_rpcs[dp_rank_id].append(rpc_model) init_model_ret = [] for dp_rank_id in range(self.vit_dp): # async init model process for tp_rank_id in range(self.vit_tp): + device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] kvargs = { "weight_dir": self.model_weightdir, - "trust_remote_code": self.trust_remote_code, - "vit_dp": self.vit_dp, + "device_id": device_id, "vit_tp": self.vit_tp, - "cache_port": self.cache_port, + "cache_port": self.args.cache_port, "tp_rank_id": tp_rank_id, "dp_rank_id": dp_rank_id, - "vit_rank_id": dp_rank_id * self.vit_tp + tp_rank_id, "data_type": self.args.data_type, "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], - "visual_gpu_ids": self.args.visual_gpu_ids, "quant_type": self.args.vit_quant_type, "quant_cfg": self.args.vit_quant_cfg, "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), @@ -98,87 +92,74 @@ async def wait_to_model_ready(self): await asyncio.gather(*init_model_ret) return - async def infer_imgs(self, images: List[ImageItem]): - if len(images) == 0: - return + def get_need_infer_images(self, group_req_indexes: GroupReqIndexes) -> List[ImageItem]: + shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) + is_aborted = shm_req.is_aborted + disable_prompt_cache = shm_req.sample_params.disable_prompt_cache + self.shm_req_manager.put_back_req_obj(shm_req) + # case 0 + if is_aborted: + # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 + # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 + # 需要一些一致的流程来保证不出现异步问题。 + return [] + + multimodal_params = group_req_indexes.multimodal_params + img_uuids = [img.uuid for img in multimodal_params.images] + # disable prompt cache通常用来测试,需要也去掉image cache的影响 + if disable_prompt_cache: + ready_image = [False] * len(img_uuids) + else: + if len(img_uuids) > 0: + ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) + else: + ready_image = [] - tasks = [] - for vit_dp_rank in range(self.vit_dp): - assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.vit_dp)] - if assigned_images: - for vit_tp_rank in range(self.vit_tp): - task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images)) - tasks.append(task) + images_need_infer = [] + for img, ready in zip(multimodal_params.images, ready_image): + if not ready: + images_need_infer.append(img) - await asyncio.gather(*tasks) - return + return images_need_infer - async def loop_for_fwd(self): - while True: - if len(self.waiting_reqs) == 0: - await asyncio.sleep(0.01) # 10ms - else: - processing_group_reqs = [] - images_need_infer = [] - ready_to_send = [] - - def flush_ready(force: bool = False): - if not ready_to_send: - return - if not force and len(ready_to_send) < self.send_batch_size: - return - - for group_req_indexes in ready_to_send: - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) - ready_to_send.clear() - - while len(self.waiting_reqs) > 0: - group_req_indexes = self.waiting_reqs.pop(0) - shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) - is_aborted = shm_req.is_aborted - disable_prompt_cache = shm_req.sample_params.disable_prompt_cache - self.shm_req_manager.put_back_req_obj(shm_req) - if is_aborted: - # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 - # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 - # 需要一些一致的流程来保证不出现异步问题。 - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) - continue - - multimodal_params = group_req_indexes.multimodal_params - - img_uuids = [img.uuid for img in multimodal_params.images] - # disable prompt cache通常用来测试,需要也去掉image cache的影响 - if disable_prompt_cache: - ready_image = [False] * len(img_uuids) - else: - ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) - - for img, ready in zip(multimodal_params.images, ready_image): - if not ready: - images_need_infer.append(img) - - if len(images_need_infer) == self.infer_batch_size: - await self.infer_imgs(images_need_infer) - images_need_infer = [] - ready_to_send.extend(processing_group_reqs) - processing_group_reqs = [] - flush_ready(force=False) - - if len(images_need_infer) == 0: - ready_to_send.append(group_req_indexes) - flush_ready(force=False) - else: - processing_group_reqs.append(group_req_indexes) + async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): + images_need_infer = self.get_need_infer_images(group_req_indexes) - if len(images_need_infer) > 0: - await self.infer_imgs(images_need_infer) - images_need_infer = [] + if len(images_need_infer) == 0: + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + # case 2 + dp_to_handle_images = collections.defaultdict(list) + for image in images_need_infer: + self.cur_dp_index += 1 + select_dp = self.cur_dp_index % self.vit_dp + dp_to_handle_images[select_dp].append((image, threading.Event())) + + taskes = [] + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + taskes.extend(self.run_task(dp_index, images=[e[0] for e in _images], events=[e[1] for e in _images])) + + with self.lock: + await asyncio.gather(*taskes) + + # 等待推理通知已经 ok + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + await asyncio.to_thread(_images[-1][1].wait) + + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return - # 这些处理完 image 的 group 也 ready 了 - ready_to_send.extend(processing_group_reqs) - processing_group_reqs = [] - flush_ready(force=True) + def run_task(self, dp_index: int, images, events): + taskes = [] + for vit_tp_rank in range(self.vit_tp): + task = self.model_rpcs[dp_index][vit_tp_rank].run_task(images, events) + taskes.append(task) + return taskes async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): @@ -193,7 +174,7 @@ async def loop_for_netio_req(self): f"visual recv req id {recv_req.group_req_id} " f"img count {len(recv_req.multimodal_params.images)}" ) - self.waiting_reqs.append(recv_req) + asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req)) else: assert False, f"Error Req Inf {recv_req}" self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) @@ -203,20 +184,18 @@ async def loop_for_netio_req(self): await asyncio.sleep(0.01) def clean_up(self): - for model_rpc in self.model_rpcs: - model_rpc.rpc_server_process.kill() - for model_rpc in self.model_rpcs: - model_rpc.rpc_server_process.join() return -def start_visual_process(args, model_rpc_ports, pipe_writer): +def start_visual_process(args, pipe_writer): + import lightllm.utils.rpyc_fix_utils as _ + # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") start_parent_check_thread() try: - visualserver = VisualManager(args=args, visual_model_rpc_ports=model_rpc_ports) + visualserver = VisualManager(args=args) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) @@ -231,6 +210,5 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - loop.create_task(visualserver.loop_for_fwd()) loop.run_until_complete(visualserver.loop_for_netio_req()) return diff --git a/lightllm/server/visualserver/model_infer/__init__.py b/lightllm/server/visualserver/model_infer/__init__.py index e69de29bb2..c186856442 100644 --- a/lightllm/server/visualserver/model_infer/__init__.py +++ b/lightllm/server/visualserver/model_infer/__init__.py @@ -0,0 +1,60 @@ +import asyncio +import rpyc +import inspect +import uuid +import os +import multiprocessing +from lightllm.utils.retry_utils import retry +from rpyc.utils.factory import unix_connect +from rpyc.utils.classic import obtain +from rpyc.utils.server import ThreadedServer +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.envs_utils import get_env_start_args +from .model_rpc_client import VisualModelRpcClient +from .model_rpc import VisualModelRpcServer + + +def _init_env(socket_path: str, success_event): + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + + import lightllm.utils.rpyc_fix_utils as _ + + t = ThreadedServer(VisualModelRpcServer(), socket_path=socket_path, protocol_config={"allow_pickle": True}) + success_event.set() + t.start() + return + + +async def start_model_process(): + import lightllm.utils.rpyc_fix_utils as _ + + socket_path = _generate_unix_socket_path() + if os.path.exists(socket_path): + os.remove(socket_path) + + success_event = multiprocessing.Event() + proc = multiprocessing.Process( + target=_init_env, + args=( + socket_path, + success_event, + ), + ) + proc.start() + await asyncio.to_thread(success_event.wait, timeout=40) + assert proc.is_alive() + + conn = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True}) + assert proc.is_alive() + + # 服务端需要调用客户端传入的event所以,客户端需要一个后台线程进行相关的处理。 + conn._bg_thread = rpyc.BgServingThread(conn) + + return VisualModelRpcClient(conn) + + +def _generate_unix_socket_path() -> str: + """Generate a random Unix socket path""" + unique_id = uuid.uuid4().hex[:8] + return f"/tmp/lightllm_model_infer_{unique_id}.sock" diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 6355ac2dbf..95415e4b4e 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -1,14 +1,13 @@ -import asyncio -import numpy as np import rpyc import torch import socket -import inspect -from datetime import timedelta -from typing import Dict, List, Tuple +import torch.multiprocessing as mp +import queue +import threading +import torch.distributed as dist +from typing import Dict, List, Tuple, Deque, Optional from transformers.configuration_utils import PretrainedConfig from rpyc.utils.classic import obtain -from rpyc.utils.server import ThreadedServer from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel @@ -22,27 +21,44 @@ from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env -from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient from lightllm.server.visualserver import set_vit_att_backend +from lightllm.server.embed_cache.afs_utils import SepEmbedHandler +from lightllm.utils.log_utils import init_logger + + +logger = init_logger(__name__) class VisualModelRpcServer(rpyc.Service): def exposed_init_model(self, kvargs): kvargs = obtain(kvargs) - import torch - import torch.distributed as dist - self.vit_dp = kvargs["vit_dp"] + # kvargs = { + # "weight_dir": self.model_weightdir, + # "device_id": device_id, + # "vit_tp": self.vit_tp, + # "cache_port": self.args.cache_port, + # "tp_rank_id": tp_rank_id, + # "dp_rank_id": dp_rank_id, + # "data_type": self.args.data_type, + # "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], + # "quant_type": self.args.vit_quant_type, + # "quant_cfg": self.args.vit_quant_cfg, + # "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + # "vit_attn_backend": self.vit_attn_backend, + # } + + weight_dir = kvargs["weight_dir"] + self.infer_max_batch_size = kvargs["max_batch_size"] + self.device_id = kvargs["device_id"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = kvargs["dp_rank_id"] self.tp_rank_id = kvargs["tp_rank_id"] self.cache_port = kvargs["cache_port"] - weight_dir = kvargs["weight_dir"] self.vit_rank_id = kvargs["vit_rank_id"] - self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) - self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.is_visual_only_mode = get_env_start_args().run_mode == "visual_only" self.data_type = kvargs["data_type"] self.vit_attn_backend = kvargs["vit_attn_backend"] set_vit_att_backend(self.vit_attn_backend) @@ -95,7 +111,23 @@ def exposed_init_model(self, kvargs): self.model.load_model(weight_dir) self.model = self.model.cuda() - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + if not self.is_visual_only_mode: + self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + else: + # 独立部署vit模式下,不需要连接 cache_client, 结果是写入 afs + args = get_env_start_args() + assert args.visual_dp == 1 + if self.tp_rank_id == 0: + self.afs_handler = SepEmbedHandler( + afs_embed_dir=self.args.afs_embed_dir, + redis_host=self.args.config_server_host, + redis_port=self.args.config_server_vit_redis_port, + capacity=self.args.afs_embed_capacity, + ) + + self._init_taskes() except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -107,112 +139,176 @@ def exposed_init_model(self, kvargs): set_random_seed(2147483647) return + def exposed_run_task(self, images: List["ImageItem"], ref_event_list: List[threading.Event]): + try: + images = obtain(images) + for i in range(len(images)): + images[i].event = ref_event_list[i] + self.infer_queue.put(images[i]) + + except BaseException as e: + logger.exception(str(e)) + raise e + return + + def _init_taskes(self): + self.args = get_env_start_args() + + # 异步队列, 用于接受任务 + self.infer_queue = queue.Queue() + # 将计算得到的结果放入 afs 或者 embed cache 的 queue + self.store_queue = queue.Queue() + + # 限制并发, 主要是为了控制内存用量,防止过多造成内存OOM + self.sempare = threading.Semaphore(self.infer_max_batch_size * 8) + + # 用于同步各个推理tp每次拿到一样的image数量建立的gloo通信组 + self.gloo_group = dist.new_group(ranks=list(range(self.vit_tp)), backend="gloo") + + # 启动任务处理线程 + self._infer_thread = threading.Thread(target=self._infer_worker, daemon=True) + self._infer_thread.start() + + self._store_thread = threading.Thread(target=self._store_worker, daemon=True) + self._store_thread.start() + return + # @calculate_time(show=True, min_cost_ms=150) @torch.no_grad() - def forward(self, images: List[ImageItem]): + def _forward(self, images: List[ImageItem]): return self.model.encode(images) - # @calculate_time(show=False, min_cost_ms=300) - def exposed_encode(self, images: List[ImageItem]): - images = obtain(images) - all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cuda")) + def _get_image_items_from_infer_queue(self, max_num: int, force_same: bool = False) -> List[ImageItem]: + """ + 从队列中批量获取任务,直到达到 max_num 或队列为空。 + """ + tasks = [] + # 至少获取一个任务,阻塞 + self.sempare.acquire() + task = self.infer_queue.get(block=True) + tasks.append(task) - if self.tp_rank_id == 0: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) - ids_to_set = [] - for i, ready in enumerate(ready_flags): - if ready: - continue - uid = uuids[i] - start, end = valid_ids[i] - image = images[i] - self.cpu_embed_cache_client.copy_vision_to_cache( - embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache - ) - ids_to_set.append(uid) - if ids_to_set: - self.cache_client.root.set_items_embed(ids_to_set) - torch.cuda.current_stream().synchronize() - return + if not force_same: + # 尝试继续获取更多任务,直到达到 max_num + while len(tasks) < max_num: + try: + self.sempare.acquire() + task = self.infer_queue.get(block=False) + tasks.append(task) + except queue.Empty: + self.sempare.release() + break + else: + while len(tasks) < max_num: + self.sempare.acquire() + task = self.infer_queue.get(block=True) + tasks.append(task) + return tasks -class VisualModelRpcClient: - def __init__(self, model_rpc, vit_tp, rpc_server_process=None): - self.model: VisualModelRpcServer = model_rpc - self.vit_tp = vit_tp - self.rpc_server_process = rpc_server_process - self.use_rpc = True - if self.use_rpc: + def _get_image_items_from_store_queue(self, max_num: int) -> List[ImageItem]: + """ + 从队列中批量获取任务,直到达到 max_num 或队列为空。 + """ + tasks = [] + # 至少获取一个任务,阻塞 + task = self.store_queue.get(block=True) + tasks.append(task) - def async_wrap(f): - f = rpyc.async_(f) + while len(tasks) < max_num: + try: + task = self.store_queue.get(block=False) + tasks.append(task) + except queue.Empty: + break - async def _func(*args, **kwargs): - ans = f(*args, **kwargs) - await asyncio.to_thread(ans.wait) - # raise if exception - return ans.value + return tasks - return _func + def _infer_worker(self): + """ + 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 + """ + torch.cuda.set_device(self.device_id) + while True: + try: + # 从队列获取任务, 阻塞等待 + if self.tp_rank_id == 0: + images = self._get_image_items_from_infer_queue(max_num=self.infer_max_batch_size) + dist.broadcast_object_list([len(images)], src=0, group=self.gloo_group) + else: + ans = [None] + dist.broadcast_object_list(ans, src=0, group=self.gloo_group) + images = self._get_image_items_from_infer_queue(max_num=ans[0], force_same=True) - self._init_model = async_wrap(self.model.init_model) - self._encode = async_wrap(self.model.encode) - else: - self._init_model = self.model.exposed_init_model - self._encode = self.model.exposed_encode - return + # 执行任务: 调用父类的forward方法处理图像 + all_img_embeds, uuids, valid_ids = self._forward(images) + all_img_embeds = all_img_embeds.to(torch.device("cuda")) - async def init_model(self, kvargs): - ans: rpyc.AsyncResult = self._init_model(kvargs) - if self.use_rpc: - await ans - return - else: - return + if self.is_visual_only_mode: + self._store_to_afs(all_img_embeds, valid_ids, images) + else: + self._store_to_cpu_cache(all_img_embeds, valid_ids, images) - async def encode(self, images: List[ImageItem]): - ans = self._encode(images) - if self.use_rpc: - return await ans - else: - return ans + except Exception as e: + logger.exception(str(e)) + raise e + def _store_to_cpu_cache(self, all_img_embeds, valid_ids, images): + for i in range(len(images)): + start, end = valid_ids[i] + image = images[i] + if self.tp_rank_id == 0: + self.cpu_embed_cache_client.copy_vision_to_cache( + embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache + ) + cuda_event = torch.cuda.Event() + cuda_event.record() + image.cuda_event = cuda_event + self.store_queue.put(image) -def _init_env(port, device_id): - # 注册graceful 退出的处理 - graceful_registry(inspect.currentframe().f_code.co_name) + def _store_to_afs(self, all_img_embeds, valid_ids, images): + all_img_embeds = all_img_embeds.detach().cpu() + for image, valid_id in zip(images, valid_ids): + start, end = valid_id + gen_embed = all_img_embeds[start:end] + image.gen_embed = gen_embed + self.store_queue.put(image) - import lightllm.utils.rpyc_fix_utils as _ + def _store_worker(self): + """ + 任务处理循环: 从队列中取出ImageItem和embed 放入 afs中, 执行完成后通知调用者 + """ + while True: + try: + # 从队列获取任务, 阻塞等待 + images: List[ImageItem] = self._get_image_items_from_store_queue(max_num=self.max_infer_batch_size) - t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True}) - t.start() - return + if self.is_visual_only_mode: + self._commit_to_afs(images=images) + else: + self._commit_to_cpu_cache(images=images) + for _ in images: + self.sempare.release() -async def start_model_process(port, vit_tp, device_id): - import multiprocessing + except Exception as e: + logger.exception(str(e)) + raise e - proc = multiprocessing.Process( - target=_init_env, - args=( - port, - device_id, - ), - ) - proc.start() - await asyncio.sleep(2) - repeat_count = 0 - while repeat_count < 20: - try: - con = rpyc.connect("localhost", port, config={"allow_pickle": True}) - con._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - break - except BaseException: - await asyncio.sleep(1) - repeat_count += 1 - if repeat_count == 20: - raise Exception("init rpc env error!") - - assert proc.is_alive() - return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc) + def _commit_to_afs(self, images): + if self.tp_rank_id == 0: + for image in images: + self.afs_handler.insert(image.md5, image.gen_embed) + image.event.set() + + def _commit_to_cpu_cache(self, images): + if self.tp_rank_id == 0: + for image in images: + # 等待拷贝到cpu cache 完成。 + image.cuda_event.synchronize() + + uuids = [image.uuid for image in images] + self.cache_client.root.set_items_embed(uuids) + + for image in images: + image.event.set() diff --git a/lightllm/server/visualserver/model_infer/model_rpc_client.py b/lightllm/server/visualserver/model_infer/model_rpc_client.py new file mode 100644 index 0000000000..682d6affcc --- /dev/null +++ b/lightllm/server/visualserver/model_infer/model_rpc_client.py @@ -0,0 +1,36 @@ +import asyncio +import rpyc +import threading +from typing import Dict, List, Tuple, Deque, Optional, Union +from lightllm.server.multimodal_params import ImageItem +from .model_rpc import VisualModelRpcServer + + +class VisualModelRpcClient: + def __init__(self, rpc_conn): + self.rpc_conn: VisualModelRpcServer = rpc_conn + + def async_wrap(f): + f = rpyc.async_(f) + + async def _func(*args, **kwargs): + ans = f(*args, **kwargs) + await asyncio.to_thread(ans.wait) + # raise if exception + return ans.value + + return _func + + self._init_model = async_wrap(self.rpc_conn.root.init_model) + self._run_task = async_wrap(self.rpc_conn.root.run_task) + + return + + async def init_model(self, kvargs): + ans: rpyc.AsyncResult = self._init_model(kvargs) + await ans + return + + async def run_task(self, images: List[ImageItem], ref_event_list: List[threading.Event]): + ans = self._run_task(images, ref_event_list) + return await ans diff --git a/lightllm/server/visualserver/objs.py b/lightllm/server/visualserver/objs.py new file mode 100644 index 0000000000..cc68e88874 --- /dev/null +++ b/lightllm/server/visualserver/objs.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +@dataclass +class VIT_Obj: + node_id: int + host_ip: str + port: int + + def to_log_str(self): + return f"VIT host_ip_port: {self.host_ip}:{self.port}, node_id: {self.node_id}" diff --git a/lightllm/server/visualserver/proxy_manager.py b/lightllm/server/visualserver/proxy_manager.py new file mode 100644 index 0000000000..ea60648a55 --- /dev/null +++ b/lightllm/server/visualserver/proxy_manager.py @@ -0,0 +1,143 @@ +import zmq +import zmq.asyncio +import asyncio +import uvloop +import rpyc +import socket +import pickle +import inspect +import setproctitle +import threading +import collections +import base64 +import httpx +from typing import List +from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes +from lightllm.server.core.objs import ShmReqManager, StartArgs + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +from lightllm.server.multimodal_params import MultimodalParams, ImageItem +from lightllm.utils.log_utils import init_logger +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name +from rpyc.utils.classic import obtain +from .manager import VisualManager + +logger = init_logger(__name__) + + +class ProxyVisualManager(VisualManager): + def __init__( + self, + args: StartArgs, + ): + super().__init__(args) + assert self.vit_dp == 1 and self.vit_tp == 1 + + async def handle_group_indexes(self, group_req_indexes: GroupReqIndexes): + images_need_infer = self.get_need_infer_images(group_req_indexes) + + # case 1 + if len(images_need_infer) == 0: + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + # case 2 + dp_to_handle_images = collections.defaultdict(list) + for image in images_need_infer: + self.cur_dp_index += 1 + select_dp = self.cur_dp_index % self.vit_dp + dp_to_handle_images[select_dp].append((image, threading.Event())) + + taskes = [] + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + taskes.extend(self.run_task(dp_index, images=[e[0] for e in _images], events=[e[1] for e in _images])) + + with self.lock: + await asyncio.gather(*taskes) + + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + await asyncio.to_thread(_images[-1][1].wait) + + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + def run_task(self, dp_index: int, images, events): + taskes = [] + for vit_tp_rank in range(self.vit_tp): + task = self.model_rpcs[dp_index][vit_tp_rank].run_task(images, events) + taskes.append(task) + return taskes + + async def loop_for_netio_req(self): + if not hasattr(self, "visual_recv_max_count"): + self.visual_recv_max_count = 64 + + while True: + try: + for _ in range(self.visual_recv_max_count): + recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if isinstance(recv_req, GroupReqIndexes): + logger.info( + f"visual recv req id {recv_req.group_req_id} " + f"img count {len(recv_req.multimodal_params.images)}" + ) + asyncio.create_task(self.handle_group_indexes(group_req_indexes=recv_req)) + else: + assert False, f"Error Req Inf {recv_req}" + self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) + except zmq.ZMQError: + # 当队列已经开始清空的时候,将一次接受数量下调 + self.visual_recv_max_count = 64 + await asyncio.sleep(0.01) + + async def loop_to_connect_remote_visual_server(self): + uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get(uri) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"Failed to get VIT instances: {response.status_code}") + return None + except Exception as e: + logger.exception(f"Error getting VIT instances: {e}") + return None + + def clean_up(self): + return + + +def start_visual_process(args, pipe_writer): + import lightllm.utils.rpyc_fix_utils as _ + + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") + start_parent_check_thread() + try: + visualserver = VisualManager(args=args) + asyncio.run(visualserver.wait_to_model_ready()) + except Exception as e: + logger.exception(str(e)) + visualserver.clean_up() + raise e + + pipe_writer.send("init ok") + + def handle_exception(loop, context): + logger.exception(f"VisualServer Caught exception: {str(context)}") + + loop = asyncio.new_event_loop() + loop.set_exception_handler(handle_exception) + asyncio.set_event_loop(loop) + loop.run_until_complete(visualserver.loop_for_netio_req()) + return diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py new file mode 100644 index 0000000000..c9b381d453 --- /dev/null +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -0,0 +1,203 @@ +import asyncio +import uvloop +import rpyc +import inspect +import setproctitle +import threading +import collections +import uuid +import pickle +import websockets +import socket +from lightllm.utils.net_utils import get_hostname_ip +from .vit_connect import VIT_Obj +from typing import List +from lightllm.server.core.objs import StartArgs + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +from lightllm.server.multimodal_params import MultimodalParams, ImageItem +from .model_infer import start_model_process, VisualModelRpcClient +from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend +from lightllm.utils.log_utils import init_logger +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name +from rpyc.utils.classic import obtain +from lightllm.server.embed_cache.utils import create_shm, get_shm_name_data, free_shm + + +logger = init_logger(__name__) + + +class VisualManager(rpyc.Service): + def __init__( + self, + args: StartArgs, + ): + self.args = args + self.model_weightdir = args.model_dir + self.vit_dp = args.visual_dp + assert self.vit_dp == 1 + self.vit_tp = args.visual_tp + # image 最大推理 batch size + self.infer_batch_size = args.visual_infer_batch_size + self.cur_dp_index = 0 + self.lock = threading.Lock() + + self.new_loop = asyncio.new_event_loop() + t = threading.Thread(target=self.event_loop, args=(self.new_loop,), daemon=True) + t.start() + + def event_loop(self, loop: asyncio.AbstractEventLoop): + asyncio.set_event_loop(loop) + loop.run_forever() + return + + async def register_to_config_server_loop(self, args: StartArgs): + assert args.host not in ["127.0.0.1", "localhost"], "remote visual server must specify host ip" + + if args.host in ["0.0.0.0"]: + host_ip = get_hostname_ip() + else: + host_ip = args.host + + while True: + try: + uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_register" + async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket: + + sock = websocket.transport.get_extra_info("socket") + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip=host_ip, port=args.visual_rpyc_port) + + await websocket.send(pickle.dumps(vit_obj)) + logger.info(f"Sent registration vit_obj: {vit_obj}") + + while True: + await websocket.send("heartbeat") + await asyncio.sleep(40) + + except Exception as e: + logger.error("connetion to config_server has error") + logger.exception(str(e)) + await asyncio.sleep(10) + logger.info("reconnection to config_server") + + async def wait_to_model_ready(self): + + self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] + self.vit_attn_backend = init_vit_att_backend(index=0) + for dp_rank_id in range(self.vit_dp): + for tp_rank_id in range(self.vit_tp): + + rpc_model = await start_model_process() + self.model_rpcs[dp_rank_id].append(rpc_model) + + init_model_ret = [] + for dp_rank_id in range(self.vit_dp): # async init model process + for tp_rank_id in range(self.vit_tp): + device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] + kvargs = { + "weight_dir": self.model_weightdir, + "device_id": device_id, + "vit_tp": self.vit_tp, + "cache_port": self.args.cache_port, + "tp_rank_id": tp_rank_id, + "dp_rank_id": dp_rank_id, + "data_type": self.args.data_type, + "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], + "quant_type": self.args.vit_quant_type, + "quant_cfg": self.args.vit_quant_cfg, + "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + "vit_attn_backend": self.vit_attn_backend, + } + init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) + await asyncio.gather(*init_model_ret) + return + + async def handle_reqs(self, images_need_infer: List[ImageItem]): + # case 2 + dp_to_handle_images = collections.defaultdict(list) + for image in images_need_infer: + self.cur_dp_index += 1 + select_dp = self.cur_dp_index % self.vit_dp + dp_to_handle_images[select_dp].append((image, threading.Event())) + + taskes = [] + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + taskes.extend(self.run_task(dp_index, images=[e[0] for e in _images], events=[e[1] for e in _images])) + + with self.lock: + await asyncio.gather(*taskes) + + for dp_index in range(self.vit_dp): + _images = dp_to_handle_images[dp_index] + if _images: + await asyncio.to_thread(_images[-1][1].wait) + return + + def run_task(self, dp_index: int, images, events): + taskes = [] + for vit_tp_rank in range(self.vit_tp): + task = self.model_rpcs[dp_index][vit_tp_rank].run_task(images, events) + taskes.append(task) + return taskes + + def clean_up(self): + return + + def exposed_infer_images(self, images: List[ImageItem], ref_event: threading.Event): + try: + images = obtain(images) + # 将 images 的内容写入到 shm 中,这里修改了原始的uuid,主要是在远端的vit + # 本身不具有 embed cache 的引用保证,则新的唯一标识来进行推理,最终写入的 + # 目标的 md5 一致即可,这样调用端一样可以拿到准确的数据。 + for image in images: + image.uuid = str(uuid.uuid4()) + create_shm(get_shm_name_data(image.uuid), image.data_bytes) + del image.data_bytes + + handle = asyncio.run_coroutine_threadsafe(self.handle_reqs(images_need_infer=images), loop=self.new_loop) + handle.result() + + ref_event.set() + except BaseException as e: + logger.exception(str(e)) + raise e + finally: + # 将 shm 进行删除 + for image in images: + free_shm(get_shm_name_data(image.uuid)) + + return + + +def start_visual_process(args: StartArgs, pipe_writer): + import lightllm.utils.rpyc_fix_utils as _ + + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") + start_parent_check_thread() + + try: + visualserver = VisualManager(args=args) + future = asyncio.run_coroutine_threadsafe(visualserver.wait_to_model_ready(), loop=visualserver.new_loop) + future.result() + + asyncio.run_coroutine_threadsafe( + visualserver.register_to_config_server_loop(args=args), loop=visualserver.new_loop + ) + t = rpyc.ThreadedServer(visualserver, port=args.visual_rpyc_port, protocol_config={"allow_pickle": True}) + except Exception as e: + logger.exception(str(e)) + visualserver.clean_up() + raise e + + pipe_writer.send("init ok") + + t.start() + return diff --git a/lightllm/server/visualserver_proxy/__init__.py b/lightllm/server/visualserver_proxy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 65ac401d4c..bec02cd4cc 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -55,14 +55,29 @@ def get_environ(environ_name): def init_vision_distributed_env(kvargs): + """ + # kvargs = { + # "weight_dir": self.model_weightdir, + # "device_id": device_id, + # "vit_tp": self.vit_tp, + # "cache_port": self.args.cache_port, + # "tp_rank_id": tp_rank_id, + # "dp_rank_id": dp_rank_id, + # "data_type": self.args.data_type, + # "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], + # "quant_type": self.args.vit_quant_type, + # "quant_cfg": self.args.vit_quant_cfg, + # "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + # "vit_attn_backend": self.vit_attn_backend, + # } + """ tp_world_size = kvargs["vit_tp"] dp_size = 1 tp_rank_id = kvargs["tp_rank_id"] set_dp_size(dp_size) set_dp_world_size(tp_world_size) set_current_rank_in_dp(tp_rank_id) - visual_gpu_ids = kvargs["visual_gpu_ids"] - device_id = visual_gpu_ids[kvargs["vit_rank_id"]] + device_id = kvargs["device_id"] set_current_device_id(device_id) torch.cuda.set_device(device_id) dist.init_process_group( diff --git a/lightllm/utils/redis_utils.py b/lightllm/utils/redis_utils.py new file mode 100644 index 0000000000..acc4deb589 --- /dev/null +++ b/lightllm/utils/redis_utils.py @@ -0,0 +1,74 @@ +import subprocess +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def start_redis_service(args): + """launch redis service""" + if not hasattr(args, "start_redis") or not args.start_redis: + return None + + config_server_host = args.config_server_host + redis_port = args.redis_port + try: + subprocess.run( + ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "FLUSHALL", "ASYNC"], check=False, timeout=2 + ) + subprocess.run( + ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "SHUTDOWN", "NOSAVE"], check=False, timeout=2 + ) + except Exception: + pass + + try: + redis_command = [ + "redis-server", + "--port", + str(redis_port), + "--bind", + f"{config_server_host}", + "--daemonize", + "no", + "--logfile", + "/dev/stdout", + "--loglevel", + "notice", + "--save", + '""', # 不触发 RDB 快照 + "--appendonly", + "no", # 关闭 AOF + ] + + logger.info(f"Starting Redis service on port {redis_port}") + redis_process = subprocess.Popen(redis_command) + + import redis + import time + + max_wait = 10 + start_time = time.time() + + while time.time() - start_time < max_wait: + try: + r = redis.Redis(host=args.config_server_host, port=redis_port, socket_connect_timeout=1) + r.ping() + logger.info(f"Redis service started successfully on port {redis_port}") + del r + break + except Exception: + time.sleep(0.5) + if redis_process.poll() is not None: + logger.error("Redis service failed to start") + return None + else: + logger.error("Redis service startup timeout") + if redis_process.poll() is None: + redis_process.terminate() + return None + + return redis_process + + except Exception as e: + logger.error(f"Failed to start Redis service: {e}") + return None diff --git a/requirements.txt b/requirements.txt index 5b0b201ae3..5331227586 100644 --- a/requirements.txt +++ b/requirements.txt @@ -94,3 +94,5 @@ partial_json_parser==0.2.1.1.post6 websockets==15.0.1 cupy-cuda12x==13.6.0 nixl==0.8.0 +xformers==0.0.33.post2 +redis==7.3.0