Ulysses position_ids pre-gather, NUMA rewrite, and operational improvements#1371
ashutoshuiuc wants to merge 2 commits into NovaSky-AI:main
Conversation
…ements

**Pre-gather Ulysses position_ids outside checkpointed region**

NCCL all_gather during gradient-checkpointing backward recompute causes hangs. Pre-compute the all-gathered position_ids in model_wrapper.py before the model call, and cache it for _ulysses_flash_attention_forward to use. Also handles models that don't propagate position_ids through decoder layers (e.g., GraniteMoeHybrid) by falling back to the cached sliced version.

**Rewrite NUMA affinity to use integer API**

The bitmask-based API (numa_parse_nodestring / numa_run_on_node_mask / numa_set_membind) caused segfaults from bitmask pointer corruption. Replaced with the integer API (numa_run_on_node / numa_set_preferred). Also uses numa_max_node() instead of numa_num_configured_nodes() to handle NVLink/GB200 virtual NUMA IDs correctly.

**Per-micro-batch dispatch with progress logging**

Instead of dispatching entire mini-batches, dispatch individual micro-batches for per-micro-batch progress visibility and ETA.

**Fix RAY_ADDRESS propagation for vLLM EngineCore subprocess**

Without RAY_ADDRESS, the subprocess fails with `KeyError: 'bundles'` when accessing placement_group_table().

Also adds get_dp_size() to WorkerDispatch and timing logs for debugging.

Relates to NovaSky-AI#1297
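The pre-gather/cache flow described above can be sketched roughly as follows. This is a minimal, hypothetical illustration, not the PR's actual code: `_CACHED_POSITION_IDS`, `pre_gather_position_ids`, and `attention_forward` are stand-in names, and the collective is faked with plain lists; the real logic lives in model_wrapper.py and the Ulysses monkey patch.

```python
# Module-level cache shared between the wrapper (which pre-gathers outside
# the checkpointed region) and the patched attention forward (the consumer).
# Names here are hypothetical stand-ins for the real module state.
_CACHED_POSITION_IDS = {"sliced": None, "gathered": None}

def pre_gather_position_ids(sliced_ids, all_gather_fn):
    """Run the (collective) gather once, BEFORE the model call, and cache
    both the local slice and the gathered result. Because this happens
    outside the checkpointed region, recompute never launches NCCL ops."""
    _CACHED_POSITION_IDS["sliced"] = sliced_ids
    _CACHED_POSITION_IDS["gathered"] = all_gather_fn(sliced_ids)

def attention_forward(position_ids=None):
    """Inside the checkpointed region: no collectives allowed here.
    If the model did not propagate position_ids through its decoder layers
    (e.g. GraniteMoeHybrid), fall back to the cached sliced version."""
    if position_ids is None:
        position_ids = _CACHED_POSITION_IDS["sliced"]
    return _CACHED_POSITION_IDS["gathered"], position_ids

# Usage: simulate a 2-rank all_gather by concatenating rank slices.
fake_all_gather = lambda x: x + [i + 10 for i in x]
pre_gather_position_ids([0, 1, 2], fake_all_gather)
gathered, local = attention_forward()  # model dropped position_ids
assert gathered == [0, 1, 2, 10, 11, 12]
assert local == [0, 1, 2]
```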
Code Review
This pull request introduces several significant improvements to the training backend, including pre-gathering position_ids for Ulysses sequence parallelism to optimize performance, rewriting the NUMA affinity logic for better stability, and refactoring the training loop for per-micro-batch dispatch to improve observability. It also includes a fix for vLLM subprocesses by propagating RAY_ADDRESS and adds more detailed timing logs. The changes address performance, stability, and operational aspects of the training process. I have one suggestion to make the NUMA affinity logic more robust by avoiding a hardcoded value.
```python
LIBNUMA.numa_set_preferred.argtypes = [c_int]

real_gpu_id = local_rank_to_real_gpu_id(self._local_rank)
num_gpus_per_numa = max(1, 8 // real_numa_nodes)  # e.g. 8//2 = 4
```
The number of GPUs per node is hardcoded as 8 here. This might not be correct for all node configurations. It would be more robust to determine the number of GPUs dynamically from the CUDA_VISIBLE_DEVICES environment variable, which is what local_rank_to_real_gpu_id does internally.
```diff
- num_gpus_per_numa = max(1, 8 // real_numa_nodes)  # e.g. 8//2 = 4
+ num_gpus_per_numa = max(1, len(os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(",")) // real_numa_nodes)  # e.g. 8//2 = 4
```
Pull request overview
This PR improves distributed training stability and observability (FSDP2 + Ulysses SP) and fixes a Ray/vLLM subprocess environment issue, while also rewriting NUMA affinity handling to avoid segfault-prone APIs.
Changes:
- Dispatches and logs per-micro-batch progress from the trainer (while still accumulating gradients and stepping once per mini-batch).
- Adds Ulysses `position_ids` pre-gather/caching to avoid NCCL collectives during gradient-checkpoint recompute.
- Rewrites NUMA affinity setup to use libnuma’s integer APIs and improves Ray/vLLM environment propagation (`RAY_ADDRESS`).
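The `RAY_ADDRESS` propagation fix amounts to making sure the child process env carries the Ray GCS address before the subprocess is spawned. A minimal sketch of that pattern (the helper name `ensure_ray_address` is made up, and `gcs_address` stands in for what would come from Ray's runtime context in practice):

```python
def ensure_ray_address(env: dict, gcs_address: str) -> dict:
    """Return a child-process environment that is guaranteed to carry
    RAY_ADDRESS, so calls like placement_group_table() in the subprocess
    can reach GCS instead of failing with KeyError: 'bundles'.
    An already-set value is respected."""
    child_env = dict(env)  # never mutate the parent's env in place
    child_env.setdefault("RAY_ADDRESS", gcs_address)
    return child_env

# Usage: missing value gets filled in; an explicit value wins.
assert ensure_ray_address({}, "10.0.0.1:6379")["RAY_ADDRESS"] == "10.0.0.1:6379"
assert ensure_ray_address({"RAY_ADDRESS": "auto"}, "10.0.0.1:6379")["RAY_ADDRESS"] == "auto"
```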
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `skyrl/train/trainer.py` | Switches to per-micro-batch dispatch with ETA/timing logs and keeps `optim_step` once per mini-batch. |
| `skyrl/backends/skyrl_train/workers/worker_dispatch.py` | Adds `get_dp_size()` plus timing logs around `ray.get()` and memory snapshot calls. |
| `skyrl/backends/skyrl_train/workers/worker.py` | Replaces bitmask-based NUMA binding with the integer libnuma API and safer GPU ID mapping. |
| `skyrl/backends/skyrl_train/workers/model_wrapper.py` | Precomputes/caches Ulysses `position_ids` outside the model call for checkpoint-safe recompute. |
| `skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py` | Ensures `RAY_ADDRESS` is set so the EngineCore subprocess can reach GCS/placement group state. |
| `skyrl/backends/skyrl_train/distributed/ulysses/monkey_patch.py` | Adds module-level cached `position_ids` (sliced + gathered) and uses it in the FA2 forward path. |
```python
        statuses = ray.get(refs)
        logger.info(f"[dispatch] ray.get(forward_backward) done in {time.time() - t0:.1f}s")

        t0 = time.time()
        self._save_memory_snapshot(model, "forward_backward")
        logger.info(f"[dispatch] _save_memory_snapshot(forward_backward) done in {time.time() - t0:.1f}s")
        return statuses[0]

    def optim_step(self, model: str) -> Optional[float]:
        """Run optimizer step. Model should already be on GPU from forward_backward."""
        t0 = time.time()
        refs = self._actor_groups[model].async_run_ray_method("pass_through", "optim_step")
        grad_norms = ray.get(refs)
        logger.info(f"[dispatch] ray.get(optim_step) done in {time.time() - t0:.1f}s")

        t0 = time.time()
        self._save_memory_snapshot(model, "optim_step")
        logger.info(f"[dispatch] _save_memory_snapshot(optim_step) done in {time.time() - t0:.1f}s")
```
These new logger.info() calls will execute for every forward_backward/optim_step dispatch (now per micro-batch), which can flood logs and add overhead. Consider using debug-level logging and/or sampling (e.g., log only every N calls or when latency exceeds a threshold).
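The suggested sampling/threshold approach could look like this. A sketch only: the helper name `log_timing` and the default thresholds are made up, not part of the PR.

```python
import logging

logger = logging.getLogger("dispatch")

def log_timing(label: str, elapsed_s: float, call_count: int,
               every_n: int = 50, slow_threshold_s: float = 5.0) -> bool:
    """Log at INFO only every N calls or when latency crosses a threshold;
    demote the common fast case to DEBUG. Returns True if INFO was used,
    so callers/tests can observe which path was taken."""
    msg = f"[dispatch] {label} done in {elapsed_s:.1f}s"
    if call_count % every_n == 0 or elapsed_s >= slow_threshold_s:
        logger.info(msg)
        return True
    logger.debug(msg)
    return False

# Usage: a fast mid-stream call is demoted; a slow call always surfaces.
assert log_timing("ray.get(forward_backward)", 0.2, call_count=3) is False
assert log_timing("ray.get(forward_backward)", 12.0, call_count=3) is True
```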
```python
rank = local_rank_to_real_gpu_id(rank)
return cuda_visible_devices[local_rank % len(cuda_visible_devices)]
```
_set_numa_affinity() still takes a rank parameter, but it’s no longer used after the rewrite (binding is based on self._local_rank). Please remove the unused parameter (and update call sites) or use it consistently to avoid confusion and keep linters happy.
```diff
+ # Mark 'rank' as intentionally accepted for API compatibility, even if not used in logic yet.
+ _ = rank
```
```python
LIBNUMA.numa_set_preferred.argtypes = [c_int]

real_gpu_id = local_rank_to_real_gpu_id(self._local_rank)
num_gpus_per_numa = max(1, 8 // real_numa_nodes)  # e.g. 8//2 = 4
```
NUMA target selection still assumes an 8-GPU node (8 // real_numa_nodes). On machines with a different GPU count (e.g., 4 GPUs / 2 NUMA nodes), this maps all GPUs to NUMA node 0 and defeats affinity. Consider deriving the GPU count from CUDA_VISIBLE_DEVICES (len of the parsed list) or torch.cuda.device_count() instead of the hard-coded 8.
```diff
- num_gpus_per_numa = max(1, 8 // real_numa_nodes)  # e.g. 8//2 = 4
+ # Derive the effective GPU count from CUDA_VISIBLE_DEVICES or torch.cuda
+ cuda_visible_devices_env = os.environ.get("CUDA_VISIBLE_DEVICES")
+ if cuda_visible_devices_env:
+     # Respect the same parsing semantics as local_rank_to_real_gpu_id
+     total_gpus = len(cuda_visible_devices_env.split(","))
+ else:
+     # Fall back to torch if CUDA_VISIBLE_DEVICES is unset or empty
+     if torch.cuda.is_available():
+         total_gpus = torch.cuda.device_count()
+     else:
+         total_gpus = 1
+ if total_gpus <= 0:
+     total_gpus = 1
+ num_gpus_per_numa = max(1, total_gpus // real_numa_nodes)
```
```python
# Dispatch individual micro-batches for progress visibility
num_micro_batches = math.ceil((mb_end_idx - mb_start_idx) / micro_dispatch_size)
t0 = time.time()
logger.info(
    f"[{model}] mini-batch {local_step + 1}/{num_mini_batches}: "
    f"dispatching {num_micro_batches} micro-batches "
    f"(micro_bs={micro_bs_per_gpu}, dp={dp_size})"
)

for ub_idx in range(num_micro_batches):
    ub_start = mb_start_idx + ub_idx * micro_dispatch_size
    ub_end = min(ub_start + micro_dispatch_size, mb_end_idx)
```
This loop uses math.ceil() + min() for micro-batch slicing. Given MeshDispatch.dispatch_from_staged() requires (end_idx-start_idx) to be divisible by dp_size, it may be safer to enforce the stronger invariant here as well (e.g., assert (mb_end_idx-mb_start_idx) % micro_dispatch_size == 0 and use integer division) so misconfigured batch sizes fail fast with a clearer error than a later dispatch assertion.
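The fail-fast variant suggested here could be sketched as follows (hypothetical helper name `split_micro_batches`; the real trainer code operates inline on `mb_start_idx`/`mb_end_idx`):

```python
def split_micro_batches(mb_start_idx: int, mb_end_idx: int,
                        micro_dispatch_size: int) -> list:
    """Require the mini-batch to divide evenly into micro-batches instead
    of relying on ceil() + min(), so a misconfigured batch size errors here
    with a clear message rather than inside a later dispatch assertion."""
    span = mb_end_idx - mb_start_idx
    assert span % micro_dispatch_size == 0, (
        f"mini-batch size {span} is not divisible by "
        f"micro_dispatch_size {micro_dispatch_size}"
    )
    num_micro_batches = span // micro_dispatch_size
    return [
        (mb_start_idx + i * micro_dispatch_size,
         mb_start_idx + (i + 1) * micro_dispatch_size)
        for i in range(num_micro_batches)
    ]

# Usage: even split succeeds; a ragged split fails immediately.
assert split_micro_batches(0, 8, 4) == [(0, 4), (4, 8)]
```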
```python
elapsed_total = time.time() - t0
avg_per_ub = elapsed_total / (ub_idx + 1)
remaining = avg_per_ub * (num_micro_batches - ub_idx - 1)
logger.info(
```
Per-micro-batch logger.info() inside the inner loop can generate very high log volume (and driver I/O overhead) for large runs. Consider making these debug-level, rate-limiting (e.g., every N micro-batches), or gating behind a config flag so production training isn’t slowed by logging.
```diff
- logger.info(
+ logger.debug(
```
```python
t0 = time.time()
statuses = ray.get(refs)
logger.info(f"[dispatch] ray.get(forward_backward) done in {time.time() - t0:.1f}s")

t0 = time.time()
self._save_memory_snapshot(model, "forward_backward")
logger.info(f"[dispatch] _save_memory_snapshot(forward_backward) done in {time.time() - t0:.1f}s")
return statuses[0]
```
With the new per-micro-batch dispatch, this path now calls _save_memory_snapshot() for every micro-batch. Even when workers have record_memory disabled (no-op), this still incurs extra Ray RPC + synchronization overhead. Consider gating _save_memory_snapshot calls on the relevant config (e.g., trainer.policy.record_memory) or a dispatch-level flag so the default path avoids per-micro-batch RPCs.
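Gating the snapshot call behind a config flag could look like the sketch below. The `record_memory` attribute here is a hypothetical stand-in for the real `trainer.policy.record_memory` config, and the class is heavily simplified relative to the actual `WorkerDispatch`.

```python
class WorkerDispatch:
    """Simplified sketch: skip the per-micro-batch snapshot RPC entirely
    when memory recording is disabled, so the default path pays no extra
    Ray round-trip per micro-batch."""

    def __init__(self, record_memory: bool = False):
        self.record_memory = record_memory
        self.snapshot_calls = 0  # stands in for 'RPCs issued to workers'

    def _save_memory_snapshot(self, model: str, tag: str) -> None:
        # In the real code this fans out a Ray RPC to every worker and
        # waits on the results; here we just count invocations.
        self.snapshot_calls += 1

    def maybe_save_memory_snapshot(self, model: str, tag: str) -> None:
        # The gate: no config flag, no RPC, no synchronization.
        if self.record_memory:
            self._save_memory_snapshot(model, tag)

# Usage: disabled dispatch never issues the RPC; enabled dispatch does.
d = WorkerDispatch(record_memory=False)
d.maybe_save_memory_snapshot("policy", "forward_backward")
assert d.snapshot_calls == 0
```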
…aram

- Derive total GPU count from CUDA_VISIBLE_DEVICES instead of hardcoded 8
- Annotate unused `rank` parameter in `_set_numa_affinity` for API compat
Summary
- Rewrite NUMA affinity to libnuma’s integer API (`numa_run_on_node`/`numa_set_preferred`); use `numa_max_node()` instead of `numa_num_configured_nodes()` for correct behavior on NVLink/GB200 virtual NUMA IDs
- Fix `KeyError: 'bundles'` when `placement_group_table()` can't reach GCS
- Add `get_dp_size()` and timing logs for dispatch debugging

Split from #1298 per maintainer feedback.
Relates to #1297