
Commit a43031e

tohtana and stas00 authored
Add comment explaining nesting torch.autocast (#1000)
* Add comment explaining outer torch.autocast in bf16_master_weight example

  The outer autocast covers loss_fn, which runs outside engine.forward().
  The nested autocast on the model forward is harmless.

  Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>

* Update training/bf16_master_weight/train.py

  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
1 parent ece52bc commit a43031e

1 file changed: 5 additions & 1 deletion

training/bf16_master_weight/train.py
```diff
@@ -292,7 +292,11 @@ def main():
     input_ids = torch.randint(0, actual_vocab_size, (args.batch_size, args.seq_length), device=device)
     labels = torch.randint(0, actual_vocab_size, (args.batch_size, args.seq_length), device=device)

-    # Forward pass with optional autocast
+    # Forward pass with an optional autocast.
+    # DeepSpeed already applies torch.autocast inside engine.forward(), but
+    # we wrap the entire forward+loss block so that loss_fn also runs under
+    # autocast. The nested autocast on engine.forward() is harmless —
+    # PyTorch's torch.autocast is idempotent when nested with the same dtype.
     if use_autocast:
         with torch.autocast(device_type="cuda", dtype=autocast_dtype):
             logits = model_engine(input_ids)
```
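For context, here is a minimal sketch (plain PyTorch, no DeepSpeed; the toy model, shapes, and data below are made up for illustration) of the behavior the new comment describes: nesting torch.autocast regions with the same dtype acts like a single region, so an outer context that also covers the loss does not conflict with an inner one around the model's forward.

```python
import torch
import torch.nn as nn

# Hypothetical toy model and data, for illustration only.
model = nn.Linear(16, 4).cuda()
loss_fn = nn.CrossEntropyLoss()
x = torch.randn(8, 16, device="cuda")
target = torch.randint(0, 4, (8,), device="cuda")

# Outer autocast covers both the forward pass and the loss computation,
# mirroring the train.py pattern where loss_fn runs outside engine.forward().
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # Inner autocast with the same dtype is redundant but harmless,
    # analogous to the one DeepSpeed applies inside engine.forward().
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(x)
    loss = loss_fn(logits, target)

print(logits.dtype)  # torch.bfloat16, same as with a single autocast region
```

Note that autocast decides the compute dtype per operation: the linear layer runs in bfloat16, while the cross-entropy loss is still computed in float32 under CUDA autocast, with or without the nesting.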
