Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions src/art/_backend_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from collections.abc import Iterable
import time
from typing import Literal

from . import dev
from .metrics_taxonomy import (
average_metric_samples,
build_training_summary_metrics,
summarize_trajectory_groups,
)
from .trajectories import TrajectoryGroup
from .types import TrainConfig


def build_rl_train_configs(
    *,
    learning_rate: float,
    advantage_balance: float = 0.0,
    scale_rewards: bool = True,
    importance_sampling_level: Literal[
        "token", "sequence", "average", "geometric_average"
    ] = "token",
    mask_prob_ratio: bool = False,
    ppo: bool = False,
    precalculate_logprobs: bool = False,
    epsilon: float | None = None,
    epsilon_high: float | None = None,
    max_negative_advantage_importance_sampling_weight: float | None = None,
    kimi_k2_tau: float | None = None,
    kl_penalty_coef: float = 0.0,
    allow_training_without_logprobs: bool | None = None,
    plot_tensors: bool | None = None,
    truncated_importance_sampling: float | None = None,
    scale_learning_rate_by_reward_std_dev: bool | None = None,
    logprob_calculation_chunk_size: int | None = None,
    num_trajectories_learning_rate_multiplier_power: float | None = None,
    kl_ref_adapter_path: str | None = None,
) -> tuple[TrainConfig, dev.TrainConfig]:
    """Assemble the stable and experimental training configs for an RL step.

    Returns a ``(config, dev_config)`` pair: ``config`` carries the stable
    knobs (learning rate, KL penalty coefficient) while ``dev_config``
    collects the experimental options. Optional keyword arguments left as
    ``None`` are omitted from ``dev_config`` entirely, so downstream code
    can tell "unset" apart from an explicit value.
    """
    config = TrainConfig(
        learning_rate=learning_rate,
        kl_penalty_coef=kl_penalty_coef,
    )
    dev_config: dev.TrainConfig = {
        "advantage_balance": advantage_balance,
        "importance_sampling_level": importance_sampling_level,
        "kl_penalty_coef": kl_penalty_coef,
        "mask_prob_ratio": mask_prob_ratio,
        "ppo": ppo,
        "precalculate_logprobs": precalculate_logprobs,
        "scale_rewards": scale_rewards,
    }

    # Optional settings: include a key only when the caller supplied a value.
    # Insertion order matches the historical key order of dev_config.
    optional_settings = {
        "allow_training_without_logprobs": allow_training_without_logprobs,
        "plot_tensors": plot_tensors,
        "truncated_importance_sampling": truncated_importance_sampling,
        "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
        "logprob_calculation_chunk_size": logprob_calculation_chunk_size,
        "num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
        "epsilon": epsilon,
        "epsilon_high": epsilon_high,
        "max_negative_advantage_importance_sampling_weight": max_negative_advantage_importance_sampling_weight,
        "kimi_k2_tau": kimi_k2_tau,
        "kl_ref_adapter_path": kl_ref_adapter_path,
    }
    for key, value in optional_settings.items():
        if value is not None:
            dev_config[key] = value  # type: ignore[literal-required]

    return config, dev_config


def aggregate_rl_training_metrics(
    *,
    training_metrics: list[dict[str, float]],
    trajectory_groups: Iterable[TrajectoryGroup],
    trainer_started: float,
) -> dict[str, float]:
    """Combine per-batch trainer metrics with trajectory-group summaries.

    Averages the raw per-step metric samples, records total trainer wall
    time (unless the trainer already reported its own ``time/step_trainer_s``),
    and fills in summary metrics derived from the trajectory groups without
    overwriting any key the trainer itself produced.
    """
    groups = list(trajectory_groups)
    metrics = average_metric_samples(training_metrics)
    summary = summarize_trajectory_groups(groups)
    # Only record wall time if the trainer did not report its own value.
    metrics.setdefault("time/step_trainer_s", time.monotonic() - trainer_started)
    for key, value in build_training_summary_metrics(
        summary,
        include_trainable_groups=True,
    ).items():
        if key not in metrics:
            metrics[key] = value
    return metrics
87 changes: 35 additions & 52 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,14 @@
from mp_actors import close_proxy, move_to_child_process

