diff --git a/src/art/dev/train.py b/src/art/dev/train.py
index b0e232c5..5da3e1ab 100644
--- a/src/art/dev/train.py
+++ b/src/art/dev/train.py
@@ -18,6 +18,7 @@ class TrainConfig(TypedDict, total=False):
     ]
     kimi_k2_tau: float | None
     kl_penalty_coef: float
+    kl_penalty_source: Literal["current_learner", "sample"]
     kl_ref_adapter_path: str | None
     logprob_calculation_chunk_size: int
     mask_prob_ratio: bool
diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index ad743757..35f162d4 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -537,6 +537,7 @@ async def train(  # type: ignore[override]
         kl_penalty_coef: float = 0.0,
         kl_penalty_reference_step: int | None = None,
         kl_ref_adapter_path: str | None = None,
+        kl_penalty_source: Literal["current_learner", "sample"] = "current_learner",
         epsilon: float | None = None,
         epsilon_high: float | None = None,
         # Advantage computation
@@ -590,6 +591,11 @@ async def train(  # type: ignore[override]
             kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
                 checkpoint to use as the KL reference. Alternative to
                 kl_penalty_reference_step.
+            kl_penalty_source: Which policy's logprobs to compare against the
+                reference when building the centered KL penalty. Use
+                "current_learner" to match the original ART implementation, or
+                "sample" to build the penalty from the rollout policy's
+                logprobs, which is usually better for async/off-policy workloads.
             epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
             epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
             advantage_balance: Balance between negative and positive advantages
@@ -641,16 +647,20 @@ async def train(  # type: ignore[override]
             raise ValueError("LocalBackend requires normalize_advantages=True.")
         if adam_params is not None:
             raise ValueError("LocalBackend requires adam_params=None.")
+        assert kl_penalty_source in {"current_learner", "sample"}
         # Build config objects from explicit kwargs
         config = TrainConfig(
-            learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
+            learning_rate=learning_rate,
+            kl_penalty_coef=kl_penalty_coef,
+            kl_penalty_source=kl_penalty_source,
         )

         dev_config: dev.TrainConfig = {
             "advantage_balance": advantage_balance,
             "allow_training_without_logprobs": allow_training_without_logprobs,
             "importance_sampling_level": importance_sampling_level,
             "kl_penalty_coef": kl_penalty_coef,
+            "kl_penalty_source": kl_penalty_source,
             "mask_prob_ratio": mask_prob_ratio,
             "plot_tensors": plot_tensors,
             "ppo": loss_fn == "ppo",
diff --git a/src/art/loss.py b/src/art/loss.py
index 5a73d7b7..59cfa46a 100644
--- a/src/art/loss.py
+++ b/src/art/loss.py
@@ -95,7 +95,14 @@ def loss_fn(
     kl_policy_ref: torch.Tensor | None = None
     kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0)
     if kl_penalty_coef > 0 and ref_logprobs is not None:
-        kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask
+        match experimental_config.get("kl_penalty_source", "current_learner"):
+            case "sample":
+                kl_source_logprobs = old_logprobs.detach()
+            case "current_learner":
+                kl_source_logprobs = new_logprobs.detach()
+            case other:
+                raise AssertionError(other)
+        kl_per_token = (kl_source_logprobs - ref_logprobs).detach() * assistant_mask
         avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
         kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
         advantages = advantages + kl_penalty
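The loss.py hunk is the core of the change: the KL penalty is folded into the advantages as a zero-mean, per-token shaping term, and the new `kl_penalty_source` switch only decides whether the rollout policy's logprobs (`old_logprobs`) or the learner's (`new_logprobs`) feed that term. A minimal standalone sketch of the arithmetic, with illustrative tensor values and a helper name that is not part of ART:

```python
import torch

def centered_kl_penalty(
    source_logprobs: torch.Tensor,  # rollout logprobs for "sample", learner logprobs for "current_learner"
    ref_logprobs: torch.Tensor,
    assistant_mask: torch.Tensor,
    kl_penalty_coef: float,
) -> torch.Tensor:
    # Per-token logprob gap to the reference policy, masked to assistant tokens.
    kl_per_token = (source_logprobs - ref_logprobs).detach() * assistant_mask
    avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
    # Centering around the batch mean makes the shaping zero-mean over the mask:
    # tokens that drifted more than average are penalized, tokens that drifted
    # less get a bonus, so the total advantage mass is unchanged.
    return kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask

penalty = centered_kl_penalty(
    source_logprobs=torch.tensor([[-0.4, -0.8]]),
    ref_logprobs=torch.tensor([[-0.1, -0.9]]),
    assistant_mask=torch.ones(1, 2),
    kl_penalty_coef=2.0,
)
print(penalty)  # tensor([[ 0.4000, -0.4000]]) -- sums to zero over the mask
```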
diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py
index 302cbe78..a50e6d57 100644
--- a/src/art/pipeline_trainer/trainer.py
+++ b/src/art/pipeline_trainer/trainer.py
@@ -78,6 +78,8 @@ def __init__(
         loss_fn_config: dict | None = None,
         normalize_advantages: bool = True,
         adam_params: object | None = None,
+        kl_penalty_coef: float = 0.0,
+        kl_penalty_reference_step: int | None = None,
         max_steps: int | None = None,
         # Discard handling
         discard_queue_multiplier: int = 100,
@@ -129,6 +131,8 @@ def __init__(
         self.loss_fn_config = loss_fn_config
         self.normalize_advantages = normalize_advantages
         self.adam_params = adam_params
+        self.kl_penalty_coef = kl_penalty_coef
+        self.kl_penalty_reference_step = kl_penalty_reference_step
         self.max_steps = max_steps
         self._status_log_interval_seconds = log_interval_seconds
         self.eval_every_n_steps = eval_every_n_steps
@@ -452,6 +456,14 @@ async def _training_stage(self) -> None:
         if os.getenv("ART_TRAIN_STEP_LOG"):
             print(f"[train] step {expected_step} starting (batch={len(batch)})")
         try:
+            kl_train_kwargs: dict[str, object] = {}
+            if self.kl_penalty_coef > 0.0:
+                kl_train_kwargs["kl_penalty_coef"] = self.kl_penalty_coef
+                kl_train_kwargs["kl_penalty_source"] = "sample"
+                if self.kl_penalty_reference_step is not None:
+                    kl_train_kwargs["kl_penalty_reference_step"] = (
+                        self.kl_penalty_reference_step
+                    )
             result = await self.backend.train(
                 self.model,
                 batch,
@@ -461,6 +473,7 @@ async def _training_stage(self) -> None:
                 normalize_advantages=self.normalize_advantages,
                 save_checkpoint=should_checkpoint,
                 adam_params=self.adam_params,
+                **kl_train_kwargs,
             )
         except Exception:
             self._status.note_training_end()
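In the trainer, the two new constructor knobs are forwarded only when a penalty is actually configured, and the source is pinned to "sample" because the pipeline trains on rollouts that may lag the learner. A hedged usage sketch; the import path is inferred from the file layout, the `model`/`backend`/`rollout_fn` objects are assumed from the surrounding codebase, and only the two KL kwargs are established by this diff:

```python
# Sketch only: import path inferred from src/art/pipeline_trainer/trainer.py.
from art.pipeline_trainer.trainer import PipelineTrainer

trainer = PipelineTrainer(
    model=model,                  # assumed art.TrainableModel
    backend=backend,              # assumed LocalBackend / TinkerNativeBackend
    rollout_fn=rollout_fn,        # assumed rollout coroutine
    kl_penalty_coef=0.25,         # > 0 turns the penalty on
    kl_penalty_reference_step=0,  # anchor the KL reference to the step-0 checkpoint
)
# _training_stage() then calls backend.train(..., kl_penalty_coef=0.25,
# kl_penalty_source="sample", kl_penalty_reference_step=0).
```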
diff --git a/src/art/test/test_kl_advantage.py b/src/art/test/test_kl_advantage.py
index d944efc6..82c0f2a2 100644
--- a/src/art/test/test_kl_advantage.py
+++ b/src/art/test/test_kl_advantage.py
@@ -2,7 +2,7 @@

 import torch

-from art.loss import Loss, loss_fn
+from art.loss import loss_fn, shift_tensor


 def _make_inputs(
@@ -114,3 +114,50 @@
     loss = loss_fn(inputs, new_logprobs, None, None, {"kl_penalty_coef": 0.5})

     assert loss.kl_policy_ref is None
+
+
+def test_kl_advantage_can_use_sample_logprobs() -> None:
+    """Sample-source KL should use stored rollout logprobs rather than learner logprobs."""
+    inputs = _make_inputs(seq_len=8)
+    inputs["logprobs"] = torch.tensor(
+        [[0.0, -0.2, -0.4, -0.6, -0.8, -1.0, -1.2, -1.4]], dtype=torch.float32
+    )
+    new_logprobs = torch.tensor(
+        [[0.0, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6]], dtype=torch.float32
+    )
+    ref_logprobs = torch.full((1, 8), -0.5)
+    assistant_mask = shift_tensor(inputs["assistant_mask"], False).to(
+        new_logprobs.dtype
+    )
+    sampled_logprobs = torch.where(
+        torch.isnan(shift_tensor(inputs["logprobs"], float("nan"))),
+        new_logprobs.detach(),
+        shift_tensor(inputs["logprobs"], float("nan")),
+    )
+    expected_sample_kl = ((sampled_logprobs - ref_logprobs) * assistant_mask).sum() / (
+        assistant_mask.sum() + 1e-6
+    )
+    expected_current_kl = ((new_logprobs - ref_logprobs) * assistant_mask).sum() / (
+        assistant_mask.sum() + 1e-6
+    )
+
+    sample_loss = loss_fn(
+        inputs,
+        new_logprobs,
+        ref_logprobs,
+        None,
+        {"kl_penalty_coef": 0.5, "kl_penalty_source": "sample"},
+    )
+    learner_loss = loss_fn(
+        inputs,
+        new_logprobs,
+        ref_logprobs,
+        None,
+        {"kl_penalty_coef": 0.5, "kl_penalty_source": "current_learner"},
+    )
+
+    assert sample_loss.kl_policy_ref is not None
+    assert learner_loss.kl_policy_ref is not None
+    assert torch.isclose(sample_loss.kl_policy_ref, expected_sample_kl)
+    assert torch.isclose(learner_loss.kl_policy_ref, expected_current_kl)
+    assert not torch.isclose(sample_loss.kl_policy_ref, learner_loss.kl_policy_ref)
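Note the `torch.where` in the test: stored rollout logprobs are NaN wherever a token has no recorded logprob, and the expected value assumes `loss_fn` substitutes the learner's logprobs at those positions. A tiny runnable illustration of that fallback (the values are made up, and the fallback behavior is inferred from this test rather than stated in the diff):

```python
import torch

stored = torch.tensor([float("nan"), -0.4, -0.8])  # NaN = no rollout logprob recorded
learner = torch.tensor([-0.9, -1.0, -1.1])
# Fall back to the learner's logprob wherever the rollout value is missing.
old_logprobs = torch.where(torch.isnan(stored), learner.detach(), stored)
print(old_logprobs)  # tensor([-0.9000, -0.4000, -0.8000])
```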
diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py
index c1687bf7..65f59ca6 100644
--- a/src/art/tinker_native/backend.py
+++ b/src/art/tinker_native/backend.py
@@ -24,6 +24,7 @@
 from openai.types.chat.completion_create_params import CompletionCreateParams
 from openai.types.completion_usage import CompletionUsage
 import tinker
+import torch
 import uvicorn

 from art.tinker.cookbook_v import renderers, tokenizer_utils
@@ -82,6 +83,76 @@ def _canonicalize_upstream_metric_key(metric: str) -> str:
     return _UPSTREAM_TRAIN_METRIC_KEYS.get(metric, metric)


+async def _apply_kl_penalty(
+    datums: list[tinker.Datum],
+    reference_sampling_client: tinker.SamplingClient,
+    kl_penalty_coef: float,
+) -> dict[str, float]:
+    assert datums
+    assert kl_penalty_coef > 0.0
+
+    full_sequences: list[tinker.ModelInput] = []
+    sampled_logprobs_by_datum: list[torch.Tensor] = []
+    masks_by_datum: list[torch.Tensor] = []
+    advantages_by_datum: list[torch.Tensor] = []
+    for datum in datums:
+        target_tokens = datum.loss_fn_inputs["target_tokens"].to_torch()
+        assert target_tokens.numel() > 0
+        full_sequences.append(
+            datum.model_input.append_int(int(target_tokens[-1].item()))
+        )
+        sampled_logprobs_by_datum.append(datum.loss_fn_inputs["logprobs"].to_torch())
+        masks_by_datum.append(datum.loss_fn_inputs["mask"].to_torch().float())
+        advantages_by_datum.append(datum.loss_fn_inputs["advantages"].to_torch())
+
+    reference_logprobs_by_datum = await asyncio.gather(
+        *[
+            reference_sampling_client.compute_logprobs_async(full_sequence)
+            for full_sequence in full_sequences
+        ]
+    )
+
+    logprob_diffs_by_datum: list[torch.Tensor] = []
+    for reference_logprobs, sampled_logprobs, mask in zip(
+        reference_logprobs_by_datum,
+        sampled_logprobs_by_datum,
+        masks_by_datum,
+        strict=True,
+    ):
+        reference_values = reference_logprobs[1:]
+        assert len(reference_values) == sampled_logprobs.numel()
+        assert all(value is not None for value in reference_values)
+        reference_logprobs_tensor = torch.tensor(
+            reference_values,
+            dtype=sampled_logprobs.dtype,
+        )
+        logprob_diffs_by_datum.append(
+            (sampled_logprobs - reference_logprobs_tensor) * mask
+        )
+
+    total_tokens = torch.stack([mask.sum() for mask in masks_by_datum]).sum()
+    assert total_tokens.item() > 0
+    avg_logprob_diff = (
+        torch.stack(
+            [logprob_diff.sum() for logprob_diff in logprob_diffs_by_datum]
+        ).sum()
+        / total_tokens
+    )
+
+    for datum, advantages, mask, logprob_diff in zip(
+        datums,
+        advantages_by_datum,
+        masks_by_datum,
+        logprob_diffs_by_datum,
+        strict=True,
+    ):
+        datum.loss_fn_inputs["advantages"] = tinker.TensorData.from_torch(
+            advantages + kl_penalty_coef * (avg_logprob_diff - logprob_diff) * mask
+        )
+
+    return {"loss/kl_policy_ref": float(avg_logprob_diff)}
+
+
 @dataclass
 class ModelState:
     service_client: tinker.ServiceClient
@@ -239,7 +310,14 @@ async def train(  # type: ignore[override]
         save_checkpoint: bool = False,
         loss_fn_config: dict | None = None,
         adam_params: tinker.AdamParams | None = None,
+        kl_penalty_coef: float = 0.0,
+        kl_penalty_reference_step: int | None = None,
+        kl_penalty_source: Literal["sample"] = "sample",
     ) -> TrainResult:
+        assert kl_penalty_source == "sample", (
+            "TinkerNativeBackend only supports kl_penalty_source='sample'."
+        )
+
         state = self._model_state[model.name]
         groups_list = list(trajectory_groups)
         summary = summarize_trajectory_groups(groups_list)
@@ -272,6 +350,23 @@ async def train(  # type: ignore[override]
             train_tokens, pricing
         )
         trainer_started = time.monotonic()
+        sampled_kl_policy_ref: float | None = None
+
+        if kl_penalty_coef > 0:
+            kl_metrics = await self._tinker_sample_call(
+                "apply_kl_penalty",
+                _apply_kl_penalty(
+                    datums,
+                    await self._get_kl_reference_sampling_client(
+                        state,
+                        model.base_model,
+                        kl_penalty_reference_step,
+                    ),
+                    kl_penalty_coef,
+                ),
+            )
+            sampled_kl_policy_ref = kl_metrics["loss/kl_policy_ref"]
+            metrics.update(kl_metrics)

         if adam_params is None:
             adam_params = tinker.AdamParams(
@@ -310,6 +405,11 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum:
             if value is None:
                 continue
             canonical_key = _canonicalize_upstream_metric_key(key)
+            if (
+                sampled_kl_policy_ref is not None
+                and canonical_key == "loss/kl_policy_ref"
+            ):
+                continue
             if canonical_key:
                 metrics[canonical_key] = float(value)
         if optim_output.metrics:
@@ -317,6 +417,11 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum:
                 if value is None:
                     continue
                 canonical_key = _canonicalize_upstream_metric_key(key)
+                if (
+                    sampled_kl_policy_ref is not None
+                    and canonical_key == "loss/kl_policy_ref"
+                ):
+                    continue
                 if canonical_key:
                     metrics[canonical_key] = float(value)

@@ -697,6 +802,19 @@ async def _get_sampler_client(
         state.sampler_clients[actual_step] = sampler_client
         return sampler_client

+    async def _get_kl_reference_sampling_client(
+        self,
+        state: ModelState,
+        base_model: str,
+        step: int | None,
+    ) -> tinker.SamplingClient:
+        if step is not None:
+            return await self._get_sampler_client(state, step)
+        return await self._tinker_sample_call(
+            "create_sampling_client_async",
+            state.service_client.create_sampling_client_async(base_model=base_model),
+        )
+
     def _normalize_messages(self, messages: Iterable[Any]) -> list[dict[str, Any]]:
         normalized: list[dict[str, Any]] = []
         for message in messages:
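Because the hosted Tinker loss never sees a reference policy, `_apply_kl_penalty` bakes the centered penalty into each datum's advantages before the forward-backward call, and the upstream `loss/kl_policy_ref` metric is skipped in favor of the locally computed one. A hedged call sketch, assuming `backend`, `model`, and `train_groups` were built as in the integration tests; only the keyword names come from this diff:

```python
# Sketch: model/group setup omitted; signature as added in the diff above.
result = await backend.train(
    model,
    train_groups,
    learning_rate=1e-5,
    kl_penalty_coef=0.25,
    kl_penalty_reference_step=0,  # None falls back to the raw base model
    # kl_penalty_source defaults to "sample" and is the only accepted value here.
)
print(result.metrics["loss/kl_policy_ref"])  # masked mean of sampled - reference logprobs
```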
diff --git a/src/art/types.py b/src/art/types.py
index 088041ad..317fc156 100644
--- a/src/art/types.py
+++ b/src/art/types.py
@@ -17,6 +17,7 @@
 class TrainConfig(pydantic.BaseModel):
     learning_rate: float = 5e-6
     kl_penalty_coef: float = 0.0
+    kl_penalty_source: Literal["current_learner", "sample"] = "current_learner"


 class TrainSFTConfig(pydantic.BaseModel):
diff --git a/tests/integration/test_pipeline_localbackend_dedicated.py b/tests/integration/test_pipeline_localbackend_dedicated.py
index d6d04bc7..11fab51d 100644
--- a/tests/integration/test_pipeline_localbackend_dedicated.py
+++ b/tests/integration/test_pipeline_localbackend_dedicated.py
@@ -1,7 +1,10 @@
 """Dedicated LocalBackend smoke test for PipelineTrainer."""

 import asyncio
+import json
+import math
 import os
+from pathlib import Path
 import tempfile
 import uuid

@@ -163,6 +166,8 @@ async def rollout_fn(
         min_batch_size=1,
         max_batch_size=1,
         max_steps=2,
+        kl_penalty_coef=0.25,
+        kl_penalty_reference_step=0,
         loss_fn="cispo",
         eval_fn=None,
     )
@@ -180,5 +185,23 @@ async def rollout_fn(
         model_ids = [m.id async for m in client.models.list()]
         assert f"{model.name}@0" in model_ids
         assert f"{model.name}@{latest_step}" in model_ids
+
+        history_path = (
+            Path(tmpdir)
+            / model.project
+            / "models"
+            / model.name
+            / "history.jsonl"
+        )
+        history_rows = [
+            json.loads(line) for line in history_path.read_text().splitlines()
+        ]
+        kl_values = [
+            row["loss/kl_policy_ref"]
+            for row in history_rows
+            if "loss/kl_policy_ref" in row
+        ]
+        assert kl_values
+        assert all(math.isfinite(value) for value in kl_values)
     finally:
         await client.close()
diff --git a/tests/integration/test_tinker_native_backend.py b/tests/integration/test_tinker_native_backend.py
index 09ff33c4..4e5c61a5 100644
--- a/tests/integration/test_tinker_native_backend.py
+++ b/tests/integration/test_tinker_native_backend.py
@@ -9,6 +9,8 @@

 import art
 from art.tinker_native import TinkerNativeBackend
+from art.tinker_native.backend import _apply_kl_penalty
+from art.tinker_native.data import trajectory_groups_to_datums

 DEFAULT_BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507"

@@ -37,6 +39,8 @@ async def simple_rollout(
         max_tokens=10,
         timeout=60,
         temperature=1,
+        logprobs=True,
+        top_logprobs=0,
     )
     choice = chat_completion.choices[0]
     content = (choice.message.content or "").lower()
@@ -115,6 +119,85 @@ async def make_group(prompt: str) -> art.TrajectoryGroup:

         await backend.close()


+@pytest.mark.skipif(
+    "TINKER_API_KEY" not in os.environ,
+    reason="TINKER_API_KEY not set - skipping TinkerNativeBackend KL test",
+)
+async def test_tinker_native_backend_kl_identity_metric():
+    model_name = f"test-tinker-native-kl-{uuid.uuid4().hex[:8]}"
+    with tempfile.TemporaryDirectory() as tmpdir:
+        backend = TinkerNativeBackend(path=tmpdir)
+        model = art.TrainableModel(
+            name=model_name,
+            project="integration-tests",
+            base_model=get_base_model(),
+        )
+        try:
+            await model.register(backend)
+
+            openai_client = model.openai_client()
+            current_step = await model.get_step()
+            model_name_step = model.get_inference_name(step=current_step)
+            prompts = ["Say yes", "Say no", "Say maybe"]
+
+            async def make_group(prompt: str) -> art.TrajectoryGroup:
+                import asyncio
+
+                trajectories = await asyncio.gather(
+                    *[
+                        simple_rollout(openai_client, model_name_step, prompt)
+                        for _ in range(2)
+                    ]
+                )
+                return art.TrajectoryGroup(trajectories)  # type: ignore[attr-defined]
+
+            train_groups = await art.gather_trajectory_groups(  # type: ignore[attr-defined]
+                [make_group(prompt) for prompt in prompts]
+            )
+            ensure_reward_variance(train_groups)
+
+            state = backend._model_state[model.name]
+            datums = trajectory_groups_to_datums(
+                train_groups,
+                state.renderer,
+                state.tokenizer,
+            )
+            assert datums
+
+            reference_sampling_client = await backend._get_kl_reference_sampling_client(
+                state,
+                model.base_model,
+                current_step,
+            )
+            expected_kl = (
+                await _apply_kl_penalty(
+                    trajectory_groups_to_datums(
+                        train_groups,
+                        state.renderer,
+                        state.tokenizer,
+                    ),
+                    reference_sampling_client,
+                    kl_penalty_coef=0.25,
+                )
+            )["loss/kl_policy_ref"]
+
+            result = await backend.train(
+                model,
+                train_groups,
+                learning_rate=1e-5,
+                kl_penalty_coef=0.25,
+                kl_penalty_reference_step=current_step,
+            )
+
+            assert result.metrics["loss/kl_policy_ref"] == pytest.approx(
+                expected_kl,
+                abs=0.05,
+            )
+            assert result.metrics["loss/kl_policy_ref"] == pytest.approx(0.0, abs=0.05)
+        finally:
+            await backend.close()
+
+
 @pytest.mark.skipif(
     "TINKER_API_KEY" not in os.environ,
     reason="TINKER_API_KEY not set - skipping TinkerNativeBackend fork test",
diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py
index e63fdb59..7219e55a 100644
--- a/tests/unit/test_pipeline_trainer_local_backend.py
+++ b/tests/unit/test_pipeline_trainer_local_backend.py
@@ -88,6 +88,44 @@ async def test_pipeline_trainer_preserves_backend_train_kwargs(tmp_path: Path) -
     }


+@pytest.mark.asyncio
+async def test_pipeline_trainer_forwards_kl_kwargs_for_generic_backend(
+    tmp_path: Path,
+) -> None:
+    model = TrainableModel(
+        name="pipeline-generic-backend-kl-kwargs",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+    backend = MagicMock()
+    backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={}))
+
+    trainer = _make_trainer(
+        model=model,
+        backend=backend,
+        kl_penalty_coef=0.25,
+        kl_penalty_reference_step=7,
+    )
+    trainer._output_queue = asyncio.Queue()
+    await trainer._output_queue.put(_make_group([0.0, 1.0]))
+    await trainer._output_queue.put(None)
+
+    await trainer._training_stage()
+
+    assert backend.train.await_args.kwargs == {
+        "learning_rate": 1e-5,
+        "loss_fn": "cispo",
+        "loss_fn_config": None,
+        "normalize_advantages": True,
+        "save_checkpoint": False,
+        "adam_params": None,
+        "kl_penalty_coef": 0.25,
+        "kl_penalty_reference_step": 7,
+        "kl_penalty_source": "sample",
+    }
+
+
 @pytest.mark.asyncio
 async def test_pipeline_trainer_uses_same_train_kwargs_for_local_backend(
     tmp_path: Path,
@@ -165,6 +203,45 @@ async def fake_train_model(
     assert seen["dev_config"]["ppo"] is True


+@pytest.mark.asyncio
+async def test_local_backend_train_passes_kl_penalty_source(tmp_path: Path) -> None:
+    model = TrainableModel(
+        name="local-backend-kl-source",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+    backend = LocalBackend(path=str(tmp_path))
+    seen: dict[str, Any] = {}
+
+    async def fake_train_model(
+        _model: TrainableModel,
+        _groups: list[TrajectoryGroup],
+        config: Any,
+        dev_config: dict[str, Any],
+        verbose: bool = False,
+    ):
+        seen["config"] = config
+        seen["dev_config"] = dev_config
+        seen["verbose"] = verbose
+        yield {}
+
+    backend._train_model = fake_train_model  # type: ignore[method-assign]
+    backend._get_step = AsyncMock(return_value=1)  # type: ignore[method-assign]
+    with patch.object(model, "_get_wandb_run", return_value=None):
+        result = await backend.train(
+            model,
+            [_make_group([1.0])],
+            kl_penalty_coef=0.25,
+            kl_penalty_source="sample",
+            save_checkpoint=False,
+        )
+
+    assert result.step == 1
+    assert seen["config"].kl_penalty_source == "sample"
+    assert seen["dev_config"]["kl_penalty_source"] == "sample"
+
+
 @pytest.mark.asyncio
 async def test_local_backend_async_context_manager_awaits_async_cleanup(
     tmp_path: Path,
diff --git a/tests/unit/test_tinker_native_kl.py b/tests/unit/test_tinker_native_kl.py
new file mode 100644
index 00000000..a2d16d01
--- /dev/null
+++ b/tests/unit/test_tinker_native_kl.py
@@ -0,0 +1,77 @@
+import pytest
+import tinker
+
+from art import TrainableModel
+from art.tinker_native.backend import TinkerNativeBackend, _apply_kl_penalty
+from art.tinker_native.data import build_datum
+
+
+class FakeSamplingClient(tinker.SamplingClient):
+    def __init__(self, responses: dict[tuple[int, ...], list[float | None]]) -> None:
+        self._responses = responses
+
+    async def compute_logprobs_async(
+        self, prompt: tinker.ModelInput
+    ) -> list[float | None]:
+        return self._responses[tuple(prompt.to_ints())]
+
+
+@pytest.mark.asyncio
+async def test_incorporate_kl_penalty_rewrites_advantages_in_place() -> None:
+    datum_a = build_datum(
+        prompt_tokens=[101, 102],
+        completion_tokens=[201, 202],
+        logprobs=[-0.4, -0.8],
+        advantage=1.0,
+    )
+    datum_b = build_datum(
+        prompt_tokens=[301, 302],
+        completion_tokens=[401],
+        logprobs=[-0.2],
+        advantage=2.0,
+    )
+    assert datum_a is not None
+    assert datum_b is not None
+
+    sampling_client = FakeSamplingClient(
+        {
+            (101, 102, 201, 202): [None, -9.0, -0.1, -0.5],
+            (301, 302, 401): [None, -7.0, -0.05],
+        }
+    )
+
+    metrics = await _apply_kl_penalty(
+        [datum_a, datum_b],
+        sampling_client,
+        kl_penalty_coef=2.0,
+    )
+
+    assert metrics == {"loss/kl_policy_ref": pytest.approx(-0.25)}
+    assert datum_a.loss_fn_inputs["advantages"].tolist() == pytest.approx(
+        [0.0, 1.1, 1.1]
+    )
+    assert datum_b.loss_fn_inputs["advantages"].tolist() == pytest.approx([0.0, 1.8])
+
+
+@pytest.mark.asyncio
+async def test_tinker_native_backend_rejects_current_learner_kl_source(
+    tmp_path,
+) -> None:
+    backend = TinkerNativeBackend(tinker_api_key="test-key", path=str(tmp_path))
+    model = TrainableModel(
+        name="tinker-native-kl-source",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+
+    with pytest.raises(
+        AssertionError,
+        match="only supports kl_penalty_source='sample'",
+    ):
+        await backend.train(
+            model,
+            [],
+            kl_penalty_coef=0.25,
+            kl_penalty_source="current_learner",  # ty:ignore[invalid-argument-type]
+        )
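As a check on the fixture arithmetic in `test_incorporate_kl_penalty_rewrites_advantages_in_place`: the masked per-token gaps (sampled minus reference logprobs) are (-0.3, -0.3) for datum_a and (-0.15,) for datum_b, so the expected values follow directly:

```python
# Worked check of the expected values in the unit test above.
avg = (-0.3 + -0.3 + -0.15) / 3                        # -0.25 == loss/kl_policy_ref
adv_a = [1.0 + 2.0 * (avg - d) for d in (-0.3, -0.3)]  # [1.1, 1.1]
adv_b = [2.0 + 2.0 * (avg - d) for d in (-0.15,)]      # [1.8]
```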