diff --git a/examples/pytorch/quantized_model_init/fully_shard.py b/examples/pytorch/quantized_model_init/fully_shard.py index 6131712001..c1d97579b1 100644 --- a/examples/pytorch/quantized_model_init/fully_shard.py +++ b/examples/pytorch/quantized_model_init/fully_shard.py @@ -13,8 +13,11 @@ local shards on each rank's GPU. 2. ``quantized_model_init`` -- Flags the model for FP8 weight initialization (actual quantization happens in ``reset_parameters`` after sharding). -3. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer. -4. ``FusedAdam`` with FP32 master weights for full-precision training updates. +3. ``preserve_high_precision_init_val`` -- Keeps the original BF16 weight + values on CPU so they can seed the optimizer's FP32 master weights, + avoiding the precision loss of round-tripping through FP8. +4. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer. +5. ``FusedAdam`` with FP32 master weights for full-precision training updates. .. note:: ``fuse_wgrad_accumulation`` is **not** used here. That feature writes @@ -38,10 +41,9 @@ from torch.distributed.tensor import DTensor import transformer_engine.pytorch as te -from transformer_engine.pytorch import QuantizedTensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule -# ── Configuration (matches main.py) ────────────────────────────────── +# ── Configuration ──────────────────────────────────────────────────── HIDDEN_SIZE = 256 FFN_HIDDEN_SIZE = 1024 NUM_ATTENTION_HEADS = 8 @@ -60,10 +62,6 @@ def dist_print(msg): def main(): # ── 1. Distributed setup ───────────────────────────────────────── - assert "TORCHELASTIC_RUN_ID" in os.environ, ( - "This script must be launched with torchrun, e.g.:\n" - " torchrun --nproc-per-node 2 fully_shard.py" - ) world_size = int(os.environ["WORLD_SIZE"]) local_rank = int(os.environ["LOCAL_RANK"]) @@ -74,10 +72,12 @@ def main(): torch.manual_seed(42) torch.cuda.manual_seed(42) - # ── 2. Create model on meta device (zero memory) ──────────────── - # quantized_model_init sets the flag for FP8 weight initialization, - # but with device="meta" no actual memory is allocated yet. - with te.quantized_model_init(enabled=True): + # ── 2. Create model on meta device (zero memory) ───────────────── + # quantized_model_init flags parameters for FP8 quantization. + # preserve_high_precision_init_val=True saves the original BF16 + # values on CPU so they can seed optimizer master weights later, + # avoiding the precision loss of dequantizing from FP8. + with te.quantized_model_init(enabled=True, preserve_high_precision_init_val=True): model = torch.nn.Sequential( *[ te.TransformerLayer( @@ -93,14 +93,10 @@ def main(): for _ in range(NUM_LAYERS) ] ) - - # Verify all parameters are on meta device (no GPU memory used). - for name, param in model.named_parameters(): - assert param.device == torch.device("meta"), f"{name} is not on meta device" dist_print("Model created on meta device (zero GPU memory).") - # ── 3. FSDP2 sharding ──────────────────────────────────────────── - # Apply sharding to the meta-device model. FSDP2 wraps parameters + # ── 3. FSDP2 sharding ─────────────────────────────────────────── + # Apply sharding to the meta-device model. FSDP2 wraps parameters # as DTensors but no GPU memory is allocated yet. mesh = DeviceMesh("cuda", list(range(world_size))) for child in model.children(): @@ -108,37 +104,40 @@ def main(): fully_shard(model, mesh=mesh) dist_print("FSDP2 sharding applied to meta-device model.") - # ── 4. Materialize parameters on GPU ────────────────────────────── + # ── 4. Materialize parameters on GPU ───────────────────────────── # reset_parameters() on each TE module materializes the local shard # on CUDA, applies weight initialization, and quantizes to FP8. + # Because preserve_high_precision_init_val=True, the pre-quantization + # BF16 values are saved on CPU for each local shard. for module in model.modules(): if isinstance(module, TransformerEngineBaseModule): module.reset_parameters() + dist_print("Parameters materialized on GPU.") - # Post-materialization verification. - for name, param in model.named_parameters(): - assert isinstance(param, DTensor), f"{name} is not a DTensor after sharding" - qt_count = sum( - 1 - for _, p in model.named_parameters() - if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor) - ) - assert qt_count > 0, "No QuantizedTensor local tensors after materialization" - dist_print( - f"Parameters materialized: {qt_count} FP8 (QuantizedTensor) weight params " - "wrapped in DTensors." - ) - - # ── 5. Optimizer ───────────────────────────────────────────────── + # ── 5. Optimizer with FP32 master weights ──────────────────────── optimizer = te.optimizers.FusedAdam( model.parameters(), lr=1e-3, master_weights=True, master_weight_dtype=torch.float32, ) - dist_print("Using FusedAdam with master_weights=True.") - # ── 6. Training loop ───────────────────────────────────────────── + # ── 6. Seed master weights from high-precision init values ─────── + # By default, FusedAdam initializes master weights by dequantizing + # the FP8 parameters, which introduces quantization noise. Instead, + # we seed them from the original BF16 init values preserved in step 2. + for param in model.parameters(): + optimizer.initialize_state(param, store_param_remainders=False) + local = param._local_tensor if isinstance(param, DTensor) else param + hp_val = getattr(local, "get_high_precision_init_val", lambda: None)() + if hp_val is not None: + optimizer.set_scaled_state( + param, "master_param", hp_val.to(device=device, dtype=torch.float32) + ) + local.clear_high_precision_init_val() + dist_print("Optimizer master weights seeded from high-precision init values.") + + # ── 7. Training loop ───────────────────────────────────────────── x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device) target = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device) @@ -153,56 +152,22 @@ def main(): optimizer.step() dist_print(f" Step {step}: loss = {loss.item():.6f}") - # ── 7. Post-training assertions ────────────────────────────────── - dist_print("\nVerifying invariants ...") - - qt_after = 0 - for name, param in model.named_parameters(): - assert isinstance(param, DTensor), f"{name} lost DTensor wrapping" - if isinstance(param._local_tensor, QuantizedTensor): - qt_after += 1 - assert qt_after > 0, "No QuantizedTensor local tensors after training" - dist_print(f" {qt_after} params still have QuantizedTensor local tensors.") - - # Optimizer states: master weights and moments should be float32. - for param in model.parameters(): - state = optimizer.state[param] - if "master_param" in state: - assert ( - state["master_param"].dtype == torch.float32 - ), f"Master weight dtype {state['master_param'].dtype}, expected float32" - assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32" - assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32" - - dist_print("All assertions passed!") - dist_print(" - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor") - dist_print(" - Optimizer master weights: float32") - dist_print(" - Optimizer states (exp_avg, exp_avg_sq): float32") - # ── 8. Distributed checkpoint: save and load ───────────────────── # torch.distributed.checkpoint (DCP) saves sharded state — each rank - # writes only its local shard. This preserves FP8 compute weights - # and the full optimizer state (master weights, moments, step count). + # writes only its local shard, preserving FP8 compute weights and + # the full optimizer state (master weights, moments, step count). import torch.distributed.checkpoint as dcp - from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - get_model_state_dict, - get_optimizer_state_dict, - ) - # Use a fixed path so all ranks agree on the checkpoint location. checkpoint_dir = "/tmp/te_fsdp2_example_checkpoint" dist_print(f"\nSaving distributed checkpoint to {checkpoint_dir} ...") - # Save sharded checkpoint. DCP handles DTensor shards natively — - # each rank writes only its local shard to the filesystem. dcp.save( {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, checkpoint_id=checkpoint_dir, ) dist_print(" Checkpoint saved (FP8 weights + optimizer state).") - # Load checkpoint back. Provide empty state dict containers with the + # Load checkpoint back. Provide empty state dict containers with the # same structure; DCP fills them from the saved files. state_to_load = {"model": model.state_dict(), "optimizer": optimizer.state_dict()} dcp.load(state_to_load, checkpoint_id=checkpoint_dir) @@ -225,6 +190,11 @@ def main(): # authoritative FP32 values (more precise than dequantizing FP8). # All ranks must participate in gathering; only rank 0 saves. from safetensors.torch import save_file + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_optimizer_state_dict, + ) full_opts = StateDictOptions(full_state_dict=True, cpu_offload=True) @@ -238,10 +208,10 @@ def main(): for key, value in full_model_state.items(): if key in opt_param_states and "master_param" in opt_param_states[key]: - # Prefer optimizer's FP32 master weight (maintained throughout training). + # Prefer optimizer's FP32 master weight. fp32_state[key] = opt_param_states[key]["master_param"].float() - elif isinstance(value, QuantizedTensor): - # Fallback: dequantize FP8 → FP32 (e.g. if master_weights was off). + elif isinstance(value, te.QuantizedTensor): + # Fallback: dequantize FP8 → FP32. fp32_state[key] = value.dequantize().float() else: # Non-FP8 params (e.g. LayerNorm weights): cast to FP32. @@ -251,14 +221,6 @@ def main(): save_file(fp32_state, save_path) dist_print(f"\nSaved FP32 model ({len(fp32_state)} params) to {save_path}") - # Quick verification: all saved tensors are float32. - from safetensors.torch import load_file - - loaded = load_file(save_path) - for k, v in loaded.items(): - assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}" - dist_print(f" Verified: all {len(loaded)} tensors are float32.") - dist.destroy_process_group()