Skip to content
Open

Vit sep #1234

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
ecf34ae
feat: vit seperation
Mar 11, 2026
9f43400
refine
Mar 20, 2026
a9db45d
fix
Mar 23, 2026
9ec3d88
fix
hiworldwzj Mar 24, 2026
155e8ee
fix
hiworldwzj Mar 24, 2026
d2e3919
fix
hiworldwzj Mar 24, 2026
0369dec
fix
hiworldwzj Mar 24, 2026
958cfc1
fix
hiworldwzj Mar 24, 2026
d95cba6
fix
hiworldwzj Mar 24, 2026
778c483
fix
hiworldwzj Mar 24, 2026
d9baab3
fix
hiworldwzj Mar 24, 2026
05d5094
fix
hiworldwzj Mar 24, 2026
b27633d
fix
hiworldwzj Mar 24, 2026
b62231a
fix
hiworldwzj Mar 26, 2026
05389e3
fix
hiworldwzj Mar 26, 2026
0125978
fix
hiworldwzj Mar 26, 2026
6af47b6
fix
hiworldwzj Mar 26, 2026
cee7837
fix
hiworldwzj Mar 26, 2026
b620e0e
fix
hiworldwzj Mar 26, 2026
46133a7
fix
hiworldwzj Mar 26, 2026
07a46c7
fix
hiworldwzj Mar 26, 2026
4d63f47
fix
hiworldwzj Mar 26, 2026
22f7a1c
fix
hiworldwzj Mar 27, 2026
dbb0ef3
fix
hiworldwzj Mar 27, 2026
4ba14b5
fix
hiworldwzj Mar 27, 2026
48be126
fix
hiworldwzj Mar 27, 2026
c95b7f8
fix
hiworldwzj Mar 27, 2026
abfaa6f
fix
hiworldwzj Mar 27, 2026
b42f4d5
fix
hiworldwzj Mar 27, 2026
9ce0edc
fix
hiworldwzj Mar 27, 2026
83cf458
fix
hiworldwzj Mar 27, 2026
ea6530f
fix
hiworldwzj Mar 27, 2026
cd05104
fix
hiworldwzj Mar 28, 2026
19f6cb6
fix
hiworldwzj Mar 28, 2026
d70d437
fix
hiworldwzj Mar 28, 2026
cbacd0a
fix
hiworldwzj Mar 28, 2026
bd3af5f
fix
hiworldwzj Mar 28, 2026
fce0136
fix
hiworldwzj Mar 28, 2026
53b9127
fix
hiworldwzj Mar 28, 2026
2df2ce3
fix
hiworldwzj Mar 28, 2026
3ad7daa
fix
hiworldwzj Mar 28, 2026
c03c7d0
fix
hiworldwzj Mar 28, 2026
36b4bb7
fix
hiworldwzj Mar 28, 2026
2cf51dd
fix
hiworldwzj Mar 28, 2026
f866f7b
fix
hiworldwzj Mar 28, 2026
7419e79
fix
hiworldwzj Mar 28, 2026
5ddb915
fix
hiworldwzj Mar 28, 2026
bb363bb
fix
hiworldwzj Mar 28, 2026
110c2ca
fix
hiworldwzj Mar 28, 2026
526500f
fix
hiworldwzj Mar 28, 2026
b48528a
fix
hiworldwzj Mar 28, 2026
bff7bb7
fix
hiworldwzj Mar 28, 2026
3fdc3a6
fix
hiworldwzj Mar 28, 2026
35e28b1
fix
hiworldwzj Mar 28, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lightllm/models/vit/triton_kernel/flashattention_nopad.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,16 @@ def flash_attention_v3_fwd(
False,
window_size[0],
window_size[1],
0.0,
attention_chunk=0,
softcap=0.0,
is_rotary_interleaved=False,
scheduler_metadata=None,
num_splits=1,
pack_gqa=None,
sm_margin=0,
sinks=None,
)

return
return o

