Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 46 additions & 84 deletions examples/pytorch/quantized_model_init/fully_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
local shards on each rank's GPU.
2. ``quantized_model_init`` -- Flags the model for FP8 weight initialization
(actual quantization happens in ``reset_parameters`` after sharding).
3. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
4. ``FusedAdam`` with FP32 master weights for full-precision training updates.
3. ``preserve_high_precision_init_val`` -- Keeps the original BF16 weight
values on CPU so they can seed the optimizer's FP32 master weights,
avoiding the precision loss of round-tripping through FP8.
4. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
5. ``FusedAdam`` with FP32 master weights for full-precision training updates.

.. note::
``fuse_wgrad_accumulation`` is **not** used here. That feature writes
Expand All @@ -38,10 +41,9 @@
from torch.distributed.tensor import DTensor

import transformer_engine.pytorch as te
from transformer_engine.pytorch import QuantizedTensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

# ── Configuration (matches main.py) ──────────────────────────────────
# ── Configuration ────────────────────────────────────────────────────
HIDDEN_SIZE = 256
FFN_HIDDEN_SIZE = 1024
NUM_ATTENTION_HEADS = 8
Expand All @@ -60,10 +62,6 @@ def dist_print(msg):

def main():
# ── 1. Distributed setup ─────────────────────────────────────────
assert "TORCHELASTIC_RUN_ID" in os.environ, (
"This script must be launched with torchrun, e.g.:\n"
" torchrun --nproc-per-node 2 fully_shard.py"
)
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])

Expand All @@ -74,10 +72,12 @@ def main():
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# ── 2. Create model on meta device (zero memory) ────────────────
# quantized_model_init sets the flag for FP8 weight initialization,
# but with device="meta" no actual memory is allocated yet.
with te.quantized_model_init(enabled=True):
# ── 2. Create model on meta device (zero memory) ─────────────────
# quantized_model_init flags parameters for FP8 quantization.
# preserve_high_precision_init_val=True saves the original BF16
# values on CPU so they can seed optimizer master weights later,
# avoiding the precision loss of dequantizing from FP8.
with te.quantized_model_init(enabled=True, preserve_high_precision_init_val=True):
model = torch.nn.Sequential(
*[
te.TransformerLayer(
Expand All @@ -93,52 +93,51 @@ def main():
for _ in range(NUM_LAYERS)
]
)

# Verify all parameters are on meta device (no GPU memory used).
for name, param in model.named_parameters():
assert param.device == torch.device("meta"), f"{name} is not on meta device"
dist_print("Model created on meta device (zero GPU memory).")

# ── 3. FSDP2 sharding ───────────────────────────────────────────
# Apply sharding to the meta-device model. FSDP2 wraps parameters
# ── 3. FSDP2 sharding ───────────────────────────────────────────
# Apply sharding to the meta-device model. FSDP2 wraps parameters
# as DTensors but no GPU memory is allocated yet.
mesh = DeviceMesh("cuda", list(range(world_size)))
for child in model.children():
fully_shard(child, mesh=mesh)
fully_shard(model, mesh=mesh)
dist_print("FSDP2 sharding applied to meta-device model.")

# ── 4. Materialize parameters on GPU ─────────────────────────────
# ── 4. Materialize parameters on GPU ─────────────────────────────
# reset_parameters() on each TE module materializes the local shard
# on CUDA, applies weight initialization, and quantizes to FP8.
# Because preserve_high_precision_init_val=True, the pre-quantization
# BF16 values are saved on CPU for each local shard.
for module in model.modules():
if isinstance(module, TransformerEngineBaseModule):
module.reset_parameters()
dist_print("Parameters materialized on GPU.")

# Post-materialization verification.
for name, param in model.named_parameters():
assert isinstance(param, DTensor), f"{name} is not a DTensor after sharding"
qt_count = sum(
1
for _, p in model.named_parameters()
if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor)
)
assert qt_count > 0, "No QuantizedTensor local tensors after materialization"
dist_print(
f"Parameters materialized: {qt_count} FP8 (QuantizedTensor) weight params "
"wrapped in DTensors."
)

# ── 5. Optimizer ─────────────────────────────────────────────────
# ── 5. Optimizer with FP32 master weights ────────────────────────
optimizer = te.optimizers.FusedAdam(
model.parameters(),
lr=1e-3,
master_weights=True,
master_weight_dtype=torch.float32,
)
dist_print("Using FusedAdam with master_weights=True.")

