[tinker] Support PPO loss with Tinker and add critic model in SkyRLTrainBackend #1389

Open

tamoghnokandar wants to merge 3 commits into NovaSky-AI:main from tamoghnokandar:add_logprobs

Conversation

tamoghnokandar (Contributor) commented Mar 25, 2026

Fixes the first issue of #1380.

Summary

  • Add critic model support to the Tinker backend, enabling actor-critic PPO training through the Tinker API
  • Extend LossFnInputs with optional values and returns fields, and add ppo_critic loss type
  • Refactor SkyrlTrainBackend from single-model to multi-model registry (model_id → role), supporting both policy
    and critic actor groups
  • Critic models are created via create_model with model_role="critic", sharing the policy's base model while
    training independently
  • Add register_actor_group() and set_algorithm_config() to WorkerDispatch for dynamic critic registration
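
For context, the ppo_critic loss type in standard PPO implementations is a clipped value loss. A minimal plain-Python sketch of that form (illustrative only; the function name and the clip_eps hyperparameter are assumptions, not this PR's actual API):

```python
def ppo_critic_loss(values, old_values, returns, clip_eps=0.2):
    """Clipped PPO value loss over per-token predictions (sketch)."""
    total = 0.0
    for v, v_old, ret in zip(values, old_values, returns):
        # Keep the new value prediction within clip_eps of the old one
        v_clipped = max(min(v, v_old + clip_eps), v_old - clip_eps)
        # Pessimistic (elementwise max) of unclipped and clipped squared error
        total += max((v - ret) ** 2, (v_clipped - ret) ** 2)
    return 0.5 * total / len(values)
```

In this sketch, values and returns correspond to the new LossFnInputs fields; where old_values comes from is an assumption about the rollout bookkeeping, not something this PR specifies.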

Test plan

  • Unit tests for new LossFnInputs fields (values, returns) and Datum.to_types() conversion
  • Unit tests for prepare_model_pass_batch with ppo_critic loss type
  • Unit tests verifying JAX backend rejects critic role and ppo_critic loss
  • Verify existing policy-only workflows remain unaffected (values/returns default to empty)
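
The last bullet's default behavior can be sketched with a dataclass; logprobs here is a hypothetical stand-in for the existing fields, since the real LossFnInputs definition is part of the Tinker API types:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class LossFnInputs:
    # Hypothetical existing field; the real type has more
    logprobs: List[float] = field(default_factory=list)
    # New optional critic fields, defaulting to empty so that
    # policy-only workflows are unaffected
    values: List[float] = field(default_factory=list)
    returns: List[float] = field(default_factory=list)
```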


gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request significantly refactors the SkyRL-Train backend to support both policy and critic models, primarily for PPO. It introduces a model_role concept, updating create_model methods, data structures (_model_ids, _model_metadata), and core functionalities like forward_backward, forward, optim_step, and checkpointing to handle distinct roles. The ppo_critic loss function is added, along with corresponding input fields (values, returns) in API types and batch preparation. The Jax backend is updated to enforce that it only supports 'policy' models and raises errors for 'ppo_critic' loss. Review comments suggest extracting duplicated validation logic for model_role and loss_fn into a helper method and refactoring the loss_fn_outputs construction for better efficiency and clarity.
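
The model_id → role registry described above might look like the following minimal sketch; _model_ids and _get_role appear in the review excerpts, while register_model is a hypothetical name used here for illustration:

```python
from typing import Dict

class SkyrlTrainBackend:
    """Sketch of the multi-model registry; the real class does far more."""

    def __init__(self) -> None:
        # model_id -> role ("policy" or "critic")
        self._model_ids: Dict[str, str] = {}

    def register_model(self, model_id: str, role: str) -> None:
        # Hypothetical registration entry point for illustration
        if role not in ("policy", "critic"):
            raise ValueError(f"Unknown model_role: {role!r}")
        self._model_ids[model_id] = role

    def _get_role(self, model_id: str) -> str:
        # Raises for unknown ids, including the empty string used for
        # base-model sampling requests
        if model_id not in self._model_ids:
            raise ValueError(f"Model not found: {model_id!r}")
        return self._model_ids[model_id]
```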


skyrl/backends/skyrl_train_backend.py (lines 470-473), severity: medium

This validation logic for role and loss_fn is duplicated in the forward method (lines 549-552). To improve maintainability and avoid code duplication, consider extracting this logic into a private helper method. For example:

def _validate_batch_role_and_loss(self, role: str, loss_fn: str):
    if role == "critic" and loss_fn != "ppo_critic":
        raise ValueError(f"Critic batches must use loss_fn='ppo_critic', got {loss_fn!r}")
    if role != "critic" and loss_fn == "ppo_critic":
        raise ValueError("loss_fn='ppo_critic' is only valid for critic models")

You could then call self._validate_batch_role_and_loss(role, loss_fn) in both forward_backward and forward methods.

skyrl/backends/skyrl_train_backend.py (lines 517-531), severity: medium

The current implementation for constructing loss_fn_outputs is a bit inefficient and could be clearer. It initializes loss_fn_outputs on line 517, and then potentially re-initializes it as an empty list on line 519 if "loss_fn_outputs" is in data.

This can be refactored to be more direct and avoid the unnecessary list creation.

            if "loss_fn_outputs" in data:
                loss_fn_outputs = []
                for i in range(start_idx, end_idx):
                    raw_output = data["loss_fn_outputs"][i]
                    formatted_output = {}
                    for key in ("elementwise_loss", "logprobs", "values"):
                        values = list(raw_output.get(key, []))
                        if values or key in raw_output:
                            formatted_output[key] = {
                                "data": values,
                                "dtype": "float32",
                                "shape": [len(values)],
                            }
                    loss_fn_outputs.append(formatted_output)
            else:
                loss_fn_outputs = [{} for _ in range(end_idx - start_idx)]

devin-ai-integration[bot]

This comment was marked as resolved.

devin-ai-integration bot (Contributor) left a comment


Devin Review found 1 new potential issue.

View 7 additional findings in Devin Review.


Comment on lines +635 to +636
model_id = next(iter(unique_models))
if self._get_role(model_id) != "policy":

🟡 sample() crashes with ValueError on base model sampling (empty model_id)

When a base model sampling request arrives with model_id="" (the standard way to sample from the base model without a LoRA adapter), self._get_role("") at skyrl/backends/skyrl_train_backend.py:636 raises a ValueError because the empty string is not a key in self._model_ids. The old code handled this gracefully via unique_models != {self._model_id}, returning a proper ErrorResponse for each request. The new code lets the exception propagate up to the engine's process_batch_requests catch-all handler (skyrl/tinker/engine.py:682), which returns a less informative error message ("Model not found") for the entire batch.

Suggested change
-    model_id = next(iter(unique_models))
-    if self._get_role(model_id) != "policy":
+    model_id = next(iter(unique_models))
+    if not model_id or self._get_role(model_id) != "policy":


