diff --git a/dev/run_qwen3_5_localbackend_yes_no_maybe.py b/dev/run_qwen3_5_localbackend_yes_no_maybe.py
index ce2e3db5..ad79d213 100644
--- a/dev/run_qwen3_5_localbackend_yes_no_maybe.py
+++ b/dev/run_qwen3_5_localbackend_yes_no_maybe.py
@@ -51,6 +51,11 @@ def _format_int_list(values: list[int]) -> str:
 parser.add_argument(
     "--enable-thinking", action=argparse.BooleanOptionalAction, default=False
 )
+parser.add_argument(
+    "--rollout-weights-mode",
+    choices=("lora", "merged"),
+    default=None,
+)
 parser.add_argument("--trainer-gpu-ids", type=int, nargs="+")
 parser.add_argument("--inference-gpu-ids", type=int, nargs="+")
 args = parser.parse_args()
@@ -98,6 +103,8 @@ def _format_int_list(values: list[int]) -> str:
         f"INFERENCE_GPU_IDS={_format_int_list(args.inference_gpu_ids)}",
     ]
 )
+if args.rollout_weights_mode is not None:
+    env.append(f"ROLLOUT_WEIGHTS_MODE={args.rollout_weights_mode}")
 env_block = " \\\n ".join(env)
 
 run_script = textwrap.dedent(
@@ -143,6 +150,7 @@ def _format_int_list(values: list[int]) -> str:
 print(f" load_in_4bit: {args.load_in_4bit}")
 print(f" load_in_16bit: {args.load_in_16bit}")
 print(f" enable_thinking: {args.enable_thinking}")
+print(f" rollout_weights_mode: {args.rollout_weights_mode}")
 print(f" trainer_gpu_ids: {args.trainer_gpu_ids}")
 print(f" inference_gpu_ids: {args.inference_gpu_ids}")
diff --git a/dev/yes-no-maybe-metrics.py b/dev/yes-no-maybe-metrics.py
index 1a418feb..5ada7b44 100644
--- a/dev/yes-no-maybe-metrics.py
+++ b/dev/yes-no-maybe-metrics.py
@@ -223,6 +223,12 @@ def build_internal_config() -> art.dev.InternalModelConfig:
         result["trainer_gpu_ids"] = trainer_gpu_ids
         result["inference_gpu_ids"] = inference_gpu_ids
 
+    rollout_weights_mode = os.environ.get("ROLLOUT_WEIGHTS_MODE")
+    if rollout_weights_mode is not None:
+        if rollout_weights_mode not in {"lora", "merged"}:
+            raise ValueError("ROLLOUT_WEIGHTS_MODE must be either 'lora' or 'merged'")
+        result["rollout_weights_mode"] = rollout_weights_mode
+
     return result
diff --git a/pyproject.toml b/pyproject.toml
index d7e16eb9..c64dd997 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ backend = [
     "pytest>=8.4.1",
     "nbmake>=1.5.5",
     "gql<4",
+    "nvidia-cudnn-frontend<1.21 ; sys_platform == 'linux'",
     "vllm @ https://github.com/vivekkalyan/vllm/releases/download/v0.17.0-art1/vllm-0.17.0%2Bart1-cp38-abi3-manylinux_2_31_x86_64.whl ; sys_platform == 'linux'",
 ]
 megatron = [
diff --git a/src/art/dev/engine.py b/src/art/dev/engine.py
index d0cef938..d79384f7 100644
--- a/src/art/dev/engine.py
+++ b/src/art/dev/engine.py
@@ -3,6 +3,10 @@
 from typing_extensions import TypedDict
 
 
+class WeightTransferConfig(TypedDict):
+    backend: Literal["nccl"]
+
+
 class EngineArgs(TypedDict, total=False):
     model: str
     served_model_name: str | list[str] | None
@@ -124,6 +128,7 @@ class EngineArgs(TypedDict, total=False):
     calculate_kv_scales: bool | None
     additional_config: dict[str, Any] | None
+    weight_transfer_config: WeightTransferConfig | None
     disable_log_requests: (
         bool  # Deprecated in vLLM 0.13+, use enable_log_requests instead
diff --git a/src/art/dev/get_model_config.py b/src/art/dev/get_model_config.py
index d56a20eb..6835dfb9 100644
--- a/src/art/dev/get_model_config.py
+++ b/src/art/dev/get_model_config.py
@@ -14,6 +14,7 @@ def get_model_config(
     config = InternalModelConfig()
     dedicated = is_dedicated_mode(config)
+    rollout_weights_mode = config.get("rollout_weights_mode", "lora")
 
     if dedicated:
         enable_sleep_mode = False
@@ -78,6 +79,7 @@ def get_model_config(
         init_args=init_args,
         engine_args=engine_args,
         peft_args=peft_args,
+        rollout_weights_mode=rollout_weights_mode,
         tinker_args=config.get("tinker_args"),
         trainer_args=trainer_args,
     )
diff --git a/src/art/dev/model.py b/src/art/dev/model.py
index 53d88711..e55b35d1 100644
--- a/src/art/dev/model.py
+++ b/src/art/dev/model.py
@@ -1,9 +1,12 @@
 from enum import Enum
+from typing import Literal
 
 from typing_extensions import Required, TypedDict
 
 from .engine import EngineArgs
 
+RolloutWeightsMode = Literal["lora", "merged"]
+
 
 # Vendored from transformers.training_args.OptimizerNames
 class OptimizerNames(str, Enum):
@@ -120,6 +123,10 @@ class InternalModelConfig(TypedDict, total=False):
             inference run on separate GPUs.
         inference_gpu_ids: GPU IDs for vLLM inference (e.g., [1]). When set
             with trainer_gpu_ids, enables dedicated mode.
+        rollout_weights_mode: How inference weights are applied in vLLM.
+            - "lora": load LoRA adapters into vLLM directly
+            - "merged": keep training with LoRA adapters, but push merged
+                weights into vLLM for inference
     """
 
     init_args: "InitArgs"
@@ -130,6 +137,7 @@ class InternalModelConfig(TypedDict, total=False):
     trainer_args: "TrainerArgs"
     trainer_gpu_ids: list[int]
     inference_gpu_ids: list[int]
+    rollout_weights_mode: "RolloutWeightsMode"
 
 
 class TinkerArgs(TypedDict, total=False):
diff --git a/src/art/dev/validate.py b/src/art/dev/validate.py
index 031464e0..7ab8c6a1 100644
--- a/src/art/dev/validate.py
+++ b/src/art/dev/validate.py
@@ -1,6 +1,11 @@
 """Validation functions for model configuration."""
 
-from .model import InternalModelConfig
+from .model import InternalModelConfig, RolloutWeightsMode
+
+QWEN3_5_MOE_MODELS = {
+    "Qwen/Qwen3.5-35B-A3B",
+    "Qwen/Qwen3.5-397B-A17B",
+}
 
 
 def is_dedicated_mode(config: InternalModelConfig) -> bool:
@@ -8,6 +13,18 @@ def is_dedicated_mode(config: InternalModelConfig) -> bool:
     return "trainer_gpu_ids" in config and "inference_gpu_ids" in config
 
 
+def _rollout_weights_mode(config: InternalModelConfig) -> RolloutWeightsMode:
+    mode = config.get("rollout_weights_mode", "lora")
+    if mode in {"lora", "merged"}:
+        return mode
+    raise ValueError("rollout_weights_mode must be either 'lora' or 'merged'")
+
+
+def _is_qwen3_5_moe_model(config: InternalModelConfig) -> bool:
+    model_name = config.get("engine_args", {}).get("model")
+    return model_name in QWEN3_5_MOE_MODELS
+
+
 def validate_dedicated_config(config: InternalModelConfig) -> None:
     """Validate dedicated mode GPU configuration.
@@ -16,12 +33,19 @@ def validate_dedicated_config(config: InternalModelConfig) -> None:
     """
     has_trainer = "trainer_gpu_ids" in config
     has_inference = "inference_gpu_ids" in config
+    rollout_weights_mode = _rollout_weights_mode(config)
 
     if has_trainer != has_inference:
         raise ValueError(
             "trainer_gpu_ids and inference_gpu_ids must both be set or both unset"
         )
 
+    if rollout_weights_mode == "merged" and not has_trainer:
+        raise ValueError(
+            "rollout_weights_mode='merged' requires dedicated mode "
+            "(set both trainer_gpu_ids and inference_gpu_ids)"
+        )
+
     if not has_trainer:
         return
@@ -65,3 +89,9 @@ def validate_dedicated_config(config: InternalModelConfig) -> None:
             "enable_sleep_mode is incompatible with dedicated mode "
             "(dedicated mode runs vLLM on a separate GPU, sleep/wake is not needed)"
         )
+
+    if _is_qwen3_5_moe_model(config) and rollout_weights_mode == "lora":
+        raise ValueError(
+            "Qwen3.5-MoE models require rollout_weights_mode='merged' with the "
+            "current vLLM version because direct LoRA inference is broken"
+        )
diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py
index 5b6a563c..0a1e17af 100644
--- a/src/art/unsloth/service.py
+++ b/src/art/unsloth/service.py
@@ -6,6 +6,7 @@
 import json
 import logging
 import os
+import socket
 import subprocess
 import sys
 from typing import TYPE_CHECKING, Any, AsyncIterator, Literal, Protocol, cast
@@ -207,6 +208,21 @@ def _get_trainer_optimizer(trainer: GRPOTrainer) -> Optimizer:
     return optimizer
 
 
+def _find_free_tcp_port() -> int:
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        sock.bind(("127.0.0.1", 0))
+        return cast(int, sock.getsockname()[1])
+
+
+def _normalize_merged_checkpoint_name(name: str) -> str:
+    # PEFT wraps adapted modules under `.base_layer`, but vLLM expects the
+    # original checkpoint parameter names during update_weights().
+    normalized = name.removeprefix("base_model.model.")
+    while ".base_layer." in normalized:
+        normalized = normalized.replace(".base_layer.", ".")
+    return normalized
+
+
 # ============================================================================
 # Model Classes
 # ============================================================================
@@ -325,12 +341,23 @@ class UnslothService:
     _vllm_log_file: Any = field(default=None, repr=False)
     _vllm_host: str = "127.0.0.1"
     _vllm_port: int = 0
+    _weight_transfer_group: Any = field(default=None, init=False, repr=False)
     _train_task: asyncio.Task[None] | None = field(default=None, init=False, repr=False)
 
     @property
     def is_dedicated(self) -> bool:
         return is_dedicated_mode(self.config)
 
+    @property
+    def rollout_weights_mode(self) -> Literal["lora", "merged"]:
+        mode = self.config["rollout_weights_mode"]
+        assert mode in {"lora", "merged"}
+        return mode
+
+    @property
+    def _vllm_base_url(self) -> str:
+        return f"http://{self._vllm_host}:{self._vllm_port}"
+
     def _next_lora_id(self) -> int:
         """Return a new unique LoRA ID to avoid collisions in vLLM."""
         self._lora_id_counter += 1
@@ -387,8 +414,13 @@ async def _start_vllm_subprocess(
         if config and "engine_args" in config:
             engine_args.update(dict(config["engine_args"]))
         engine_args.setdefault("generation_config", "vllm")
-        engine_args["enable_lora"] = True
-        engine_args.setdefault("max_loras", 2)
+        if self.rollout_weights_mode == "merged":
+            engine_args["weight_transfer_config"] = {"backend": "nccl"}
+            engine_args.pop("enable_lora", None)
+            engine_args.pop("max_loras", None)
+        else:
+            engine_args["enable_lora"] = True
+            engine_args.setdefault("max_loras", 2)
         for key in ("model", "served_model_name", "enable_sleep_mode"):
             engine_args.pop(key, None)
@@ -402,6 +434,7 @@ async def _start_vllm_subprocess(
             f"--cuda-visible-devices={cuda_devices}",
             f"--lora-path={lora_path}",
             f"--served-model-name={self.model_name}@{self._latest_step}",
+            f"--rollout-weights-mode={self.rollout_weights_mode}",
             f"--engine-args-json={json.dumps(engine_args)}",
             f"--server-args-json={json.dumps(server_args)}",
         ]
@@ -451,6 +484,194 @@ async def _start_vllm_subprocess(
         logger.info("vLLM subprocess ready on port %d (GPUs: %s)", port, cuda_devices)
         return self._vllm_host, self._vllm_port
 
+    async def _set_served_model_name(self, step: int) -> None:
+        import httpx
+
+        served_model_name = f"{self.model_name}@{step}"
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self._vllm_base_url}/art/set_served_model_name",
+                json={"name": served_model_name},
+                timeout=30.0,
+            )
+            response.raise_for_status()
+        logger.info(
+            "[DEDICATED] Updated merged rollout alias to %s",
+            served_model_name,
+        )
+
+    async def _init_merged_weight_transfer(self) -> None:
+        import httpx
+        from vllm.distributed.weight_transfer.nccl_engine import (
+            NCCLWeightTransferEngine,
+        )
+
+        if self._weight_transfer_group is not None:
+            return
+
+        async with httpx.AsyncClient() as client:
+            world_size_response = await client.get(
+                f"{self._vllm_base_url}/get_world_size",
+                timeout=30.0,
+            )
+            try:
+                world_size_response.raise_for_status()
+            except httpx.HTTPStatusError as exc:
+                raise RuntimeError(
+                    "Merged rollout weights require a vLLM build with the "
+                    "/get_world_size endpoint"
+                ) from exc
+            inference_world_size = int(world_size_response.json()["world_size"])
+
+            master_port = _find_free_tcp_port()
+            init_info = {
+                "master_address": "127.0.0.1",
+                "master_port": master_port,
+                "rank_offset": 1,
+                "world_size": inference_world_size + 1,
+            }
+
+            remote_init_task = asyncio.create_task(
+                client.post(
f"{self._vllm_base_url}/init_weight_transfer_engine", + json={"init_info": init_info}, + timeout=300.0, + ) + ) + # TODO: replace this with a real readiness handshake if this ever flakes. + await asyncio.sleep(1.0) + self._weight_transfer_group = await asyncio.to_thread( + NCCLWeightTransferEngine.trainer_init, + { + "master_address": init_info["master_address"], + "master_port": init_info["master_port"], + "world_size": init_info["world_size"], + }, + ) + remote_init_response = await remote_init_task + try: + remote_init_response.raise_for_status() + except httpx.HTTPStatusError as exc: + raise RuntimeError( + "Merged rollout weights require a vLLM build with the " + "/init_weight_transfer_engine endpoint" + ) from exc + + logger.info( + "[DEDICATED] Initialized merged weight transfer: inference_world_size=%d", + inference_world_size, + ) + + def _merged_checkpoint_weights(self) -> list[tuple[str, torch.Tensor]]: + model = self._state.peft_model.base_model.model + device = next(model.parameters()).device + assert device.type == "cuda" + + weights: list[tuple[str, torch.Tensor]] = [] + normalized_names: set[str] = set() + for name, tensor in model.state_dict().items(): + if "lora_" in name: + continue + normalized_name = _normalize_merged_checkpoint_name(name) + assert normalized_name not in normalized_names + normalized_names.add(normalized_name) + detached = tensor.detach() + if detached.device != device: + detached = detached.to(device=device, non_blocking=True) + weights.append((normalized_name, detached)) + + assert weights + return weights + + async def _sync_merged_weights( + self, + step: int, + pause_generation: bool, + ) -> None: + import httpx + from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLWeightTransferEngine, + ) + + assert self._weight_transfer_group is not None + + peft_model = self._state.peft_model + merged = False + error: Exception | None = None + logger.info("[DEDICATED] Syncing merged rollout weights for step %d", step) + + async with httpx.AsyncClient() as client: + try: + if pause_generation: + response = await client.post( + f"{self._vllm_base_url}/pause", + params={"mode": "wait"}, + timeout=300.0, + ) + response.raise_for_status() + + peft_model.merge_adapter() + merged = True + torch.cuda.synchronize() + + weights = self._merged_checkpoint_weights() + update_info = { + "names": [name for name, _ in weights], + "dtype_names": [ + str(tensor.dtype).removeprefix("torch.") + for _, tensor in weights + ], + "shapes": [list(tensor.shape) for _, tensor in weights], + "is_checkpoint_format": True, + } + + _, update_response = await asyncio.gather( + asyncio.to_thread( + NCCLWeightTransferEngine.trainer_send_weights, + iter(weights), + {"group": self._weight_transfer_group}, + ), + client.post( + f"{self._vllm_base_url}/update_weights", + json={"update_info": update_info}, + timeout=600.0, + ), + ) + try: + update_response.raise_for_status() + except httpx.HTTPStatusError as exc: + raise RuntimeError( + "Merged rollout weights require a vLLM build with the " + "/update_weights endpoint" + ) from exc + self._latest_step = step + await self._set_served_model_name(step) + except Exception as exc: + error = exc + raise + finally: + if merged: + peft_model.unmerge_adapter() + torch.cuda.synchronize() + if pause_generation: + try: + response = await client.post( + f"{self._vllm_base_url}/resume", + timeout=30.0, + ) + response.raise_for_status() + except Exception: + if error is None: + raise + logger.exception( + "Failed to resume generation after merged 
+                        )
+
+        logger.info(
+            "[DEDICATED] Merged rollout sync complete for step %d",
+            step,
+        )
+
     async def _reload_adapter(self, checkpoint_path: str, step: int) -> None:
         """Reload LoRA adapter in vLLM subprocess via HTTP."""
         import httpx
@@ -478,6 +699,7 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None:
 
     def close(self) -> None:
         """Terminate vLLM subprocess if running."""
+        self._weight_transfer_group = None
         if self._vllm_process is None:
             return
         self._vllm_process.terminate()
@@ -510,7 +732,16 @@ async def start_openai_server(
 
         if self.is_dedicated:
             port = (config or {}).get("server_args", {}).get("port", 8000)
-            return await self._start_vllm_subprocess(lora_path, port, config=config)
+            vllm_location = await self._start_vllm_subprocess(
+                lora_path,
+                port,
+                config=config,
+            )
+            if self.rollout_weights_mode == "merged":
+                _ = self._state
+                await self._init_merged_weight_transfer()
+                await self._sync_merged_weights(self._latest_step, False)
+            return vllm_location
 
         # Shared mode: in-process vLLM
         self._state.offload_to_cpu()
@@ -544,7 +775,10 @@ async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
             f"checkpoint_dir={checkpoint_dir} is_dedicated={self.is_dedicated}"
         )
         if self.is_dedicated:
-            await self._reload_adapter(checkpoint_dir, step)
+            if self.rollout_weights_mode == "merged":
+                await self._set_served_model_name(step)
+            else:
+                await self._reload_adapter(checkpoint_dir, step)
             self._latest_step = step
             return
@@ -655,14 +889,21 @@ async def _train_dedicated(
         )
 
         new_step = int(os.path.basename(checkpoint_dir))
-        logger.info(
-            f"[DEDICATED] _train_dedicated: saved checkpoint step={new_step}, "
-            f"reloading adapter..."
-        )
-        await self._reload_adapter(checkpoint_dir, new_step)
+        if self.rollout_weights_mode == "merged":
+            logger.info(
+                "[DEDICATED] _train_dedicated: saved checkpoint step=%s, syncing merged weights...",
+                new_step,
+            )
+            await self._sync_merged_weights(new_step, True)
+        else:
+            logger.info(
+                "[DEDICATED] _train_dedicated: saved checkpoint step=%s, reloading adapter...",
+                new_step,
+            )
+            await self._reload_adapter(checkpoint_dir, new_step)
         self._latest_step = new_step
         logger.info(
-            f"[DEDICATED] _train_dedicated: adapter reloaded for step {new_step}"
+            f"[DEDICATED] _train_dedicated: inference weights updated for step {new_step}"
         )
 
     async def _train_shared(
diff --git a/src/art/vllm/dedicated_server.py b/src/art/vllm/dedicated_server.py
index 72e60cae..47921be6 100644
--- a/src/art/vllm/dedicated_server.py
+++ b/src/art/vllm/dedicated_server.py
@@ -19,8 +19,14 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
     parser.add_argument("--port", type=int, required=True)
     parser.add_argument("--host", default="127.0.0.1")
     parser.add_argument("--cuda-visible-devices", required=True)
-    parser.add_argument("--lora-path", required=True, help="Initial LoRA adapter path")
+    parser.add_argument("--lora-path", required=True, help="Initial checkpoint path")
     parser.add_argument("--served-model-name", required=True)
+    parser.add_argument(
+        "--rollout-weights-mode",
+        choices=("lora", "merged"),
+        default="lora",
+        help="Whether the dedicated server serves LoRA adapters or merged weights",
+    )
     parser.add_argument(
         "--engine-args-json", default="{}", help="Additional engine args as JSON"
     )
@@ -32,12 +38,75 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
     return parser.parse_args(argv)
 
 
+def _patch_art_dedicated_routes() -> None:
+    from fastapi import APIRouter, FastAPI, Request
+    from fastapi.responses import JSONResponse
+    from vllm.entrypoints.openai import api_server
+    from vllm.tasks import SupportedTask
+
+    if getattr(api_server, "_art_dedicated_routes_patched", False):
+        return
+
+    original_build_app = api_server.build_app
+
+    def art_build_app(
+        args: argparse.Namespace,
+        supported_tasks: tuple[SupportedTask, ...] | None = None,
+    ) -> FastAPI:
+        app = original_build_app(args, supported_tasks)
+        router = APIRouter()
+
+        @router.post("/art/set_served_model_name")
+        async def set_served_model_name(raw_request: Request) -> JSONResponse:
+            body = await raw_request.json()
+            name = body["name"]
+            assert isinstance(name, str) and name
+            models = raw_request.app.state.openai_serving_models
+            assert models.base_model_paths
+            models.base_model_paths[0].name = name
+            return JSONResponse(content={"name": name})
+
+        app.include_router(router)
+        return app
+
+    setattr(api_server, "build_app", art_build_app)
+    setattr(api_server, "_art_dedicated_routes_patched", True)
+
+
+def _append_cli_arg(vllm_args: list[str], key: str, value: object) -> None:
+    cli_key = f"--{key.replace('_', '-')}"
+    match value:
+        case True:
+            vllm_args.append(cli_key)
+        case False | None:
+            return
+        case str() | int() | float():
+            vllm_args.append(f"{cli_key}={value}")
+        case dict():
+            vllm_args.append(f"{cli_key}={json.dumps(value)}")
+        case list():
+            for item in value:
+                match item:
+                    case str() | int() | float():
+                        vllm_args.append(f"{cli_key}={item}")
+                    case dict():
+                        vllm_args.append(f"{cli_key}={json.dumps(item)}")
+                    case _:
+                        assert False, (
+                            f"Unsupported CLI list item for {key}: {type(item)}"
+                        )
+        case _:
+            assert False, f"Unsupported CLI arg for {key}: {type(value)}"
+
+
 def main(argv: list[str] | None = None) -> None:
     args = parse_args(argv)
 
     # Must set CUDA_VISIBLE_DEVICES before any torch/CUDA import
     os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices
     os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "1"
+    if args.rollout_weights_mode == "merged":
+        os.environ["VLLM_SERVER_DEV_MODE"] = "1"
 
     # Patches must be applied before vLLM's api_server is imported
     from .patches import (
@@ -60,27 +129,25 @@ def main(argv: list[str] | None = None) -> None:
     engine_args = json.loads(args.engine_args_json)
     server_args = json.loads(args.server_args_json)
 
+    if args.rollout_weights_mode == "merged":
+        _patch_art_dedicated_routes()
+
     vllm_args = [
         f"--model={args.model}",
         f"--port={args.port}",
         f"--host={args.host}",
         f"--served-model-name={args.served_model_name}",
-        "--enable-lora",
-        f"--lora-modules={args.served_model_name}={args.lora_path}",
     ]
+    if args.rollout_weights_mode == "lora":
+        vllm_args.extend(
+            [
+                "--enable-lora",
+                f"--lora-modules={args.served_model_name}={args.lora_path}",
+            ]
+        )
     for extra_args in (engine_args, server_args):
         for key, value in extra_args.items():
-            if value is None:
-                continue
-            cli_key = f"--{key.replace('_', '-')}"
-            if isinstance(value, bool):
-                if value:
-                    vllm_args.append(cli_key)
-            elif isinstance(value, list):
-                for item in value:
-                    vllm_args.append(f"{cli_key}={item}")
-            else:
-                vllm_args.append(f"{cli_key}={value}")
+            _append_cli_arg(vllm_args, key, value)
 
     vllm_parser = FlexibleArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server."
diff --git a/tests/unit/test_dedicated_config.py b/tests/unit/test_dedicated_config.py
index 3de780ef..d1e2f63d 100644
--- a/tests/unit/test_dedicated_config.py
+++ b/tests/unit/test_dedicated_config.py
@@ -143,6 +143,7 @@ def test_get_model_config_shared_mode():
     assert "inference_gpu_ids" not in result
     assert result["engine_args"]["enable_sleep_mode"] is True
     assert result["init_args"].get("fast_inference") is False
+    assert result["rollout_weights_mode"] == "lora"
 
 
 def test_get_model_config_dedicated_mode():
@@ -158,6 +159,7 @@ def test_get_model_config_dedicated_mode():
     assert result["inference_gpu_ids"] == [1]
     assert result["engine_args"]["enable_sleep_mode"] is False
     assert "fast_inference" not in result["init_args"]
+    assert result["rollout_weights_mode"] == "lora"
 
 
 def test_get_model_config_dedicated_preserves_user_engine_args():
@@ -173,3 +175,71 @@ def test_get_model_config_dedicated_preserves_user_engine_args():
     assert result["engine_args"]["max_model_len"] == 4096
     # Sleep mode should still be disabled even if user didn't set it
     assert result["engine_args"]["enable_sleep_mode"] is False
+
+
+def test_get_model_config_preserves_rollout_weights_mode():
+    from art.dev.get_model_config import get_model_config
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        config = InternalModelConfig(
+            trainer_gpu_ids=[0],
+            inference_gpu_ids=[1],
+            rollout_weights_mode="merged",
+        )
+        result = get_model_config("test-model", tmpdir, config)
+        assert result["rollout_weights_mode"] == "merged"
+
+
+def test_invalid_rollout_weights_mode():
+    with pytest.raises(
+        ValueError, match="rollout_weights_mode must be either 'lora' or 'merged'"
+    ):
+        validate_dedicated_config(
+            InternalModelConfig(rollout_weights_mode="bad-mode")  # type: ignore[typeddict-item]
+        )
+
+
+def test_merged_rollout_weights_requires_dedicated_mode():
+    with pytest.raises(
+        ValueError, match="rollout_weights_mode='merged' requires dedicated mode"
+    ):
+        validate_dedicated_config(InternalModelConfig(rollout_weights_mode="merged"))
+
+
+def test_qwen3_5_moe_requires_merged_rollout_weights():
+    with pytest.raises(
+        ValueError,
+        match="Qwen3.5-MoE models require rollout_weights_mode='merged'",
+    ):
+        validate_dedicated_config(
+            InternalModelConfig(
+                trainer_gpu_ids=[0],
+                inference_gpu_ids=[1],
+                engine_args={"model": "Qwen/Qwen3.5-35B-A3B"},  # type: ignore[typeddict-item]
+            )
+        )
+
+
+def test_qwen3_5_moe_allows_merged_rollout_weights():
+    validate_dedicated_config(
+        InternalModelConfig(
+            trainer_gpu_ids=[0],
+            inference_gpu_ids=[1],
+            rollout_weights_mode="merged",
+            engine_args={"model": "Qwen/Qwen3.5-35B-A3B"},  # type: ignore[typeddict-item]
+        )
+    )
+
+
+def test_other_qwen3_5_moe_requires_merged_rollout_weights():
+    with pytest.raises(
+        ValueError,
+        match="Qwen3.5-MoE models require rollout_weights_mode='merged'",
+    ):
+        validate_dedicated_config(
+            InternalModelConfig(
+                trainer_gpu_ids=[0],
+                inference_gpu_ids=[1],
+                engine_args={"model": "Qwen/Qwen3.5-397B-A17B"},  # type: ignore[typeddict-item]
+            )
+        )
diff --git a/tests/unit/test_dedicated_server.py b/tests/unit/test_dedicated_server.py
index 0acf7baa..11209cef 100644
--- a/tests/unit/test_dedicated_server.py
+++ b/tests/unit/test_dedicated_server.py
@@ -5,7 +5,7 @@
 pytest.importorskip("cloudpickle")
 pytest.importorskip("vllm")
 
-from art.vllm.dedicated_server import parse_args
+from art.vllm.dedicated_server import _append_cli_arg, parse_args
 
 
 def test_parse_args_required():
@@ -29,6 +29,7 @@ def test_parse_args_required():
     assert args.lora_path == "/tmp/checkpoints/0000"
"/tmp/checkpoints/0000" assert args.served_model_name == "my-model@0" assert args.host == "127.0.0.1" + assert args.rollout_weights_mode == "lora" assert args.engine_args_json == "{}" assert args.server_args_json == "{}" @@ -95,3 +96,47 @@ def test_parse_args_with_server_args(): server_args = json.loads(args.server_args_json) assert server_args["enable_auto_tool_choice"] is True assert server_args["tool_call_parser"] == "hermes" + + +def test_parse_args_merged_mode(): + args = parse_args( + [ + "--model", + "test-model", + "--port", + "8000", + "--cuda-visible-devices", + "1", + "--lora-path", + "/tmp/lora", + "--served-model-name", + "test@0", + "--rollout-weights-mode", + "merged", + ] + ) + + assert args.rollout_weights_mode == "merged" + assert args.lora_path == "/tmp/lora" + + +def test_parse_args_requires_lora_path(): + with pytest.raises(SystemExit): + parse_args( + [ + "--model", + "test-model", + "--port", + "8000", + "--cuda-visible-devices", + "1", + "--served-model-name", + "test@0", + ] + ) + + +def test_append_cli_arg_serializes_dict_values(): + args: list[str] = [] + _append_cli_arg(args, "weight_transfer_config", {"backend": "nccl"}) + assert args == ['--weight-transfer-config={"backend": "nccl"}'] diff --git a/tests/unit/test_merged_weight_names.py b/tests/unit/test_merged_weight_names.py new file mode 100644 index 00000000..bc9b4890 --- /dev/null +++ b/tests/unit/test_merged_weight_names.py @@ -0,0 +1,55 @@ +import pytest + +pytest.importorskip("trl") +pytest.importorskip("vllm") + +from art.unsloth.service import _normalize_merged_checkpoint_name + + +def test_normalize_merged_checkpoint_name_strips_peft_wrapper_segments(): + assert ( + _normalize_merged_checkpoint_name( + "model.language_model.layers.3.self_attn.q_proj.base_layer.weight" + ) + == "model.language_model.layers.3.self_attn.q_proj.weight" + ) + assert ( + _normalize_merged_checkpoint_name( + "model.language_model.layers.3.mlp.shared_expert.gate_proj.base_layer.weight" + ) + == "model.language_model.layers.3.mlp.shared_expert.gate_proj.weight" + ) + assert ( + _normalize_merged_checkpoint_name( + "model.language_model.layers.3.mlp.experts.base_layer.base_layer.gate_up_proj" + ) + == "model.language_model.layers.3.mlp.experts.gate_up_proj" + ) + assert ( + _normalize_merged_checkpoint_name( + "model.language_model.layers.3.mlp.experts.base_layer.base_layer.down_proj" + ) + == "model.language_model.layers.3.mlp.experts.down_proj" + ) + + +def test_normalize_merged_checkpoint_name_strips_peft_prefix(): + assert ( + _normalize_merged_checkpoint_name( + "base_model.model.model.language_model.layers.7.self_attn.o_proj.base_layer.weight" + ) + == "model.language_model.layers.7.self_attn.o_proj.weight" + ) + assert ( + _normalize_merged_checkpoint_name("base_model.model.lm_head.weight") + == "lm_head.weight" + ) + + +def test_normalize_merged_checkpoint_name_leaves_regular_names_unchanged(): + assert ( + _normalize_merged_checkpoint_name( + "model.language_model.layers.3.self_attn.q_norm.weight" + ) + == "model.language_model.layers.3.self_attn.q_norm.weight" + ) diff --git a/uv.lock b/uv.lock index 792f6ead..8d5497ee 100644 --- a/uv.lock +++ b/uv.lock @@ -5360,6 +5360,7 @@ backend = [ { name = "hf-xet" }, { name = "nbclient" }, { name = "nbmake" }, + { name = "nvidia-cudnn-frontend", marker = "sys_platform == 'linux'" }, { name = "peft" }, { name = "pyarrow" }, { name = "pytest" }, @@ -5452,6 +5453,7 @@ requires-dist = [ { name = "nbmake", marker = "extra == 'backend'", specifier = ">=1.5.5" }, { name = 
"nest-asyncio", specifier = ">=1.6.0" }, { name = "numpy", marker = "extra == 'tinker'" }, + { name = "nvidia-cudnn-frontend", marker = "sys_platform == 'linux' and extra == 'backend'", specifier = "<1.21" }, { name = "nvidia-ml-py", marker = "extra == 'megatron'", specifier = "==13.580.82" }, { name = "openai", specifier = ">=2.14.0" }, { name = "peft", marker = "extra == 'backend'", specifier = ">=0.14.0" },