21 changes: 16 additions & 5 deletions modelopt/torch/export/plugins/mcore_nemotron.py
@@ -58,16 +58,25 @@
"D": NameRemapping("backbone.layers.{}.mixer.D", REPLICATE),
"dt_bias": NameRemapping("backbone.layers.{}.mixer.dt_bias", REPLICATE),
"conv1d": NameRemapping("backbone.layers.{}.mixer.conv1d.", REPLICATE),
"in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj.", COL_TP),
# mapping layer_norm_weight to None tells _name_remapping to skip it;
# the fused layer_norm_weight is loaded separately via the "fused_norm" rule.
"in_proj": NameRemapping(
"backbone.layers.{}.mixer.in_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}}
),
"out_proj": NameRemapping("backbone.layers.{}.mixer.out_proj.", ROW_TP),
# Attention
"input_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE),
"linear_qkv": QKVMerging("backbone.layers.{}.mixer.", COL_TP),
"linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj.", ROW_TP),
# MLP
"pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE),
"linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj.", COL_TP),
"linear_fc1": NameRemapping(
"backbone.layers.{}.mixer.up_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}}
),
"linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj.", ROW_TP),
# Fused layer norm: loads the HF norm weight into fused TELayerNormColumnParallelLinear
# modules (in_proj, linear_qkv, linear_fc1) when using TE spec.
"fused_norm": NameRemapping("backbone.layers.{}.norm.weight"),
# MoE
"router": NameRemapping(
"backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}}
@@ -92,12 +101,14 @@
"mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}),
"mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}),
"mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}),
# Grouped local experts in MTP
# Grouped local experts (used for TEGroupedMLP in both decoder and MTP layers).
# The prefix uses "backbone" for regular decoder layers; when called from MTP
# context (is_mtp=True), _grouped_mlp_merging replaces "backbone" with "mtp".
"experts.linear_fc1": GroupedMLPMerging(
"mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}
"backbone.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP
),
"experts.linear_fc2": GroupedMLPMerging(
"mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}
"backbone.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP
),
}
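A minimal, self-contained sketch of how the two pieces used in the rule table above are meant to compose: the dict-union rule config (COL_TP | {"mapping": {"layer_norm_weight": None}}) and the skip semantics of a None mapping value in _name_remapping. COL_TP's real contents are not shown in this diff, so the stand-in config, hf_weights, and module_state below are illustrative assumptions, not the library's actual objects.

import torch

# Stand-in for the real COL_TP parallel config (assumed shape, illustration only).
COL_TP = {"parallel": "column"}
config = COL_TP | {"mapping": {"layer_norm_weight": None}}
mapping = config.get("mapping", {})

# Toy module state and HF checkpoint tensors.
module_state = {"weight": torch.zeros(4), "layer_norm_weight": torch.ones(4)}
hf_weights = {"backbone.layers.0.mixer.in_proj.weight": torch.full((4,), 2.0)}

new_state = {}
for key, val in module_state.items():
    source_key = mapping.get(key, key)
    if source_key is None:
        # Skip: keep the module's existing tensor; the fused norm weight is
        # loaded later by the separate "fused_norm" rule.
        new_state[key] = val
        continue
    new_state[key] = hf_weights[f"backbone.layers.0.mixer.in_proj.{source_key}"]

assert torch.equal(new_state["layer_norm_weight"], torch.ones(4))

For the grouped-expert rules, the comment above already notes that _grouped_mlp_merging swaps the "backbone" prefix for "mtp" when is_mtp=True, so the table only needs the decoder-side prefix.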

39 changes: 39 additions & 0 deletions modelopt/torch/export/plugins/megatron_importer.py
@@ -200,6 +200,12 @@ def _name_remapping(
state_dict[key] = val
else:
source_key = mapping.get(key, key)
# A mapping value of None means "skip this key" (keep existing value).
# This is used for fused TE modules where layer_norm_weight is loaded
# separately from a different HF path.
if source_key is None:
state_dict[key] = val
continue
# For bias tensors in ROW_TP layers, don't use parallel config to avoid sharding
# since bias should always be replicated, not sharded
if (
@@ -537,6 +543,15 @@ def _import_mamba_layer(self, layer, layer_id, layer_pbar):
self.rules["in_proj"](layer.mixer.in_proj, layer_id)
self.rules["out_proj"](layer.mixer.out_proj, layer_id)

# TE spec: layer norm is fused into in_proj (TELayerNormColumnParallelLinear).
# Load the fused layer_norm_weight from the HF norm path.
if (
isinstance(layer.norm, IdentityOp)
and hasattr(layer.mixer.in_proj, "layer_norm_weight")
and "fused_norm" in self.rules
):
self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id)

def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False):
if not isinstance(layer.input_layernorm, IdentityOp):
self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp)
@@ -578,6 +593,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp
)

# TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear).
# Load the fused layer_norm_weight from the HF norm path.
if (
isinstance(layer.input_layernorm, IdentityOp)
and hasattr(attention, "linear_qkv")
and hasattr(attention.linear_qkv, "layer_norm_weight")
and "fused_norm" in self.rules
):
self.rules["fused_norm"](
attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp
)

if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp)

@@ -671,6 +698,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp)
self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp)

# TE spec: pre_mlp_layernorm is fused into linear_fc1
# (TELayerNormColumnParallelLinear).
# Load the fused layer_norm_weight from the HF norm path.
if (
isinstance(layer.pre_mlp_layernorm, IdentityOp)
and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
and "fused_norm" in self.rules
):
self.rules["fused_norm"](
layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp
)

def _import_state_dict(self):
model = self.model
layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm)
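For reference, an illustrative summary of the three fused-norm call sites added in this file and the HF tensor each loads. The HF paths follow the rule table in mcore_nemotron.py; the attribute paths assume the TE spec and the self_attention attribute name, and the dict itself is not consumed anywhere.

# Illustrative summary only; not used by the importer.
FUSED_NORM_SITES = {
    "mamba": ("layer.mixer.in_proj.layer_norm_weight", "backbone.layers.{}.norm.weight"),
    "attention": ("layer.self_attention.linear_qkv.layer_norm_weight", "backbone.layers.{}.norm.weight"),
    "mlp": ("layer.mlp.linear_fc1.layer_norm_weight", "backbone.layers.{}.norm.weight"),
}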