Skip to content

Commit 267f1df

Browse files
committed
Extend FSDP2 unit tests to include DCP checkpointing and parity tests.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent 4f7f9c0 commit 267f1df

12 files changed

Lines changed: 256 additions & 205 deletions

File tree

tests/pytorch/distributed/run_fsdp2_model.py

Lines changed: 180 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
#
55
# See LICENSE for license information.
66

7+
import argparse
78
import os
89
import sys
9-
import argparse
10+
import shutil
11+
from contextlib import nullcontext
12+
from copy import deepcopy
1013
from dataclasses import dataclass
14+
from pathlib import Path
1115

1216
import transformer_engine.pytorch as te
1317
import transformer_engine.common.recipe
14-
18+
from transformer_engine.pytorch import QuantizedTensor
1519
import torch
1620
import torch.distributed as dist
1721
from torch.distributed.checkpoint import save, load
@@ -27,11 +31,13 @@
2731
from torch.distributed import DeviceMesh
2832
from torch.distributed._composable.fsdp import fully_shard
2933
from torch.distributed.device_mesh import init_device_mesh
30-
from transformer_engine.pytorch import QuantizedTensor
31-
from contextlib import nullcontext
3234

3335
LOCAL_RANK = None
3436

37+
# Needed for `torch.distributed.checkpoint.{save,load}` because
38+
# multiple processes need to write to the same directory.
39+
SHARED_TMP_DIR = "/tmp/pytest-shared-tmp"
40+
3541