from .. import dev
from .._backend_training import (
aggregate_rl_training_metrics,
build_rl_train_configs,
)
from ..backend import AnyTrainableModel, Backend
from ..costs import build_cost_calculator, get_model_pricing
from ..metrics_taxonomy import (
TRAIN_GRADIENT_STEPS_KEY,
average_metric_samples,
build_training_summary_metrics,
summarize_trajectory_groups,
)
Expand Down Expand Up @@ -642,45 +645,36 @@ async def train( # type: ignore[override]
if adam_params is not None:
raise ValueError("LocalBackend requires adam_params=None.")

# Build config objects from explicit kwargs
config = TrainConfig(
learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
)
dev_config: dev.TrainConfig = {
"advantage_balance": advantage_balance,
"allow_training_without_logprobs": allow_training_without_logprobs,
"importance_sampling_level": importance_sampling_level,
"kl_penalty_coef": kl_penalty_coef,
"mask_prob_ratio": mask_prob_ratio,
"plot_tensors": plot_tensors,
"ppo": loss_fn == "ppo",
"precalculate_logprobs": precalculate_logprobs,
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
"scale_rewards": scale_rewards,
"logprob_calculation_chunk_size": logprob_calculation_chunk_size,
"num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
}
# Only include optional fields if they're set
if epsilon is not None:
dev_config["epsilon"] = epsilon
if epsilon_high is not None:
dev_config["epsilon_high"] = epsilon_high
if max_negative_advantage_importance_sampling_weight is not None:
dev_config["max_negative_advantage_importance_sampling_weight"] = (
max_negative_advantage_importance_sampling_weight
)
if kimi_k2_tau is not None:
dev_config["kimi_k2_tau"] = kimi_k2_tau
if truncated_importance_sampling is not None:
dev_config["truncated_importance_sampling"] = truncated_importance_sampling
if kl_ref_adapter_path is not None:
dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path
elif kl_penalty_reference_step is not None:
ref_checkpoint_dir = get_step_checkpoint_dir(
resolved_kl_ref_adapter_path = kl_ref_adapter_path
if (
resolved_kl_ref_adapter_path is None
and kl_penalty_reference_step is not None
):
resolved_kl_ref_adapter_path = get_step_checkpoint_dir(
get_model_dir(model=model, art_path=self._path),
kl_penalty_reference_step,
)
dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir
config, dev_config = build_rl_train_configs(
learning_rate=learning_rate,
advantage_balance=advantage_balance,
scale_rewards=scale_rewards,
importance_sampling_level=importance_sampling_level,
mask_prob_ratio=mask_prob_ratio,
ppo=loss_fn == "ppo",
precalculate_logprobs=precalculate_logprobs,
epsilon=epsilon,
epsilon_high=epsilon_high,
max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
kimi_k2_tau=kimi_k2_tau,
kl_penalty_coef=kl_penalty_coef,
allow_training_without_logprobs=allow_training_without_logprobs,
plot_tensors=plot_tensors,
truncated_importance_sampling=truncated_importance_sampling,
scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev,
logprob_calculation_chunk_size=logprob_calculation_chunk_size,
num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power,
kl_ref_adapter_path=resolved_kl_ref_adapter_path,
)

# Collect metrics from training
training_metrics: list[dict[str, float]] = []
Expand All @@ -690,21 +684,10 @@ async def train( # type: ignore[override]
):
training_metrics.append(metrics)

# Aggregate metrics
avg_metrics = average_metric_samples(training_metrics)
summary = summarize_trajectory_groups(groups_list)
avg_metrics.setdefault(
"time/step_trainer_s", time.monotonic() - trainer_started
)
avg_metrics.update(
{
key: value
for key, value in build_training_summary_metrics(
summary,
include_trainable_groups=True,
).items()
if key not in avg_metrics
}
avg_metrics = aggregate_rl_training_metrics(
training_metrics=training_metrics,
trajectory_groups=groups_list,
trainer_started=trainer_started,
)

# Get step and checkpoint path
Expand Down
31 changes: 31 additions & 0 deletions src/art/megatron/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Literal

from pydantic import BaseModel

from .. import dev, types
from ..preprocessing.pack import DiskPackedTensors

DEFAULT_TRAINING_LOG_PATH = "/tmp/megatron_training_log.jsonl"
DEFAULT_JOBS_DIR = "/tmp/megatron_training_jobs"
DEFAULT_VLLM_WAKE_LOCK_PATH = "/tmp/megatron_vllm_waking"


class MegatronTrainingJob(BaseModel):
    """Job format for communicating an RL training request to train.py."""

    lora_path: str  # path of the LoRA adapter being trained
    optimizer_state_path: str  # where the trainer's optimizer state lives
    disk_packed_tensors: DiskPackedTensors  # pre-packed training tensors on disk
    config: types.TrainConfig  # stable training config (learning rate, KL penalty)
    experimental_config: dev.TrainConfig  # experimental/dev-only training options
    log_path: str = DEFAULT_TRAINING_LOG_PATH  # JSONL file the trainer appends progress to


class MegatronSFTTrainingJob(BaseModel):
    """Job format describing a supervised fine-tuning (SFT) run."""

    job_type: Literal["sft"] = "sft"  # discriminator tag identifying the job kind
    lora_path: str  # path of the LoRA adapter being trained
    optimizer_state_path: str  # where the trainer's optimizer state lives
    sft_data_dir: str  # directory containing the SFT training data
    num_batches: int  # number of batches to train on
    learning_rates: list[float]  # presumably a per-batch LR schedule — confirm with trainer
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0  # gradient-clipping threshold
    log_path: str = DEFAULT_TRAINING_LOG_PATH  # JSONL file the trainer appends progress to
15 changes: 15 additions & 0 deletions src/art/megatron/runtime_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os


def _set_cache_dir(env_var: str, default_path: str) -> None:
    """Point *env_var* at a cache directory, creating it if needed.

    An unset (or empty) variable falls back to the user-expanded
    *default_path*; an existing value is left untouched.
    """
    current = os.environ.get(env_var)
    if not current:
        current = os.path.expanduser(default_path)
        os.environ[env_var] = current
    os.makedirs(current, exist_ok=True)


def configure_megatron_runtime_env() -> None:
    """Set the environment variables the Megatron training runtime expects.

    Forces single-connection CUDA streams, expandable CUDA allocator
    segments, a compute-capability 9.0 arch list, and per-user compiler
    caches for torchinductor and Triton.
    """
    overrides = {
        "CUDA_DEVICE_MAX_CONNECTIONS": "1",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TORCH_CUDA_ARCH_LIST": "9.0",
    }
    os.environ.update(overrides)
    _set_cache_dir("TORCHINDUCTOR_CACHE_DIR", "~/.cache/torchinductor")
    _set_cache_dir("TRITON_CACHE_DIR", "~/.triton/cache")
38 changes: 19 additions & 19 deletions src/art/megatron/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import asdict, dataclass
import datetime
from functools import cached_property
import importlib
import json
import os
from pathlib import Path
Expand All @@ -10,9 +11,6 @@
from typing import Any, AsyncIterator

from peft.tuners.lora.config import LoraConfig
from pydantic import BaseModel
from safetensors import safe_open
from safetensors.torch import load_file, save_file
import torch
from vllm import AsyncEngineArgs
from vllm.lora.request import LoRARequest
Expand All @@ -26,16 +24,17 @@
from ..utils.get_model_step import get_step_from_dir
from ..utils.output_dirs import get_step_checkpoint_dir
from ..vllm import get_llm, openai_server_task, run_on_workers
from .jobs import (
DEFAULT_JOBS_DIR,
DEFAULT_TRAINING_LOG_PATH,
MegatronTrainingJob,
)


class MegatronTrainingJob(BaseModel):
"""Job format for communication with train.py"""

lora_path: str
optimizer_state_path: str
disk_packed_tensors: DiskPackedTensors
config: types.TrainConfig
experimental_config: dev.TrainConfig
# Resolve safetensors via importlib and alias the three symbols used below.
# NOTE(review): behaviorally equivalent to `from safetensors import safe_open`
# etc.; the reason a plain import was avoided is not visible here — presumably
# to defer import cost or satisfy a tooling constraint. Confirm intent.
safetensors = importlib.import_module("safetensors")
safetensors_torch = importlib.import_module("safetensors.torch")
safe_open = safetensors.safe_open
load_file = safetensors_torch.load_file
save_file = safetensors_torch.save_file


@dataclass
Expand Down Expand Up @@ -236,34 +235,35 @@ async def train(

self._optimizer_state_path = self._get_optimizer_state_path()

jobs_dir = "/tmp/megatron_training_jobs"
os.makedirs(jobs_dir, exist_ok=True)
for job_name in os.listdir(jobs_dir):
os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
for job_name in os.listdir(DEFAULT_JOBS_DIR):
if job_name.endswith(".json"):
os.remove(os.path.join(jobs_dir, job_name))
os.remove(os.path.join(DEFAULT_JOBS_DIR, job_name))
job = MegatronTrainingJob(
lora_path=lora_path,
optimizer_state_path=self._optimizer_state_path,
disk_packed_tensors=disk_packed_tensors,
config=config,
experimental_config=_config,
)
job_path = os.path.join(jobs_dir, f"{datetime.datetime.now().isoformat()}.json")
job_path = os.path.join(
DEFAULT_JOBS_DIR, f"{datetime.datetime.now().isoformat()}.json"
)
with open(job_path, "w") as f:
f.write(job.model_dump_json())

num_lines = 0
while True:
await asyncio.sleep(0.1)
try:
with open("/tmp/megatron_training_log.jsonl", "a+") as log_file:
with open(DEFAULT_TRAINING_LOG_PATH, "a+") as log_file:
log_file.seek(0)
lines = log_file.readlines()[num_lines:]
for line in lines:
if line := line.strip():
if line == "all done":
self._merge_lora_adapter(lora_path)
os.remove("/tmp/megatron_training_log.jsonl")
os.remove(DEFAULT_TRAINING_LOG_PATH)
break
num_lines += 1
yield json.loads(line)
Expand Down
Loading
Loading