[tinker] Support PPO loss with Tinker and add critic model in SkyRLTrainBackend#1389
[tinker] Support PPO loss with Tinker and add critic model in SkyRLTrainBackend#1389tamoghnokandar wants to merge 3 commits intoNovaSky-AI:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request significantly refactors the SkyRL-Train backend to support both policy and critic models, primarily for PPO. It introduces a model_role concept, updating create_model methods, data structures (_model_ids, _model_metadata), and core functionalities like forward_backward, forward, optim_step, and checkpointing to handle distinct roles. The ppo_critic loss function is added, along with corresponding input fields (values, returns) in API types and batch preparation. The Jax backend is updated to enforce that it only supports 'policy' models and raises errors for 'ppo_critic' loss. Review comments suggest extracting duplicated validation logic for model_role and loss_fn into a helper method and refactoring the loss_fn_outputs construction for better efficiency and clarity.
I am having trouble creating individual review comments. Click here to see my feedback.
skyrl/backends/skyrl_train_backend.py (470-473)
This validation logic for role and loss_fn is duplicated in the forward method (lines 549-552). To improve maintainability and avoid code duplication, consider extracting this logic into a private helper method. For example:
def _validate_batch_role_and_loss(self, role: str, loss_fn: str):
if role == "critic" and loss_fn != "ppo_critic":
raise ValueError(f"Critic batches must use loss_fn='ppo_critic', got {loss_fn!r}")
if role != "critic" and loss_fn == "ppo_critic":
raise ValueError("loss_fn='ppo_critic' is only valid for critic models")You could then call self._validate_batch_role_and_loss(role, loss_fn) in both forward_backward and forward methods.
skyrl/backends/skyrl_train_backend.py (517-531)
The current implementation for constructing loss_fn_outputs is a bit inefficient and could be clearer. It initializes loss_fn_outputs on line 517, and then potentially re-initializes it as an empty list on line 519 if "loss_fn_outputs" is in data.
This can be refactored to be more direct and avoid the unnecessary list creation.
if "loss_fn_outputs" in data:
loss_fn_outputs = []
for i in range(start_idx, end_idx):
raw_output = data["loss_fn_outputs"][i]
formatted_output = {}
for key in ("elementwise_loss", "logprobs", "values"):
values = list(raw_output.get(key, []))
if values or key in raw_output:
formatted_output[key] = {
"data": values,
"dtype": "float32",
"shape": [len(values)],
}
loss_fn_outputs.append(formatted_output)
else:
loss_fn_outputs = [{} for _ in range(end_idx - start_idx)]| model_id = next(iter(unique_models)) | ||
| if self._get_role(model_id) != "policy": |
There was a problem hiding this comment.
🟡 sample() crashes with ValueError on base model sampling (empty model_id)
When a base model sampling request arrives with model_id="" (the standard way to sample from the base model without a LoRA adapter), self._get_role("") at skyrl/backends/skyrl_train_backend.py:636 raises a ValueError because the empty string is not a key in self._model_ids. The old code handled this gracefully via unique_models != {self._model_id}, returning a proper ErrorResponse for each request. The new code lets the exception propagate up to the engine's process_batch_requests catch-all handler (skyrl/tinker/engine.py:682), which returns a less informative error message ("Model not found") for the entire batch.
| model_id = next(iter(unique_models)) | |
| if self._get_role(model_id) != "policy": | |
| model_id = next(iter(unique_models)) | |
| if not model_id or self._get_role(model_id) != "policy": |
Was this helpful? React with 👍 or 👎 to provide feedback.
Fixes the first issue of #1380.
Summary
LossFnInputswith optionalvaluesandreturnsfields, and addppo_criticloss typeSkyrlTrainBackendfrom single-model to multi-model registry (model_id → role), supporting both policyand critic actor groups
create_modelwithmodel_role="critic", sharing the policy's base model withindependent training
register_actor_group()andset_algorithm_config()toWorkerDispatchfor dynamic critic registrationTest plan
LossFnInputsfields (values,returns) andDatum.to_types()conversionprepare_model_pass_batchwithppo_criticloss typeppo_criticloss