44#
55# See LICENSE for license information.
66
7+ import argparse
78import os
89import sys
9- import argparse
10+ import shutil
11+ from contextlib import nullcontext
12+ from copy import deepcopy
1013from dataclasses import dataclass
14+ from pathlib import Path
1115
1216import transformer_engine .pytorch as te
1317import transformer_engine .common .recipe
14-
18+ from transformer_engine . pytorch import QuantizedTensor
1519import torch
1620import torch .distributed as dist
1721from torch .distributed .checkpoint import save , load
2731from torch .distributed import DeviceMesh
2832from torch .distributed ._composable .fsdp import fully_shard
2933from torch .distributed .device_mesh import init_device_mesh
30- from transformer_engine .pytorch import QuantizedTensor
31- from contextlib import nullcontext
3234
3335LOCAL_RANK = None
3436
37+ # Needed for `torch.distributed.checkpoint.{save,load}` because
38+ # multiple processes need to write to the same directory.
39+ SHARED_TMP_DIR = "/tmp/pytest-shared-tmp"
40+
3541
3642@dataclass
3743class AppState (Stateful ):
@@ -63,7 +69,7 @@ def state_dict(self):
6369 # yet get_state_dict / _init_optim_state produce empty Tensors.
6470 # TransformerEngine uses empty Tensors for dummy Parameters.
6571 optimizer_state_dict ["state" ][fqn ] = {}
66- if fqn .endswith (". _extra_state" ):
72+ if fqn .endswith ("_extra_state" ):
6773 # Evict `_extra_state` quantization data from model checkpoint.
6874 model_state_dict .pop (fqn )
6975 return {
@@ -203,14 +209,30 @@ def init_te_model(config):
203209 kwargs ["device" ] = config .device
204210 kwargs ["tp_size" ] = config .tp_size
205211
212+ # DeviceMesh / DTensor-related model parameter operations!
213+ # NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters.
214+ # If not using meta device initialization, reset_parameters is called during __init__.
215+ if config .tp_size > 1 : # (H/F)SDP-TP
216+ assert "dp_shard" in config .mesh .mesh_dim_names
217+ assert "tp" in config .mesh .mesh_dim_names
218+ dist_print (f"Tensor parallelism activated with size: { config .tp_size } " )
219+ # For TP shards as DTensors.
220+ kwargs ["tp_mesh" ] = config .mesh ["tp" ]
221+ # For per-tensor quantization recipes with TP.
222+ kwargs ["weight_mesh" ] = config .mesh ["dp_shard" , "tp" ]._flatten ("weight_mesh" )
223+ elif len (config .mesh .mesh_dim_names ) > 1 : # HSDP
224+ assert "dp_shard" in config .mesh .mesh_dim_names
225+ # HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`.
226+ # Used for per-tensor quantization recipes like Float8CurrentScaling.
227+ kwargs ["weight_mesh" ] = config .mesh ["dp_shard" ] # Only sharding with FSDP.
228+
206229 layer_type = get_te_layer_from_string (config .layer_type )
207230 # We are creating model in a way so that we can test both reshard_after_forward=True/False cases.
208231 # more details below.
209232 if layer_type in [
210233 te .TransformerLayer ,
211234 te .MultiheadAttention ,
212235 te .LayerNormMLP ,
213- # TODO(@cspades): GroupedLinear testing.
214236 ]:
215237 # For this case, we are creating a model that resembles production use-cases
216238 # wherein there are multiple TransformerLayers in the model. And we would need
@@ -221,24 +243,9 @@ def init_te_model(config):
221243 kwargs ["fuse_qkv_params" ] = True
222244 if layer_type is te .MultiheadAttention :
223245 kwargs ["input_layernorm" ] = True
224- # DeviceMesh / DTensor-related model parameter operations!
225- # NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters.
226- # If not using meta device initialization, reset_parameters is called during __init__.
227246 if config .tp_size > 1 :
228- assert "dp_shard" in config .mesh .mesh_dim_names
229- assert "tp" in config .mesh .mesh_dim_names
230- dist_print (f"Tensor parallelism activated with size: { config .tp_size } " )
231247 # Activate TP in TE.
232248 kwargs ["set_parallel_mode" ] = True
233- # For TP shards as DTensors.
234- kwargs ["tp_mesh" ] = config .mesh ["tp" ]
235- # For per-tensor quantization recipes with TP.
236- kwargs ["weight_mesh" ] = config .mesh ["dp_shard" , "tp" ]._flatten ("weight_mesh" )
237- elif len (config .mesh .mesh_dim_names ) > 1 :
238- assert "dp_shard" in config .mesh .mesh_dim_names
239- # HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`.
240- # Used for per-tensor quantization recipes like Float8CurrentScaling.
241- kwargs ["weight_mesh" ] = config .mesh ["dp_shard" ] # Only sharding with FSDP.
242249 # Initialize model.
243250 model = nn .Sequential (* [layer_type (* args , ** kwargs ) for _ in range (config .num_layers )])
244251 elif layer_type in [te .LayerNormLinear , te .Linear ]:
@@ -247,26 +254,11 @@ def init_te_model(config):
247254 # reshard_after_forward=True for the parameters of these models.
248255 args [1 ] *= 3 # QKV projection
249256 out_shape [- 1 ] *= 3
250- # DeviceMesh / DTensor-related model parameter operations!
251- # NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters.
252- # If not using meta device initialization, reset_parameters is called during __init__.
253257 if config .tp_size > 1 :
254- assert "dp_shard" in config .mesh .mesh_dim_names
255- assert "tp" in config .mesh .mesh_dim_names
256- dist_print (f"Tensor parallelism activated with size: { config .tp_size } " )
257258 # Activate TP in TE.
258259 kwargs ["parallel_mode" ] = "column"
259- # For TP shards as DTensors.
260- kwargs ["tp_mesh" ] = config .mesh ["tp" ]
261- # For per-tensor quantization recipes with TP.
262- kwargs ["weight_mesh" ] = config .mesh ["dp_shard" , "tp" ]._flatten ("weight_mesh" )
263260 # Modify output shape for column-parallel Linear.
264261 out_shape [- 1 ] //= config .tp_size
265- elif len (config .mesh .mesh_dim_names ) > 1 :
266- assert "dp_shard" in config .mesh .mesh_dim_names
267- # HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`.
268- # Used for per-tensor quantization recipes like Float8CurrentScaling.
269- kwargs ["weight_mesh" ] = config .mesh ["dp_shard" ] # Only sharding with FSDP.
270262 # Initialize model.
271263 model = layer_type (* args , ** kwargs )
272264 else :
@@ -352,7 +344,9 @@ def test_fp8_fsdp2_allgather(model):
352344 # FP32 manual weight allgather
353345 fp32_allgathered_params = {}
354346 for name , param in model .named_parameters ():
355- assert isinstance (param , DTensor )
347+ assert isinstance (
348+ param , DTensor
349+ ), f"[test_fp8_fsdp2_allgather] { param } should be a DTensor."
356350 local_tensor = param ._local_tensor
357351 device_mesh = param .device_mesh
358352 dist_group = (
@@ -471,7 +465,7 @@ def _train(args):
471465 optimizer = optim .Adam (model .parameters (), lr = 1e-3 )
472466
473467 """
474- Pre-Save Training
468+ FSDP2 Training
475469 """
476470 for iteration in range (args .iter ):
477471 # Zero the parameter gradients
@@ -499,6 +493,154 @@ def _train(args):
499493 if args .fp8_init :
500494 test_fp8_fsdp2_allgather (model )
501495
496+ """
497+ DCP Checkpoint Testing
498+ """
499+ # Compute the pre-save model loss to the last random input
500+ # with respect to the last random target.
501+ model .eval ()
502+ with (
503+ torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 )
504+ if args .recipe == "NVFP4BlockScaling"
505+ else nullcontext ()
506+ ):
507+ with te .autocast (enabled = True , recipe = fp8_recipe ):
508+ output = model (input_data )
509+ pre_save_loss = F .mse_loss (output , target )
510+
511+ # Save deep copy of the model and optimizer state before checkpointing.
512+ # NOTE(@cspades): deepcopy has issues with DTensors. Just clone().
513+ s1 = {}
514+ for key , val in model .state_dict ().items ():
515+ s1 [key ] = val .clone ()
516+ optim_state_dict = optimizer .state_dict ()
517+ o1 = {"state" : {}}
518+ for idx , state in optim_state_dict ["state" ].items ():
519+ o1_state = o1 ["state" ].setdefault (idx , {})
520+ for key , val in state .items ():
521+ o1_state [key ] = val .clone ()
522+ o1 ["param_groups" ] = deepcopy (optim_state_dict ["param_groups" ])
523+
524+ # Write model to checkpoint.
525+ CKPT_DIR = (
526+ Path (SHARED_TMP_DIR )
527+ / "run_fsdp2_model"
528+ / f"dcp-{ '_' .join (str (x ) for x in args .sharding_dims )} -{ args .layer_type } -{ args .recipe } -fp8_init_{ args .fp8_init } "
529+ )
530+ CKPT_DIR .mkdir (parents = True , exist_ok = True , mode = 0o777 )
531+ state_dict = {"app" : AppState (model = model , optimizer = optimizer )}
532+ torch .distributed .checkpoint .save (state_dict , checkpoint_id = str (CKPT_DIR ))
533+
534+ # Perform an extra training step to change the weights such that
535+ # state parity tests will fail unless the checkpoint is loaded
536+ # without any errors or incongruities vs. the saved model state.
537+ model .train ()
538+ for iteration in range (args .iter ):
539+ optimizer .zero_grad ()
540+ with (
541+ torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 )
542+ if args .recipe == "NVFP4BlockScaling"
543+ else nullcontext ()
544+ ):
545+ with te .autocast (enabled = True , recipe = fp8_recipe ):
546+ output = model (torch .randn (inp_shape ).to (device ))
547+ loss = F .mse_loss (output , torch .randn (out_shape ).to (device ))
548+ loss .backward ()
549+ optimizer .step ()
550+
551+ # Load the checkpoint.
552+ state_dict = {"app" : AppState (model = model , optimizer = optimizer )}
553+ torch .distributed .checkpoint .load (state_dict = state_dict , checkpoint_id = str (CKPT_DIR ))
554+
555+ # FIXME(@cspades): DelayedScaling checkpointing has tiny (+/- 1) uint8
556+ # parity issues that affects the dequantized model state. Only test loss
557+ # parity if using DelayedScaling with quantized parameters.
558+ if not args .fp8_init or args .recipe != "DelayedScaling" :
559+ # Validate checkpoint parity with pre-save state dictionaries.
560+ # Compare pre-save and post-load model state dictionaries.
561+ s2 = model .state_dict ()
562+ nonempty_model_state = False
563+ for key in s1 .keys () | s2 .keys ():
564+ if key .endswith ("_extra_state" ):
565+ # Don't parity test _extra_state. Shape can change after reset_parameters().
566+ continue
567+ v1 = s1 .get (key , None )
568+ if isinstance (v1 , DTensor ):
569+ v1 = v1 .to_local ()
570+ v2 = s2 .get (key , None )
571+ if isinstance (v2 , DTensor ):
572+ v2 = v2 .to_local ()
573+ assert (
574+ v1 is not None and v2 is not None
575+ ), f"[{ key } Not Found] Original Param: { v1 } | Checkpoint Param: { v2 } "
576+ assert (
577+ v1 .shape == v2 .shape
578+ ), f"[Checkpoint Param { key } Shape Mismatch] { v1 .shape } != { v2 .shape } "
579+ assert torch .allclose (v1 , v2 ), f"[Checkpoint Param { key } Value Mismatch] { v1 } != { v2 } "
580+ nonempty_model_state = True
581+ assert nonempty_model_state , "Model state should not be empty for evenly-sharded DTensors!"
582+
583+ # Compare pre-save and post-load optimizer state dictionaries.
584+ o2 = optimizer .state_dict ()
585+ nonempty_optim_state = False
586+ for param_id in o1 ["state" ].keys () | o2 ["state" ].keys ():
587+ param_state_1 = o1 ["state" ].get (param_id , None )
588+ param_state_2 = o2 ["state" ].get (param_id , None )
589+ assert param_state_1 is not None and param_state_2 is not None , (
590+ f"[{ param_id } Not Found] Original Optim State: { param_state_1 } | Checkpoint Optim"
591+ f" State: { param_state_2 } "
592+ )
593+ for key in param_state_1 .keys () | param_state_2 .keys ():
594+ v1 = param_state_1 .get (key , None )
595+ if isinstance (v1 , DTensor ):
596+ v1 = v1 .to_local ()
597+ v2 = param_state_2 .get (key , None )
598+ if isinstance (v2 , DTensor ):
599+ v2 = v2 .to_local ()
600+ assert v1 is not None and v2 is not None , (
601+ f"[{ param_id } { key } Not Found] Original Optim State: { v1 } | Checkpoint Optim"
602+ f" State: { v2 } "
603+ )
604+ assert (
605+ v1 .shape == v2 .shape
606+ ), f"[Optim State { param_id } { key } Shape Mismatch] { v1 .shape } != { v2 .shape } "
607+ assert torch .allclose (
608+ v1 , v2
609+ ), f"[Optim State { param_id } { key } Value Mismatch] { v1 } != { v2 } "
610+ nonempty_optim_state = True # Optimizer state depends on wgrad, verify this!
611+ assert (
612+ nonempty_optim_state
613+ ), "Optimizer state should not be empty for evenly-sharded DTensors!"
614+ assert len (o1 ["param_groups" ]) == len (o2 ["param_groups" ]), (
615+ f"[Optim State Param Groups Length Mismatch] { o1 ['param_groups' ]} !="
616+ f" { o2 ['param_groups' ]} "
617+ )
618+ for i in range (len (o2 ["param_groups" ])):
619+ for key in o1 ["param_groups" ][i ].keys ():
620+ v1 = o1 ["param_groups" ][i ][key ]
621+ v2 = o2 ["param_groups" ][i ][key ]
622+ assert v1 == v2 , f"[Optim State Param Group { i } { key } Value Mismatch] { v1 } != { v2 } "
623+
624+ # Validate post-load model loss.
625+ model .eval ()
626+ with (
627+ torch .autocast (device_type = "cuda" , dtype = torch .bfloat16 )
628+ if args .recipe == "NVFP4BlockScaling"
629+ else nullcontext ()
630+ ):
631+ with te .autocast (enabled = True , recipe = fp8_recipe ):
632+ output = model (input_data )
633+ post_load_loss = F .mse_loss (output , target )
634+ # Allow for 1% disparity due to _extra_state disparity.
635+ assert torch .allclose (
636+ pre_save_loss , post_load_loss , rtol = 1e-2
637+ ), f"Pre-Save Loss: { pre_save_loss } != Post-Load Loss: { post_load_loss } "
638+
639+ # Clean up temporary checkpoint directory.
640+ if not torch .distributed .is_initialized () or torch .distributed .get_rank () == 0 :
641+ shutil .rmtree (CKPT_DIR )
642+ torch .distributed .barrier ()
643+
502644 dist .destroy_process_group ()
503645 return 0
504646
0 commit comments