From 1e0cc4c864bfe96cc2e1e87d78ac6465e7fd2075 Mon Sep 17 00:00:00 2001 From: Kovbo Date: Sat, 21 Mar 2026 02:03:08 +0000 Subject: [PATCH 1/5] Refactor shared Megatron and Unsloth training code --- src/art/_backend_training.py | 109 ++++++ src/art/local/backend.py | 181 ++++------ src/art/megatron/jobs.py | 31 ++ src/art/megatron/runtime_env.py | 15 + src/art/megatron/service.py | 31 +- src/art/megatron/shared.py | 570 ++++++++++++++++++++++++++++++++ src/art/megatron/train.py | 332 ++----------------- src/art/serverless/backend.py | 58 ++-- src/art/unsloth/service.py | 531 ++--------------------------- src/art/unsloth/shared.py | 454 +++++++++++++++++++++++++ 10 files changed, 1332 insertions(+), 980 deletions(-) create mode 100755 src/art/_backend_training.py mode change 100644 => 100755 src/art/local/backend.py create mode 100755 src/art/megatron/jobs.py create mode 100755 src/art/megatron/runtime_env.py mode change 100644 => 100755 src/art/megatron/service.py create mode 100755 src/art/megatron/shared.py mode change 100644 => 100755 src/art/megatron/train.py mode change 100644 => 100755 src/art/serverless/backend.py mode change 100644 => 100755 src/art/unsloth/service.py create mode 100755 src/art/unsloth/shared.py diff --git a/src/art/_backend_training.py b/src/art/_backend_training.py new file mode 100755 index 00000000..b8e82377 --- /dev/null +++ b/src/art/_backend_training.py @@ -0,0 +1,109 @@ +from collections.abc import Iterable +import time +from typing import Literal + +from . import dev +from .metrics_taxonomy import ( + average_metric_samples, + build_training_summary_metrics, + summarize_trajectory_groups, +) +from .trajectories import TrajectoryGroup +from .types import TrainConfig + + +def build_rl_train_configs( + *, + learning_rate: float, + advantage_balance: float = 0.0, + scale_rewards: bool = True, + importance_sampling_level: Literal[ + "token", "sequence", "average", "geometric_average" + ] = "token", + mask_prob_ratio: bool = False, + ppo: bool = False, + precalculate_logprobs: bool = False, + epsilon: float | None = None, + epsilon_high: float | None = None, + max_negative_advantage_importance_sampling_weight: float | None = None, + kimi_k2_tau: float | None = None, + kl_penalty_coef: float = 0.0, + allow_training_without_logprobs: bool | None = None, + plot_tensors: bool | None = None, + truncated_importance_sampling: float | None = None, + scale_learning_rate_by_reward_std_dev: bool | None = None, + logprob_calculation_chunk_size: int | None = None, + num_trajectories_learning_rate_multiplier_power: float | None = None, + kl_ref_adapter_path: str | None = None, +) -> tuple[TrainConfig, dev.TrainConfig]: + config = TrainConfig( + learning_rate=learning_rate, + kl_penalty_coef=kl_penalty_coef, + ) + dev_config: dev.TrainConfig = { + "advantage_balance": advantage_balance, + "importance_sampling_level": importance_sampling_level, + "kl_penalty_coef": kl_penalty_coef, + "mask_prob_ratio": mask_prob_ratio, + "ppo": ppo, + "precalculate_logprobs": precalculate_logprobs, + "scale_rewards": scale_rewards, + } + + if allow_training_without_logprobs is not None: + dev_config["allow_training_without_logprobs"] = ( + allow_training_without_logprobs + ) + if plot_tensors is not None: + dev_config["plot_tensors"] = plot_tensors + if truncated_importance_sampling is not None: + dev_config["truncated_importance_sampling"] = truncated_importance_sampling + if scale_learning_rate_by_reward_std_dev is not None: + dev_config["scale_learning_rate_by_reward_std_dev"] = ( + 
scale_learning_rate_by_reward_std_dev + ) + if logprob_calculation_chunk_size is not None: + dev_config["logprob_calculation_chunk_size"] = ( + logprob_calculation_chunk_size + ) + if num_trajectories_learning_rate_multiplier_power is not None: + dev_config["num_trajectories_learning_rate_multiplier_power"] = ( + num_trajectories_learning_rate_multiplier_power + ) + if epsilon is not None: + dev_config["epsilon"] = epsilon + if epsilon_high is not None: + dev_config["epsilon_high"] = epsilon_high + if max_negative_advantage_importance_sampling_weight is not None: + dev_config["max_negative_advantage_importance_sampling_weight"] = ( + max_negative_advantage_importance_sampling_weight + ) + if kimi_k2_tau is not None: + dev_config["kimi_k2_tau"] = kimi_k2_tau + if kl_ref_adapter_path is not None: + dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path + + return config, dev_config + + +def aggregate_rl_training_metrics( + *, + training_metrics: list[dict[str, float]], + trajectory_groups: Iterable[TrajectoryGroup], + trainer_started: float, +) -> dict[str, float]: + groups_list = list(trajectory_groups) + avg_metrics = average_metric_samples(training_metrics) + summary = summarize_trajectory_groups(groups_list) + avg_metrics.setdefault("time/step_trainer_s", time.monotonic() - trainer_started) + avg_metrics.update( + { + key: value + for key, value in build_training_summary_metrics( + summary, + include_trainable_groups=True, + ).items() + if key not in avg_metrics + } + ) + return avg_metrics diff --git a/src/art/local/backend.py b/src/art/local/backend.py old mode 100644 new mode 100755 index ad743757..abb6395d --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -43,11 +43,13 @@ from mp_actors import close_proxy, move_to_child_process from .. import dev +from .._backend_training import ( + aggregate_rl_training_metrics, + build_rl_train_configs, +) from ..backend import AnyTrainableModel, Backend -from ..costs import build_cost_calculator, get_model_pricing from ..metrics_taxonomy import ( TRAIN_GRADIENT_STEPS_KEY, - average_metric_samples, build_training_summary_metrics, summarize_trajectory_groups, ) @@ -160,9 +162,6 @@ def _allocated_gpu_count(self, model: Model) -> int: def __enter__(self) -> Self: return self - async def __aenter__(self) -> Self: - return self - def __exit__( self, exc_type: type[BaseException] | None, @@ -171,30 +170,14 @@ def __exit__( ) -> None: self._close() - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc: BaseException | None, - tb: TracebackType | None, - ) -> None: - await self.close() - async def close(self) -> None: """ If running vLLM in a separate process, this will kill that process and close the communication threads. 
""" - for service in self._services.values(): - aclose = getattr(service, "aclose", None) - if aclose is None: - close = getattr(service, "close", None) - if close is not None: - close() - else: - await aclose() - close_proxy(service) + self._close() def _close(self) -> None: - for service in self._services.values(): + for _, service in self._services.items(): close = getattr(service, "close", None) if close is not None: close() @@ -226,11 +209,6 @@ async def register( # (wandb initialization is now handled by the model's _get_wandb_run method) if model.trainable and "WANDB_API_KEY" in os.environ: _ = model._get_wandb_run() - if model.trainable: - trainable_model = cast(TrainableModel, model) - pricing = get_model_pricing(trainable_model.base_model) - if pricing is not None: - trainable_model.set_cost_calculator(build_cost_calculator(pricing)) def _model_inference_name(self, model: Model, step: int | None = None) -> str: """Return the inference name for a model checkpoint. @@ -244,27 +222,25 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str: If None, returns name for latest checkpoint (step 0 initially). """ - requested_step = step - - if step is None and isinstance(model, TrainableModel): - from ..dev.validate import is_dedicated_mode - - service = self._services.get(model.name) - if service is not None and is_dedicated_mode( - model._internal_config or dev.InternalModelConfig() - ): - loaded_step = getattr(service, "_latest_step", None) - if isinstance(loaded_step, int): - step = loaded_step - - if step is None: - # The checkpoint directory is written before dedicated-mode - # vLLM finishes reloading the new adapter. - step = self.__get_step(model) - name = f"{model.name}@{step}" + # For LocalBackend, vLLM always serves LoRA adapters with @step suffix + # Default to step 0 when not specified (the initial checkpoint created at registration) + if step is not None: + actual_step = step + elif model.name in self._services and self._in_process: + # In dedicated mode the service tracks which adapter vLLM has + # actually loaded. Reading the filesystem would race: the + # checkpoint directory appears before the HTTP reload completes. + svc = self._services[model.name] + loaded_step = getattr(svc, "_latest_step", None) + actual_step = ( + loaded_step if loaded_step is not None else self.__get_step(model) + ) + else: + actual_step = self.__get_step(model) + name = f"{model.name}@{actual_step}" logger.debug( - f"[BACKEND] _model_inference_name: step_arg={requested_step} " - f"actual_step={step} -> {name}" + f"[BACKEND] _model_inference_name: step_arg={step} " + f"actual_step={actual_step} -> {name}" ) return name @@ -529,14 +505,12 @@ async def train( # type: ignore[override] *, # Core training parameters learning_rate: float = 5e-6, - loss_fn: Literal["cispo", "ppo"] = "cispo", - loss_fn_config: dict | None = None, - normalize_advantages: bool = True, - adam_params: object | None = None, # KL-penalized advantage adjustment kl_penalty_coef: float = 0.0, kl_penalty_reference_step: int | None = None, kl_ref_adapter_path: str | None = None, + # RL algorithm settings + ppo: bool = False, epsilon: float | None = None, epsilon_high: float | None = None, # Advantage computation @@ -573,14 +547,6 @@ async def train( # type: ignore[override] model: The trainable model to train. trajectory_groups: Batches of trajectories to train on. learning_rate: Learning rate for training. Defaults to 5e-6. - loss_fn: RL loss function. LocalBackend currently supports - "cispo" and "ppo". 
- loss_fn_config: Additional loss-function config. Not supported by - LocalBackend. - normalize_advantages: Whether to normalize advantages. LocalBackend - currently requires True. - adam_params: Custom optimizer params. Not supported by - LocalBackend. kl_penalty_coef: Coefficient for KL-penalized advantage adjustment. Tokens diverging more from the reference get reduced advantages. Defaults to 0.0 (disabled). @@ -590,7 +556,8 @@ async def train( # type: ignore[override] kl_ref_adapter_path: Direct filesystem path to a LoRA adapter checkpoint to use as the KL reference. Alternative to kl_penalty_reference_step. - epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn. + ppo: Whether to use PPO clipping. Defaults to False. + epsilon: Clip epsilon for importance sampling. Defaults based on ppo. epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. advantage_balance: Balance between negative and positive advantages in range [-1.0, 1.0]. Defaults to 0.0 (balanced). @@ -633,54 +600,37 @@ async def train( # type: ignore[override] # await model.log(metrics=result.metrics, step=result.step) """ groups_list = list(trajectory_groups) - if loss_fn not in {"cispo", "ppo"}: - raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.") - if loss_fn_config is not None: - raise ValueError("LocalBackend requires loss_fn_config=None.") - if not normalize_advantages: - raise ValueError("LocalBackend requires normalize_advantages=True.") - if adam_params is not None: - raise ValueError("LocalBackend requires adam_params=None.") - - # Build config objects from explicit kwargs - config = TrainConfig( - learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef - ) - dev_config: dev.TrainConfig = { - "advantage_balance": advantage_balance, - "allow_training_without_logprobs": allow_training_without_logprobs, - "importance_sampling_level": importance_sampling_level, - "kl_penalty_coef": kl_penalty_coef, - "mask_prob_ratio": mask_prob_ratio, - "plot_tensors": plot_tensors, - "ppo": loss_fn == "ppo", - "precalculate_logprobs": precalculate_logprobs, - "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev, - "scale_rewards": scale_rewards, - "logprob_calculation_chunk_size": logprob_calculation_chunk_size, - "num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power, - } - # Only include optional fields if they're set - if epsilon is not None: - dev_config["epsilon"] = epsilon - if epsilon_high is not None: - dev_config["epsilon_high"] = epsilon_high - if max_negative_advantage_importance_sampling_weight is not None: - dev_config["max_negative_advantage_importance_sampling_weight"] = ( - max_negative_advantage_importance_sampling_weight - ) - if kimi_k2_tau is not None: - dev_config["kimi_k2_tau"] = kimi_k2_tau - if truncated_importance_sampling is not None: - dev_config["truncated_importance_sampling"] = truncated_importance_sampling - if kl_ref_adapter_path is not None: - dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path - elif kl_penalty_reference_step is not None: - ref_checkpoint_dir = get_step_checkpoint_dir( + + resolved_kl_ref_adapter_path = kl_ref_adapter_path + if ( + resolved_kl_ref_adapter_path is None + and kl_penalty_reference_step is not None + ): + resolved_kl_ref_adapter_path = get_step_checkpoint_dir( get_model_dir(model=model, art_path=self._path), kl_penalty_reference_step, ) - dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir + config, dev_config = build_rl_train_configs( + 
learning_rate=learning_rate, + advantage_balance=advantage_balance, + scale_rewards=scale_rewards, + importance_sampling_level=importance_sampling_level, + mask_prob_ratio=mask_prob_ratio, + ppo=ppo, + precalculate_logprobs=precalculate_logprobs, + epsilon=epsilon, + epsilon_high=epsilon_high, + max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight, + kimi_k2_tau=kimi_k2_tau, + kl_penalty_coef=kl_penalty_coef, + allow_training_without_logprobs=allow_training_without_logprobs, + plot_tensors=plot_tensors, + truncated_importance_sampling=truncated_importance_sampling, + scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev, + logprob_calculation_chunk_size=logprob_calculation_chunk_size, + num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power, + kl_ref_adapter_path=resolved_kl_ref_adapter_path, + ) # Collect metrics from training training_metrics: list[dict[str, float]] = [] @@ -690,21 +640,10 @@ async def train( # type: ignore[override] ): training_metrics.append(metrics) - # Aggregate metrics - avg_metrics = average_metric_samples(training_metrics) - summary = summarize_trajectory_groups(groups_list) - avg_metrics.setdefault( - "time/step_trainer_s", time.monotonic() - trainer_started - ) - avg_metrics.update( - { - key: value - for key, value in build_training_summary_metrics( - summary, - include_trainable_groups=True, - ).items() - if key not in avg_metrics - } + avg_metrics = aggregate_rl_training_metrics( + training_metrics=training_metrics, + trajectory_groups=groups_list, + trainer_started=trainer_started, ) # Get step and checkpoint path diff --git a/src/art/megatron/jobs.py b/src/art/megatron/jobs.py new file mode 100755 index 00000000..b4c39e00 --- /dev/null +++ b/src/art/megatron/jobs.py @@ -0,0 +1,31 @@ +from typing import Literal + +from pydantic import BaseModel + +from .. 
import dev, types +from ..preprocessing.pack import DiskPackedTensors + +DEFAULT_TRAINING_LOG_PATH = "/tmp/megatron_training_log.jsonl" +DEFAULT_JOBS_DIR = "/tmp/megatron_training_jobs" +DEFAULT_VLLM_WAKE_LOCK_PATH = "/tmp/megatron_vllm_waking" + + +class MegatronTrainingJob(BaseModel): + lora_path: str + optimizer_state_path: str + disk_packed_tensors: DiskPackedTensors + config: types.TrainConfig + experimental_config: dev.TrainConfig + log_path: str = DEFAULT_TRAINING_LOG_PATH + + +class MegatronSFTTrainingJob(BaseModel): + job_type: Literal["sft"] = "sft" + lora_path: str + optimizer_state_path: str + sft_data_dir: str + num_batches: int + learning_rates: list[float] + weight_decay: float = 0.0 + max_grad_norm: float = 1.0 + log_path: str = DEFAULT_TRAINING_LOG_PATH diff --git a/src/art/megatron/runtime_env.py b/src/art/megatron/runtime_env.py new file mode 100755 index 00000000..c74a4b66 --- /dev/null +++ b/src/art/megatron/runtime_env.py @@ -0,0 +1,15 @@ +import os + + +def _set_cache_dir(env_var: str, default_path: str) -> None: + if not os.environ.get(env_var): + os.environ[env_var] = os.path.expanduser(default_path) + os.makedirs(os.environ[env_var], exist_ok=True) + + +def configure_megatron_runtime_env() -> None: + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" + _set_cache_dir("TORCHINDUCTOR_CACHE_DIR", "~/.cache/torchinductor") + _set_cache_dir("TRITON_CACHE_DIR", "~/.triton/cache") diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py old mode 100644 new mode 100755 index 8ed6b82c..686af707 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -10,7 +10,6 @@ from typing import Any, AsyncIterator from peft.tuners.lora.config import LoraConfig -from pydantic import BaseModel from safetensors import safe_open from safetensors.torch import load_file, save_file import torch @@ -26,16 +25,11 @@ from ..utils.get_model_step import get_step_from_dir from ..utils.output_dirs import get_step_checkpoint_dir from ..vllm import get_llm, openai_server_task, run_on_workers - - -class MegatronTrainingJob(BaseModel): - """Job format for communication with train.py""" - - lora_path: str - optimizer_state_path: str - disk_packed_tensors: DiskPackedTensors - config: types.TrainConfig - experimental_config: dev.TrainConfig +from .jobs import ( + DEFAULT_JOBS_DIR, + DEFAULT_TRAINING_LOG_PATH, + MegatronTrainingJob, +) @dataclass @@ -236,11 +230,10 @@ async def train( self._optimizer_state_path = self._get_optimizer_state_path() - jobs_dir = "/tmp/megatron_training_jobs" - os.makedirs(jobs_dir, exist_ok=True) - for job_name in os.listdir(jobs_dir): + os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True) + for job_name in os.listdir(DEFAULT_JOBS_DIR): if job_name.endswith(".json"): - os.remove(os.path.join(jobs_dir, job_name)) + os.remove(os.path.join(DEFAULT_JOBS_DIR, job_name)) job = MegatronTrainingJob( lora_path=lora_path, optimizer_state_path=self._optimizer_state_path, @@ -248,7 +241,9 @@ async def train( config=config, experimental_config=_config, ) - job_path = os.path.join(jobs_dir, f"{datetime.datetime.now().isoformat()}.json") + job_path = os.path.join( + DEFAULT_JOBS_DIR, f"{datetime.datetime.now().isoformat()}.json" + ) with open(job_path, "w") as f: f.write(job.model_dump_json()) @@ -256,14 +251,14 @@ async def train( while True: await asyncio.sleep(0.1) try: - with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: + with 
open(DEFAULT_TRAINING_LOG_PATH, "a+") as log_file: log_file.seek(0) lines = log_file.readlines()[num_lines:] for line in lines: if line := line.strip(): if line == "all done": self._merge_lora_adapter(lora_path) - os.remove("/tmp/megatron_training_log.jsonl") + os.remove(DEFAULT_TRAINING_LOG_PATH) break num_lines += 1 yield json.loads(line) diff --git a/src/art/megatron/shared.py b/src/art/megatron/shared.py new file mode 100755 index 00000000..b84b873d --- /dev/null +++ b/src/art/megatron/shared.py @@ -0,0 +1,570 @@ +import gc +import json +import math +import os +import shutil +import time +from dataclasses import dataclass +from typing import Any + +from megatron.core import parallel_state as ps +from safetensors.torch import load_file, save_file +import torch + +from ..loss import loss_fn, shift_tensor +from ..preprocessing.pack import PackedTensors, packed_tensors_from_dir +from .flex_attention import create_shared_prefix_attention_state +from .jobs import MegatronSFTTrainingJob, MegatronTrainingJob + + +@dataclass +class MegatronTrainContext: + model: list[Any] + optimizer: Any + rank: int + world_size: int + + +def create_megatron_train_context(model_identifier: str) -> MegatronTrainContext: + from megatron.core.distributed import DistributedDataParallelConfig + from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer + from torch._inductor.runtime.cache_dir_utils import ( + cache_dir as inductor_cache_dir, + ) + + from .lora import apply_lora_adapters + from .provider import get_provider + + provider = get_provider(model_identifier) + provider.register_pre_wrap_hook( + lambda model_chunks: _freeze_model(model_chunks) or model_chunks + ) + + model = provider.provide_distributed_model( + ddp_config=DistributedDataParallelConfig(), + data_parallel_random_init=False, + ) + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + if rank == 0: + print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"]) + print("Resolved inductor cache_dir():", inductor_cache_dir()) + print("TRITON_CACHE_DIR:", os.environ["TRITON_CACHE_DIR"]) + + _install_gpt_preprocess_hooks(model) + apply_lora_adapters(model, provider) + + optimizer = get_megatron_optimizer( + config=OptimizerConfig( + bf16=True, + lr=5e-6, + adam_beta1=0.9, + adam_beta2=0.99, + clip_grad=0.1, + weight_decay=0.1, + ), + model_chunks=model, # type: ignore[arg-type] + ) + _print_optimizer_parameter_stats( + rank=rank, + optimizer=optimizer, + model_chunks=model, + ) + + return MegatronTrainContext( + model=model, + optimizer=optimizer, + rank=rank, + world_size=world_size, + ) + + +def run_megatron_rl_job( + ctx: MegatronTrainContext, + job: MegatronTrainingJob, + *, + job_path: str | None = None, +) -> None: + packed_tensors = None + adapter_model = None + + try: + adapter_model = _load_lora_and_optimizer( + ctx, + lora_path=job.lora_path, + optimizer_state_path=job.optimizer_state_path, + ) + + ctx.optimizer.config.clip_grad = 0.1 + for param_group in ctx.optimizer.param_groups: + param_group["weight_decay"] = 0.1 + + _print0(ctx.rank, "Loading packed tensors from", job.disk_packed_tensors["dir"]) + packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) + indices = _data_parallel_indices( + job.disk_packed_tensors["num_sequences"], + dp_rank=ps.get_data_parallel_rank(), + dp_world_size=ps.get_data_parallel_world_size(), + ) + + for index in indices: + inputs = _packed_tensors_for_index(packed_tensors, index) + device = next(ctx.model[0].parameters()).device 
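+            # The packed slice must live on the same device as this rank's
+            # model shard. With labels supplied, the GPT forward below returns
+            # per-token negative log-likelihoods, so negating them recovers
+            # the new logprobs that loss_fn compares against the sampled
+            # logprobs stored in the packed tensors.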
+ inputs = _move_packed_tensors_to_device(inputs, device) + + new_logprobs: torch.Tensor = -ctx.model[0]( + input_ids=inputs["tokens"], + position_ids=inputs["input_pos"], + attention_mask=_placeholder_attention_mask(device), + labels=shift_tensor(inputs["tokens"], 0), + extra_block_kwargs={ + "attention_bias": _shared_prefix_attention_state(inputs), + }, + ) + loss_output = loss_fn( + inputs, # type: ignore[arg-type] + new_logprobs, + None, + None, + job.experimental_config, + ) + probs_corr = loss_output.probs_corr.item() + _print0( + ctx.rank, + "Correlation between old and new probabilities:", + probs_corr, + ) + + loss = loss_output.mean_policy_loss + loss.backward() + + start = time.perf_counter() + num_grads = _reduce_lora_grads( + ctx.model, + op=torch.distributed.ReduceOp.AVG, + ) + _print0( + ctx.rank, + f"Reduced {num_grads} LoRA grads in " + f"{(time.perf_counter() - start) * 1e3:.1f} ms", + ) + + for param_group in ctx.optimizer.param_groups: + param_group["lr"] = job.config.learning_rate + update_successful, grad_norm, num_zeros_in_grad = ctx.optimizer.step() + ctx.optimizer.zero_grad() + + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + + if ctx.rank == 0: + with open(job.log_path, "a+") as log_file: + log_msg = json.dumps( + { + "loss": loss.item(), + "grad_norm": float(grad_norm), + "probs_corr": probs_corr, + **( + {"kl_policy_ref": loss_output.kl_policy_ref.item()} + if loss_output.kl_policy_ref is not None + else {} + ), + } + ) + print("Logging", log_msg) + log_file.write(log_msg + "\n") + + _save_lora_and_optimizer( + ctx, + adapter_model=adapter_model, + lora_path=job.lora_path, + optimizer_state_path=job.optimizer_state_path, + ) + _complete_job( + ctx, + job_path=job_path, + log_path=job.log_path, + cleanup_path=job.disk_packed_tensors["dir"], + ) + finally: + if packed_tensors is not None: + del packed_tensors + if adapter_model is not None: + del adapter_model + if "inputs" in locals(): + del inputs + gc.collect() + torch.cuda.empty_cache() + + +def run_megatron_sft_job( + ctx: MegatronTrainContext, + job: MegatronSFTTrainingJob, + *, + job_path: str | None = None, +) -> None: + adapter_model = None + + try: + adapter_model = _load_lora_and_optimizer( + ctx, + lora_path=job.lora_path, + optimizer_state_path=job.optimizer_state_path, + ) + + ctx.optimizer.config.clip_grad = job.max_grad_norm + for param_group in ctx.optimizer.param_groups: + param_group["weight_decay"] = job.weight_decay + + device = next(ctx.model[0].parameters()).device + dp_rank = ps.get_data_parallel_rank() + dp_world_size = ps.get_data_parallel_world_size() + + for batch_idx in range(job.num_batches): + batch_start_time = time.perf_counter() + batch_dir = os.path.join(job.sft_data_dir, f"batch_{batch_idx:06d}") + batch_metadata, trajectory_tensors = _load_sft_batch_from_disk(batch_dir) + + for param_group in ctx.optimizer.param_groups: + param_group["lr"] = job.learning_rates[batch_idx] + + num_trainable_tokens = batch_metadata["num_trainable_tokens"] + assert num_trainable_tokens > 0, ( + f"Batch {batch_idx} has no trainable tokens" + ) + + batch_loss = torch.tensor(0.0, device=device) + local_trajectory_tensors = trajectory_tensors[dp_rank::dp_world_size] + for traj_tensors in local_trajectory_tensors: + attention_mask_1d = traj_tensors["attention_mask"] + actual_len = int(attention_mask_1d.sum().item()) + input_ids = ( + traj_tensors["input_ids"][:actual_len].unsqueeze(0).to(device) + ) + labels = traj_tensors["labels"][:actual_len].unsqueeze(0).to(device) + + 
seq_len = input_ids.shape[1] + position_ids = torch.arange(seq_len, device=device).unsqueeze(0) + shifted_labels = shift_tensor(labels, -100) + attention_state = _causal_attention_state(seq_len, device) + + per_token_loss: torch.Tensor = ctx.model[0]( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=_placeholder_attention_mask(device), + labels=shifted_labels, + extra_block_kwargs={"attention_bias": attention_state}, + ) + masked_loss = per_token_loss[shifted_labels != -100].sum() + loss = masked_loss / num_trainable_tokens + loss.backward() + batch_loss += masked_loss.detach() + + start = time.perf_counter() + num_grads = _reduce_lora_grads( + ctx.model, + op=torch.distributed.ReduceOp.SUM, + ) + _print0( + ctx.rank, + f"SFT batch {batch_idx}: reduced {num_grads} LoRA grads in " + f"{(time.perf_counter() - start) * 1e3:.1f} ms", + ) + + update_successful, grad_norm, num_zeros_in_grad = ctx.optimizer.step() + ctx.optimizer.zero_grad() + + torch.distributed.reduce(batch_loss, dst=0, op=torch.distributed.ReduceOp.SUM) + avg_loss = batch_loss / num_trainable_tokens + + batch_time = time.perf_counter() - batch_start_time + tokens_per_second = ( + num_trainable_tokens / batch_time if batch_time > 0 else 0.0 + ) + + if ctx.rank == 0: + with open(job.log_path, "a+") as log_file: + log_msg = json.dumps( + { + "loss": avg_loss.item(), + "learning_rate": job.learning_rates[batch_idx], + "grad_norm": float(grad_norm), + "num_trajectories": float(batch_metadata["num_trajectories"]), + "num_trainable_tokens": float(num_trainable_tokens), + "tokens_per_second": tokens_per_second, + } + ) + print("Logging SFT", log_msg) + log_file.write(log_msg + "\n") + + _save_lora_and_optimizer( + ctx, + adapter_model=adapter_model, + lora_path=job.lora_path, + optimizer_state_path=job.optimizer_state_path, + ) + _complete_job( + ctx, + job_path=job_path, + log_path=job.log_path, + cleanup_path=job.sft_data_dir, + ) + finally: + if adapter_model is not None: + del adapter_model + gc.collect() + torch.cuda.empty_cache() + + +def _freeze_model(model_chunks: list[Any]) -> list[Any]: + for module in model_chunks: + for param in module.parameters(): + param.requires_grad = False + return model_chunks + + +def _install_gpt_preprocess_hooks(model_chunks: list[Any]) -> None: + from megatron.core.models.gpt.gpt_model import GPTModel + + for module in model_chunks: + while not isinstance(module, GPTModel) and hasattr(module, "module"): + module = module.module + if not isinstance(module, GPTModel): + continue + + _preprocess = module._preprocess + + def _preprocess_hook(*args: Any, _preprocess: Any = _preprocess, **kwargs: Any): + preproc_output = list(_preprocess(*args, **kwargs)) + preproc_output[0].requires_grad = True # type: ignore[index] + table = preproc_output[1] # [S,B,1,D] type: ignore[index] + d_model = table.size(-1) # type: ignore[union-attr] + table_flat = table.view(table.size(0), d_model) # type: ignore[union-attr] + position_ids = kwargs["position_ids"] + batch_size, seq_len = position_ids.shape + gathered = table_flat.index_select(0, position_ids.reshape(-1)) + gathered = gathered.view(batch_size, seq_len, d_model).permute(1, 0, 2) + preproc_output[1] = gathered.contiguous().unsqueeze(2) + return tuple(preproc_output) + + module._preprocess = _preprocess_hook # type: ignore[attr-defined] + + +def _print_optimizer_parameter_stats( + *, + rank: int, + optimizer: Any, + model_chunks: list[Any], +) -> None: + if rank != 0: + return + + num_params = sum( + param.numel() + for group in 
optimizer.param_groups + if not group["is_decoupled_lr"] + for param in group["params"] + ) + print(f"Number of parameters in optimizer: {num_params:,}") + total_params = sum( + param.numel() for chunk in model_chunks for param in chunk.parameters() + ) + percent = (num_params / total_params) * 100 if total_params > 0 else 0.0 + print(f"Optimizer parameters as percent of total: {percent:0.2f}%") + + +def _print0(rank: int, *values: Any) -> None: + if rank == 0: + print(*values) + + +def _placeholder_attention_mask(device: torch.device) -> torch.Tensor: + return torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) + + +def _causal_attention_state(seq_len: int, device: torch.device) -> Any: + group_ids = torch.zeros((1, seq_len), dtype=torch.int64, device=device) + parent_ids = torch.zeros_like(group_ids) + return create_shared_prefix_attention_state(group_ids, parent_ids) + + +def _shared_prefix_attention_state(packed_tensors: PackedTensors) -> Any: + return create_shared_prefix_attention_state( + group_ids=packed_tensors["group_ids"], + parent_ids=packed_tensors["parent_ids"], + ) + + +def _load_sft_batch_from_disk( + batch_dir: str, +) -> tuple[dict[str, Any], list[dict[str, torch.Tensor]]]: + with open(os.path.join(batch_dir, "metadata.json")) as f: + metadata = json.load(f) + + trajectory_tensors = [] + for i in range(metadata["num_trajectory_tensors"]): + tensors = load_file(os.path.join(batch_dir, f"trajectory_{i}.safetensors")) + trajectory_tensors.append(tensors) + return metadata, trajectory_tensors + + +def _data_parallel_indices( + num_sequences: int, + *, + dp_rank: int, + dp_world_size: int, +) -> list[int]: + if num_sequences <= 0: + raise ValueError("num_sequences must be positive") + + num_indices = math.ceil(num_sequences / dp_world_size) + indices = list(range(dp_rank, num_sequences, dp_world_size)) + if not indices: + indices = [dp_rank % num_sequences] + repeat = math.ceil(num_indices / len(indices)) + return (indices * repeat)[:num_indices] + + +def _packed_tensors_for_index( + packed_tensors: PackedTensors, + index: int, +) -> PackedTensors: + values: dict[str, Any] = {} + for key, value in packed_tensors.items(): + if isinstance(value, torch.Tensor): + values[key] = value[index : index + 1] + elif isinstance(value, list): + values[key] = value[index : index + 1] + + values.setdefault("pixel_values", [None]) + values.setdefault("image_grid_thw", [None]) + return PackedTensors(**values) # type: ignore[arg-type] + + +def _move_packed_tensors_to_device( + packed_tensors: PackedTensors, + device: torch.device, +) -> PackedTensors: + for key, value in packed_tensors.items(): + if isinstance(value, torch.Tensor): + packed_tensors[key] = value.to(device) # type: ignore[index] + return packed_tensors + + +def _load_lora_and_optimizer( + ctx: MegatronTrainContext, + *, + lora_path: str, + optimizer_state_path: str, +) -> dict[str, torch.Tensor]: + adapter_model_path = os.path.join(lora_path, "adapter_model.safetensors") + if os.path.exists(adapter_model_path): + _print0(ctx.rank, "Loading adapter model from", adapter_model_path) + adapter_model = load_file(adapter_model_path) + with torch.no_grad(): + for chunk in ctx.model: + for module in chunk.modules(): + if hasattr(module, "load_lora"): + module.load_lora(adapter_model) # type: ignore[attr-defined] + else: + _print0(ctx.rank, "No adapter model found at", adapter_model_path) + adapter_model = {} + with torch.no_grad(): + for chunk in ctx.model: + for module in chunk.modules(): + if hasattr(module, 
"reset_lora_parameters"): + module.reset_lora_parameters() # type: ignore[attr-defined] + + optimizer_shard_path = os.path.join( + optimizer_state_path, + f"{ctx.rank + 1:02d}-of-{ctx.world_size:02d}.pt", + ) + if os.path.exists(optimizer_shard_path): + _print0(ctx.rank, "Loading optimizer state from", optimizer_shard_path) + ctx.optimizer.load_state_dict(torch.load(optimizer_shard_path)) + else: + _print0( + ctx.rank, + "No optimizer state found at", + optimizer_shard_path, + "— resetting optimizer for new run", + ) + ctx.optimizer.optimizer.state.clear() + ctx.optimizer.reload_model_params() + return adapter_model + + +def _save_lora_and_optimizer( + ctx: MegatronTrainContext, + *, + adapter_model: dict[str, torch.Tensor], + lora_path: str, + optimizer_state_path: str, +) -> None: + sharded_state_dict: dict[str, torch.Tensor] = {} + for chunk in ctx.model: + for module in chunk.modules(): + if not hasattr(module, "sharded_lora_state_dict"): + continue + + module_state_dict: dict[str, torch.Tensor] = ( + module.sharded_lora_state_dict() # type: ignore[attr-defined] + ) + for key, value in module_state_dict.items(): + target_dtype = ( + adapter_model[key].dtype if key in adapter_model else value.dtype + ) + sharded_state_dict[key] = value.to(target_dtype) + + os.makedirs(lora_path, exist_ok=True) + shard_path = os.path.join( + lora_path, + f"adapter_model-{ctx.rank + 1:02d}-of-{ctx.world_size:02d}.safetensors", + ) + print("Saving adapter shard to", shard_path) + save_file(sharded_state_dict, shard_path) + + optimizer_shard_path = os.path.join( + optimizer_state_path, + f"{ctx.rank + 1:02d}-of-{ctx.world_size:02d}.pt", + ) + print("Saving optimizer shard to", optimizer_shard_path) + os.makedirs(optimizer_state_path, exist_ok=True) + torch.save(ctx.optimizer.state_dict(), optimizer_shard_path) + + +def _reduce_lora_grads( + model_chunks: list[Any], + *, + op: Any, +) -> int: + num_grads = 0 + for chunk in model_chunks: + for param in chunk.parameters(): + if param.grad is None: + continue + torch.distributed.all_reduce( + param.grad, + op=op, + group=ps.get_data_parallel_group(), + ) + num_grads += 1 + return num_grads + + +def _complete_job( + ctx: MegatronTrainContext, + *, + job_path: str | None, + log_path: str, + cleanup_path: str, +) -> None: + torch.distributed.barrier() + if ctx.rank != 0: + return + + if job_path is not None: + os.remove(job_path) + with open(log_path, "a+") as log_file: + log_file.write("all done\n") + shutil.rmtree(cleanup_path) diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py old mode 100644 new mode 100755 index 02e3b7cd..bb1c5cfd --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -1,330 +1,50 @@ # isort: off -import os - - -def _set_cache_dir(env_var: str, default_path: str) -> None: - if not os.environ.get(env_var): - os.environ[env_var] = os.path.expanduser(default_path) - os.makedirs(os.environ[env_var], exist_ok=True) +from art.megatron.runtime_env import configure_megatron_runtime_env - -os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" -_set_cache_dir("TORCHINDUCTOR_CACHE_DIR", "~/.cache/torchinductor") -_set_cache_dir("TRITON_CACHE_DIR", "~/.triton/cache") +configure_megatron_runtime_env() # isort: on -import gc -import json -import math -import shutil +import os import time -from typing import Any, cast -from megatron.core import parallel_state as ps -from megatron.core.distributed import 
DistributedDataParallelConfig -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer -from megatron.core.transformer.module import MegatronModule -from pydantic import BaseModel -from safetensors.torch import load_file, save_file import torch -from torch._inductor.runtime.cache_dir_utils import cache_dir as inductor_cache_dir -from art import dev, types -from art.loss import loss_fn, shift_tensor -from art.megatron.flex_attention import create_shared_prefix_attention_state -from art.megatron.lora import apply_lora_adapters -from art.megatron.offload import OffloadState, offload_to_cpu, reload_to_gpu -from art.megatron.provider import get_provider -from art.preprocessing.pack import ( - DiskPackedTensors, - PackedTensors, - packed_tensors_from_dir, +from art.megatron.jobs import ( + DEFAULT_JOBS_DIR, + DEFAULT_VLLM_WAKE_LOCK_PATH, + MegatronTrainingJob, ) +from art.megatron.offload import OffloadState, offload_to_cpu, reload_to_gpu +from art.megatron.shared import create_megatron_train_context, run_megatron_rl_job -provider = get_provider( +ctx = create_megatron_train_context( os.environ.get("MODEL_IDENTIFIER", "Qwen/Qwen3-30B-A3B-Instruct-2507") ) - -def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]: - for module in model_chunks: - for param in module.parameters(): - param.requires_grad = False - return model_chunks - - -provider.register_pre_wrap_hook(lambda x: freeze_model(x) or x) - -model = provider.provide_distributed_model( - ddp_config=DistributedDataParallelConfig(), - data_parallel_random_init=False, -) - -rank = torch.distributed.get_rank() # ty:ignore[possibly-missing-attribute] -world_size = torch.distributed.get_world_size() # ty:ignore[possibly-missing-attribute] - -if rank == 0: - print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"]) - print("Resolved inductor cache_dir():", inductor_cache_dir()) - print("TRITON_CACHE_DIR:", os.environ["TRITON_CACHE_DIR"]) - -for module in model: - while not isinstance(module, GPTModel) and hasattr(module, "module"): - module = module.module - if isinstance(module, GPTModel): - _preprocess = module._preprocess - - def _preprocess_hook(*args, **kwargs): - preproc_output = list(_preprocess(*args, **kwargs)) - preproc_output[0].requires_grad = True # type: ignore - table = preproc_output[1] # [S,B,1,D] type: ignore - D = table.size(-1) # type: ignore - table_flat = table.view(table.size(0), D) # type: ignore - # position_ids: [B, S] - position_ids = kwargs["position_ids"] - B, S = position_ids.shape - gathered = table_flat.index_select(0, position_ids.reshape(-1)) # [B*S, D] - gathered = gathered.view(B, S, D).permute(1, 0, 2).contiguous() # [S, B, D] - preproc_output[1] = gathered.unsqueeze(2) # [S, B, 1, D] - return tuple(preproc_output) - - module._preprocess = _preprocess_hook # type: ignore[attr-defined] - - -apply_lora_adapters(model, provider) - -optimizer = get_megatron_optimizer( - config=OptimizerConfig( - bf16=True, - lr=5e-6, - adam_beta1=0.9, - adam_beta2=0.99, - clip_grad=0.1, - weight_decay=0.1, - ), - model_chunks=model, # type: ignore -) - -if rank == 0: - # Print the number of parameters in the optimizer, nicely formatted - num_params = sum( - p.numel() - for group in optimizer.param_groups - if not group["is_decoupled_lr"] - for p in group["params"] - ) - print(f"Number of parameters in optimizer: {num_params:,}") - total_params = sum(p.numel() for m in model for p in m.parameters()) - percent = 
(num_params / total_params) * 100 if total_params > 0 else 0 - print(f"Optimizer parameters as percent of total: {percent:0.2f}%") - - -class TrainingJob(BaseModel): - lora_path: str - optimizer_state_path: str - disk_packed_tensors: DiskPackedTensors - config: types.TrainConfig - experimental_config: dev.TrainConfig - - -def print0(*values: Any) -> None: - if rank == 0: - print(*values) - - offload_state = OffloadState() - - -offload_to_cpu(model, optimizer, rank, offload_state) +offload_to_cpu(ctx.model, ctx.optimizer, ctx.rank, offload_state) while True: - torch.distributed.barrier() # ty:ignore[possibly-missing-attribute] - jobs_dir = "/tmp/megatron_training_jobs" - os.makedirs(jobs_dir, exist_ok=True) + torch.distributed.barrier() + os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True) job_names = sorted( - job_name for job_name in os.listdir(jobs_dir) if job_name.endswith(".json") + job_name for job_name in os.listdir(DEFAULT_JOBS_DIR) if job_name.endswith(".json") ) if not job_names: time.sleep(1) continue - wake_lock_path = "/tmp/megatron_vllm_waking" - while os.path.exists(wake_lock_path): + while os.path.exists(DEFAULT_VLLM_WAKE_LOCK_PATH): time.sleep(0.2) - reload_to_gpu(model, optimizer, rank, offload_state) - - job_name = job_names[0] - job_path = os.path.join(jobs_dir, job_name) - with open(job_path, "rb") as f: - job = TrainingJob.model_validate_json(f.read()) - config = job.config - experimental_config = job.experimental_config - print0("Loaded job from", job_path) - print0("Job:", job) - adapter_model_path = f"{job.lora_path}/adapter_model.safetensors" - if os.path.exists(adapter_model_path): - print0("Loading adapter model from", adapter_model_path) - adapter_model = load_file(adapter_model_path) - with torch.no_grad(): - for chunk in model: - for module in chunk.modules(): - if hasattr(module, "load_lora"): - module.load_lora(adapter_model) # type: ignore - else: - print0("No adapter model found at", adapter_model_path) - adapter_model = {} - with torch.no_grad(): - for chunk in model: - for module in chunk.modules(): - if hasattr(module, "reset_lora_parameters"): - module.reset_lora_parameters() # type: ignore - optimizer_shard_path = os.path.join( - job.optimizer_state_path, f"{rank + 1:02d}-of-{world_size:02d}.pt" - ) - if os.path.exists(optimizer_shard_path): - print( - "Loading optimizer state from", - optimizer_shard_path, - ) - optimizer.load_state_dict(torch.load(optimizer_shard_path)) - else: - # No checkpoint for this run; reset optimizer state to avoid cross-run leakage - print( - "No optimizer state found at", - optimizer_shard_path, - "— resetting optimizer for new run", - ) - optimizer.optimizer.state.clear() - optimizer.reload_model_params() - print0("Loading packed tensors from", job.disk_packed_tensors["dir"]) - packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) - num_sequences = job.disk_packed_tensors["num_sequences"] - dp_rank = ps.get_data_parallel_rank() - dp_world_size = ps.get_data_parallel_world_size() - num_indices = math.ceil(num_sequences / dp_world_size) - indices = list(range(dp_rank, num_sequences, dp_world_size)) - if not indices: - indices = [dp_rank % num_sequences] - # pad indices by repeating & slicing to target length - repeat = math.ceil(num_indices / len(indices)) - indices = (indices * repeat)[:num_indices] - for index in indices: - inputs = PackedTensors( # type: ignore - **{ - key: value[index : index + 1] - for key, value in packed_tensors.items() - if isinstance(value, torch.Tensor) - }, - pixel_values=[None], - 
image_grid_thw=[None], - ) - ref_logprobs = None - device = next(model[0].parameters()).device - for key, value in inputs.items(): - if isinstance(value, torch.Tensor): - inputs[key] = value.to(device) # type: ignore - attention_state = create_shared_prefix_attention_state( # should happen after group_ids is moved to device - group_ids=inputs["group_ids"], - parent_ids=inputs["parent_ids"], - ) - # Megatron full-layer recompute saves positional tensor args, so keep a tiny - # placeholder Tensor here and pass flex BlockMask state via attention_bias. - attention_mask = torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) - new_logprobs: torch.Tensor = -model[0]( - input_ids=inputs["tokens"], - position_ids=inputs["input_pos"], - attention_mask=attention_mask, - labels=shift_tensor(inputs["tokens"], 0), - extra_block_kwargs={"attention_bias": attention_state}, - ) - loss = loss_fn( - inputs, # type: ignore - new_logprobs, - ref_logprobs, - None, - experimental_config, - ) - probs_corr = loss.probs_corr.item() - print0("Correlation between old and new probabilities:", probs_corr) - loss = loss.mean_policy_loss - loss.backward() - # Reduce LoRA grads - start = time.perf_counter() - num_grads = 0 - for chunk in model: - for param in chunk.parameters(): - if param.grad is None: - continue - torch.distributed.all_reduce( # ty:ignore[possibly-missing-attribute] - param.grad, - op=torch.distributed.ReduceOp.AVG, # ty:ignore[possibly-missing-attribute] - group=ps.get_data_parallel_group(), - ) - num_grads += 1 - print0( - f"Reduced {num_grads} LoRA grads in {(time.perf_counter() - start) * 1e3:.1f} ms" - ) - for param_group in optimizer.param_groups: - param_group["lr"] = config.learning_rate - update_successful, grad_norm, num_zeros_in_grad = cast( - tuple[bool, float, int | None], optimizer.step() - ) - optimizer.zero_grad() - - # Mean reduce loss across all ranks for logging - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) # ty:ignore[possibly-missing-attribute] - - if rank == 0: - with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: - log_msg = json.dumps( - { - "loss/train": loss.item(), - "loss/grad_norm": grad_norm, - "probs_corr": probs_corr, - } - ) - print("Logging", log_msg) - log_file.write(log_msg + "\n") - - sharded_state_dict = {} - for chunk in model: - for module in chunk.modules(): - if hasattr(module, "sharded_lora_state_dict"): - module_sharded_lora_state_dict: dict[str, torch.Tensor] = ( - module.sharded_lora_state_dict() # type: ignore - ) - for key, value in module_sharded_lora_state_dict.items(): - target_dtype = ( - adapter_model[key].dtype - if key in adapter_model - else value.dtype - ) - sharded_state_dict[key] = value.to(target_dtype) - shard_path = os.path.join( - job.lora_path, - f"adapter_model-{rank + 1:02d}-of-{world_size:02d}.safetensors", - ) - print("Saving adapter shard to", shard_path) - save_file(sharded_state_dict, shard_path) - print("Saving optimizer shard to", optimizer_shard_path) - os.makedirs(job.optimizer_state_path, exist_ok=True) - torch.save(optimizer.state_dict(), optimizer_shard_path) - offload_to_cpu(model, optimizer, rank, offload_state) - # Release mmap-backed packed tensor references on all ranks before rank0 cleanup. 
- del packed_tensors - del adapter_model - if "inputs" in locals(): - del inputs - gc.collect() - torch.cuda.empty_cache() - # Ensure all ranks have finished saving before signaling completion - torch.distributed.barrier() # ty:ignore[possibly-missing-attribute] - if rank == 0: - os.remove(job_path) - with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: - log_file.write("all done\n") - shutil.rmtree(job.disk_packed_tensors["dir"]) + reload_to_gpu(ctx.model, ctx.optimizer, ctx.rank, offload_state) + try: + job_path = os.path.join(DEFAULT_JOBS_DIR, job_names[0]) + with open(job_path, "rb") as f: + job = MegatronTrainingJob.model_validate_json(f.read()) + if ctx.rank == 0: + print("Loaded job from", job_path) + print("Job:", job) + run_megatron_rl_job(ctx, job, job_path=job_path) + finally: + offload_to_cpu(ctx.model, ctx.optimizer, ctx.rank, offload_state) diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py old mode 100644 new mode 100755 index ce530fe5..fb469eb9 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -9,10 +9,13 @@ from art.serverless.client import Client, ExperimentalTrainingConfig from .. import dev +from .._backend_training import ( + aggregate_rl_training_metrics, + build_rl_train_configs, +) from ..backend import AnyTrainableModel, Backend from ..metrics_taxonomy import ( TRAIN_GRADIENT_STEPS_KEY, - average_metric_samples, build_training_summary_metrics, summarize_trajectory_groups, ) @@ -254,27 +257,19 @@ async def train( # type: ignore[override] """ groups_list = list(trajectory_groups) - # Build config objects from explicit kwargs - config = TrainConfig(learning_rate=learning_rate) - dev_config: dev.TrainConfig = { - "advantage_balance": advantage_balance, - "importance_sampling_level": importance_sampling_level, - "mask_prob_ratio": mask_prob_ratio, - "ppo": ppo, - "precalculate_logprobs": precalculate_logprobs, - "scale_rewards": scale_rewards, - } - # Only include optional fields if they're set - if epsilon is not None: - dev_config["epsilon"] = epsilon - if epsilon_high is not None: - dev_config["epsilon_high"] = epsilon_high - if max_negative_advantage_importance_sampling_weight is not None: - dev_config["max_negative_advantage_importance_sampling_weight"] = ( - max_negative_advantage_importance_sampling_weight - ) - if kimi_k2_tau is not None: - dev_config["kimi_k2_tau"] = kimi_k2_tau + config, dev_config = build_rl_train_configs( + learning_rate=learning_rate, + advantage_balance=advantage_balance, + scale_rewards=scale_rewards, + importance_sampling_level=importance_sampling_level, + mask_prob_ratio=mask_prob_ratio, + ppo=ppo, + precalculate_logprobs=precalculate_logprobs, + epsilon=epsilon, + epsilon_high=epsilon_high, + max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight, + kimi_k2_tau=kimi_k2_tau, + ) # Collect metrics from training training_metrics: list[dict[str, float]] = [] @@ -284,21 +279,10 @@ async def train( # type: ignore[override] ): training_metrics.append(metrics) - # Aggregate metrics - avg_metrics = average_metric_samples(training_metrics) - summary = summarize_trajectory_groups(groups_list) - avg_metrics.setdefault( - "time/step_trainer_s", time.monotonic() - trainer_started - ) - avg_metrics.update( - { - key: value - for key, value in build_training_summary_metrics( - summary, - include_trainable_groups=True, - ).items() - if key not in avg_metrics - } + avg_metrics = aggregate_rl_training_metrics( + training_metrics=training_metrics, + 
trajectory_groups=groups_list, + trainer_started=trainer_started, ) # Get step and artifact name diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py old mode 100644 new mode 100755 index 5b6a563c..6a6a0eb9 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -8,15 +8,9 @@ import os import subprocess import sys -from typing import TYPE_CHECKING, Any, AsyncIterator, Literal, Protocol, cast - -from datasets import Dataset -import peft -import torch -from torch.optim import Optimizer -from transformers import GenerationMixin, PreTrainedModel -from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from trl import GRPOConfig, GRPOTrainer +from typing import Any, AsyncIterator + +from trl import GRPOTrainer from vllm import AsyncEngineArgs from vllm.lora.request import LoRARequest from vllm.v1.engine.async_llm import AsyncLLM @@ -24,139 +18,24 @@ from .. import dev, types from ..dev.validate import is_dedicated_mode from ..local.checkpoints import get_last_checkpoint_dir -from ..preprocessing.inputs import TrainInputs, create_train_inputs -from ..preprocessing.pack import ( - DiskPackedTensors, - PackedTensors, - packed_tensors_from_dir, -) +from ..preprocessing.pack import DiskPackedTensors from ..preprocessing.tokenize import SFTBatch from ..utils.convert_moe_lora import convert_checkpoint_if_needed from ..utils.get_model_step import get_step_from_dir from ..utils.output_dirs import get_step_checkpoint_dir from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers -from .train import StopTrainingLoop, gc_and_empty_cuda_cache, train +from .shared import ( + UnslothTrainContext, + create_unsloth_train_context, + run_unsloth_rl_training, + run_unsloth_sft_training, +) +from .train import gc_and_empty_cuda_cache logger = logging.getLogger(__name__) -if TYPE_CHECKING: - from peft.peft_model import PeftModelForCausalLM - from trl import GRPOTrainer - - -# ============================================================================ -# Shared Utilities -# ============================================================================ - - -class SupportsLoadLora(Protocol): - """Protocol for models that support the optimized load_lora method.""" - - def load_lora(self, lora_path: str, load_tensors: bool = True) -> LoRARequest: ... 
- - -class _StopTrainInputs: - """Dedicated sentinel for stopping the background trainer loop.""" - - -_STOP_TRAIN_INPUT = _StopTrainInputs() -_TRAIN_TASK_SHUTDOWN_TIMEOUT_S = 5.0 -_TrainLoopInput = TrainInputs | _StopTrainInputs - - -def precalculate_new_logprobs( - trainer: "GRPOTrainer", - peft_model: "PeftModelForCausalLM", - packed_tensors: PackedTensors, - config: types.TrainConfig, - _config: dev.TrainConfig, -) -> torch.Tensor: - """Precalculate logprobs for all offsets and return as a tensor.""" - return torch.cat( - [ - trainer.compute_loss( - peft_model, - TrainInputs( # ty:ignore[missing-typed-dict-key] - **{ - k: v[_offset : _offset + 1] - for k, v in packed_tensors.items() - if isinstance(v, torch.Tensor) - }, - pixel_values=packed_tensors["pixel_values"][_offset : _offset + 1], - image_grid_thw=packed_tensors["image_grid_thw"][ - _offset : _offset + 1 - ], - config=config, - _config=_config, - return_new_logprobs=True, - ), - ) - for _offset in range(0, packed_tensors["tokens"].shape[0]) - ] - ).to("cpu") - - -async def process_train_batch( - packed_tensors: PackedTensors, - config: types.TrainConfig, - _config: dev.TrainConfig, - inputs_queue: asyncio.Queue[_TrainLoopInput], - results_queue: asyncio.Queue[dict[str, float]], - train_task: asyncio.Task[None], - trainer: "GRPOTrainer", - peft_model: "PeftModelForCausalLM", - warmup: bool, - verbose: bool = False, -): - """ - Process training batches and yield results. - - Yields tuples of (result, warmup_done) where warmup_done indicates if warmup just finished. - """ - precalculate_logprobs = _config.get("precalculate_logprobs", False) - - for offset in range(0, packed_tensors["tokens"].shape[0]): - for _ in range(2 if warmup else 1): - if precalculate_logprobs and not warmup: - # Preserve original logprobs before overwriting - packed_tensors["original_logprobs"] = packed_tensors["logprobs"] # type: ignore - packed_tensors["logprobs"] = precalculate_new_logprobs( - trainer, peft_model, packed_tensors, config, _config - ) - precalculate_logprobs = False - - inputs_queue.put_nowait( - create_train_inputs(packed_tensors, offset, config, _config, warmup) - ) - - # Wait for a result from the queue or for the training task to, - # presumably, raise an exception - done, _ = await asyncio.wait( - [ - asyncio.create_task(results_queue.get()), - train_task, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - if verbose: - print( - "Done waiting for a result from the queue or for the training task to, presumably, raise an exception" - ) - for task in done: - result = task.result() - # If `result` is `None`, the training task finished somehow. - assert result is not None, "The training task should never finish." 
- results_queue.task_done() - if warmup: - gc_and_empty_cuda_cache() - await asyncio.sleep(0.1) - warmup = False - else: - yield result - - def save_checkpoint( - trainer: "GRPOTrainer", + trainer: GRPOTrainer, output_dir: str, verbose: bool = False, ) -> str: @@ -200,111 +79,6 @@ def save_checkpoint( return checkpoint_dir -def _get_trainer_optimizer(trainer: GRPOTrainer) -> Optimizer: - optimizer = cast(Optimizer | None, getattr(trainer, "optimizer", None)) - if optimizer is None: - raise RuntimeError("Trainer optimizer must be initialized before training") - return optimizer - - -# ============================================================================ -# Model Classes -# ============================================================================ - - -class CausalLM(PreTrainedModel, GenerationMixin): - """Dummy class for type checking.""" - - pass - - -@dataclass -class UnslothState: - model: CausalLM - tokenizer: PreTrainedTokenizerBase - peft_model: peft.peft_model.PeftModelForCausalLM - trainer: GRPOTrainer - inputs_queue: asyncio.Queue[_TrainLoopInput] - results_queue: asyncio.Queue[dict[str, float]] - _is_offloaded: bool = False - _pinned_buffers: dict[str, torch.Tensor] | None = None - - def offload_to_cpu(self) -> None: - """Offload training model and optimizer to CPU using pinned memory for faster transfers.""" - if self._is_offloaded: - return - - # Initialize pinned buffer storage - if self._pinned_buffers is None: - self._pinned_buffers = {} - - # Offload model parameters to pinned memory for faster reload - for name, param in self.peft_model.named_parameters(): - if param.device.type == "cuda": - # Create pinned buffer if not exists or wrong size - if ( - name not in self._pinned_buffers - or self._pinned_buffers[name].shape != param.shape - ): - self._pinned_buffers[name] = torch.empty( - param.shape, dtype=param.dtype, device="cpu", pin_memory=True - ) - # Async copy to pinned memory - self._pinned_buffers[name].copy_(param.data, non_blocking=True) - param.data = self._pinned_buffers[name] - - # Offload optimizer state to pinned memory - optimizer = getattr(self.trainer, "optimizer", None) - if optimizer is not None and hasattr(optimizer, "state"): - for param_id, state in optimizer.state.items(): - for k, v in state.items(): - if isinstance(v, torch.Tensor) and v.device.type == "cuda": - key = f"opt_{id(param_id)}_{k}" - if ( - key not in self._pinned_buffers - or self._pinned_buffers[key].shape != v.shape - ): - self._pinned_buffers[key] = torch.empty( - v.shape, dtype=v.dtype, device="cpu", pin_memory=True - ) - self._pinned_buffers[key].copy_(v, non_blocking=True) - state[k] = self._pinned_buffers[key] - - # Sync to ensure all copies are complete before freeing GPU memory - torch.cuda.synchronize() - - self._is_offloaded = True - gc_and_empty_cuda_cache() - - def reload_to_gpu(self, device: str = "cuda:0") -> None: - """Reload training model and optimizer back to GPU using async transfers.""" - if not self._is_offloaded: - return - - # Reload model parameters from pinned memory (fast async transfer) - for name, param in self.peft_model.named_parameters(): - if param.device.type == "cpu": - # Allocate on GPU and async copy from pinned memory - gpu_tensor = torch.empty(param.shape, dtype=param.dtype, device=device) - gpu_tensor.copy_(param.data, non_blocking=True) - param.data = gpu_tensor - - # Reload optimizer state - optimizer = getattr(self.trainer, "optimizer", None) - if optimizer is not None and hasattr(optimizer, "state"): - for state in 
optimizer.state.values(): - for k, v in state.items(): - if isinstance(v, torch.Tensor) and v.device.type == "cpu": - gpu_tensor = torch.empty(v.shape, dtype=v.dtype, device=device) - gpu_tensor.copy_(v, non_blocking=True) - state[k] = gpu_tensor - - # Sync to ensure all copies are complete before training - torch.cuda.synchronize() - - self._is_offloaded = False - - # ============================================================================ # Service # ============================================================================ @@ -317,7 +91,6 @@ class UnslothService: config: dev.InternalModelConfig output_dir: str _is_sleeping: bool = False - _last_training_mode: Literal["sft", "rl"] | None = None _latest_step: int = 0 _lora_id_counter: int = 1 # Start from 1 since 0 is reserved # Dedicated mode subprocess state @@ -325,7 +98,6 @@ class UnslothService: _vllm_log_file: Any = field(default=None, repr=False) _vllm_host: str = "127.0.0.1" _vllm_port: int = 0 - _train_task: asyncio.Task[None] | None = field(default=None, init=False, repr=False) @property def is_dedicated(self) -> bool: @@ -336,24 +108,6 @@ def _next_lora_id(self) -> int: self._lora_id_counter += 1 return self._lora_id_counter - async def aclose(self) -> None: - train_task = self._train_task - self._train_task = None - if train_task is None or train_task.done(): - self.close() - return - - # `_state` is a cached_property. Read from __dict__ directly so - # closing does not instantiate trainer state only to stop a task. - state = self.__dict__.get("_state") - assert isinstance(state, UnslothState) - state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT) - try: - await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_SHUTDOWN_TIMEOUT_S) - except asyncio.TimeoutError: - train_task.cancel() - self.close() - # ========================================================================= # Dedicated mode: vLLM subprocess lifecycle # ========================================================================= @@ -564,27 +318,6 @@ async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: self._latest_step = step await llm.resume_generation() - def _reset_optimizer_if_mode_changed( - self, - mode: Literal["sft", "rl"], - ) -> None: - """Reset optimizer state if training mode changed. - - Uses a single shared optimizer (trainer.optimizer) for both SFT and RL. - Resets optimizer state (momentum, variance) only when switching between - training modes to avoid stale state from a different loss landscape. 
- """ - mode_changed = ( - self._last_training_mode is not None and self._last_training_mode != mode - ) - optimizer = _get_trainer_optimizer(self._state.trainer) - - if mode_changed: - # Clear all optimizer state (exp_avg, exp_avg_sq, step for each param) - optimizer.state.clear() - - self._last_training_mode = mode - async def train( self, disk_packed_tensors: DiskPackedTensors, @@ -612,38 +345,11 @@ async def _train_dedicated( verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: """Train in dedicated mode — no sleep/wake, vLLM keeps running on separate GPU.""" - self._reset_optimizer_if_mode_changed("rl") - optimizer = _get_trainer_optimizer(self._state.trainer) - - rl_weight_decay = 0.1 - for param_group in optimizer.param_groups: - param_group["weight_decay"] = rl_weight_decay - - packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) - - await self._state.results_queue.join() - - if self._train_task is None: - self._train_task = asyncio.create_task( - train( - trainer=self._state.trainer, - results_queue=self._state.results_queue, - ) - ) - warmup = True - else: - warmup = False - - async for result in process_train_batch( - packed_tensors=packed_tensors, + async for result in run_unsloth_rl_training( + self._state, + disk_packed_tensors=disk_packed_tensors, config=config, _config=_config, - inputs_queue=self._state.inputs_queue, - results_queue=self._state.results_queue, - train_task=self._train_task, - trainer=self._state.trainer, - peft_model=self._state.peft_model, - warmup=warmup, verbose=verbose, ): yield result @@ -697,44 +403,11 @@ async def _train_shared( # Reload training model to GPU (after vLLM is asleep) self._state.reload_to_gpu() - # Reset optimizer state if switching from SFT to RL - self._reset_optimizer_if_mode_changed("rl") - optimizer = _get_trainer_optimizer(self._state.trainer) - - # Set RL-specific hyperparameters - rl_weight_decay = 0.1 - for param_group in optimizer.param_groups: - param_group["weight_decay"] = rl_weight_decay - - # Load packed tensors - packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) - - # Wait for existing batches to finish - await self._state.results_queue.join() - - # If we haven't already, start the training task - if self._train_task is None: - self._train_task = asyncio.create_task( - train( - trainer=self._state.trainer, - results_queue=self._state.results_queue, - ) - ) - warmup = True - else: - warmup = False - - # Train on the batch using shared logic - async for result in process_train_batch( - packed_tensors=packed_tensors, + async for result in run_unsloth_rl_training( + self._state, + disk_packed_tensors=disk_packed_tensors, config=config, _config=_config, - inputs_queue=self._state.inputs_queue, - results_queue=self._state.results_queue, - train_task=self._train_task, - trainer=self._state.trainer, - peft_model=self._state.peft_model, - warmup=warmup, verbose=verbose, ): yield result @@ -806,8 +479,6 @@ async def train_sft( raise NotImplementedError( "train_sft is not yet supported in dedicated mode" ) - import time - llm = await self.llm # === Setup === @@ -830,91 +501,19 @@ async def train_sft( # Reload training model to GPU (after vLLM is asleep) self._state.reload_to_gpu() - # Get model and optimizer - peft_model = self._state.peft_model - self._reset_optimizer_if_mode_changed("sft") - optimizer = _get_trainer_optimizer(self._state.trainer) - - # Set SFT-specific hyperparameters - sft_weight_decay = 0.01 - for param_group in optimizer.param_groups: - param_group["weight_decay"] = 
sft_weight_decay - - # Reset environment variable that may be set by RL training - os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" - - peft_model.train() - device = next(peft_model.parameters()).device - max_grad_norm = 1.0 - if verbose: print("SFT training started") - # === Process batches === - batch_idx = 0 - for batch in batches: - batch_start_time = time.perf_counter() - batch_loss = 0.0 - - # Update learning rate for this batch - for param_group in optimizer.param_groups: - param_group["lr"] = batch.learning_rate - - # Total trainable tokens for loss normalization - num_items_in_batch = torch.tensor( - batch.num_trainable_tokens, dtype=torch.long, device=device - ) - - # Process each trajectory in the batch (gradient accumulation) - for trajectory_tensor in batch.trajectory_tensors: - # Move tensors to device - input_ids = trajectory_tensor["input_ids"].to(device) - attention_mask = trajectory_tensor["attention_mask"].to(device) - labels = trajectory_tensor["labels"].to(device) - - # Forward pass with num_items_in_batch for proper loss normalization - outputs = peft_model( - input_ids=input_ids, - attention_mask=attention_mask, - labels=labels, - num_items_in_batch=num_items_in_batch, - ) - - loss = outputs.loss - - # Backward pass - accumulate gradients - loss.backward() - - # Track metrics - batch_loss += loss.item() - - # Gradient clipping - grad_norm = torch.nn.utils.clip_grad_norm_( - peft_model.parameters(), max_grad_norm - ).item() - - # Optimizer step at the end of each batch - optimizer.step() - optimizer.zero_grad() - - # Compute timing metrics - batch_time = time.perf_counter() - batch_start_time - tokens_per_second = ( - batch.num_trainable_tokens / batch_time if batch_time > 0 else 0.0 - ) - - if verbose: - print( - f"Batch {batch_idx}: loss={batch_loss:.4f}, lr={batch.learning_rate:.2e}, " - f"grad_norm={grad_norm:.4f}, tok/s={tokens_per_second:.1f}" - ) - - batch_idx += 1 - + async for result in run_unsloth_sft_training( + self._state, + batches, + verbose=verbose, + weight_decay=0.01, + ): yield { - "loss/train": batch_loss, - "loss/learning_rate": batch.learning_rate, - "loss/grad_norm": grad_norm, + "loss/train": result["loss"], + "loss/learning_rate": result["learning_rate"], + "loss/grad_norm": result["grad_norm"], } # === Cleanup === @@ -958,82 +557,18 @@ async def train_sft( print("SFT training finished") @cached_property - def _state(self) -> UnslothState: - import unsloth - - # Initialize Unsloth model - init_args = self.config.get("init_args", {}) + def _state(self) -> UnslothTrainContext: + init_args = dict(self.config.get("init_args", {})) checkpoint_dir = get_last_checkpoint_dir(self.output_dir) if checkpoint_dir: init_args["model_name"] = checkpoint_dir else: init_args["model_name"] = self.base_model - model, tokenizer = cast( - tuple[CausalLM, PreTrainedTokenizerBase], - unsloth.FastLanguageModel.from_pretrained(**init_args), - ) - - # Initialize PEFT model - skip if already a PeftModel (e.g. loaded from checkpoint) - if ( - hasattr(model, "peft_config") - and getattr(model, "peft_config", None) is not None - ): - # Model already has LoRA adapters (loaded from checkpoint) - peft_model = cast(peft.peft_model.PeftModelForCausalLM, model) - else: - peft_model = cast( - peft.peft_model.PeftModelForCausalLM, - unsloth.FastLanguageModel.get_peft_model( - model, **self.config.get("peft_args", {}) - ), - ) - - # Unsloth's model patching can leave the PEFT model without - # `warnings_issued`, which GRPOTrainer expects during init. 
- if not hasattr(peft_model, "warnings_issued"): - peft_model.warnings_issued = {} # type: ignore[attr-defined] - - # Initialize trainer with dummy dataset - data = {"prompt": ""} - trainer = GRPOTrainer( - model=peft_model, # type: ignore - reward_funcs=[], - args=GRPOConfig(**self.config.get("trainer_args", {})), - train_dataset=Dataset.from_list([data for _ in range(10_000_000)]), - processing_class=tokenizer, - ) - - # Initialize optimizer eagerly using trainer's configured settings. - if trainer.optimizer is None: - trainer.create_optimizer() - - # Initialize queues - inputs_queue: asyncio.Queue[_TrainLoopInput] = asyncio.Queue() - results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue() - - # Patch trainer _prepare_inputs() to pull from queue - def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]: - async def get_inputs() -> _TrainLoopInput: - return await inputs_queue.get() - - # Force otherwise synchronous _prepare_inputs() to yield - # with nested asyncio.run() call - inputs = asyncio.run(get_inputs()) - if isinstance(inputs, _StopTrainInputs): - raise StopTrainingLoop() - - return cast(dict[str, torch.Tensor], inputs) - - trainer._prepare_inputs = _async_prepare_inputs - - return UnslothState( - model=model, - tokenizer=tokenizer, - peft_model=peft_model, - trainer=trainer, - inputs_queue=inputs_queue, - results_queue=results_queue, + return create_unsloth_train_context( + init_args=init_args, + peft_args=dict(self.config.get("peft_args", {})), + trainer_args=dict(self.config.get("trainer_args", {})), ) @cached_property diff --git a/src/art/unsloth/shared.py b/src/art/unsloth/shared.py new file mode 100755 index 00000000..9854bf2a --- /dev/null +++ b/src/art/unsloth/shared.py @@ -0,0 +1,454 @@ +import asyncio +from dataclasses import dataclass +import os +import time +from typing import Any, AsyncIterator, Iterable, Literal, cast + +from datasets import Dataset +import nest_asyncio +import peft +import torch +from torch.optim import Optimizer +from transformers import GenerationMixin, PreTrainedModel +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from trl import GRPOConfig, GRPOTrainer + +from .. 
import dev, types +from ..preprocessing.inputs import TrainInputs, create_train_inputs +from ..preprocessing.pack import ( + DiskPackedTensors, + PackedTensors, + packed_tensors_from_dir, +) +from ..preprocessing.tokenize import SFTBatch +from .train import gc_and_empty_cuda_cache, train + +nest_asyncio.apply() + + +class CausalLM(PreTrainedModel, GenerationMixin): + """Dummy class for type checking.""" + + pass + + +@dataclass +class UnslothTrainContext: + model: CausalLM + tokenizer: PreTrainedTokenizerBase + peft_model: peft.peft_model.PeftModelForCausalLM + trainer: GRPOTrainer + inputs_queue: asyncio.Queue[TrainInputs] + results_queue: asyncio.Queue[dict[str, float]] + train_task: asyncio.Task[None] | None = None + warmup_pending: bool = True + last_training_mode: Literal["sft", "rl"] | None = None + _is_offloaded: bool = False + _pinned_buffers: dict[str, torch.Tensor] | None = None + + def offload_to_cpu(self) -> None: + if self._is_offloaded: + return + + if self._pinned_buffers is None: + self._pinned_buffers = {} + + for name, param in self.peft_model.named_parameters(): + if param.device.type != "cuda": + continue + if ( + name not in self._pinned_buffers + or self._pinned_buffers[name].shape != param.shape + ): + self._pinned_buffers[name] = torch.empty( + param.shape, + dtype=param.dtype, + device="cpu", + pin_memory=True, + ) + self._pinned_buffers[name].copy_(param.data, non_blocking=True) + param.data = self._pinned_buffers[name] + + optimizer = getattr(self.trainer, "optimizer", None) + if optimizer is not None and hasattr(optimizer, "state"): + for param_id, state in optimizer.state.items(): + for key, value in state.items(): + if not isinstance(value, torch.Tensor) or value.device.type != "cuda": + continue + buffer_key = f"opt_{id(param_id)}_{key}" + if ( + buffer_key not in self._pinned_buffers + or self._pinned_buffers[buffer_key].shape != value.shape + ): + self._pinned_buffers[buffer_key] = torch.empty( + value.shape, + dtype=value.dtype, + device="cpu", + pin_memory=True, + ) + self._pinned_buffers[buffer_key].copy_(value, non_blocking=True) + state[key] = self._pinned_buffers[buffer_key] + + torch.cuda.synchronize() + self._is_offloaded = True + gc_and_empty_cuda_cache() + + def reload_to_gpu(self, device: str = "cuda:0") -> None: + if not self._is_offloaded: + return + + for _, param in self.peft_model.named_parameters(): + if param.device.type != "cpu": + continue + gpu_tensor = torch.empty(param.shape, dtype=param.dtype, device=device) + gpu_tensor.copy_(param.data, non_blocking=True) + param.data = gpu_tensor + + optimizer = getattr(self.trainer, "optimizer", None) + if optimizer is not None and hasattr(optimizer, "state"): + for state in optimizer.state.values(): + for key, value in state.items(): + if not isinstance(value, torch.Tensor) or value.device.type != "cpu": + continue + gpu_tensor = torch.empty(value.shape, dtype=value.dtype, device=device) + gpu_tensor.copy_(value, non_blocking=True) + state[key] = gpu_tensor + + torch.cuda.synchronize() + self._is_offloaded = False + + async def load_lora_adapter(self, lora_path: str) -> None: + try: + await self.results_queue.join() + except Exception: + pass + try: + torch.cuda.synchronize() + except Exception: + pass + + try: + from safetensors.torch import load_file as load_safetensors # type: ignore + except Exception: + load_safetensors = None # type: ignore[assignment] + + state_dict = None + st_path = os.path.join(lora_path, "adapter_model.safetensors") + bin_path = os.path.join(lora_path, 
"adapter_model.bin") + alt_st_path = os.path.join(lora_path, "model.safetensors") + alt_bin_path = os.path.join(lora_path, "pytorch_model.bin") + try: + if os.path.exists(st_path) and load_safetensors is not None: + state_dict = load_safetensors(st_path, device="cpu") + elif os.path.exists(bin_path): + state_dict = torch.load(bin_path, map_location="cpu") # type: ignore[call-arg] + elif os.path.exists(alt_st_path) and load_safetensors is not None: + state_dict = load_safetensors(alt_st_path, device="cpu") + elif os.path.exists(alt_bin_path): + state_dict = torch.load(alt_bin_path, map_location="cpu") # type: ignore[call-arg] + else: + raise FileNotFoundError(f"No adapter weights found in {lora_path}") + except Exception as exc: + raise RuntimeError(f"Failed to load LoRA adapter weights: {exc}") + + with torch.no_grad(): + self.peft_model.zero_grad(set_to_none=True) + optimizer = getattr(self.trainer, "optimizer", None) + if optimizer is not None: + optimizer = getattr(optimizer, "optimizer", optimizer) + if hasattr(optimizer, "zero_grad"): + optimizer.zero_grad(set_to_none=True) # type: ignore[arg-type] + if hasattr(optimizer, "state") and isinstance(optimizer.state, dict): + optimizer.state.clear() + + try: + try: + from peft.utils.save_and_load import ( + set_peft_model_state_dict as _set_peft_model_state_dict, + ) + except Exception: + from peft import ( + set_peft_model_state_dict as _set_peft_model_state_dict, # type: ignore + ) + + active_adapter = getattr(self.peft_model, "active_adapter", "default") + _set_peft_model_state_dict( + self.peft_model, + state_dict, + adapter_name=active_adapter, + ) + self.peft_model.set_adapter(active_adapter) + except Exception as exc: + raise RuntimeError(f"Failed to set LoRA weights in-place: {exc}") + + try: + torch.cuda.synchronize() + except Exception: + pass + + async def load_optimizer_state(self, checkpoint_dir: str) -> None: + try: + await self.results_queue.join() + except Exception: + pass + try: + torch.cuda.synchronize() + except Exception: + pass + + optimizer_path = os.path.join(checkpoint_dir, "optimizer.pt") + if os.path.exists(optimizer_path): + optimizer_state = torch.load(optimizer_path, map_location="cpu") + self.trainer.optimizer.load_state_dict(optimizer_state) + + def save_lora_adapter(self, lora_path: str) -> None: + self.trainer.save_model(lora_path) + + def save_optimizer_state(self, checkpoint_dir: str) -> None: + optimizer_path = os.path.join(checkpoint_dir, "optimizer.pt") + torch.save(self.trainer.optimizer.state_dict(), optimizer_path) + + +def create_unsloth_train_context( + *, + init_args: dict[str, Any], + peft_args: dict[str, Any], + trainer_args: dict[str, Any], + use_fast_model: bool = False, +) -> UnslothTrainContext: + import unsloth + + loader_cls = unsloth.FastModel if use_fast_model else unsloth.FastLanguageModel + model, tokenizer = cast( + tuple[CausalLM, PreTrainedTokenizerBase], + loader_cls.from_pretrained(**init_args), + ) + + if hasattr(model, "peft_config") and getattr(model, "peft_config", None) is not None: + peft_model = cast(peft.peft_model.PeftModelForCausalLM, model) + else: + peft_model = cast( + peft.peft_model.PeftModelForCausalLM, + loader_cls.get_peft_model(model, **peft_args), + ) + + if not hasattr(peft_model, "warnings_issued"): + peft_model.warnings_issued = {} # type: ignore[attr-defined] + + trainer = GRPOTrainer( + model=peft_model, # type: ignore[arg-type] + reward_funcs=[], + args=GRPOConfig(**trainer_args), + train_dataset=Dataset.from_list([{"prompt": ""} for _ in 
range(10_000_000)]), + processing_class=tokenizer, + ) + if trainer.optimizer is None: + trainer.create_optimizer() + + inputs_queue: asyncio.Queue[TrainInputs] = asyncio.Queue() + results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue() + + def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]: + async def get_inputs() -> TrainInputs: + return await inputs_queue.get() + + inputs = asyncio.run(get_inputs()) + return cast(dict[str, torch.Tensor], inputs) + + trainer._prepare_inputs = _async_prepare_inputs + + return UnslothTrainContext( + model=model, + tokenizer=tokenizer, + peft_model=peft_model, + trainer=trainer, + inputs_queue=inputs_queue, + results_queue=results_queue, + ) + + +def _get_trainer_optimizer(ctx: UnslothTrainContext) -> Optimizer: + optimizer = cast(Optimizer | None, getattr(ctx.trainer, "optimizer", None)) + if optimizer is None: + raise RuntimeError("Trainer optimizer must be initialized before training") + return optimizer + + +def _reset_optimizer_if_mode_changed( + ctx: UnslothTrainContext, + mode: Literal["sft", "rl"], +) -> None: + mode_changed = ctx.last_training_mode is not None and ctx.last_training_mode != mode + if mode_changed: + _get_trainer_optimizer(ctx).state.clear() + ctx.last_training_mode = mode + + +def _precalculate_new_logprobs( + ctx: UnslothTrainContext, + packed_tensors: PackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, +) -> torch.Tensor: + return torch.cat( + [ + ctx.trainer.compute_loss( + ctx.peft_model, + TrainInputs( # ty:ignore[missing-typed-dict-key] + **{ + key: value[offset : offset + 1] + for key, value in packed_tensors.items() + if isinstance(value, torch.Tensor) + }, + pixel_values=packed_tensors["pixel_values"][offset : offset + 1], + image_grid_thw=packed_tensors["image_grid_thw"][offset : offset + 1], + config=config, + _config=_config, + return_new_logprobs=True, + ), + ) + for offset in range(0, packed_tensors["tokens"].shape[0]) + ] + ).to("cpu") + + +async def run_unsloth_rl_training( + ctx: UnslothTrainContext, + disk_packed_tensors: DiskPackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, + verbose: bool = False, +) -> AsyncIterator[dict[str, float]]: + _reset_optimizer_if_mode_changed(ctx, "rl") + optimizer = _get_trainer_optimizer(ctx) + for param_group in optimizer.param_groups: + param_group["weight_decay"] = 0.1 + + packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) + await ctx.results_queue.join() + + if ctx.train_task is None: + ctx.train_task = asyncio.create_task( + train( + trainer=ctx.trainer, + results_queue=ctx.results_queue, + ) + ) + + warmup = ctx.warmup_pending + precalculate_logprobs = _config.get("precalculate_logprobs", False) + + for offset in range(0, packed_tensors["tokens"].shape[0]): + for _ in range(2 if warmup else 1): + if precalculate_logprobs and not warmup: + packed_tensors["original_logprobs"] = packed_tensors["logprobs"] # type: ignore + packed_tensors["logprobs"] = _precalculate_new_logprobs( + ctx, + packed_tensors, + config, + _config, + ) + precalculate_logprobs = False + + ctx.inputs_queue.put_nowait( + create_train_inputs(packed_tensors, offset, config, _config, warmup) + ) + + done, _ = await asyncio.wait( + [ + asyncio.create_task(ctx.results_queue.get()), + ctx.train_task, + ], + return_when=asyncio.FIRST_COMPLETED, + ) + if verbose: + print( + "Done waiting for a result from the queue or for the training task to, presumably, raise an exception" + ) + for task in done: + result = task.result() + assert 
result is not None, "The training task should never finish." + ctx.results_queue.task_done() + if warmup: + gc_and_empty_cuda_cache() + await asyncio.sleep(0.1) + warmup = False + ctx.warmup_pending = False + else: + yield result + + +async def run_unsloth_sft_training( + ctx: UnslothTrainContext, + batches: Iterable[SFTBatch], + verbose: bool = False, + *, + weight_decay: float = 0.0, + max_grad_norm: float = 1.0, +) -> AsyncIterator[dict[str, float]]: + _reset_optimizer_if_mode_changed(ctx, "sft") + optimizer = _get_trainer_optimizer(ctx) + + os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" + + for param_group in optimizer.param_groups: + param_group["weight_decay"] = weight_decay + + ctx.peft_model.train() + device = next(ctx.peft_model.parameters()).device + + for batch_idx, batch in enumerate(batches): + batch_start_time = time.perf_counter() + batch_loss = 0.0 + + for param_group in optimizer.param_groups: + param_group["lr"] = batch.learning_rate + + num_trainable_tokens = torch.tensor( + batch.num_trainable_tokens, + dtype=torch.long, + device=device, + ) + + for trajectory_tensor in batch.trajectory_tensors: + input_ids = trajectory_tensor["input_ids"].to(device) + attention_mask = trajectory_tensor["attention_mask"].to(device) + labels = trajectory_tensor["labels"].to(device) + + outputs = ctx.peft_model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + num_items_in_batch=num_trainable_tokens, + ) + loss = outputs.loss + loss.backward() + batch_loss += loss.item() + + grad_norm = torch.nn.utils.clip_grad_norm_( + ctx.peft_model.parameters(), + max_grad_norm, + ).item() + + optimizer.step() + optimizer.zero_grad() + + batch_time = time.perf_counter() - batch_start_time + tokens_per_second = ( + batch.num_trainable_tokens / batch_time if batch_time > 0 else 0.0 + ) + + if verbose: + print( + f"Batch {batch_idx}: loss={batch_loss:.4f}, lr={batch.learning_rate:.2e}, " + f"grad_norm={grad_norm:.4f}, tok/s={tokens_per_second:.1f}" + ) + + yield { + "loss": batch_loss, + "learning_rate": batch.learning_rate, + "grad_norm": grad_norm, + "num_trajectories": float(batch.num_trajectories), + "num_trainable_tokens": float(batch.num_trainable_tokens), + "tokens_per_second": tokens_per_second, + } From a1b8efcd7f6ba5015d7abb0ea27561b9ea3d9d5f Mon Sep 17 00:00:00 2001 From: Kovbo Date: Sat, 21 Mar 2026 02:03:21 +0000 Subject: [PATCH 2/5] Normalize ART Python file modes --- src/art/_backend_training.py | 0 src/art/local/backend.py | 0 src/art/megatron/jobs.py | 0 src/art/megatron/runtime_env.py | 0 src/art/megatron/service.py | 0 src/art/megatron/shared.py | 0 src/art/megatron/train.py | 0 src/art/serverless/backend.py | 0 src/art/unsloth/service.py | 0 src/art/unsloth/shared.py | 0 10 files changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 src/art/_backend_training.py mode change 100755 => 100644 src/art/local/backend.py mode change 100755 => 100644 src/art/megatron/jobs.py mode change 100755 => 100644 src/art/megatron/runtime_env.py mode change 100755 => 100644 src/art/megatron/service.py mode change 100755 => 100644 src/art/megatron/shared.py mode change 100755 => 100644 src/art/megatron/train.py mode change 100755 => 100644 src/art/serverless/backend.py mode change 100755 => 100644 src/art/unsloth/service.py mode change 100755 => 100644 src/art/unsloth/shared.py diff --git a/src/art/_backend_training.py b/src/art/_backend_training.py old mode 100755 new mode 100644 diff --git a/src/art/local/backend.py b/src/art/local/backend.py old mode 100755 
new mode 100644 diff --git a/src/art/megatron/jobs.py b/src/art/megatron/jobs.py old mode 100755 new mode 100644 diff --git a/src/art/megatron/runtime_env.py b/src/art/megatron/runtime_env.py old mode 100755 new mode 100644 diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py old mode 100755 new mode 100644 diff --git a/src/art/megatron/shared.py b/src/art/megatron/shared.py old mode 100755 new mode 100644 diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py old mode 100755 new mode 100644 diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py old mode 100755 new mode 100644 diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py old mode 100755 new mode 100644 diff --git a/src/art/unsloth/shared.py b/src/art/unsloth/shared.py old mode 100755 new mode 100644 From 86ae9336c34d1c6fe71d17bedf663aa18eb1ce02 Mon Sep 17 00:00:00 2001 From: Kovbo Date: Mon, 23 Mar 2026 18:15:43 +0000 Subject: [PATCH 3/5] Fix lint and typing regressions in shared training refactor --- src/art/_backend_training.py | 8 ++------ src/art/megatron/service.py | 9 +++++++-- src/art/megatron/shared.py | 16 ++++++++++++---- src/art/megatron/train.py | 4 +++- src/art/unsloth/service.py | 13 +++++++++---- src/art/unsloth/shared.py | 23 ++++++++++++++++++----- 6 files changed, 51 insertions(+), 22 deletions(-) diff --git a/src/art/_backend_training.py b/src/art/_backend_training.py index b8e82377..e698a7f1 100644 --- a/src/art/_backend_training.py +++ b/src/art/_backend_training.py @@ -51,9 +51,7 @@ def build_rl_train_configs( } if allow_training_without_logprobs is not None: - dev_config["allow_training_without_logprobs"] = ( - allow_training_without_logprobs - ) + dev_config["allow_training_without_logprobs"] = allow_training_without_logprobs if plot_tensors is not None: dev_config["plot_tensors"] = plot_tensors if truncated_importance_sampling is not None: @@ -63,9 +61,7 @@ def build_rl_train_configs( scale_learning_rate_by_reward_std_dev ) if logprob_calculation_chunk_size is not None: - dev_config["logprob_calculation_chunk_size"] = ( - logprob_calculation_chunk_size - ) + dev_config["logprob_calculation_chunk_size"] = logprob_calculation_chunk_size if num_trajectories_learning_rate_multiplier_power is not None: dev_config["num_trajectories_learning_rate_multiplier_power"] = ( num_trajectories_learning_rate_multiplier_power diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 686af707..d11e8d75 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -2,6 +2,7 @@ from dataclasses import asdict, dataclass import datetime from functools import cached_property +import importlib import json import os from pathlib import Path @@ -10,8 +11,6 @@ from typing import Any, AsyncIterator from peft.tuners.lora.config import LoraConfig -from safetensors import safe_open -from safetensors.torch import load_file, save_file import torch from vllm import AsyncEngineArgs from vllm.lora.request import LoRARequest @@ -31,6 +30,12 @@ MegatronTrainingJob, ) +safetensors = importlib.import_module("safetensors") +safetensors_torch = importlib.import_module("safetensors.torch") +safe_open = safetensors.safe_open +load_file = safetensors_torch.load_file +save_file = safetensors_torch.save_file + @dataclass class MegatronService: diff --git a/src/art/megatron/shared.py b/src/art/megatron/shared.py index b84b873d..dba62103 100644 --- a/src/art/megatron/shared.py +++ b/src/art/megatron/shared.py @@ -1,14 +1,14 @@ +from dataclasses import 
dataclass import gc +import importlib import json import math import os import shutil import time -from dataclasses import dataclass from typing import Any from megatron.core import parallel_state as ps -from safetensors.torch import load_file, save_file import torch from ..loss import loss_fn, shift_tensor @@ -16,6 +16,10 @@ from .flex_attention import create_shared_prefix_attention_state from .jobs import MegatronSFTTrainingJob, MegatronTrainingJob +safetensors_torch = importlib.import_module("safetensors.torch") +load_file = safetensors_torch.load_file +save_file = safetensors_torch.save_file + @dataclass class MegatronTrainContext: @@ -274,7 +278,9 @@ def run_megatron_sft_job( update_successful, grad_norm, num_zeros_in_grad = ctx.optimizer.step() ctx.optimizer.zero_grad() - torch.distributed.reduce(batch_loss, dst=0, op=torch.distributed.ReduceOp.SUM) + torch.distributed.reduce( + batch_loss, dst=0, op=torch.distributed.ReduceOp.SUM + ) avg_loss = batch_loss / num_trainable_tokens batch_time = time.perf_counter() - batch_start_time @@ -289,7 +295,9 @@ def run_megatron_sft_job( "loss": avg_loss.item(), "learning_rate": job.learning_rates[batch_idx], "grad_norm": float(grad_norm), - "num_trajectories": float(batch_metadata["num_trajectories"]), + "num_trajectories": float( + batch_metadata["num_trajectories"] + ), "num_trainable_tokens": float(num_trainable_tokens), "tokens_per_second": tokens_per_second, } diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index bb1c5cfd..fa0520c1 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -28,7 +28,9 @@ torch.distributed.barrier() os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True) job_names = sorted( - job_name for job_name in os.listdir(DEFAULT_JOBS_DIR) if job_name.endswith(".json") + job_name + for job_name in os.listdir(DEFAULT_JOBS_DIR) + if job_name.endswith(".json") ) if not job_names: time.sleep(1) diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index 6a6a0eb9..c94e6ea3 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -8,7 +8,7 @@ import os import subprocess import sys -from typing import Any, AsyncIterator +from typing import Any, AsyncIterator, cast from trl import GRPOTrainer from vllm import AsyncEngineArgs @@ -18,6 +18,7 @@ from .. 
import dev, types from ..dev.validate import is_dedicated_mode from ..local.checkpoints import get_last_checkpoint_dir +from ..preprocessing.inputs import TrainInputs from ..preprocessing.pack import DiskPackedTensors from ..preprocessing.tokenize import SFTBatch from ..utils.convert_moe_lora import convert_checkpoint_if_needed @@ -34,6 +35,7 @@ logger = logging.getLogger(__name__) + def save_checkpoint( trainer: GRPOTrainer, output_dir: str, @@ -558,7 +560,7 @@ async def train_sft( @cached_property def _state(self) -> UnslothTrainContext: - init_args = dict(self.config.get("init_args", {})) + init_args = dict(cast(dict[str, Any], self.config.get("init_args") or {})) checkpoint_dir = get_last_checkpoint_dir(self.output_dir) if checkpoint_dir: init_args["model_name"] = checkpoint_dir @@ -567,8 +569,11 @@ def _state(self) -> UnslothTrainContext: return create_unsloth_train_context( init_args=init_args, - peft_args=dict(self.config.get("peft_args", {})), - trainer_args=dict(self.config.get("trainer_args", {})), + peft_args=cast(dict[str, Any], self.config.get("peft_args") or {}), + trainer_args=cast( + dict[str, Any], + self.config.get("trainer_args") or {}, + ), ) @cached_property diff --git a/src/art/unsloth/shared.py b/src/art/unsloth/shared.py index 9854bf2a..dc117ed0 100644 --- a/src/art/unsloth/shared.py +++ b/src/art/unsloth/shared.py @@ -73,7 +73,10 @@ def offload_to_cpu(self) -> None: if optimizer is not None and hasattr(optimizer, "state"): for param_id, state in optimizer.state.items(): for key, value in state.items(): - if not isinstance(value, torch.Tensor) or value.device.type != "cuda": + if ( + not isinstance(value, torch.Tensor) + or value.device.type != "cuda" + ): continue buffer_key = f"opt_{id(param_id)}_{key}" if ( @@ -108,9 +111,14 @@ def reload_to_gpu(self, device: str = "cuda:0") -> None: if optimizer is not None and hasattr(optimizer, "state"): for state in optimizer.state.values(): for key, value in state.items(): - if not isinstance(value, torch.Tensor) or value.device.type != "cpu": + if ( + not isinstance(value, torch.Tensor) + or value.device.type != "cpu" + ): continue - gpu_tensor = torch.empty(value.shape, dtype=value.dtype, device=device) + gpu_tensor = torch.empty( + value.shape, dtype=value.dtype, device=device + ) gpu_tensor.copy_(value, non_blocking=True) state[key] = gpu_tensor @@ -224,7 +232,10 @@ def create_unsloth_train_context( loader_cls.from_pretrained(**init_args), ) - if hasattr(model, "peft_config") and getattr(model, "peft_config", None) is not None: + if ( + hasattr(model, "peft_config") + and getattr(model, "peft_config", None) is not None + ): peft_model = cast(peft.peft_model.PeftModelForCausalLM, model) else: peft_model = cast( @@ -301,7 +312,9 @@ def _precalculate_new_logprobs( if isinstance(value, torch.Tensor) }, pixel_values=packed_tensors["pixel_values"][offset : offset + 1], - image_grid_thw=packed_tensors["image_grid_thw"][offset : offset + 1], + image_grid_thw=packed_tensors["image_grid_thw"][ + offset : offset + 1 + ], config=config, _config=_config, return_new_logprobs=True, From d8d8e4d7db2f9ad6256fde155e9f0d573f3a78bc Mon Sep 17 00:00:00 2001 From: Kovbo Date: Mon, 23 Mar 2026 18:31:20 +0000 Subject: [PATCH 4/5] Fix ART ty failures in local backend and MoE conversion --- src/art/local/backend.py | 27 ++++++++++++++++++++++++++- src/art/utils/convert_moe_lora.py | 8 +++++--- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/art/local/backend.py b/src/art/local/backend.py index abb6395d..3d3b38de 100644 --- 
a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -162,6 +162,9 @@ def _allocated_gpu_count(self, model: Model) -> int: def __enter__(self) -> Self: return self + async def __aenter__(self) -> Self: + return self + def __exit__( self, exc_type: type[BaseException] | None, @@ -170,11 +173,19 @@ def __exit__( ) -> None: self._close() + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + await self.close() + async def close(self) -> None: """ If running vLLM in a separate process, this will kill that process and close the communication threads. """ - self._close() + await self._aclose() def _close(self) -> None: for _, service in self._services.items(): @@ -183,6 +194,17 @@ def _close(self) -> None: close() close_proxy(service) + async def _aclose(self) -> None: + for _, service in self._services.items(): + aclose = getattr(service, "aclose", None) + if aclose is not None: + await aclose() + else: + close = getattr(service, "close", None) + if close is not None: + close() + close_proxy(service) + async def register( self, model: Model, @@ -505,6 +527,7 @@ async def train( # type: ignore[override] *, # Core training parameters learning_rate: float = 5e-6, + loss_fn: Literal["cispo", "ppo"] | None = None, # KL-penalized advantage adjustment kl_penalty_coef: float = 0.0, kl_penalty_reference_step: int | None = None, @@ -600,6 +623,8 @@ async def train( # type: ignore[override] # await model.log(metrics=result.metrics, step=result.step) """ groups_list = list(trajectory_groups) + if loss_fn is not None: + ppo = loss_fn == "ppo" resolved_kl_ref_adapter_path = kl_ref_adapter_path if ( diff --git a/src/art/utils/convert_moe_lora.py b/src/art/utils/convert_moe_lora.py index 0ea80f63..ff3e893c 100644 --- a/src/art/utils/convert_moe_lora.py +++ b/src/art/utils/convert_moe_lora.py @@ -12,13 +12,15 @@ ... 
""" +import importlib import json import os import re -import safetensors.torch import torch +safetensors_torch = importlib.import_module("safetensors.torch") + def _has_fused_moe_lora(tensors: dict[str, torch.Tensor]) -> bool: """Check if the adapter contains fused MoE LoRA tensors.""" @@ -152,7 +154,7 @@ def convert_checkpoint_if_needed(checkpoint_dir: str) -> None: if not os.path.exists(adapter_path) or not os.path.exists(config_path): return - tensors = safetensors.torch.load_file(adapter_path) + tensors = safetensors_torch.load_file(adapter_path) if not _has_fused_moe_lora(tensors): return @@ -168,7 +170,7 @@ def convert_checkpoint_if_needed(checkpoint_dir: str) -> None: ) # Overwrite the adapter with the converted tensors - safetensors.torch.save_file(new_tensors, adapter_path) + safetensors_torch.save_file(new_tensors, adapter_path) # Update adapter_config.json target_modules adapter_config["target_modules"] = [ From b19e94c958f155f7e99690ab136fb153d88b3b77 Mon Sep 17 00:00:00 2001 From: Kovbo Date: Tue, 24 Mar 2026 19:49:33 +0000 Subject: [PATCH 5/5] Minimize LocalBackend diff against main --- src/art/local/backend.py | 97 ++++++++++++++++++++++++---------------- 1 file changed, 58 insertions(+), 39 deletions(-) diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 3d3b38de..baec8576 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -48,6 +48,7 @@ build_rl_train_configs, ) from ..backend import AnyTrainableModel, Backend +from ..costs import build_cost_calculator, get_model_pricing from ..metrics_taxonomy import ( TRAIN_GRADIENT_STEPS_KEY, build_training_summary_metrics, @@ -185,26 +186,23 @@ async def close(self) -> None: """ If running vLLM in a separate process, this will kill that process and close the communication threads. """ - await self._aclose() + for service in self._services.values(): + aclose = getattr(service, "aclose", None) + if aclose is None: + close = getattr(service, "close", None) + if close is not None: + close() + else: + await aclose() + close_proxy(service) def _close(self) -> None: - for _, service in self._services.items(): + for service in self._services.values(): close = getattr(service, "close", None) if close is not None: close() close_proxy(service) - async def _aclose(self) -> None: - for _, service in self._services.items(): - aclose = getattr(service, "aclose", None) - if aclose is not None: - await aclose() - else: - close = getattr(service, "close", None) - if close is not None: - close() - close_proxy(service) - async def register( self, model: Model, @@ -231,6 +229,11 @@ async def register( # (wandb initialization is now handled by the model's _get_wandb_run method) if model.trainable and "WANDB_API_KEY" in os.environ: _ = model._get_wandb_run() + if model.trainable: + trainable_model = cast(TrainableModel, model) + pricing = get_model_pricing(trainable_model.base_model) + if pricing is not None: + trainable_model.set_cost_calculator(build_cost_calculator(pricing)) def _model_inference_name(self, model: Model, step: int | None = None) -> str: """Return the inference name for a model checkpoint. @@ -244,25 +247,27 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str: If None, returns name for latest checkpoint (step 0 initially). 
""" - # For LocalBackend, vLLM always serves LoRA adapters with @step suffix - # Default to step 0 when not specified (the initial checkpoint created at registration) - if step is not None: - actual_step = step - elif model.name in self._services and self._in_process: - # In dedicated mode the service tracks which adapter vLLM has - # actually loaded. Reading the filesystem would race: the - # checkpoint directory appears before the HTTP reload completes. - svc = self._services[model.name] - loaded_step = getattr(svc, "_latest_step", None) - actual_step = ( - loaded_step if loaded_step is not None else self.__get_step(model) - ) - else: - actual_step = self.__get_step(model) - name = f"{model.name}@{actual_step}" + requested_step = step + + if step is None and isinstance(model, TrainableModel): + from ..dev.validate import is_dedicated_mode + + service = self._services.get(model.name) + if service is not None and is_dedicated_mode( + model._internal_config or dev.InternalModelConfig() + ): + loaded_step = getattr(service, "_latest_step", None) + if isinstance(loaded_step, int): + step = loaded_step + + if step is None: + # The checkpoint directory is written before dedicated-mode + # vLLM finishes reloading the new adapter. + step = self.__get_step(model) + name = f"{model.name}@{step}" logger.debug( - f"[BACKEND] _model_inference_name: step_arg={step} " - f"actual_step={actual_step} -> {name}" + f"[BACKEND] _model_inference_name: step_arg={requested_step} " + f"actual_step={step} -> {name}" ) return name @@ -527,13 +532,14 @@ async def train( # type: ignore[override] *, # Core training parameters learning_rate: float = 5e-6, - loss_fn: Literal["cispo", "ppo"] | None = None, + loss_fn: Literal["cispo", "ppo"] = "cispo", + loss_fn_config: dict | None = None, + normalize_advantages: bool = True, + adam_params: object | None = None, # KL-penalized advantage adjustment kl_penalty_coef: float = 0.0, kl_penalty_reference_step: int | None = None, kl_ref_adapter_path: str | None = None, - # RL algorithm settings - ppo: bool = False, epsilon: float | None = None, epsilon_high: float | None = None, # Advantage computation @@ -570,6 +576,14 @@ async def train( # type: ignore[override] model: The trainable model to train. trajectory_groups: Batches of trajectories to train on. learning_rate: Learning rate for training. Defaults to 5e-6. + loss_fn: RL loss function. LocalBackend currently supports + "cispo" and "ppo". + loss_fn_config: Additional loss-function config. Not supported by + LocalBackend. + normalize_advantages: Whether to normalize advantages. LocalBackend + currently requires True. + adam_params: Custom optimizer params. Not supported by + LocalBackend. kl_penalty_coef: Coefficient for KL-penalized advantage adjustment. Tokens diverging more from the reference get reduced advantages. Defaults to 0.0 (disabled). @@ -579,8 +593,7 @@ async def train( # type: ignore[override] kl_ref_adapter_path: Direct filesystem path to a LoRA adapter checkpoint to use as the KL reference. Alternative to kl_penalty_reference_step. - ppo: Whether to use PPO clipping. Defaults to False. - epsilon: Clip epsilon for importance sampling. Defaults based on ppo. + epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn. epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. advantage_balance: Balance between negative and positive advantages in range [-1.0, 1.0]. Defaults to 0.0 (balanced). 
@@ -623,8 +636,14 @@ async def train( # type: ignore[override] # await model.log(metrics=result.metrics, step=result.step) """ groups_list = list(trajectory_groups) - if loss_fn is not None: - ppo = loss_fn == "ppo" + if loss_fn not in {"cispo", "ppo"}: + raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.") + if loss_fn_config is not None: + raise ValueError("LocalBackend requires loss_fn_config=None.") + if not normalize_advantages: + raise ValueError("LocalBackend requires normalize_advantages=True.") + if adam_params is not None: + raise ValueError("LocalBackend requires adam_params=None.") resolved_kl_ref_adapter_path = kl_ref_adapter_path if ( @@ -641,7 +660,7 @@ async def train( # type: ignore[override] scale_rewards=scale_rewards, importance_sampling_level=importance_sampling_level, mask_prob_ratio=mask_prob_ratio, - ppo=ppo, + ppo=loss_fn == "ppo", precalculate_logprobs=precalculate_logprobs, epsilon=epsilon, epsilon_high=epsilon_high,
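A condensed, illustrative sketch of how the public `loss_fn` argument
introduced in this patch reaches the internal `ppo` flag consumed by
build_rl_train_configs; the helper name and literal values below are made up
and the control flow is simplified:

    from typing import Literal


    def resolve_ppo(loss_fn: Literal["cispo", "ppo"]) -> bool:
        # LocalBackend accepts exactly these two loss functions.
        if loss_fn not in {"cispo", "ppo"}:
            raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.")
        return loss_fn == "ppo"


    assert resolve_ppo("cispo") is False  # default: no PPO clipping
    assert resolve_ppo("ppo") is True     # enables PPO clipping downstream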