diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py
index 92611f54b..6883c51c9 100644
--- a/modelopt/torch/export/plugins/mcore_nemotron.py
+++ b/modelopt/torch/export/plugins/mcore_nemotron.py
@@ -58,7 +58,11 @@
     "D": NameRemapping("backbone.layers.{}.mixer.D", REPLICATE),
     "dt_bias": NameRemapping("backbone.layers.{}.mixer.dt_bias", REPLICATE),
     "conv1d": NameRemapping("backbone.layers.{}.mixer.conv1d.", REPLICATE),
-    "in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj.", COL_TP),
+    # mapping layer_norm_weight to None tells _name_remapping to skip it;
+    # the fused layer_norm_weight is loaded separately via the "fused_norm" rule.
+    "in_proj": NameRemapping(
+        "backbone.layers.{}.mixer.in_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}}
+    ),
     "out_proj": NameRemapping("backbone.layers.{}.mixer.out_proj.", ROW_TP),
     # Attention
     "input_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE),
@@ -66,8 +70,13 @@
     "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj.", ROW_TP),
     # MLP
     "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE),
-    "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj.", COL_TP),
+    "linear_fc1": NameRemapping(
+        "backbone.layers.{}.mixer.up_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}}
+    ),
     "linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj.", ROW_TP),
+    # Fused layer norm: loads the HF norm weight into fused TELayerNormColumnParallelLinear
+    # modules (in_proj, linear_qkv, linear_fc1) when using TE spec.
+    "fused_norm": NameRemapping("backbone.layers.{}.norm.weight"),
     # MoE
     "router": NameRemapping(
         "backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}}
@@ -92,12 +101,14 @@
     "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}),
     "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}),
     "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}),
-    # Grouped local experts in MTP
+    # Grouped local experts (used for TEGroupedMLP in both decoder and MTP layers).
+    # The prefix uses "backbone" for regular decoder layers; when called from MTP
+    # context (is_mtp=True), _grouped_mlp_merging replaces "backbone" with "mtp".
     "experts.linear_fc1": GroupedMLPMerging(
-        "mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True}
+        "backbone.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP
     ),
     "experts.linear_fc2": GroupedMLPMerging(
-        "mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True}
+        "backbone.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP
     ),
 }
diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py
index b4c1ec694..a156f2cd8 100644
--- a/modelopt/torch/export/plugins/megatron_importer.py
+++ b/modelopt/torch/export/plugins/megatron_importer.py
@@ -200,6 +200,12 @@ def _name_remapping(
                 state_dict[key] = val
             else:
                 source_key = mapping.get(key, key)
+                # A mapping value of None means "skip this key" (keep existing value).
+                # This is used for fused TE modules where layer_norm_weight is loaded
+                # separately from a different HF path.
+                if source_key is None:
+                    state_dict[key] = val
+                    continue
                 # For bias tensors in ROW_TP layers, don't use parallel config to avoid sharding
                 # since bias should always be replicated, not sharded
                 if (
@@ -537,6 +543,15 @@ def _import_mamba_layer(self, layer, layer_id, layer_pbar):
         self.rules["in_proj"](layer.mixer.in_proj, layer_id)
         self.rules["out_proj"](layer.mixer.out_proj, layer_id)
 
+        # TE spec: layer norm is fused into in_proj (TELayerNormColumnParallelLinear).
+        # Load the fused layer_norm_weight from the HF norm path.
+        if (
+            isinstance(layer.norm, IdentityOp)
+            and hasattr(layer.mixer.in_proj, "layer_norm_weight")
+            and "fused_norm" in self.rules
+        ):
+            self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id)
+
     def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False):
         if not isinstance(layer.input_layernorm, IdentityOp):
             self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp)
@@ -578,6 +593,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
                 attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp
             )
 
+        # TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear).
+        # Load the fused layer_norm_weight from the HF norm path.
+        if (
+            isinstance(layer.input_layernorm, IdentityOp)
+            and hasattr(attention, "linear_qkv")
+            and hasattr(attention.linear_qkv, "layer_norm_weight")
+            and "fused_norm" in self.rules
+        ):
+            self.rules["fused_norm"](
+                attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp
+            )
+
         if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
             self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp)
 
@@ -671,6 +698,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
             self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp)
             self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp)
 
+            # TE spec: pre_mlp_layernorm is fused into linear_fc1
+            # (TELayerNormColumnParallelLinear).
+            # Load the fused layer_norm_weight from the HF norm path.
+            if (
+                isinstance(layer.pre_mlp_layernorm, IdentityOp)
+                and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
+                and "fused_norm" in self.rules
+            ):
+                self.rules["fused_norm"](
+                    layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp
+                )
+
     def _import_state_dict(self):
         model = self.model
         layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm)
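For reviewers, a minimal standalone sketch of the skip-on-None mapping behavior added to _name_remapping above. toy_name_remapping, module_sd, and hf_weights are illustrative stand-ins, not the importer's real API; only the None check mirrors the diff.

import torch


def toy_name_remapping(module_state_dict, hf_weights, prefix, mapping):
    """Rebuild a module state dict from HF weights.

    A mapping value of None means "skip this key": the module's existing tensor
    (here the fused layer_norm_weight) is kept, to be loaded by a separate rule.
    """
    state_dict = {}
    for key, val in module_state_dict.items():
        source_key = mapping.get(key, key)
        if source_key is None:
            state_dict[key] = val  # keep existing value; loaded elsewhere
            continue
        state_dict[key] = hf_weights[prefix + source_key]
    return state_dict


module_sd = {
    "weight": torch.zeros(4, 4),
    "layer_norm_weight": torch.ones(4),  # already loaded via the fused-norm path
}
hf_weights = {"backbone.layers.0.mixer.in_proj.weight": torch.randn(4, 4)}
new_sd = toy_name_remapping(
    module_sd,
    hf_weights,
    "backbone.layers.0.mixer.in_proj.",
    mapping={"layer_norm_weight": None},
)
assert torch.equal(new_sd["layer_norm_weight"], torch.ones(4))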
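Likewise, a hedged sketch of the guarded "fused_norm" dispatch used in the three importer call sites. IdentityOp, FusedNormLinear, and the rules dict are toy stand-ins for the real Megatron/TE classes and rule callables; only the isinstance/hasattr/"fused_norm" guard mirrors the diff.

import torch


class IdentityOp:
    """Stand-in for Megatron's IdentityOp (norm replaced by a no-op under TE spec)."""


class FusedNormLinear:
    """Stand-in for a TELayerNormColumnParallelLinear-style fused module."""

    def __init__(self, dim):
        self.weight = torch.zeros(dim, dim)
        self.layer_norm_weight = torch.zeros(dim)


def maybe_load_fused_norm(norm_module, fused_linear, rules, layer_id):
    # Fire only when the standalone norm is an IdentityOp (the norm is fused into
    # the following linear) and a "fused_norm" rule is registered.
    if (
        isinstance(norm_module, IdentityOp)
        and hasattr(fused_linear, "layer_norm_weight")
        and "fused_norm" in rules
    ):
        rules["fused_norm"](fused_linear.layer_norm_weight, layer_id)


hf_weights = {"backbone.layers.0.norm.weight": torch.ones(8)}
rules = {
    "fused_norm": lambda dst, layer_id: dst.copy_(
        hf_weights[f"backbone.layers.{layer_id}.norm.weight"]
    )
}
fused = FusedNormLinear(8)
maybe_load_fused_norm(IdentityOp(), fused, rules, layer_id=0)
assert torch.equal(fused.layer_norm_weight, torch.ones(8))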