except ImportError:
print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")
Expand Down
34 changes: 33 additions & 1 deletion lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--run_mode",
type=str,
choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"],
choices=[
"normal",
"prefill",
"decode",
"nixl_prefill",
"nixl_decode",
"pd_master",
"config_server",
"visual_only",
],
default="normal",
help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode,
config_server is for pd split mode used to register pd_master node, and get pd_master node list,
Expand Down Expand Up @@ -61,6 +70,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=None,
help="The port number for the config server in config_server mode.",
)
parser.add_argument(
"--config_server_visual_redis_port",
type=int,
default=None,
help="""when run_mode is config_server, set this params will start a redis server,
when a llm infer node start to set this params, the visual infer module will start a
proxy module use config server to find remote vit infer nodes to infer img""",
)
parser.add_argument(
"--nixl_pd_kv_page_num",
type=int,
Expand Down Expand Up @@ -440,6 +457,15 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=None,
help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502",
)
parser.add_argument(
"--visual_rpyc_port",
type=int,
default=None,
help="""
when run_mode is visual_only, set this port, make others to call local visual infer to
transfer image to embed.
""",
)
parser.add_argument(
"--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
)
Expand Down Expand Up @@ -605,6 +631,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=0.03,
help="""The interval of the schedule time, default is 30ms.""",
)
parser.add_argument(
"--afs_image_embed_dir",
type=str,
default=None,
help="path for vit embed, when use vit remote infer mode",
)
parser.add_argument(
"--enable_cpu_cache",
action="store_true",
Expand Down
20 changes: 18 additions & 2 deletions lightllm/server/api_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from .multimodal_params import MultimodalParams
from .httpserver.manager import HttpServerManager
from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
from .api_lightllm import lightllm_get_score
from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding
from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.error_utils import ServerBusyError
Expand Down Expand Up @@ -92,6 +92,8 @@ def set_args(self, args: StartArgs):
self.httpserver_manager = HttpServerManagerForPDMaster(
args=args,
)
elif args.run_mode in ["visual", "visual_only"]:
self.metric_client = MetricClient(args.metric_port)
else:
init_tokenizer(args) # for openai api
SamplingParams.load_generation_cfg(args.model_dir)
Expand Down Expand Up @@ -136,7 +138,7 @@ def get_model_name():
@app.get("/health", summary="Check server health")
@app.head("/health", summary="Check server health")
async def healthcheck(request: Request):
if g_objs.args.run_mode == "pd_master":
if g_objs.args.run_mode in ["pd_master", "visual", "visual_only"]:
return JSONResponse({"message": "Ok"}, status_code=200)

if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true":
Expand Down Expand Up @@ -221,6 +223,18 @@ async def get_score(request: Request) -> Response:
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))


# HTTP entry point that forwards an image-embedding request to
# lightllm_get_image_embedding and maps failures to error responses:
#   * ServerBusyError -> 503 SERVICE_UNAVAILABLE
#   * anything else   -> 417 EXPECTATION_FAILED
# Plain comments are used instead of a docstring so the route's generated
# API description is left exactly as it was.
@app.post("/get_image_embedding")
async def get_image_embed(request: Request) -> Response:
    try:
        # The manager lookup stays inside the try so that a failure to resolve
        # it is reported through the same error path as a failed request.
        embed_response = await lightllm_get_image_embedding(request, g_objs.httpserver_manager)
    except ServerBusyError as busy_err:
        busy_detail = str(busy_err)
        logger.error("%s", busy_detail, exc_info=True)
        return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, busy_detail)
    except Exception as err:
        err_detail = str(err)
        logger.error("An error occurred: %s", err_detail, exc_info=True)
        return create_error_response(HTTPStatus.EXPECTATION_FAILED, err_detail)
    else:
        return embed_response


