Conversation
… time measurement
…tivations and weights
Einsum transformer
```python
def __exit__(self, exc_type, exc_value, traceback):
    if self._curr_step is None:
        raise RuntimeError("SteppableMemoryProfilerContext exited without being entered")
    if self._curr_step < self._num_wait_steps + self._num_warmup_steps + self._num_active_steps:
        # if we exit before finishing all steps, dump the memory snapshot
        raise RuntimeError("SteppableMemoryProfilerContext exited before finishing all steps")
    return
```
Should we log the error if this exit path is reached via an exception?
It's not suppressing the error; it's still propagated up the call stack. I would leave error handling to the caller.
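For context, here is a minimal standalone sketch (illustrative only, not code from this PR) of the semantics being relied on: an `__exit__` that returns a falsy value does not suppress an in-flight exception, so the caller still sees it.

```python
# Illustrative sketch: a falsy return from __exit__ re-raises the exception.
class Demo:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None (falsy) tells Python not to suppress the exception.
        return None


try:
    with Demo():
        raise ValueError("boom")
except ValueError as err:
    print(f"caller sees: {err}")  # -> caller sees: boom
```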
```python
with open(self._memory_snapshot_path, "wb") as output:
    pickle.dump(torch.cuda.memory._snapshot(), output)
```
Suggested change:

```diff
-with open(self._memory_snapshot_path, "wb") as output:
-    pickle.dump(torch.cuda.memory._snapshot(), output)
+torch.cuda.memory._dump_snapshot(self._memory_snapshot_path)
```
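For what it's worth, `torch.cuda.memory._dump_snapshot` should be behaviorally equivalent here: it pickles the current snapshot to the given path internally, so the manual `open`/`pickle.dump` pair can be dropped.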
```python
self.dp_degree = get_parallel_degree(
    device_mesh, [ParallelismDegrees.DP_REPLICATE, ParallelismDegrees.DP_SHARD]
)
```
Could we encode the information that these two form the data-parallel degree somewhere more global (e.g. in a `get_data_parallel_degree(device_mesh)` function at an appropriate place, or use `components.settings.step_profile.dp_degree`)?
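For illustration, the proposed helper might look roughly like this (a sketch only; `get_parallel_degree` and `ParallelismDegrees` are taken from the diff above, while the function name, placement, and omitted imports are assumptions):

```python
def get_data_parallel_degree(device_mesh) -> int:
    # Sketch of the suggested helper: encodes in one place that the
    # replicate and shard degrees together form the data-parallel degree.
    return get_parallel_degree(
        device_mesh, [ParallelismDegrees.DP_REPLICATE, ParallelismDegrees.DP_SHARD]
    )
```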
```python
return new_path
return original_path
```
```python
@overload
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Forward pass of the GPT2LLM module.

    Args:
        inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
            - sample_key (str): Key for the input tensor containing token ids.

    Returns:
        dict[str, torch.Tensor]: A dictionary containing output tensors.
            - prediction_key (str): Key for the output tensor containing logits.
    """
    ...

@overload
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the module.

    Args:
        inputs (torch.Tensor): A tensor containing input token ids.

    Returns:
        torch.Tensor: A tensor containing output logits.
    """
    ...

def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor:
    """
    Forward pass of the module.

    Args:
        inputs (dict[str, torch.Tensor] | torch.Tensor): Input data.

    Returns:
        dict[str, torch.Tensor] | torch.Tensor: Model output.
    """
    if isinstance(inputs, dict):
        return {self.prediction_key: self.forward_impl(inputs[self.sample_key])}
    else:
        return self.forward_impl(inputs)
```
Maybe we should consider moving this code into `NNModel`, so that you can derive from that class and don't need to copy this boilerplate when adding a new model.
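As a sketch of that idea (illustrative only; the actual `NNModel` interface in this repo may differ, and `forward_impl`, `sample_key`, and `prediction_key` are taken from the snippet above):

```python
import torch
from torch import nn


class NNModel(nn.Module):
    """Base-class sketch that owns the dict/tensor dispatch for subclasses."""

    def __init__(self, sample_key: str, prediction_key: str) -> None:
        super().__init__()
        self.sample_key = sample_key
        self.prediction_key = prediction_key

    def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor:
        # Shared dispatch: unpack dict inputs by key, pass plain tensors through.
        if isinstance(inputs, dict):
            return {self.prediction_key: self.forward_impl(inputs[self.sample_key])}
        return self.forward_impl(inputs)

    def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("subclasses implement the actual model")
```

New models would then only implement `forward_impl` instead of duplicating the overloads and dispatch logic.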
**What’s inside**
- `train.py`: registers the custom model and launches the run.
- `einsum_transformer_config.yaml`: training + model config.
- `run.sh`: example `torchrun` command for 8 GPUs.
Suggested change:

```diff
-- `run.sh`: example `torchrun` command for 8 GPUs.
+- `run.sh`: example `torchrun` command for 4 GPUs.
```
**What does this PR do?**

This PR introduces multiple changes at once.

**General Changes**

**Breaking Changes**

**Checklist before submitting final PR**

- [ ] Tests pass (`python tests/tests.py`)
- [ ] Changelog updated (`CHANGELOG_DEV.md`)