1 change: 1 addition & 0 deletions src/art/dev/train.py
@@ -18,6 +18,7 @@ class TrainConfig(TypedDict, total=False):
]
kimi_k2_tau: float | None
kl_penalty_coef: float
kl_penalty_source: Literal["current_learner", "sample"]
kl_ref_adapter_path: str | None
logprob_calculation_chunk_size: int
mask_prob_ratio: bool
12 changes: 11 additions & 1 deletion src/art/local/backend.py
@@ -537,6 +537,7 @@ async def train( # type: ignore[override]
kl_penalty_coef: float = 0.0,
kl_penalty_reference_step: int | None = None,
kl_ref_adapter_path: str | None = None,
kl_penalty_source: Literal["current_learner", "sample"] = "current_learner",
epsilon: float | None = None,
epsilon_high: float | None = None,
# Advantage computation
@@ -590,6 +591,11 @@ async def train( # type: ignore[override]
kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
checkpoint to use as the KL reference. Alternative to
kl_penalty_reference_step.
            kl_penalty_source: Which policy's logprobs to compare against the
                reference when building the centered KL penalty. Use
                "current_learner" to match the original ART implementation, or
                "sample" to compute the penalty from the stored rollout-policy
                logprobs, which is usually better for async/off-policy workloads.
epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
advantage_balance: Balance between negative and positive advantages
@@ -641,16 +647,20 @@ async def train( # type: ignore[override]
raise ValueError("LocalBackend requires normalize_advantages=True.")
if adam_params is not None:
raise ValueError("LocalBackend requires adam_params=None.")
assert kl_penalty_source in {"current_learner", "sample"}

# Build config objects from explicit kwargs
config = TrainConfig(
-            learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
+            learning_rate=learning_rate,
+            kl_penalty_coef=kl_penalty_coef,
+            kl_penalty_source=kl_penalty_source,
)
dev_config: dev.TrainConfig = {
"advantage_balance": advantage_balance,
"allow_training_without_logprobs": allow_training_without_logprobs,
"importance_sampling_level": importance_sampling_level,
"kl_penalty_coef": kl_penalty_coef,
"kl_penalty_source": kl_penalty_source,
"mask_prob_ratio": mask_prob_ratio,
"plot_tensors": plot_tensors,
"ppo": loss_fn == "ppo",
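
A minimal call-site sketch for trying the new flag (hypothetical wiring: backend, model, and groups stand in for an initialized LocalBackend, a registered model, and collected trajectory groups; everything else keeps its defaults):

    result = await backend.train(
        model,
        groups,
        kl_penalty_coef=0.1,          # > 0 turns the centered KL penalty on
        kl_penalty_reference_step=0,  # reference = the step-0 checkpoint
        kl_penalty_source="sample",   # compare stored rollout logprobs to the reference
    )
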
9 changes: 8 additions & 1 deletion src/art/loss.py
@@ -95,7 +95,14 @@ def loss_fn(
kl_policy_ref: torch.Tensor | None = None
kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0)
if kl_penalty_coef > 0 and ref_logprobs is not None:
-        kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask
+        match experimental_config.get("kl_penalty_source", "current_learner"):
+            case "sample":
+                kl_source_logprobs = old_logprobs.detach()
+            case "current_learner":
+                kl_source_logprobs = new_logprobs.detach()
+            case other:
+                raise AssertionError(other)
+        kl_per_token = (kl_source_logprobs - ref_logprobs).detach() * assistant_mask
avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
advantages = advantages + kl_penalty
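
To see what the penalty does to advantages, here is a self-contained numeric sketch (toy values, not from the PR) mirroring the hunk above with kl_penalty_source="sample":

    import torch

    old_logprobs = torch.tensor([[-0.5, -0.8, -1.2, -0.6]])  # rollout ("sample") policy
    ref_logprobs = torch.full((1, 4), -0.6)                  # frozen KL reference
    assistant_mask = torch.tensor([[0.0, 0.0, 1.0, 1.0]])    # train on the last two tokens
    advantages = torch.zeros(1, 4)
    kl_penalty_coef = 0.5

    kl_per_token = (old_logprobs - ref_logprobs).detach() * assistant_mask
    avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)  # (-0.6 + 0.0) / 2 = -0.3
    kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
    advantages = advantages + kl_penalty
    print(kl_penalty)  # ~ [[0.0, 0.0, 0.15, -0.15]]

Because the shift is centered on avg_kl, it sums to (roughly) zero over the masked tokens: tokens whose log-ratio against the reference sits above the batch average are pushed down, tokens below it are pushed up, and the mean advantage is unchanged.
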
13 changes: 13 additions & 0 deletions src/art/pipeline_trainer/trainer.py
@@ -78,6 +78,8 @@ def __init__(
loss_fn_config: dict | None = None,
normalize_advantages: bool = True,
adam_params: object | None = None,
kl_penalty_coef: float = 0.0,
kl_penalty_reference_step: int | None = None,
max_steps: int | None = None,
# Discard handling
discard_queue_multiplier: int = 100,
@@ -129,6 +131,8 @@ def __init__(
self.loss_fn_config = loss_fn_config
self.normalize_advantages = normalize_advantages
self.adam_params = adam_params
self.kl_penalty_coef = kl_penalty_coef
self.kl_penalty_reference_step = kl_penalty_reference_step
self.max_steps = max_steps
self._status_log_interval_seconds = log_interval_seconds
self.eval_every_n_steps = eval_every_n_steps
@@ -452,6 +456,14 @@ async def _training_stage(self) -> None:
if os.getenv("ART_TRAIN_STEP_LOG"):
print(f"[train] step {expected_step} starting (batch={len(batch)})")
try:
kl_train_kwargs: dict[str, object] = {}
if self.kl_penalty_coef > 0.0:
kl_train_kwargs["kl_penalty_coef"] = self.kl_penalty_coef
kl_train_kwargs["kl_penalty_source"] = "sample"
if self.kl_penalty_reference_step is not None:
kl_train_kwargs["kl_penalty_reference_step"] = (
self.kl_penalty_reference_step
)
result = await self.backend.train(
self.model,
batch,
@@ -461,6 +473,7 @@
normalize_advantages=self.normalize_advantages,
save_checkpoint=should_checkpoint,
adam_params=self.adam_params,
**kl_train_kwargs,
)
except Exception:
self._status.note_training_end()
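
Enabling the penalty from the trainer is two constructor kwargs; a sketch (the usual model/backend/rollout wiring is elided, values borrowed from the integration test below):

    trainer = PipelineTrainer(
        ...,                          # model/backend/rollout wiring elided
        kl_penalty_coef=0.25,         # > 0 opts in; the trainer pins kl_penalty_source="sample"
        kl_penalty_reference_step=0,  # pin the reference to the step-0 checkpoint
    )
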
49 changes: 48 additions & 1 deletion src/art/test/test_kl_advantage.py
@@ -2,7 +2,7 @@

import torch

-from art.loss import Loss, loss_fn
+from art.loss import loss_fn, shift_tensor


def _make_inputs(
@@ -114,3 +114,50 @@ def test_kl_advantage_does_not_affect_when_no_ref():

loss = loss_fn(inputs, new_logprobs, None, None, {"kl_penalty_coef": 0.5})
assert loss.kl_policy_ref is None


def test_kl_advantage_can_use_sample_logprobs() -> None:
"""Sample-source KL should use stored rollout logprobs rather than learner logprobs."""
inputs = _make_inputs(seq_len=8)
inputs["logprobs"] = torch.tensor(
[[0.0, -0.2, -0.4, -0.6, -0.8, -1.0, -1.2, -1.4]], dtype=torch.float32
)
new_logprobs = torch.tensor(
[[0.0, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6]], dtype=torch.float32
)
ref_logprobs = torch.full((1, 8), -0.5)
assistant_mask = shift_tensor(inputs["assistant_mask"], False).to(
new_logprobs.dtype
)
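    # Mirror loss_fn's apparent NaN handling: positions whose stored rollout
    # logprobs are NaN fall back to the learner's (detached) logprobs.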
sampled_logprobs = torch.where(
torch.isnan(shift_tensor(inputs["logprobs"], float("nan"))),
new_logprobs.detach(),
shift_tensor(inputs["logprobs"], float("nan")),
)
expected_sample_kl = ((sampled_logprobs - ref_logprobs) * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
expected_current_kl = ((new_logprobs - ref_logprobs) * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)

sample_loss = loss_fn(
inputs,
new_logprobs,
ref_logprobs,
None,
{"kl_penalty_coef": 0.5, "kl_penalty_source": "sample"},
)
learner_loss = loss_fn(
inputs,
new_logprobs,
ref_logprobs,
None,
{"kl_penalty_coef": 0.5, "kl_penalty_source": "current_learner"},
)

assert sample_loss.kl_policy_ref is not None
assert learner_loss.kl_policy_ref is not None
assert torch.isclose(sample_loss.kl_policy_ref, expected_sample_kl)
assert torch.isclose(learner_loss.kl_policy_ref, expected_current_kl)
assert not torch.isclose(sample_loss.kl_policy_ref, learner_loss.kl_policy_ref)
118 changes: 118 additions & 0 deletions src/art/tinker_native/backend.py
@@ -24,6 +24,7 @@
from openai.types.chat.completion_create_params import CompletionCreateParams
from openai.types.completion_usage import CompletionUsage
import tinker
import torch
import uvicorn

from art.tinker.cookbook_v import renderers, tokenizer_utils
@@ -82,6 +83,76 @@ def _canonicalize_upstream_metric_key(metric: str) -> str:
return _UPSTREAM_TRAIN_METRIC_KEYS.get(metric, metric)


async def _apply_kl_penalty(
datums: list[tinker.Datum],
reference_sampling_client: tinker.SamplingClient,
kl_penalty_coef: float,
) -> dict[str, float]:
assert datums
assert kl_penalty_coef > 0.0

full_sequences: list[tinker.ModelInput] = []
sampled_logprobs_by_datum: list[torch.Tensor] = []
masks_by_datum: list[torch.Tensor] = []
advantages_by_datum: list[torch.Tensor] = []
for datum in datums:
target_tokens = datum.loss_fn_inputs["target_tokens"].to_torch()
assert target_tokens.numel() > 0
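        # Presumably model_input holds the inputs (targets shifted right by one), so
        # appending the final target token gives the reference a full sequence to
        # score, one logprob per target position (see the length assert below).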
full_sequences.append(
datum.model_input.append_int(int(target_tokens[-1].item()))
)
sampled_logprobs_by_datum.append(datum.loss_fn_inputs["logprobs"].to_torch())
masks_by_datum.append(datum.loss_fn_inputs["mask"].to_torch().float())
advantages_by_datum.append(datum.loss_fn_inputs["advantages"].to_torch())

reference_logprobs_by_datum = await asyncio.gather(
*[
reference_sampling_client.compute_logprobs_async(full_sequence)
for full_sequence in full_sequences
]
)

logprob_diffs_by_datum: list[torch.Tensor] = []
for reference_logprobs, sampled_logprobs, mask in zip(
reference_logprobs_by_datum,
sampled_logprobs_by_datum,
masks_by_datum,
strict=True,
):
reference_values = reference_logprobs[1:]
assert len(reference_values) == sampled_logprobs.numel()
assert all(value is not None for value in reference_values)
reference_logprobs_tensor = torch.tensor(
reference_values,
dtype=sampled_logprobs.dtype,
)
logprob_diffs_by_datum.append(
(sampled_logprobs - reference_logprobs_tensor) * mask
)

total_tokens = torch.stack([mask.sum() for mask in masks_by_datum]).sum()
assert total_tokens.item() > 0
avg_logprob_diff = (
torch.stack(
[logprob_diff.sum() for logprob_diff in logprob_diffs_by_datum]
).sum()
/ total_tokens
)

for datum, advantages, mask, logprob_diff in zip(
datums,
advantages_by_datum,
masks_by_datum,
logprob_diffs_by_datum,
strict=True,
):
datum.loss_fn_inputs["advantages"] = tinker.TensorData.from_torch(
advantages + kl_penalty_coef * (avg_logprob_diff - logprob_diff) * mask
)

return {"loss/kl_policy_ref": float(avg_logprob_diff)}


@dataclass
class ModelState:
service_client: tinker.ServiceClient
@@ -239,7 +310,14 @@ async def train( # type: ignore[override]
save_checkpoint: bool = False,
loss_fn_config: dict | None = None,
adam_params: tinker.AdamParams | None = None,
kl_penalty_coef: float = 0.0,
kl_penalty_reference_step: int | None = None,
kl_penalty_source: Literal["sample"] = "sample",
) -> TrainResult:
assert kl_penalty_source == "sample", (
"TinkerNativeBackend only supports kl_penalty_source='sample'."
)

state = self._model_state[model.name]
groups_list = list(trajectory_groups)
summary = summarize_trajectory_groups(groups_list)
@@ -272,6 +350,23 @@
train_tokens, pricing
)
trainer_started = time.monotonic()
sampled_kl_policy_ref: float | None = None

if kl_penalty_coef > 0:
kl_metrics = await self._tinker_sample_call(
"apply_kl_penalty",
_apply_kl_penalty(
datums,
await self._get_kl_reference_sampling_client(
state,
model.base_model,
kl_penalty_reference_step,
),
kl_penalty_coef,
),
)
sampled_kl_policy_ref = kl_metrics["loss/kl_policy_ref"]
metrics.update(kl_metrics)

if adam_params is None:
adam_params = tinker.AdamParams(
@@ -310,13 +405,23 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum:
if value is None:
continue
canonical_key = _canonicalize_upstream_metric_key(key)
if (
sampled_kl_policy_ref is not None
and canonical_key == "loss/kl_policy_ref"
):
continue
if canonical_key:
metrics[canonical_key] = float(value)
if optim_output.metrics:
for key, value in optim_output.metrics.items():
if value is None:
continue
canonical_key = _canonicalize_upstream_metric_key(key)
if (
sampled_kl_policy_ref is not None
and canonical_key == "loss/kl_policy_ref"
):
continue
if canonical_key:
metrics[canonical_key] = float(value)

@@ -697,6 +802,19 @@ async def _get_sampler_client(
state.sampler_clients[actual_step] = sampler_client
return sampler_client

async def _get_kl_reference_sampling_client(
self,
state: ModelState,
base_model: str,
step: int | None,
) -> tinker.SamplingClient:
if step is not None:
return await self._get_sampler_client(state, step)
return await self._tinker_sample_call(
"create_sampling_client_async",
state.service_client.create_sampling_client_async(base_model=base_model),
)

def _normalize_messages(self, messages: Iterable[Any]) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
for message in messages:
1 change: 1 addition & 0 deletions src/art/types.py
@@ -17,6 +17,7 @@
class TrainConfig(pydantic.BaseModel):
learning_rate: float = 5e-6
kl_penalty_coef: float = 0.0
kl_penalty_source: Literal["current_learner", "sample"] = "current_learner"


class TrainSFTConfig(pydantic.BaseModel):
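
A quick validation sketch (import path assumed from this diff):

    from art.types import TrainConfig

    config = TrainConfig(kl_penalty_coef=0.1, kl_penalty_source="sample")  # ok
    TrainConfig(kl_penalty_source="rollout")  # pydantic raises a ValidationError
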
23 changes: 23 additions & 0 deletions tests/integration/test_pipeline_localbackend_dedicated.py
@@ -1,7 +1,10 @@
"""Dedicated LocalBackend smoke test for PipelineTrainer."""

import asyncio
import json
import math
import os
from pathlib import Path
import tempfile
import uuid

@@ -163,6 +166,8 @@ async def rollout_fn(
min_batch_size=1,
max_batch_size=1,
max_steps=2,
kl_penalty_coef=0.25,
kl_penalty_reference_step=0,
loss_fn="cispo",
eval_fn=None,
)
@@ -180,5 +185,23 @@ async def rollout_fn(
model_ids = [m.id async for m in client.models.list()]
assert f"{model.name}@0" in model_ids
assert f"{model.name}@{latest_step}" in model_ids

history_path = (
Path(tmpdir)
/ model.project
/ "models"
/ model.name
/ "history.jsonl"
)
history_rows = [
json.loads(line) for line in history_path.read_text().splitlines()
]
kl_values = [
row["loss/kl_policy_ref"]
for row in history_rows
if "loss/kl_policy_ref" in row
]
assert kl_values
assert all(math.isfinite(value) for value in kl_values)
finally:
await client.close()