@app.post("/")
async def compat_generate(request: Request) -> Response:
if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
Expand Down Expand Up @@ -359,6 +373,8 @@ async def startup_event():
logger.info("server start up")
loop = asyncio.get_event_loop()
g_objs.set_args(get_env_start_args())
if g_objs.httpserver_manager is None:
return
loop.create_task(g_objs.httpserver_manager.handle_loop())
logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}")
return
21 changes: 20 additions & 1 deletion lightllm/server/api_lightllm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import collections
from typing import AsyncGenerator
from fastapi import BackgroundTasks, Request
from fastapi.responses import Response, StreamingResponse
from fastapi.responses import Response, StreamingResponse, JSONResponse
from lightllm.server.core.objs.sampling_params import SamplingParams
from .multimodal_params import MultimodalParams
from .httpserver.manager import HttpServerManager
from lightllm.utils.envs_utils import get_env_start_args
import ujson as json


Expand Down Expand Up @@ -150,3 +151,21 @@ async def stream_results() -> AsyncGenerator[bytes, None]:

background_tasks = BackgroundTasks()
return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)


async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response:
    """Handle an image-embedding request and acknowledge it with a JSON "OK".

    The JSON body must contain a ``parameters`` mapping (used to initialise a
    ``SamplingParams``) and may contain a ``multimodal_params`` mapping (used
    to build a ``MultimodalParams``). Audio inputs are rejected.

    The start args must have vision enabled and ``enable_remote_vit`` set;
    both are checked with ``assert``, so a violation surfaces as an
    ``AssertionError`` to the caller.

    Returns:
        A ``JSONResponse`` with ``{"message": "OK"}`` and status 200 once
        ``httpserver_manager.get_image_embeding`` has completed.
    """
    payload = await request.json()

    # This endpoint only makes sense on a node configured for remote ViT.
    start_args = get_env_start_args()
    assert not start_args.disable_vision
    assert start_args.enable_remote_vit

    # "parameters" is mandatory: a missing key raises KeyError here, before
    # any sampling object is created.
    sampling_kwargs = payload["parameters"]
    sampling_params = SamplingParams()
    sampling_params.init(tokenizer=None, **sampling_kwargs)
    sampling_params.verify()

    mm_kwargs = payload.get("multimodal_params", {})
    assert not mm_kwargs.get("audios")
    multimodal_params = MultimodalParams(**mm_kwargs)

    await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request)

    return JSONResponse({"message": "OK"}, status_code=200)
4 changes: 3 additions & 1 deletion lightllm/server/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess
parser = make_argument_parser()
args = parser.parse_args()
from .api_start import pd_master_start, normal_or_p_d_start, config_server_start
from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start

if args.run_mode == "pd_master":
pd_master_start(args)
elif args.run_mode == "config_server":
config_server_start(args)
elif args.run_mode in ["visual", "visual_only"]:
visual_start(args)
else:
normal_or_p_d_start(args)
128 changes: 121 additions & 7 deletions lightllm/server/api_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,37 @@
from .router.manager import start_router_process
from lightllm.utils.process_check import is_process_active
from lightllm.utils.multinode_utils import send_and_receive_node_ip
from lightllm.utils.redis_utils import start_redis_service
from lightllm.utils.shm_size_check import check_recommended_shm_size
from lightllm.utils.config_utils import has_audio_module, has_vision_module

logger = init_logger(__name__)


def _ensure_remote_vit_embed_dir(image_embed_dir: str) -> None:
if os.path.exists(image_embed_dir):
if not os.path.isdir(image_embed_dir):
raise ValueError(f"image_embed_dir is not a directory: {image_embed_dir}")
return

os.makedirs(image_embed_dir, mode=0o777, exist_ok=True)
os.chmod(image_embed_dir, 0o777)
Comment on lines +31 to +32
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using 0o777 permissions is highly permissive and can be a security risk, as it gives read, write, and execute permissions to everyone on the system. Consider using more restrictive permissions, such as 0o755 for directories, depending on the access requirements.

Suggested change
os.makedirs(image_embed_dir, mode=0o777, exist_ok=True)
os.chmod(image_embed_dir, 0o777)
os.makedirs(image_embed_dir, mode=0o755, exist_ok=True)
os.chmod(image_embed_dir, 0o755)