# ── 6. Training loop ─────────────────────────────────────────────
# ── 6. Seed master weights from high-precision init values ───────
# By default, FusedAdam initializes master weights by dequantizing
# the FP8 parameters, which introduces quantization noise. Instead,
# we seed them from the original BF16 init values, which were requested
# in step 2 and saved on CPU during materialization in step 4.
for param in model.parameters():
optimizer.initialize_state(param, store_param_remainders=False)
local = param._local_tensor if isinstance(param, DTensor) else param
hp_val = getattr(local, "get_high_precision_init_val", lambda: None)()
if hp_val is not None:
optimizer.set_scaled_state(
param, "master_param", hp_val.to(device=device, dtype=torch.float32)
)
local.clear_high_precision_init_val()
dist_print("Optimizer master weights seeded from high-precision init values.")

# ── 7. Training loop ─────────────────────────────────────────────
x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)
target = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)

Expand All @@ -153,56 +152,22 @@ def main():
optimizer.step()
dist_print(f" Step {step}: loss = {loss.item():.6f}")

# ── 7. Post-training assertions ──────────────────────────────────
dist_print("\nVerifying invariants ...")

qt_after = 0
for name, param in model.named_parameters():
assert isinstance(param, DTensor), f"{name} lost DTensor wrapping"
if isinstance(param._local_tensor, QuantizedTensor):
qt_after += 1
assert qt_after > 0, "No QuantizedTensor local tensors after training"
dist_print(f" {qt_after} params still have QuantizedTensor local tensors.")

# Optimizer states: master weights and moments should be float32.
for param in model.parameters():
state = optimizer.state[param]
if "master_param" in state:
assert (
state["master_param"].dtype == torch.float32
), f"Master weight dtype {state['master_param'].dtype}, expected float32"
assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32"
assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32"

dist_print("All assertions passed!")
dist_print(" - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor")
dist_print(" - Optimizer master weights: float32")
dist_print(" - Optimizer states (exp_avg, exp_avg_sq): float32")

# ── 8. Distributed checkpoint: save and load ─────────────────────
# torch.distributed.checkpoint (DCP) saves sharded state — each rank
# writes only its local shard. This preserves FP8 compute weights
# and the full optimizer state (master weights, moments, step count).
# writes only its local shard, preserving FP8 compute weights and
# the full optimizer state (master weights, moments, step count).
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
)

# Use a fixed path so all ranks agree on the checkpoint location.
checkpoint_dir = "/tmp/te_fsdp2_example_checkpoint"
dist_print(f"\nSaving distributed checkpoint to {checkpoint_dir} ...")

# Save sharded checkpoint. DCP handles DTensor shards natively —
# each rank writes only its local shard to the filesystem.
dcp.save(
{"model": model.state_dict(), "optimizer": optimizer.state_dict()},
checkpoint_id=checkpoint_dir,
)
dist_print(" Checkpoint saved (FP8 weights + optimizer state).")

# Load checkpoint back. Provide empty state dict containers with the
# Load checkpoint back. Pass state dicts with the same structure as the
# saved ones; DCP overwrites their contents in place from the saved files.
state_to_load = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
dcp.load(state_to_load, checkpoint_id=checkpoint_dir)
Expand All @@ -225,6 +190,11 @@ def main():
# authoritative FP32 values (more precise than dequantizing FP8).
# All ranks must participate in gathering; only rank 0 saves.
from safetensors.torch import save_file
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
)

full_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)

Expand All @@ -238,10 +208,10 @@ def main():

for key, value in full_model_state.items():
if key in opt_param_states and "master_param" in opt_param_states[key]:
# Prefer optimizer's FP32 master weight (maintained throughout training).
# Prefer optimizer's FP32 master weight.
fp32_state[key] = opt_param_states[key]["master_param"].float()
elif isinstance(value, QuantizedTensor):
# Fallback: dequantize FP8 → FP32 (e.g. if master_weights was off).
elif isinstance(value, te.QuantizedTensor):
# Fallback: dequantize FP8 → FP32.
fp32_state[key] = value.dequantize().float()
else:
# Non-FP8 params (e.g. LayerNorm weights): cast to FP32.
Expand All @@ -251,14 +221,6 @@ def main():
save_file(fp32_state, save_path)
dist_print(f"\nSaved FP32 model ({len(fp32_state)} params) to {save_path}")

# Quick verification: all saved tensors are float32.
from safetensors.torch import load_file

loaded = load_file(save_path)
for k, v in loaded.items():
assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}"
dist_print(f" Verified: all {len(loaded)} tensors are float32.")

dist.destroy_process_group()


Expand Down
Loading