
[train][2/N] Support for Megatron PP + CP for R3 #1335

Merged
erictang000 merged 39 commits into main from r3-pp-cp on Mar 26, 2026

Conversation

@devpatelio (Collaborator) commented Mar 17, 2026

Summary

Extending #1273, this PR provides support for pipeline parallelism and context parallelism for R3. See #815 for tracking of future tasks to fully support routing replay in all settings.

Implementation

Pipeline Parallelism
For pipeline parallelism, we add a helper function _get_current_pp_stage_layer_range(model_config) that maps the current PP rank and its layers to a global layer offset across all model layers, so we can use this offset to select the corresponding replay instances from RouterReplay.global_router_replay_instances.

First, we get the number of pipeline stages from the PP world size along with the total number of model layers. For models containing dense layers or unequal pipeline stages, Megatron supports setting a custom number of layers for the first and last PP ranks. We capture these values from the model config and check that the remaining layers can be evenly distributed across the remaining PP ranks. Finally, we return the transformer-layer range owned by the current PP rank as s_p, n_p, where:

  • s_p is the global starting layer index for rank p
  • n_p is the number of transformer layers assigned to that stage

For an even partition with L total layers and P pipeline stages:

  • next_n_pp_layers = L // P, start_index = next_n_pp_layers * pp_rank
  • the offset thus spans (next_n_pp_layers * pp_rank) : (next_n_pp_layers * (pp_rank + 1))

For uneven partitioning, if the first and/or last stages are assigned custom layer counts, we subtract those from $L$, split the remaining layers evenly among the remaining stages, and then shift the start index accordingly. This means we can support cases like Moonlight-16B models which have 27 layers, where we can pass num_layers_in_first_pipeline_stage as 13 for PP=2.
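The partitioning logic above can be sketched as follows. This is a hedged, illustrative sketch, not the PR's actual code: the function name, signature, and argument names (`pp_rank`, `pp_size`, `num_layers_first`, `num_layers_last`) are hypothetical stand-ins for values the real helper `_get_current_pp_stage_layer_range(model_config)` reads from the Megatron model config.

```python
# Illustrative sketch of the layer-range computation described above.
# All names are hypothetical; the real helper reads these from model_config.
def get_pp_stage_layer_range(pp_rank, pp_size, num_layers,
                             num_layers_first=None, num_layers_last=None):
    """Return (s_p, n_p): global start index and layer count for stage pp_rank."""
    custom_first = num_layers_first is not None
    custom_last = num_layers_last is not None
    remaining = num_layers - (num_layers_first or 0) - (num_layers_last or 0)
    middle_stages = pp_size - custom_first - custom_last
    if middle_stages > 0:
        assert remaining % middle_stages == 0, "remaining layers must split evenly"
        per_stage = remaining // middle_stages
    else:
        per_stage = 0
    if custom_first and pp_rank == 0:
        return 0, num_layers_first
    if custom_last and pp_rank == pp_size - 1:
        return num_layers - num_layers_last, num_layers_last
    # Middle ranks: shift past any custom first stage, then take even chunks.
    start = (num_layers_first or 0) + (pp_rank - custom_first) * per_stage
    return start, per_stage
```

Under this sketch, the Moonlight-16B case (27 layers, PP=2, `num_layers_first=13`) gives rank 0 layers [0, 13) and rank 1 layers [13, 27).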

Context Parallelism

When using sample_packing, our megatron worker pre-processes and post-processes packed sequences. When CP is enabled, each packed sequence is split into cp_size * 2 chunks, so each GPU effectively gets 2 chunks of half the size. See NVIDIA/TransformerEngine#1368. To account for this extra chunking, the setup_per_microbatch_replay_forward method is updated so that effective_seq_len is aligned to cp_size * 2 (matching the alignment in preprocess_packed_seqs in megatron_utils.py) and seqlen_per_cp is halved to seqlen_per_cp // 2. We then index the front and back halves of these CP chunks from the aligned indices across the CP ranks and concatenate them. This ensures the router replay indices see the correct tokens under Megatron's CP chunking.
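The front/back-half chunking can be sketched as follows. This is an illustrative sketch of Transformer Engine's load-balanced CP layout (the function name is hypothetical; the actual indexing lives in preprocess_packed_seqs and setup_per_microbatch_replay_forward): a sequence of length S is split into 2 * cp_size chunks, and rank r gets chunk r (front half) plus chunk 2 * cp_size - 1 - r (back half).

```python
# Hypothetical sketch of the CP token-index selection described above.
def cp_token_indices(seq_len, cp_size, cp_rank):
    assert seq_len % (2 * cp_size) == 0, "seq_len must be aligned to cp_size * 2"
    chunk = seq_len // (2 * cp_size)
    # Front half: chunk number cp_rank.
    front = list(range(cp_rank * chunk, (cp_rank + 1) * chunk))
    # Back half: the mirrored chunk from the end of the sequence.
    back_id = 2 * cp_size - 1 - cp_rank
    back = list(range(back_id * chunk, (back_id + 1) * chunk))
    return front + back
```

For example, with seq_len=8 and cp_size=2, rank 0 owns tokens [0, 1, 6, 7] and rank 1 owns [2, 3, 4, 5]; together the ranks cover every token exactly once.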

Testing

You can test with CP and/or PP configs from the test_router_replay file.



@devpatelio devpatelio changed the title replay utils update [train][2/N] Support for Megatron PP + CP for R3 Mar 17, 2026
@devpatelio devpatelio requested a review from erictang000 March 20, 2026 21:08
@devpatelio devpatelio marked this pull request as ready for review March 20, 2026 21:08
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for Megatron's pipeline parallelism (PP) and context parallelism (CP) to the router replay (R3) feature. The changes are well-implemented, introducing a new helper function to manage layer ranges in PP and updating index packing logic for CP. The test suite has been significantly enhanced with parameterized tests that cover various parallelism configurations using real data, which is a great improvement. I've found a minor configuration issue in the new test script that should be addressed.

Comment thread examples/train/router_replay/router_replay_fully_async.sh Outdated
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path=$MODEL_NAME \
trainer.placement.colocate_all=false \
trainer.strategy=fsdp2 \
Contributor


medium

This line sets trainer.strategy=fsdp2, but it's later overridden by trainer.strategy=megatron on line 76. Since this script is for a Megatron-based run, this line is redundant and could cause confusion. It should be removed.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
devin-ai-integration[bot]

This comment was marked as resolved.

…dices

Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Contributor

@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 2 new potential issues.

View 8 additional findings in Devin Review.


Comment on lines +377 to +379
self.enable_router_replay = transformer_config_kwargs.get(
"moe_enable_routing_replay", megatron_config.moe_enable_routing_replay
)
Contributor

@devin-ai-integration devin-ai-integration bot Mar 20, 2026


🟡 Validation for moe_enable_routing_replay doesn't check transformer_config_kwargs, creating a gap with the worker's logic

The PR changes megatron_worker.py:377-379 to determine self.enable_router_replay from transformer_config_kwargs.get("moe_enable_routing_replay", megatron_config.moe_enable_routing_replay), but the validation in skyrl/train/utils/utils.py:205 only checks cfg.trainer.policy.megatron_config.moe_enable_routing_replay (the top-level config field). If a user sets moe_enable_routing_replay=True only in transformer_config_kwargs (as the tests do at tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py:237-239), the worker enables replay but the validation skips the assertion that enable_return_routed_experts must also be True. This could lead to router replay being silently ineffective (since rollout_expert_indices would be None without enable_return_routed_experts).


trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path=$MODEL_NAME \
trainer.placement.colocate_all=false \
trainer.strategy=fsdp2 \
Contributor


🟡 Duplicate trainer.strategy setting in example shell script — first value (fsdp2) is a leftover

Line 64 sets trainer.strategy=fsdp2 and line 76 sets trainer.strategy=megatron. The second override wins, but the first is clearly a leftover from copying the async example script. Since the rest of the script configures Megatron-specific parameters (TP/PP/CP/EP), the fsdp2 value on line 64 is incorrect and misleading. A user who copies only the first section of this script would get the wrong strategy.

Suggested change
trainer.strategy=fsdp2 \
trainer.strategy=megatron \

SumanthRH and others added 13 commits March 20, 2026 23:09
…awn` (#1333)

# What does this PR do?

Adds `skyrl.utils.worker_setup.worker_setup_fn`, which sets the
multiprocessing start method to 'spawn', and wires it into `ray.init`
via the `worker_process_setup_hook` runtime_env key. Includes unit tests
verifying the hook is applied in Ray workers.

We've previously seen many examples where ray and fork interact in weird
ways. There's code in the `skyrl_entrypoint` task (torch dataloader) as
well as in other ray worker processes (e.g. megatron workers) that relies
on multiprocessing. Using the worker setup hook gives us a
consistent way to handle this for all Ray worker processes.

Test plan:

- [x]  CPU tests
- [x] One GPU test that uses ray : `test_model_wrapper.py`
- [x] New CPU tests for worker setup function


---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
…e; Use the correct GSM8k path for `test_generator_multi_turn_gsm8k_router_replay` (#1339)
…ensors be right-aligned (#1285)

## Summary

Fix cross-sample padding inflation in
`convert_prompts_responses_to_batch_tensors` by replacing the
two-segment padding layout with a unified left-padded layout.

**Before:** sequences padded to `max(prompt_lens) + max(response_lens)`.
When the longest prompt and longest response come from different samples
(common in step-wise training), this can approach **2x max_seq_len**.

**After:** sequences padded to `max(prompt_i + response_i)` — the
tightest bound that preserves every real token. **No tokens are
truncated.**

## Problem

The old layout pads prompts and responses as independent segments:

```
| [PAD..PAD PROMPT] | [RESPONSE PAD..PAD] |
|<-- max_prompt --->|<--- max_response -->|
```

In step-wise training, prompt and response lengths are anti-correlated
across steps:
- Early turns: short prompt (5K), long response (50K)
- Late turns: long prompt (55K), short response (4K)

The padded `seq_len = 55K + 50K = 105K`, far exceeding the actual
`max_seq_len = 60K`. With 61,440 step-samples, this inflates the
`sequences` tensor from ~75 GB to ~103 GB (for max_seq_len=80K) — pure
padding waste.

## Solution

Eliminate the two-segment layout. Each sequence is now a single
left-padded block:

```
| [PAD..PAD  PROMPT  RESPONSE] |
|<------- max_total ---------->|
```

Where `max_total = max(prompt_i + response_i)`. The response is always
at the end of the sequence, so the existing model forward pass slicing
(`log_probs[:, -num_actions-1:-1]`) works unchanged.
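The unified layout can be sketched in a few lines. This is a minimal illustrative sketch using plain token-id lists, not the actual `convert_prompts_responses_to_batch_tensors` implementation in `skyrl/train/dataset/preprocess.py`; the function name and `pad_id` default are hypothetical.

```python
# Minimal sketch of the unified left-padded layout: each row is
# [PAD..PAD PROMPT RESPONSE], padded to max(prompt_i + response_i).
def pack_left_padded(prompts, responses, pad_id=0):
    max_total = max(len(p) + len(r) for p, r in zip(prompts, responses))
    return [
        [pad_id] * (max_total - len(p) - len(r)) + p + r
        for p, r in zip(prompts, responses)
    ]
```

Note that `max_total` is the max of per-sample totals, not `max(prompt_lens) + max(response_lens)`, which is exactly where the anti-correlated step-wise case saves padding.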

### Response data alignment change: left-aligned → right-aligned

Because response tokens are now at the **end** of each sequence (with
variable-length prompts before them), the response logprobs extracted by
the model are **right-aligned** within the `(batch, max_response)`
slice. Response-level tensors (action_mask, rewards, loss_masks,
rollout_logprobs) are correspondingly right-aligned to match:

```
Old (left-aligned):  [resp_tok, resp_tok, resp_tok, PAD, PAD]
New (right-aligned): [PAD, PAD, resp_tok, resp_tok, resp_tok]
```

All downstream consumers use masked operations (`masked_mean(loss *
loss_mask, loss_mask)`, `scores.unsqueeze(-1) * response_mask`, etc.)
which are alignment-agnostic. The `loss_fn_outputs` extraction for the
Tinker API path uses `action_mask.sum()` + `[:valid_len]` which would
need a follow-up adjustment for that specific code path (currently not
used in the standard RL training loop — it's popped at
`trainer.py:1088`).

## Changes

| File | Change |
|------|--------|
| `skyrl/train/dataset/preprocess.py` | Unified left-pad layout,
right-aligned response data, optional `max_seq_len` warning |
| `skyrl/train/trainer.py` | Pass `max_seq_len` from config to padding
function |
| `tests/train/dataset/test_preprocess.py` | 8 new tests for unified
layout, right-alignment, anti-correlation, no-mutation, backward compat
|

## Test plan

- [x] All 12 unit tests pass (4 existing updated + 8 new)
- [x] Verify step-wise training run produces same loss curves
(right-alignment changes tensor layout but not masked loss values)
- [x] Verify non-step-wise training is unaffected (max_total =
max_prompt + max_response when not anti-correlated)

## Curves

GSM8K CI run:
<img width="1431" height="278" alt="image"
src="https://github.com/user-attachments/assets/5cc4ea54-f6df-498d-ae7e-c3cf243610fa"
/>


https://wandb.ai/sky-posttraining-uc-berkeley/skyrl-search-padding/reports/Untitled-Report--VmlldzoxNjE3MTA0OQ

Ran with 8xH100s 
- Baseline from previous PRs (blue) -- without TIS
- Non-step wise search r1 ran with (red) -- with TIS
```bash
USE_CONVERSATION_MULTI_TURN=true bash examples/train/search/run_search.sh \
  generator.inference_engine.num_engines=8 \
  generator.inference_engine.tensor_parallel_size=1
```
- Step-wise search r1 ran with (brown) -- with TIS
```bash
USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh \
  generator.inference_engine.num_engines=8 \
  generator.inference_engine.tensor_parallel_size=1
```

<img width="559" height="252" alt="image"
src="https://github.com/user-attachments/assets/0a287d07-5f5f-471e-a02e-570905ad468a"
/>

The step-wise training time is much worse as of now (roughly 4x, scales
with num turns), and hopefully can be improved after
#1277

<img width="839" height="500" alt="image"
src="https://github.com/user-attachments/assets/2d1cbd65-1a25-4e7e-8c60-d5b221a97800"
/>

Gsm8k + Megatron CI run (purple is with this PR after rebasing)
https://wandb.ai/sky-posttraining-uc-berkeley/gsm8k_ci_megatron/runs/uoaga4uz?nw=nwusercharlieruan
<img width="1464" height="271" alt="image"
src="https://github.com/user-attachments/assets/6129afb9-1550-465b-92f8-d8d12063142c"
/>
---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…rt method to `spawn`" (#1344)

Reverts #1333

Fixes #1342 and #1343 . It looks like we hit the same issue as
ray-project/ray#61350 when dealing with worker
process setup hook and vllm with the ray backend.

The long term fix is actually in the ray repo - the bug has been fixed
in ray-project/ray#61473 and we should be able
to make use of the setup hook after upgrading to the next ray release.
Until then, I've just reverted the changes and added `spawn` for the mp
context for our dataloader

I did a quick smoke test by running the gsm8k example and the script
enters the first step successfully

---------

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Alternatively, we could ignore step order entirely and avoid the `cumsum`
trick in the advantage broadcast.

### Summary

Previously, `validate_generator_output()` was **skipped entirely** when
`step_wise_trajectories=True`:

```python
if not self.cfg.generator.step_wise_trajectories:
    validate_generator_output(len(input_batch["prompts"]), generator_output)
```

This meant step-wise generator outputs had no validation at all —
malformed `is_last_step`, missing `trajectory_ids`, or non-contiguous
trajectory ordering would silently produce wrong training results.

The non-contiguous case is particularly dangerous: the trainer's
advantage broadcast uses a `cumsum` trick that assumes all steps of the
same trajectory are adjacent in the batch. If steps are interleaved
across trajectories, advantages are silently mapped to the wrong steps
with no error.
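A contiguity check of the kind described above can be sketched as follows (illustrative only; the function name is hypothetical and the real checks live in `_validate_step_wise_fields()` in `trainer_utils.py`):

```python
# Raise if any trajectory's steps are interleaved with another trajectory's,
# i.e. if a trajectory id reappears after a different id was seen.
def check_contiguous_trajectories(trajectory_ids):
    seen = set()
    prev = object()  # sentinel that never equals a real id
    for tid in trajectory_ids:
        if tid != prev:
            if tid in seen:
                raise ValueError(f"trajectory {tid!r} appears non-contiguously")
            seen.add(tid)
            prev = tid
```

Running this on an interleaved batch like `["t1", "t2", "t1"]` raises immediately, instead of letting the `cumsum`-based broadcast silently misassign advantages.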

### Changes

**`skyrl/train/utils/trainer_utils.py`**
- Added `step_wise: bool = False` parameter to
`validate_generator_output()` (backward compatible — existing callers
are unaffected)
- Extracted `_validate_step_wise_fields()` for step-wise specific
checks:
  - `is_last_step` and `trajectory_ids` are present and correctly sized
- `is_last_step[-1]` is `True` (last sample must be a trajectory's final
step)
- **Contiguous ordering**: all steps of the same trajectory are adjacent
(catches the silent `cumsum` bug)
- **Boundary alignment**: `is_last_step[i]` is `True` wherever (and only
when) `trajectory_ids` changes between consecutive samples
- In step-wise mode, `num_prompts != num_responses` is allowed (step
expansion is expected)

**`skyrl/train/trainer.py`**
- Changed from skipping validation to calling with `step_wise=True`:
```python
validate_generator_output(
    len(input_batch["prompts"]),
    generator_output,
    step_wise=self.cfg.generator.step_wise_trajectories,
)
```

**`tests/train/test_trainer_utils.py`**
- 9 new tests covering all step-wise validation cases

### Test plan

- [x] `pytest tests/train/test_trainer_utils.py` — all 44 tests pass (35
existing + 9 new)
- [x] Existing non-step-wise validation tests unaffected (backward
compatible `step_wise=False` default)
- [x] New tests cover: valid output, single-step trajectories, missing
fields, length mismatches, non-contiguous ordering, boundary
misalignment, all-False `is_last_step`

### E2E test

Ran the multi-turn gsm8k example E2E. Made sure it is indeed multi-turn
since `generate/batch_num_seq` is ~6800 rather than 2560 (512 * 5)
```bash
  # Run training (script defaults to 1 GPU, override for 8 GPU + step-wise multi-turn)                             
  bash examples/train/turn_level_rewards/run_gsm8k_multi_turn.sh \
    generator.step_wise_trajectories=true \
    generator.use_conversation_multi_turn=true \
    generator.max_turns=5 \
```
…PP-collective caches with bucketing (#1345)

## Summary
- Bucketed weight sync reused `WeightConversionTask` objects (and their
mapping caches) across sync cycles, causing incorrect vLLM weight
updates for DeepSeek-V3 style models (like Moonlight-16B-A3B, or
GLM-4.7-Flash) with PP > 1 (for both LoRA and full finetuning). The
mapping objects cache PP-collective metadata that becomes stale across
train/offload/reload cycles.
- Fix: store only the bucket index structure (which task indices go in
which bucket) once, and rebuild fresh tasks with clean mapping objects
on each `extract_weights()` call. This preserves packed-broadcast
performance while ensuring correct PP collectives every sync.

This manifested in extremely unstable training + reward collapsing for
Deepseek-v3 style models with megatron.
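The fix can be sketched as follows. This is a hypothetical sketch, not the actual weight-sync code: class and method names are illustrative, but it shows the key design choice of caching only the bucket index structure and constructing fresh task objects (with clean mapping caches) on every sync cycle.

```python
# Hypothetical sketch: cache bucket *indices* once, rebuild tasks each cycle.
class BucketedWeightSync:
    def __init__(self, task_specs, bucket_indices):
        self._task_specs = list(task_specs)      # immutable task descriptions
        self._bucket_indices = bucket_indices    # computed once, reused forever

    def extract_weights(self, make_task):
        # Fresh task objects every call, so stale PP-collective
        # metadata from a previous train/offload/reload cycle never leaks.
        for bucket in self._bucket_indices:
            yield [make_task(self._task_specs[i]) for i in bucket]
```

The packed-broadcast performance is preserved because the (cheap-to-reuse) bucketing decision is still made only once; only the (cheap-to-rebuild) task objects are recreated.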

## Test plan
- [x] Moonlight-16B with PP=2: reward increases without significant
weight sync time regression

## Results

### Moonlight-16B-A3B GSM8k
Before in purple, after the fix in blue:
<img width="315" height="246" alt="image"
src="https://github.com/user-attachments/assets/9a7119af-aea8-4b81-888f-f05cd3865c99"
/>
<img width="318" height="242" alt="image"
src="https://github.com/user-attachments/assets/b899766f-9a7e-43fc-98f9-2f856dc04c3c"
/>

Weight sync timing (~10s after, ~8s before):
<img width="317" height="246" alt="image"
src="https://github.com/user-attachments/assets/ed46cdbc-9b95-4533-bb5f-416db56a2847"
/>

### GLM-4.7-Flash GSM8k

Before in red, after the fix in tan (GLM with LoRA)
<img width="337" height="525" alt="image"
src="https://github.com/user-attachments/assets/e647d6f4-42a9-4317-99d4-e3d1d940b038"
/>

Weight sync timing (~15s after, ~12s before):
<img width="331" height="239" alt="image"
src="https://github.com/user-attachments/assets/cc98f4e9-314c-462c-ab7f-80b29bc96f70"
/>




---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Fixes comment
#1281 (comment)

There was no good reason to only account for `is_last_step` since
response IDs are not cumulative in step-wise training (unlike input
tokens, which are cumulative)
#1346)

# What does this PR do?

Fixes CI failure on main for the `SkyRL-Train-CPU` workflow:
https://github.com/NovaSky-AI/SkyRL/actions/runs/23273262330/job/67670625938

After #1344 , we added `multiprocessing_context='spawn'` to the
`build_dataloader` function. It looks like there was one case where the
change here affected a test that was not affected by the usage of
`worker_process_startup_hook` previously. A CPU test
`test_dataloader_seeding` referenced a local dataset class in dataloader
map function. After switch to `spawn`, we need to ensure that the
dataset class is a global variable.

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
# What does this PR do?

Adds LoRA support in the new inference codepath.

## Summary
- Adds a new `/v1/load_lora_adapter` endpoint to `VLLMServerActor` to
support loading lora weights in the vllm server
- Adds support for lora weights in `RemoteInferenceClient` with a new
method `load_lora_from_disk`. The caller is expected to provide a
`lora_path` argument pointing to the lora weights on disk.
- Improved concurrency in `generate`: while running E2E tests for LoRA
in the new codepath, I noticed that generation speeds were 10x
worse than in the old `inference_engines/` codepath in SkyRL. The root
cause is that we currently use a single semaphore for the generate +
detokenize stages, which led to the detokenize stage starving GPUs when
max concurrency was reached. The solution is to use separate semaphores
so that generation can proceed independently. Generation speed is now at
parity with the old codepath.
- Adds a LoRA weight sync test to CI
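The semaphore split can be sketched as follows. This is an illustrative sketch, not the actual `RemoteInferenceClient` code: the limits, function, and stub `engine`/`tokenizer` objects are hypothetical; it only shows the structural fix of gating generation and detokenization with independent semaphores.

```python
import asyncio

# Hypothetical limits; the real values come from client configuration.
GEN_LIMIT, DETOK_LIMIT = 64, 16
gen_sem = asyncio.Semaphore(GEN_LIMIT)
detok_sem = asyncio.Semaphore(DETOK_LIMIT)

async def generate_one(request, engine, tokenizer):
    async with gen_sem:        # gates only GPU generation
        token_ids = await engine.generate(request)
    async with detok_sem:      # detokenize under an independent limit
        return tokenizer.decode(token_ids)
```

With a single shared semaphore, slow detokenization holds permits that generation needs; with two, a backlog of decoding work can no longer idle the GPUs.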

## Test Plan
- [x] Unit tests: `_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra
dev --extra fsdp pytest -s
tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py`
- [x]  E2E lora test


## Results

I ran the lora example
`examples/train/lora/run_qwen2_5_0.5b_gsm8k_grpo_lora.sh` with
`_SKYRL_USE_NEW_INFERENCE=1` and compared it with the baseline

Some curves are below: 

<img width="455" height="299" alt="Screenshot 2026-03-18 at 11 17 10 PM"
src="https://github.com/user-attachments/assets/3ec1be26-6357-4eb0-a833-49366f19d135"
/>

<img width="448" height="291" alt="Screenshot 2026-03-18 at 11 17 23 PM"
src="https://github.com/user-attachments/assets/f6c8f836-0cc0-4f46-87a9-d286d7d54f57"
/>

<img width="676" height="576" alt="image"
src="https://github.com/user-attachments/assets/55126ef2-f1a3-48f4-9de2-83f7323d8fee"
/>



---------

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
…_kl_loss=True) (#1353)

Fixes #1340
devin-ai-integration[bot]

This comment was marked as resolved.

Collaborator

@erictang000 erictang000 left a comment


some comments + questions

particularly a little confused about how it works for MoE models with dense layers before moe layers and PP together.

Comment thread tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py Outdated
Comment thread tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
Comment thread tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py Outdated
Comment thread tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py Outdated
Comment thread tests/backends/skyrl_train/gpu/gpu_ci/test_router_replay.py
Comment thread skyrl/train/utils/utils.py Outdated
Comment thread skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py Outdated
Comment thread skyrl/backends/skyrl_train/utils/replay_utils.py Outdated
Comment thread skyrl/backends/skyrl_train/utils/replay_utils.py Outdated
Comment thread skyrl/backends/skyrl_train/utils/replay_utils.py
@erictang000
Collaborator

screenshot from DMs:

image

Collaborator

@erictang000 erictang000 left a comment


nice work! looks a lot cleaner, tested it out and everything is working as expected

Contributor

@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 1 new potential issue.

View 13 additional findings in Devin Review.


Comment thread skyrl/backends/skyrl_train/utils/replay_utils.py
@erictang000 erictang000 merged commit 135bf67 into main Mar 26, 2026
6 checks passed
@erictang000 erictang000 deleted the r3-pp-cp branch March 26, 2026 22:05
4 participants