def _prepare_remote_vit_embed_dir(args):
    """Resolve and create the image-embed directory when remote ViT is in use.

    Does nothing unless ``args.enable_remote_vit`` is truthy or
    ``args.run_mode`` is ``"visual"`` / ``"visual_only"``. Otherwise
    ``args.image_embed_dir`` is rewritten to its absolute path, the directory
    is created if needed, and the resolved path is logged.

    Raises:
        ValueError: If remote ViT is in use but ``args.image_embed_dir`` is
            empty or unset, or if the path exists and is not a directory.
    """
    if not (args.enable_remote_vit or args.run_mode in ["visual", "visual_only"]):
        return

    # NOTE(review): the CLI hunk in this change registers the option as
    # --afs_image_embed_dir, while this code reads args.image_embed_dir and the
    # message below mentions --image_embed_dir; confirm which name is intended.
    configured_dir = args.image_embed_dir
    if not configured_dir:
        raise ValueError("remote vit mode requires --image_embed_dir to be set")

    embed_dir = os.path.abspath(configured_dir)
    args.image_embed_dir = embed_dir
    _ensure_remote_vit_embed_dir(embed_dir)

    logger.info(f"using image_embed_dir: {embed_dir}")


def setup_signal_handlers(http_server_process, process_manager):
def signal_handler(sig, frame):
if sig == signal.SIGINT:
Expand Down Expand Up @@ -57,11 +82,12 @@ def signal_handler(sig, frame):
signal.signal(signal.SIGINT, signal_handler)

logger.info(f"start process pid {os.getpid()}")
logger.info(f"http server pid {http_server_process.pid}")
if http_server_process:
logger.info(f"http server pid {http_server_process.pid}")
return


def normal_or_p_d_start(args):
def normal_or_p_d_start(args, only_prepare=False):
from lightllm.server.core.objs.start_args_type import StartArgs

args: StartArgs = args
Expand All @@ -73,7 +99,7 @@ def normal_or_p_d_start(args):

enable_mps()

if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]:
if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual", "visual_only"]:
return

# 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块
Expand Down Expand Up @@ -168,6 +194,7 @@ def normal_or_p_d_start(args):
assert args.mtp_draft_model_dir is None
assert args.mtp_step == 0

_prepare_remote_vit_embed_dir(args)
# 检查GPU数量是否足够
if args.visual_gpu_ids is None:
args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
Expand Down Expand Up @@ -229,20 +256,27 @@ def normal_or_p_d_start(args):
args.data_type = get_dtype(args.model_dir)
assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]

if only_prepare:
return

already_uesd_ports = [args.port]
if args.nccl_port is not None:
already_uesd_ports.append(args.nccl_port)
if args.pd_decode_rpyc_port is not None:
already_uesd_ports.append(args.pd_decode_rpyc_port)
if args.visual_nccl_ports is not None:
already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp])

# 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能
# 捕获到端口设置冲突的问题
ports_locker = PortLocker(already_uesd_ports)
ports_locker.lock_port()

node_world_size = args.tp // args.nnodes
need_visual_nccl_ports = 0 if args.visual_nccl_ports is not None else args.visual_dp
can_use_ports = alloc_can_use_network_port(
num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports
num=10 + node_world_size + args.visual_dp * args.visual_tp + need_visual_nccl_ports,
used_ports=already_uesd_ports,
)
logger.info(f"alloced ports: {can_use_ports}")
(
Expand All @@ -265,8 +299,12 @@ def normal_or_p_d_start(args):
tp_ports_for_dp = can_use_ports[0 : args.visual_tp]
visual_model_tp_ports.append(tp_ports_for_dp)
can_use_ports = can_use_ports[args.visual_tp :]
visual_nccl_ports.append(can_use_ports[0])
can_use_ports = can_use_ports[1:]
if args.visual_nccl_ports is None:
visual_nccl_ports.append(can_use_ports[0])
can_use_ports = can_use_ports[1:]

if args.visual_nccl_ports is not None:
args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]

