
Monitoring improvements #425

Merged
le1nux merged 45 commits into main from monitoring_improvements
Feb 11, 2026

Conversation

@le1nux
Member

@le1nux le1nux commented Dec 6, 2025

What does this PR do?

This PR bundles several changes:

  • Configurable multi-layer FSDP units
  • Option to provide an experiment root path to modalities
  • A steppable profiler, e.g. for tracing forward/backward passes (see the sketch below)
  • Fix: hybrid sharding is now correctly configurable
  • A complete refactoring of the profiling code
  • Improved error handling: errors are now captured and stored as JSON
  • Tutorials on the Einsum Transformer (an example model integration) and on profiling
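For orientation, here is a minimal sketch of the wait → warmup → active stepping pattern such a steppable profiler follows, written against torch.profiler directly; the actual modalities wrapper and its configuration surface may differ.

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Step phases analogous to the wait/warmup/active steps used by the steppable profiler.
prof_schedule = schedule(wait=1, warmup=1, active=3)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=prof_schedule,
    on_trace_ready=lambda p: p.export_chrome_trace("trace.json"),
) as prof:
    for _ in range(5):
        # ... one training step: forward, backward, optimizer ...
        prof.step()  # advance the profiler to the next scheduled phase
```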

General Changes

  • ..

Breaking Changes

  • ..

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

le1nux added 30 commits December 6, 2025 12:30
@le1nux le1nux marked this pull request as ready for review February 10, 2026 09:45
@le1nux le1nux requested a review from BlueCrescent February 10, 2026 14:35
Comment on lines +103 to +109
def __exit__(self, exc_type, exc_value, traceback):
    if self._curr_step is None:
        raise RuntimeError("SteppableMemoryProfilerContext exited without being entered")
    if self._curr_step < self._num_wait_steps + self._num_warmup_steps + self._num_active_steps:
        # exiting before all steps have finished means the memory snapshot was never dumped
        raise RuntimeError("SteppableMemoryProfilerContext exited before finishing all steps")
    return
Member

Should we log the error if this exit is reached by an exception?

Member Author

It's not suppressing the error; it's still propagated up the call stack. I would leave error handling to the caller.
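For reference, a self-contained illustration of the context-manager rule this reply relies on: an `__exit__` that returns a falsy value does not suppress an in-flight exception, so it propagates to the caller.

```python
class NonSuppressingContext:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None (falsy) tells Python NOT to swallow the exception.
        return

try:
    with NonSuppressingContext():
        raise ValueError("boom")
except ValueError:
    print("the exception propagated to the caller")  # this branch runs
```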

Comment on lines +127 to +128
with open(self._memory_snapshot_path, "wb") as output:
    pickle.dump(torch.cuda.memory._snapshot(), output)
Member

Suggested change
- with open(self._memory_snapshot_path, "wb") as output:
-     pickle.dump(torch.cuda.memory._snapshot(), output)
+ torch.cuda.memory._dump_snapshot(self._memory_snapshot_path)
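For context, a minimal sketch of the surrounding snapshot workflow: `_dump_snapshot` pickles the current snapshot to the given path, making the manual `pickle.dump` redundant. Note that these are underscore-prefixed private PyTorch APIs and may change between releases.

```python
import torch

# Start recording allocator history (private API, subject to change).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run the steps to be profiled ...

# Writes the same pickle that pickle.dump(torch.cuda.memory._snapshot(), f) would.
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Stop recording.
torch.cuda.memory._record_memory_history(enabled=None)
```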

Comment on lines +37 to +39
self.dp_degree = get_parallel_degree(
    device_mesh, [ParallelismDegrees.DP_REPLICATE, ParallelismDegrees.DP_SHARD]
)
Member

Could we encode the information that these two form the data parallel degree somewhere more globally (e.g. in a get_data_parallel_degree(device_mesh) function at an appropriate place or use components.settings.step_profile.dp_degree)?
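A sketch of the helper this comment proposes, built only from the names visible in the diff above (`get_parallel_degree`, `ParallelismDegrees`); this is an illustration, not a confirmed modalities API.

```python
def get_data_parallel_degree(device_mesh) -> int:
    # Hypothetical helper: records in one place that the data-parallel degree
    # combines the replicate and shard dimensions of the device mesh.
    return get_parallel_degree(
        device_mesh, [ParallelismDegrees.DP_REPLICATE, ParallelismDegrees.DP_SHARD]
    )
```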

Comment on lines -66 to +67
- return new_path
+ return original_path
Member

Why?

Comment on lines +146 to +187
@overload
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Forward pass of the GPT2LLM module.

    Args:
        inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
            - sample_key (str): Key for the input tensor containing token ids.

    Returns:
        dict[str, torch.Tensor]: A dictionary containing output tensors.
            - prediction_key (str): Key for the output tensor containing logits.
    """
    ...

@overload
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the module.

    Args:
        inputs (torch.Tensor): A tensor containing input token ids.

    Returns:
        torch.Tensor: A tensor containing output logits.
    """
    ...

def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor:
    """
    Forward pass of the module.

    Args:
        inputs (dict[str, torch.Tensor] | torch.Tensor): Input data.

    Returns:
        dict[str, torch.Tensor] | torch.Tensor: Model output.
    """
    if isinstance(inputs, dict):
        return {self.prediction_key: self.forward_impl(inputs[self.sample_key])}
    else:
        return self.forward_impl(inputs)
Member

Maybe we should consider moving this code into NNModel so that you have the option to derive from that class and do not need to copy this stuff when adding a new model.
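A hedged sketch of what that refactoring could look like; the class and method names below are illustrative, and NNModel's actual interface in modalities may differ.

```python
import torch
from torch import nn

class DispatchingModelBase(nn.Module):
    """Illustrative base class owning the dict/tensor dispatch, so new models
    only implement forward_impl. Not the actual NNModel code."""

    def __init__(self, sample_key: str, prediction_key: str):
        super().__init__()
        self.sample_key = sample_key
        self.prediction_key = prediction_key

    def forward_impl(self, token_ids: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # subclasses provide the real computation

    def forward(
        self, inputs: dict[str, torch.Tensor] | torch.Tensor
    ) -> dict[str, torch.Tensor] | torch.Tensor:
        if isinstance(inputs, dict):
            return {self.prediction_key: self.forward_impl(inputs[self.sample_key])}
        return self.forward_impl(inputs)
```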

**What’s inside**
- `train.py`: registers the custom model and launches the run.
- `einsum_transformer_config.yaml`: training + model config.
- `run.sh`: example `torchrun` command for 8 GPUs.
Member

Suggested change
- - `run.sh`: example `torchrun` command for 8 GPUs.
+ - `run.sh`: example `torchrun` command for 4 GPUs.

@le1nux le1nux merged commit e462f57 into main Feb 11, 2026
3 checks passed
@le1nux le1nux deleted the monitoring_improvements branch February 11, 2026 21:50