3642
@dataclass
3743
class AppState(Stateful):
@@ -63,7 +69,7 @@ def state_dict(self):
6369
# yet get_state_dict / _init_optim_state produce empty Tensors.
6470
# TransformerEngine uses empty Tensors for dummy Parameters.
6571
optimizer_state_dict["state"][fqn] = {}
66-
if fqn.endswith("._extra_state"):
72+
if fqn.endswith("_extra_state"):
6773
# Evict `_extra_state` quantization data from model checkpoint.
6874
model_state_dict.pop(fqn)
6975
return {
@@ -203,14 +209,30 @@ def init_te_model(config):
203209
kwargs["device"] = config.device
204210
kwargs["tp_size"] = config.tp_size
205211

212+
# DeviceMesh / DTensor-related model parameter operations!
213+
# NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters.
214+
# If not using meta device initialization, reset_parameters is called during __init__.
215+
if config.tp_size > 1: # (H/F)SDP-TP
216+
assert "dp_shard" in config.mesh.mesh_dim_names
217+
assert "tp" in config.mesh.mesh_dim_names
218+
dist_print(f"Tensor parallelism activated with size: {config.tp_size}")
219+
# For TP shards as DTensors.
220+
kwargs["tp_mesh"] = config.mesh["tp"]
221+
# For per-tensor quantization recipes with TP.
222+
kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh")
223+
elif len(config.mesh.mesh_dim_names) > 1: # HSDP
224+
assert "dp_shard" in config.mesh.mesh_dim_names
225+
# HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`.
226+
# Used for per-tensor quantization recipes like Float8CurrentScaling.
227+
kwargs["weight_mesh"] = config.mesh["dp_shard"] # Only sharding with FSDP.
228+
206229
layer_type = get_te_layer_from_string(config.layer_type)
207230
# We are creating model in a way so that we can test both reshard_after_forward=True/False cases.
208231
# more details below.
209232
if layer_type in [
210233
te.TransformerLayer,
211234
te.MultiheadAttention,
212235
te.LayerNormMLP,
213-
# TODO(@cspades): GroupedLinear testing.
214236
]:
215237
# For this case, we are creating a model that resembles production use-cases
216238
# wherein there are multiple TransformerLayers in the model. And we would need
@@ -221,24 +243,9 @@ def init_te_model(config):
221243
kwargs["fuse_qkv_params"] = True
222244
if layer_type is te.MultiheadAttention:
223245
kwargs["input_layernorm"] = True
224-
# DeviceMesh / DTensor-related model parameter operations!
225-
# NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters.
226-
# If not using meta device initialization, reset_parameters is called during __init__.
227246
if config.tp_size > 1:
228-
assert "dp_shard" in config.mesh.mesh_dim_names
229-
assert "tp" in config.mesh.mesh_dim_names
230-
dist_print(f"Tensor parallelism activated with size: {config.tp_size}")
231247
# Activate TP in TE.
232248
kwargs["set_parallel_mode"] = True
233-
# For TP shards as DTensors.
234-
kwargs["tp_mesh"] = config.mesh["tp"]
235-
# For per-tensor quantization recipes with TP.
236-
kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh")
237-
elif len(config.mesh.mesh_dim_names) > 1:
238-
assert "dp_shard" in config.mesh.mesh_dim_names
239-
# HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`.
240-
# Used for per-tensor quantization recipes like Float8CurrentScaling.
241-
kwargs["weight_mesh"] = config.mesh["dp_shard"] # Only sharding with FSDP.
242249
# Initialize model.
243250
model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)])
244251
elif layer_type in [te.LayerNormLinear, te.Linear]:
@@ -247,26 +254,11 @@ def init_te_model(config):
247254
# reshard_after_forward=True for the parameters of these model.
248255
args[1] *= 3 # QKV projection
249256
out_shape[-1] *= 3
250-
# DeviceMesh / DTensor-related model parameter operations!
251-
# NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters.
252-
# If not using meta device initialization, reset_parameters is called during __init__.
253257
if config.tp_size > 1:
254-
assert "dp_shard" in config.mesh.mesh_dim_names
255-
assert "tp" in config.mesh.mesh_dim_names
256-
dist_print(f"Tensor parallelism activated with size: {config.tp_size}")
257258
# Activate TP in TE.
258259
kwargs["parallel_mode"] = "column"
259-
# For TP shards as DTensors.
260-
kwargs["tp_mesh"] = config.mesh["tp"]
261-
# For per-tensor quantization recipes with TP.
262-
kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh")
263260
# Modify output shape for column-parallel Linear.
264261
out_shape[-1] //= config.tp_size
265-
elif len(config.mesh.mesh_dim_names) > 1:
266-
assert "dp_shard" in config.mesh.mesh_dim_names
267-
# HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`.
268-
# Used for per-tensor quantization recipes like Float8CurrentScaling.
269-
kwargs["weight_mesh"] = config.mesh["dp_shard"] # Only sharding with FSDP.
270262
# Initialize model.
271263
model = layer_type(*args, **kwargs)
272264
else:
@@ -352,7 +344,9 @@ def test_fp8_fsdp2_allgather(model):
352344
# FP32 manual weight allgather
353345
fp32_allgathered_params = {}
354346
for name, param in model.named_parameters():
355-
assert isinstance(param, DTensor)
347+
assert isinstance(
348+
param, DTensor
349+
), f"[test_fp8_fsdp2_allgather] {param} should be a DTensor."
356350
local_tensor = param._local_tensor
357351
device_mesh = param.device_mesh
358352
dist_group = (
@@ -471,7 +465,7 @@ def _train(args):
471465
optimizer = optim.Adam(model.parameters(), lr=1e-3)
472466

473467
"""
474-
Pre-Save Training
468+
FSDP2 Training
475469
"""
476470
for iteration in range(args.iter):
477471
# Zero the parameter gradients
@@ -499,6 +493,154 @@ def _train(args):
499493
if args.fp8_init:
500494
test_fp8_fsdp2_allgather(model)
501495

496+
"""
497+
DCP Checkpoint Testing
498+
"""
499+
# Compute the pre-save model loss to the last random input
500+
# with respect to the last random target.
501+
model.eval()
502+
with (
503+
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
504+
if args.recipe == "NVFP4BlockScaling"
505+
else nullcontext()
506+
):
507+
with te.autocast(enabled=True, recipe=fp8_recipe):
508+
output = model(input_data)
509+
pre_save_loss = F.mse_loss(output, target)
510+
511+
# Save deep copy of the model and optimizer state before checkpointing.
512+
# NOTE(@cspades): deepcopy has issues with DTensors. Just clone().
513+
s1 = {}
514+
for key, val in model.state_dict().items():
515+
s1[key] = val.clone()
516+
optim_state_dict = optimizer.state_dict()
517+
o1 = {"state": {}}
518+
for idx, state in optim_state_dict["state"].items():
519+
o1_state = o1["state"].setdefault(idx, {})
520+
for key, val in state.items():
521+
o1_state[key] = val.clone()
522+
o1["param_groups"] = deepcopy(optim_state_dict["param_groups"])
523+
524+
# Write model to checkpoint.
525+
CKPT_DIR = (
526+
Path(SHARED_TMP_DIR)
527+
/ "run_fsdp2_model"
528+
/ f"dcp-{'_'.join(str(x) for x in args.sharding_dims)}-{args.layer_type}-{args.recipe}-fp8_init_{args.fp8_init}"
529+
)
530+
CKPT_DIR.mkdir(parents=True, exist_ok=True, mode=0o777)
531+
state_dict = {"app": AppState(model=model, optimizer=optimizer)}
532+
torch.distributed.checkpoint.save(state_dict, checkpoint_id=str(CKPT_DIR))
533+
534+
# Perform an extra training step to change the weights such that
535+
# state parity tests will fail unless the checkpoint is loaded
536+
# without any errors or incongruities vs. the saved model state.
537+
model.train()
538+
for iteration in range(args.iter):
539+
optimizer.zero_grad()
540+
with (
541+
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
542+
if args.recipe == "NVFP4BlockScaling"
543+
else nullcontext()
544+
):
545+
with te.autocast(enabled=True, recipe=fp8_recipe):
546+
output = model(torch.randn(inp_shape).to(device))
547+
loss = F.mse_loss(output, torch.randn(out_shape).to(device))
548+
loss.backward()
549+
optimizer.step()
550+
551+
# Load the checkpoint.
552+
state_dict = {"app": AppState(model=model, optimizer=optimizer)}
553+
torch.distributed.checkpoint.load(state_dict=state_dict, checkpoint_id=str(CKPT_DIR))
554+
555+
# FIXME(@cspades): DelayedScaling checkpointing has tiny (+/- 1) uint8
556+
# parity issues that affect the dequantized model state. Only test loss
557+
# parity if using DelayedScaling with quantized parameters.
558+
if not args.fp8_init or args.recipe != "DelayedScaling":
559+
# Validate checkpoint parity with pre-save state dictionaries.
560+
# Compare pre-save and post-load model state dictionaries.
561+
s2 = model.state_dict()
562+
nonempty_model_state = False
563+
for key in s1.keys() | s2.keys():
564+
if key.endswith("_extra_state"):
565+
# Don't parity test _extra_state. Shape can change after reset_parameters().
566+
continue
567+
v1 = s1.get(key, None)
568+
if isinstance(v1, DTensor):
569+
v1 = v1.to_local()
570+
v2 = s2.get(key, None)
571+
if isinstance(v2, DTensor):
572+
v2 = v2.to_local()
573+
assert (
574+
v1 is not None and v2 is not None
575+
), f"[{key} Not Found] Original Param: {v1} | Checkpoint Param: {v2}"
576+
assert (
577+
v1.shape == v2.shape
578+
), f"[Checkpoint Param {key} Shape Mismatch] {v1.shape} != {v2.shape}"
579+
assert torch.allclose(v1, v2), f"[Checkpoint Param {key} Value Mismatch] {v1} != {v2}"
580+
nonempty_model_state = True
581+
assert nonempty_model_state, "Model state should not be empty for evenly-sharded DTensors!"
582+
583+
# Compare pre-save and post-load optimizer state dictionaries.
584+
o2 = optimizer.state_dict()
585+
nonempty_optim_state = False
586+
for param_id in o1["state"].keys() | o2["state"].keys():
587+
param_state_1 = o1["state"].get(param_id, None)
588+
param_state_2 = o2["state"].get(param_id, None)
589+
assert param_state_1 is not None and param_state_2 is not None, (
590+
f"[{param_id} Not Found] Original Optim State: {param_state_1} | Checkpoint Optim"
591+
f" State: {param_state_2}"
592+
)
593+
for key in param_state_1.keys() | param_state_2.keys():
594+
v1 = param_state_1.get(key, None)
595+
if isinstance(v1, DTensor):
596+
v1 = v1.to_local()
597+
v2 = param_state_2.get(key, None)
598+
if isinstance(v2, DTensor):
599+
v2 = v2.to_local()
600+
assert v1 is not None and v2 is not None, (
601+
f"[{param_id} {key} Not Found] Original Optim State: {v1} | Checkpoint Optim"
602+
f" State: {v2}"
603+
)
604+
assert (
605+
v1.shape == v2.shape
606+
), f"[Optim State {param_id} {key} Shape Mismatch] {v1.shape} != {v2.shape}"
607+
assert torch.allclose(
608+
v1, v2
609+
), f"[Optim State {param_id} {key} Value Mismatch] {v1} != {v2}"
610+
nonempty_optim_state = True # Optimizer state depends on wgrad, verify this!
611+
assert (
612+
nonempty_optim_state
613+
), "Optimizer state should not be empty for evenly-sharded DTensors!"
614+
assert len(o1["param_groups"]) == len(o2["param_groups"]), (
615+
f"[Optim State Param Groups Length Mismatch] {o1['param_groups']} !="
616+
f" {o2['param_groups']}"
617+
)
618+
for i in range(len(o2["param_groups"])):
619+
for key in o1["param_groups"][i].keys():
620+
v1 = o1["param_groups"][i][key]
621+
v2 = o2["param_groups"][i][key]
622+
assert v1 == v2, f"[Optim State Param Group {i} {key} Value Mismatch] {v1} != {v2}"
623+
624+
# Validate post-load model loss.
625+
model.eval()
626+
with (
627+
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
628+
if args.recipe == "NVFP4BlockScaling"
629+
else nullcontext()
630+
):
631+
with te.autocast(enabled=True, recipe=fp8_recipe):
632+
output = model(input_data)
633+
post_load_loss = F.mse_loss(output, target)
634+
# Allow for 1% disparity due to _extra_state disparity.
635+
assert torch.allclose(
636+
pre_save_loss, post_load_loss, rtol=1e-2
637+
), f"Pre-Save Loss: {pre_save_loss} != Post-Load Loss: {post_load_loss}"
638+
639+
# Clean up temporary checkpoint directory.
640+
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
641+
shutil.rmtree(CKPT_DIR)
642+
torch.distributed.barrier()
643+
502644
dist.destroy_process_group()
503645
return 0
504646

tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# See LICENSE for license information.
44

5+
import math
56
import os
67
import subprocess
78
from pathlib import Path
@@ -74,7 +75,7 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
7475
subprocess.run(test_cmd, env=os.environ, check=True)
7576

7677

77-
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
78+
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs.")
7879
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
7980
@pytest.mark.parametrize(
8081
"sharding_dims",
@@ -83,16 +84,20 @@ def _run_test(fp_init, sharding_dims, recipe, layer_type):
8384
[NUM_PROCS],
8485
# HSDP
8586
[2, NUM_PROCS // 2],
86-
# FSDP-TP
87-
[1, 2, NUM_PROCS // 2],
88-
# HSDP-TP
87+
# (H/F)SDP-TP
8988
[NUM_PROCS // 4, 2, 2],
9089
),
9190
)
9291
@pytest.mark.parametrize("fp8_init", (False, True))
9392
@pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
9493
def test_distributed(fp8_init, sharding_dims, fp_recipe, layer_type):
9594

95+
parallel_size = math.prod(x for x in sharding_dims if x != 0)
96+
if NUM_PROCS < parallel_size:
97+
pytest.skip(
98+
f"Insufficient devices ({NUM_PROCS}) to test sharding configuration: {sharding_dims}"
99+
)
100+
96101
if fp_recipe in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init:
97102
pytest.xfail(f"{fp_recipe} + fp8_init: test_fp8_fsdp2_allgather is currently failing.")
98103

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
CudaRNGStatesTracker,
4848
graph_safe_rng_available,
4949
_convert_param_to_dtensor_param,
50+
_extract_trainable_tensor_from_dtensor,
5051
)
5152
from transformer_engine.pytorch.jit import no_torch_dynamo
5253
from transformer_engine.pytorch.graph import is_graph_capturing
@@ -575,7 +576,8 @@ def set_device_mesh(
575576
weight_mesh : Optional[DeviceMesh]
576577
Not used for DotProductAttention as there are no quantized weights.
577578
"""
578-
warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
579+
if weight_mesh is not None:
580+
warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
579581
if tp_mesh is not None:
580582
# Validate TP DeviceMesh / Group. Must be consistent with tp_size.
581583
assert tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(), (
@@ -864,13 +866,16 @@ def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> Non
864866

865867
def _get_softmax_offset(self) -> torch.Tensor:
866868
"""Get the softmax offset."""
869+
softmax_offset = self.softmax_offset
870+
if isinstance(softmax_offset, DTensor):
871+
# Extract the trainable compute Tensor.
872+
softmax_offset = _extract_trainable_tensor_from_dtensor(softmax_offset)
873+
# Reshape softmax offset.
867874
softmax_offset = (
868-
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
869-
if self.softmax_offset is not None
875+
softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
876+
if softmax_offset is not None
870877
else None
871878
)
872-
if isinstance(softmax_offset, DTensor):
873-
softmax_offset = softmax_offset.to_local()
874879
return softmax_offset
875880

876881
@no_torch_dynamo(recursive=False)

0 commit comments

Comments
 (0)