# 将申请好的端口放入args参数中
if args.nccl_port is None:
Expand Down Expand Up @@ -316,7 +354,7 @@ def normal_or_p_d_start(args):
start_args=[(args,)],
)

if not args.disable_vision:
if not args.disable_vision and not args.enable_remote_vit:
from .visualserver.manager import start_visual_process

process_manager.start_submodule_processes(
Expand Down Expand Up @@ -463,13 +501,89 @@ def pd_master_start(args):
http_server_process.wait()


def visual_start(args):
    """Start a visual-only node: the cache manager plus the visual process.

    Runs the argument preparation of ``normal_or_p_d_start`` without starting
    its processes, allocates the ports this node needs, stores them on
    ``args``, publishes ``args`` via ``set_env_start_args`` and then launches
    the cache manager and the visual process. The function then blocks in a
    sleep loop until a ``KeyboardInterrupt`` arrives, at which point all
    sub-processes are terminated and the interpreter exits with status 0.
    """
    # only_prepare=True makes normal_or_p_d_start return before it allocates
    # ports or starts any process.
    normal_or_p_d_start(args, only_prepare=True)

    # Ports that are already spoken for and must not be handed out again.
    already_uesd_ports = [args.remote_vit_port]
    if args.nccl_port is not None:
        already_uesd_ports.append(args.nccl_port)
    if args.visual_nccl_ports is not None:
        already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp])

    # NCCL ports are only allocated here when the user did not supply them.
    need_visual_nccl_ports = 0 if args.visual_nccl_ports is not None else args.visual_dp
    # NOTE(review): visual_dp * visual_tp extra ports are requested, but this
    # function only consumes the first five plus the NCCL ports; confirm
    # whether the remainder is intentionally left unused.
    can_use_ports = alloc_can_use_network_port(
        num=5 + args.visual_dp * args.visual_tp + need_visual_nccl_ports,
        used_ports=already_uesd_ports,
    )
    logger.info(f"alloced ports: {can_use_ports}")
    (
        router_port,
        visual_port,
        audio_port,
        cache_port,
        metric_port,
    ) = can_use_ports[0:5]
    can_use_ports = can_use_ports[5:]

    # Take one NCCL port per visual data-parallel group when none were given.
    visual_nccl_ports = []
    for _ in range(args.visual_dp):
        if args.visual_nccl_ports is None:
            visual_nccl_ports.append(can_use_ports[0])
            can_use_ports = can_use_ports[1:]

    # User-supplied NCCL ports are truncated to visual_dp entries; otherwise
    # the freshly allocated ones are used.
    if args.visual_nccl_ports is not None:
        args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
    else:
        args.visual_nccl_ports = visual_nccl_ports

    # Put the allocated ports into args.
    args.router_port = router_port
    args.visual_port = visual_port
    args.audio_port = audio_port
    args.cache_port = cache_port
    args.metric_port = metric_port
    # Random identifier for this visual node.
    args.visual_node_id = uuid.uuid4().int

    logger.info(f"all start args:{args}")

    set_env_start_args(args)

    from .visualserver.manager import start_visual_process

    # The cache manager is started before the visual process.
    process_manager.start_submodule_processes(
        start_funcs=[
            start_cache_manager,
        ],
        start_args=[(args,)],
    )
    process_manager.start_submodule_processes(
        start_funcs=[
            start_visual_process,
        ],
        start_args=[
            (args,),
        ],
    )
    # There is no HTTP server process in this mode, hence None.
    setup_signal_handlers(None, process_manager)
    # Keep the parent process alive while the sub-processes run.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt, shutting down...")
        process_manager.terminate_all_processes()
        logger.info("All processes have been terminated gracefully.")
        sys.exit(0)


def config_server_start(args):
set_unique_server_name(args)
if args.run_mode != "config_server":
return

logger.info(f"all start args:{args}")

if args.start_redis:
start_redis_service(args)

set_env_start_args(args)

command = [
Expand Down
Loading
Loading