
Add DCP compatibility for FSDP2-TP sharding in TransformerEngine. #2713

Open
cspades wants to merge 14 commits into NVIDIA:main from cspades:cye/fsdp2-tp-dcp

Conversation


@cspades cspades commented Feb 26, 2026

Summary

  • Support (H/F)SDP2 x TP strided sharding, and DTensor FP8 parameters for Torch DCP checkpointing, across all TransformerEngineBaseModule(s).
    • Except GroupedLinear, pending FSDP2 standalone pipe-cleaning. All other modules under transformer_engine.pytorch.modules are supported.
    • FusibleOperation support is also a WIP, except for LayerNorm and RMSNorm, which are TE modules.
  • Associated with BioNeMo-Recipes Llama3 TP: Enable TransformerEngine-backed Tensor Parallelism with Llama3. bionemo-framework#1483
    • Notably, TransformerEngine TP can be easily mixed with DTensor-based TP when unified by Torch DCP! In the Llama3 recipe, we use DTensor-based TP on the torch.nn.Embedding, TransformerEngine-based TP on the LM head, and weight-tie the LM head to the torch.nn.Embedding, which is why we do not need to call set_device_mesh for the LM head!
  • Credit to @pstjohn for coming up with this idea!

Usage / Documentation

(tp_mesh and weight_mesh can also be passed in TEModule.__init__.)

    def set_device_mesh(
        self,
        tp_mesh: Optional[DeviceMesh] = None,
        weight_mesh: Optional[DeviceMesh] = None,
    ) -> None:
        """
        Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor
        depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2.

        TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless
        integration with Torch DCP checkpointing. This method should only be invoked when
        using DTensor parameters, e.g. when using FSDP2 or DCP.

        When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically
        convert them into FSDP-TP strided or non-strided shards depending on the current
        sharding dimension and factor of the DTensor. When the sharding dimension of FSDP
        matches that of TP, FSDP uses a _StridedShard placement type instead of Shard.
        This experimental FSDP-TP logic resides in this FSDP2 initialization function:
        ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param``

        Parameters
        ----------
        tp_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
            Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
        weight_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required
            when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor
            parameters and if the DTensor DeviceMesh includes dimensions that do not
            shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
            For example:
                - device_mesh["dp"] for FSDP.
                - device_mesh["dp_cp"] if using CP ranks in FSDP.
                - device_mesh["dp_shard"] if using HSDP ("dp_replicate", "dp_shard").
                - device_mesh["tp"] if using TP.
                - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
        """

Details

DTensor Lifecycle in TransformerEngine

  • Initialization
    • __init__
      • TransformerEngine model parameters are initialized either on device or meta device with the appropriate tp_size and TP sharding strategy, e.g. parallel_mode and sequence_parallel.
    • TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)
      • Converts parameters to DTensor with appropriate TP placement(s) based on the TP sharding strategy specified in __init__, using transformer_engine.pytorch.distributed._convert_param_to_dtensor_param.
        • tp_mesh is a 1-D DeviceMesh containing the TP ProcessGroup that will be registered with the TransformerEngine module.
        • weight_mesh is the 1-D DeviceMesh containing the ProcessGroup that shards TransformerEngine module weights, i.e. the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP, and is mainly required for per-tensor scaling recipes like Float8CurrentScaling.
      • Needs to be invoked prior to fully_shard (which responds to the TP placements) and prior to reset_parameters(defer_init=False), which quantizes parameters.
      • Can also be directly invoked during __init__(tp_mesh, weight_mesh) for supported TransformerEngine modules.
    • fully_shard shards the TransformerEngine model with FSDP2.
      • When fully_shard encounters TP sharding on dim=0, it will use a _StridedShard for DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, followed by concatenating the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left-to-right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in the DeviceMesh and DTensor.placements. (See Appendix for visualization of this sharding strategy.)
    • reset_parameters is called if using meta device initialization.
  • Training
    • Pre-forward, FSDP2 all-gathers the sharded DTensor "main" weight that it registered during fully_shard. (Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such as FusedAdam must be used to properly handle high-precision main weights.)
      • When using FSDP2 x TP, the all-gathered Tensor is actually a TP-sharded DTensor, which deviates from the original FSDP2 paradigm where the all-gathered Tensor is fully unsharded and the DTensor wrapping is discarded. To support these DTensor compute weights in TransformerEngine modules, we use transformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensor to localize the DTensor and to inherit the requires_grad attribute from the DTensor parameter, since DTensor.from_local(Tensor) leaves requires_grad un-set on the local Tensor for FP8 parameters specifically!
    • Post-backward, the Tensor gradient is converted to DTensor and attached to the DTensor.grad attribute. Handled by DTensor <> Tensor Autograd conversion functions, and in the case of FusibleOperation, casted during the backward implementation.

QuantizedTensor Storage

  • When both row and column data are None, we point untyped_storage() to a default 1-byte storage, which unblocks DCP checkpoint-loading assertions that use storage as a definition of "emptiness". This is because a 0-byte storage has data_ptr() == nullptr, which breaks DCP.
    • While untyped_storage is not used anywhere in TransformerEngine, this may break code that uses Storage to decide whether a Tensor is empty, since QuantizedTensor storage will now always be at least 1 byte even when neither row nor column data is set. Such checks would instead need to compare the storage size against 1 byte instead of 0 bytes.
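The nullptr problem can be reproduced with plain torch storages; a minimal sketch (the looks_empty helper is hypothetical, not TE API, and shows the updated 1-byte comparison):

```python
import torch

# A 0-byte storage has a null data_ptr(), which trips DCP's load-time checks;
# a 1-byte placeholder storage has a real allocation.
empty = torch.UntypedStorage(0)
placeholder = torch.UntypedStorage(1)
assert empty.nbytes() == 0 and empty.data_ptr() == 0
assert placeholder.nbytes() == 1 and placeholder.data_ptr() != 0

# Hypothetical emptiness check, updated for the 1-byte placeholder convention.
def looks_empty(t: torch.Tensor) -> bool:
    return t.untyped_storage().nbytes() <= 1
```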

Bugs

  • Fix bug where "shard" was the presumed weight sharding sub-mesh in the DTensor.device_mesh. Now, users can precisely specify their own custom weight-sharding DeviceMesh for per-tensor amax_reduction_group via the set_device_mesh(weight_mesh) API.
  • TransformerEngineBaseModule: fix the type of self.quantizers = {"scaling_fwd": [], "scaling_bwd": []}, whose values are lists (populated via .extend()), not dicts.

Testing

# TransformerEngine Main
[Rank 0] (after 1 iterations) memory (MB) | allocated: 23511.65 | max allocated: 25189.68 | reserved: 25678.00 | max reserved: 25678.00
 [2026-03-02 09:55:17.189564] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12715.7 | throughput per GPU (TFLOP/s/GPU): 530.6 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124915E+00 | loss scale: 1.0 | grad norm: 5.474 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:55:29.768521] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12578.7 | throughput per GPU (TFLOP/s/GPU): 536.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.143806E+00 | loss scale: 1.0 | grad norm: 5.366 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Post-DCP Modifications (This PR)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 23511.65 | max allocated: 29783.24 | reserved: 25678.00 | max reserved: 31510.00
 [2026-03-02 09:29:36.550070] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12556.5 | throughput per GPU (TFLOP/s/GPU): 537.3 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124463E+00 | loss scale: 1.0 | grad norm: 5.471 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:29:49.216068] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12665.7 | throughput per GPU (TFLOP/s/GPU): 532.7 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.142863E+00 | loss scale: 1.0 | grad norm: 5.355 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • NOTE(@cspades): DelayedScaling has DCP save/load disparity issues, i.e. discrepancies on the scale of ±1 in the uint8 parameter checkpoint!

Appendix

_StridedShard - Using FSDP2 x TP Strided-Sharding

# (DP=4, TP=2)
(_StridedShard(dim=0, sf=2), Shard(dim=0))

┌───┬───┐
│ 0 │ 4 │ ← DP=0
├───┼───┤
│ 1 │ 5 │ ← DP=1
├───┼───┤          FSDP all-gather happens across the DP ranks,
│ 2 │ 6 │ ← DP=2   so we need to form the 0-3 and 4-7 TP shards!
├───┼───┤
│ 3 │ 7 │ ← DP=3
└───┴───┘
  ↑   ↑
TP=0 TP=1

When redistribute'ing a global DTensor to (_StridedShard(dim=0, sf=2), Shard(dim=0)), DTensor will perform the following steps:

  • Pre-shard the Tensor data with respect to the stride / shard factor, which is defined as the product of the parallelism sizes of all Shard placements to the right of _StridedShard. (In the above example, since TP=2, the factor is 2.)
    • [0 1 2 3 4 5 6 7] -> [0 1 2 3] and [4 5 6 7].
    • In the context of this PR and fully_shard, this has already been done via initializing the TransformerEngine module with TP and calling _convert_param_to_dtensor_param!
  • Shard the pre-shards for _StridedShard.
    • [0] [1] [2] [3] and [4] [5] [6] [7]
  • Concatenate the strided shards.
    • [0 4] [1 5] [2 6] [3 7], which are assigned to the _StridedShard ranks.
    • Note that this is very different if we did left-to-right-sharding, which would have given us [0 1] [2 3] [4 5] [6 7]!
  • Subsequently / finally, each strided shard is sharded on the Shard placement.
    • [0] [4] / [1] [5] / [2] [6] / [3] [7], which are assigned to the Shard ranks.
    • Note that this is very different if we did left-to-right sharding, which would have given us [0] [1] / [2] [3] / [4] [5] / [6] [7]!

PyTorch also supports the inverse un-sharding of this redistribute, which is literally the inverse of these simple operations! (Though things get a bit more complicated with uneven shards from dimension sizes that do not divide evenly.)
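The four redistribute steps above can be simulated with plain Python lists (a sketch assuming evenly divisible shards; function names are illustrative, not PyTorch API):

```python
def strided_shard(rows, dp, tp):
    """Simulate redistribute to (_StridedShard(dim=0, sf=tp), Shard(dim=0))."""
    def split(xs, n):
        k = len(xs) // n
        return [xs[i * k:(i + 1) * k] for i in range(n)]

    # 1) Pre-shard by the stride / shard factor (here, the TP size).
    pre = split(rows, tp)
    # 2+3) Shard each pre-shard across DP, then concatenate per DP rank
    #      to form the strided shards.
    strided = [sum((split(p, dp)[r] for p in pre), []) for r in range(dp)]
    # 4) Shard each strided shard on the Shard (TP) placement.
    return [split(s, tp) for s in strided]

shards = strided_shard(list(range(8)), dp=4, tp=2)
# DP ranks hold [0 4] / [1 5] / [2 6] / [3 7] before the final TP split,
# matching the diagram above rather than naive left-to-right sharding.
assert shards == [[[0], [4]], [[1], [5]], [[2], [6]], [[3], [7]]]
```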



greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR adds Torch DCP (Distributed Checkpoint) compatibility for FSDP2 × TP strided sharding across all TransformerEngineBaseModule subclasses, enabling seamless interoperability between TransformerEngine's TP mechanics and PyTorch's DTensor-based checkpointing infrastructure.

Key changes and observations:

  • New set_device_mesh(tp_mesh, weight_mesh) API is added to every supported TE module (Linear, LayerNormLinear, LayerNormMLP, LayerNorm, RMSNorm, DotProductAttention, MultiheadAttention, TransformerLayer). It converts main weights to DTensor with the appropriate TP placements so that FSDP2's fully_shard can apply _StridedShard semantics for HSDP-TP.
  • _extract_trainable_tensor_from_dtensor uses a custom _ToLocalIdentity autograd function to preserve object identity between the returned local tensor and DTensor._local_tensor, ensuring FSDP2's post-all-gather in-place updates remain visible; it also explicitly inherits requires_grad from the DTensor (since DTensor.from_local resets it).
  • QuantizedTensor subclasses (Float8Tensor, Float8BlockwiseQTensor, MXFP8Tensor, NVFP4Tensor) now expose untyped_storage() returning a 1-byte _default_storage when both data buffers are None, unblocking DCP checkpoint-load assertions. Note: The _default_storage is allocated eagerly in __new__ on the current CUDA device, which may cause device mismatches or unexpected CUDA memory allocation in certain initialization patterns (e.g., meta-device init or before torch.cuda.set_device).
  • quantizers type fix: self.quantizers = {"scaling_fwd": [], "scaling_bwd": []} corrects a longstanding inconsistency where the inner values were dicts but were populated as lists via .extend().
  • Bug fix in _LayerNormMLP backward: isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) was always False (quantizers are not storage objects); corrected to check ctx.fc1_weight instead, re-enabling update_usage(columnwise_usage=True).
  • NVFP4Tensor.untyped_storage docstring incorrectly references "MXFP8Tensor" (copy-paste error).
  • Test coverage is significantly expanded with a DCP save/load round-trip, pre/post-save loss parity, and optimizer state validation. Note: Unused save/load imports in run_fsdp2_model.py will trigger lint warnings; code uses fully-qualified torch.distributed.checkpoint.save() / load() calls instead.

Confidence Score: 4/5

  • Safe to merge with attention to the _default_storage device-assumption caveat in quantized tensor storage classes.
  • The core DTensor conversion, _ToLocalIdentity gradient pathway, and set_device_mesh implementations are well-reasoned and consistent across all modules. The changes are validated by Megatron-LM CI/parity tests. Score is 4 rather than 5 due to: (1) the _default_storage eagerly allocating CUDA memory in __new__ before device assignment in certain initialization orderings (meta-device init, pre-torch.cuda.set_device); (2) unused imports (save/load) that will trigger lint warnings; and (3) a minor docstring copy-paste error in NVFP4Tensor.
  • transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py, transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py, transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py (device allocation assumption), and tests/pytorch/distributed/run_fsdp2_model.py (unused imports).

Comments Outside Diff (3)

  1. transformer_engine/pytorch/tensor/nvfp4_tensor.py, line 555-562 (link)

    Docstring copy-paste error

    The docstring for NVFP4Tensor.untyped_storage() incorrectly references "MXFP8Tensor" in the first line of the description — this is an NVFP4Tensor method.

  2. tests/pytorch/distributed/run_fsdp2_model.py, line 21 (link)

    Unused imports

    save and load are imported from torch.distributed.checkpoint but are never used directly — all call sites use the fully-qualified torch.distributed.checkpoint.save(...) and torch.distributed.checkpoint.load(...) forms instead. This will trigger unused-import lint warnings.

    Either use the short names at the call sites or remove the starred imports and keep the fully-qualified calls.

  3. transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py, line 66 (link)

    Eager CUDA memory allocation in __new__ may cause device issues

    torch.UntypedStorage(1, device=torch.cuda.current_device()) is called unconditionally in __new__, which runs every time an instance is created. This can cause problems in certain initialization patterns:

    1. Meta-device init: When a model is constructed with device="meta", this __new__ call will still allocate 1 byte of CUDA memory on the current device (e.g., GPU 0), even though the model should not occupy any device memory yet.

    2. Device mismatch: If torch.cuda.set_device(LOCAL_RANK) is called after the tensor is constructed (e.g., user code that initializes the model before setting up distributed contexts), _default_storage will reside on device 0 for all ranks, causing device mismatches on other ranks.

    Consider deferring _default_storage creation to first access (lazy property), basing the device on self.device rather than current_device(), or using a meta-device-safe allocation pattern.
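A minimal sketch of the suggested fix (hypothetical names, not the PR's code): defer the placeholder to first access and keep it on CPU so __new__ never touches CUDA.

```python
import torch

class LazyDefaultStorageMixin:
    """Hypothetical mixin: create the 1-byte placeholder lazily, on CPU."""

    _default_storage = None  # class-level sentinel; set per instance on first use

    @property
    def default_storage(self) -> torch.UntypedStorage:
        if self._default_storage is None:
            # CPU allocation: only the non-null identity matters, not device
            # residency, and this avoids calling torch.cuda.current_device()
            # (or allocating CUDA memory) during construction.
            self._default_storage = torch.UntypedStorage(1)
        return self._default_storage
```

Repeated accesses return the same storage object, so identity-based emptiness checks still behave consistently.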

Last reviewed commit: "Merge branch 'main' ..."

cspades force-pushed the cye/fsdp2-tp-dcp branch from 4ec2947 to dbb9d14 on March 4, 2026 18:10
cspades force-pushed the cye/fsdp2-tp-dcp branch from fcdd5bd to c912f5b on March 5, 2026 16:06
cspades force-pushed the cye/fsdp2-tp-dcp branch 5 times, most recently from bc82f02 to 267f1df, on March 10, 2026 01:30
@vthumbe1503

/te-ci L1 pytorch

cspades force-pushed the cye/fsdp2-tp-dcp branch 4 times, most recently from f0b3cae to af7362a, on March 12, 2026 15:26
cspades force-pushed the cye/fsdp2-tp-dcp branch 4 times, most recently from 9435382 to 15df86f, on March 16, 2026 19:16

cspades commented Mar 16, 2026

/te-ci L1 pytorch


cspades commented Mar 17, 2026


cspades commented Mar 18, 2026

For some reason, after ~2.3k training steps I start to get NaNs: https://wandb.ai/nvidia/bionemo-recipes/runs/nmzugu0a?nw=nwusercye_nv

Restarting from this checkpoint, the same thing happens around 500 steps later.

cspades and others added 13 commits March 19, 2026 08:50

cspades commented Mar 19, 2026

/te-ci L1 pytorch

module.reset_parameters()

# Run a training step to initialize FSDP2 lazy state and update quantization
# scales before testing the allgather. Block-scaling formats (Float8BlockScaling,

I believe Float8BlockScaling allgather should work now, right?

input_data = torch.randn(inp_shape, device=device)
target = torch.randn(inp_shape, device=device)
nvfp4_ctx = (
torch.autocast(device_type="cuda", dtype=torch.bfloat16)

Why a separate nvfp4 context? In general, adding multiple context managers adds CPU overhead in the training loop.



@dataclass
class AppState(Stateful):

This seems like a useful class, with some extra state specific to TE. Might make sense to move it to the TE distributed module. Thoughts @cspades?

# TransformerEngine uses empty Tensors for dummy Parameters.
optimizer_state_dict["state"][fqn] = {}
if fqn.endswith("_extra_state"):
# Evict `_extra_state` quantization data from model checkpoint.

If this is evicted, how do we make sure it is updated correctly after load from checkpoint?

Comment on lines +125 to +128
(
# FSDP
[NUM_PROCS],
# HSDP

Does this work OK if NUM_PROCS < 4? E.g., if NUM_PROCS = 2, the TP dimension will be 0. Curious what happens in that case.


if tp_mesh is not None or weight_mesh is not None:
# Apply DeviceMesh and DTensor-related modifications.
self.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh)

In the set_device_mesh function, it says weight_mesh is not necessary, but we only call it when tp_mesh or weight_mesh is not None. So should it include the weight_mesh is not None condition at all?

else device_mesh.get_group()
)
quantizer.amax_reduction_group = amax_reduction_group
quantizer.amax_reduction_group = device_mesh.get_group()

Which group will it return in case of multiple dimensions? For instance, if the weight is both FSDP- and TP-sharded, will this give the FSDP dim or the TP dim?

Comment on lines +223 to +224
if isinstance(weight, DTensor):
weight = weight.to_local()

Shouldn't we use _extract_trainable_tensor_from_dtensor here?


Also applicable to a couple of other places in the ops folder.

instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
instance._default_storage = torch.UntypedStorage(1, device=torch.cuda.current_device())

I am leaning towards creating this default_storage on CPU instead, for a couple of reasons:

  1. Since the idea of default_storage here is to show unique identity, keeping it on CPU vs. GPU shouldn't matter.
  2. Calling torch.cuda.current_device() on every single Tensor creation adds Python overhead.

integration with Torch DCP checkpointing. This method should only be invoked when
using DTensor parameters, e.g. when using FSDP2 or DCP.

When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically

Why is it that we haven't added tp_mesh and weight_mesh to the constructors of rmsnorm and layer_norm, but we have for every other layer?

param = getattr(self, bias)
placements = (Replicate(),)
if self.parallel_mode == "column":
placements = (Shard(dim=0),)

@vthumbe1503 vthumbe1503 Mar 20, 2026


I am wondering if we can make all the set_device_mesh functions share a helper, set_tp_mesh, defined in base.py, that takes a dictionary of {parameter name: parallel_mode} and a tp_mesh, and converts the parameters to DTensors. Each module's set_device_mesh could then call it.

Something like this:
def set_tp_mesh(self, param_mode_dict: dict, tp_mesh: Optional[DeviceMesh]):


And put the big docstring that we have in base.py; it seems to be repeated in every module.

@vthumbe1503

Generally LGTM @cspades. Let's get it merged after comments are addressed.
