[OMNIML-3232] Support full TE spec for NemotronH HF-to-Megatron import#884

Merged
yueshen2016 merged 1 commit into main from yueshen/TE-spec-MLM on Feb 14, 2026

Conversation

@yueshen2016 (Contributor) commented Feb 12, 2026

What does this PR do?

Type of change: new feature

Overview: Enable full TE spec support for NemotronH (Mamba hybrid) models during HF-to-Megatron weight import via import_mcore_gpt_from_hf.

Previously, importing HF weights into a Megatron model built with the full TE spec (TELayerNormColumnParallelLinear, TEGroupedMLP, etc.) failed for NemotronH models due to two issues:

  1. Grouped expert prefix bug: The experts.linear_fc1/fc2 import rules had a hard-coded mtp.layers.{} prefix, which was only correct for MTP layers. When regular decoder MoE layers use TEGroupedMLP (via the full TE spec), the importer generated incorrect HF keys (e.g., mtp.layers.27.mixer.experts.0.up_proj.weight instead of backbone.layers.27.mixer.experts.0.up_proj.weight).

  2. Fused layer norm loading: In the full TE spec, layer norms are fused into TELayerNormColumnParallelLinear modules as layer_norm_weight. The importer's _name_remapping would crash trying to load layer_norm_weight from a non-existent HF path (e.g., backbone.layers.X.mixer.in_proj.layer_norm_weight), while the actual HF norm weight lives at backbone.layers.X.norm.weight. Both failure modes are illustrated in the short sketch after this list.
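
A minimal, self-contained illustration of both failure modes, built only from the example keys quoted above (layer 27 and expert 0 are just the example indices from the description; the real importer derives these keys from its rule templates rather than literal strings):

# Illustration only: reproduces the key mismatches described above with plain f-strings.
layer_idx, expert_idx = 27, 0

# Issue 1: hard-coded MTP prefix on the grouped-expert rules.
buggy_key = f"mtp.layers.{layer_idx}.mixer.experts.{expert_idx}.up_proj.weight"
fixed_key = f"backbone.layers.{layer_idx}.mixer.experts.{expert_idx}.up_proj.weight"
print(buggy_key)  # not present in the HF checkpoint for regular decoder MoE layers
print(fixed_key)  # the actual HF key

# Issue 2: fused TE modules expose layer_norm_weight, but appending it to the in_proj
# prefix points at a key that does not exist in the HF checkpoint.
missing_key = f"backbone.layers.{layer_idx}.mixer.in_proj.layer_norm_weight"
actual_key = f"backbone.layers.{layer_idx}.norm.weight"  # where HF actually stores the norm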

Changes

mcore_nemotron.py:

  • Fixed grouped expert prefix from mtp.layers.{} to backbone.layers.{}. The _grouped_mlp_merging function already handles the backbone → mtp replacement when is_mtp=True, so both decoder and MTP layers work correctly.
  • Added mapping={"layer_norm_weight": None} to in_proj and linear_fc1 rules to skip layer_norm_weight during _name_remapping (loaded separately via fused_norm).
  • Added fused_norm rule (NameRemapping("backbone.layers.{}.norm.weight")) to load HF norm weights into fused TE modules; a hedged sketch of these rule shapes follows below.
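
For reference, a hedged sketch of what the updated rule shapes might look like. NameRemapping below is a local stand-in for the helper of the same name used in mcore_nemotron.py, and the exact linear_fc1 HF prefix is an assumption; only the in_proj path, the fused_norm target, and the mapping={"layer_norm_weight": None} hint come from this PR description.

from dataclasses import dataclass, field

@dataclass
class NameRemapping:  # stand-in for the real rule helper referenced above
    hf_target: str
    mapping: dict = field(default_factory=dict)

nemotron_h_te_rules_excerpt = {
    # Fused TE modules carry layer_norm_weight; mapping it to None tells _name_remapping
    # to skip it (the weight is loaded separately via the fused_norm rule below).
    "in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj.",
                             mapping={"layer_norm_weight": None}),
    "linear_fc1": NameRemapping("backbone.layers.{}.mlp.up_proj.",  # prefix is assumed
                                mapping={"layer_norm_weight": None}),
    # New rule: the HF norm weight lives beside the mixer, not inside the fused module.
    "fused_norm": NameRemapping("backbone.layers.{}.norm.weight"),
}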

megatron_importer.py:

  • Added source_key is None check in _name_remapping to skip keys mapped to None in the mapping dict (keeps existing value instead of crashing on missing HF key).
  • Added fused norm loading in _import_mamba_layer: after loading in_proj, loads layer_norm_weight from HF via fused_norm rule when layer.norm is IdentityOp.
  • Added fused norm loading in _import_transformer_layer: loads layer_norm_weight into linear_qkv (when input_layernorm is IdentityOp) and into linear_fc1 (when pre_mlp_layernorm is IdentityOp). Both importer-side behaviors are sketched below.
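
A hedged, self-contained sketch of these two importer-side behaviors. Every name here (name_remapping, maybe_load_fused_norm, rules, layer, hf_weights) is a simplified placeholder rather than the actual megatron_importer.py code; only the None-skip rule and the IdentityOp / "fused_norm" guard come from the description above.

def name_remapping(dest_params: dict, hf_weights: dict, prefix: str, mapping: dict) -> None:
    """Copy HF tensors into destination parameters, skipping keys mapped to None."""
    for name in dest_params:
        source_key = mapping.get(name, name)
        if source_key is None:
            # Deliberately unmapped (e.g. layer_norm_weight of a fused TE module):
            # keep the existing value instead of failing on a missing HF key.
            continue
        dest_params[name] = hf_weights[f"{prefix}.{source_key}"]

def maybe_load_fused_norm(layer, rules: dict, hf_weights: dict, layer_idx: int) -> None:
    """Load the fused norm weight only for full-TE-spec layers.

    In this sketch, rules maps a rule name to a plain HF key template string.
    """
    # Under the full TE spec the norm is fused into the following linear, so the
    # standalone norm module is an IdentityOp and the weight comes via the fused_norm rule.
    if type(layer.norm).__name__ == "IdentityOp" and "fused_norm" in rules:
        hf_key = rules["fused_norm"].format(layer_idx)  # e.g. backbone.layers.27.norm.weight
        layer.mixer.in_proj.layer_norm_weight = hf_weights[hf_key]

# Tiny demo of the None-skip path:
dest = {"weight": None, "layer_norm_weight": "kept (loaded via fused_norm)"}
hf = {"backbone.layers.27.mixer.in_proj.weight": "W"}
name_remapping(dest, hf, "backbone.layers.27.mixer.in_proj",
               mapping={"layer_norm_weight": None})
print(dest)  # {'weight': 'W', 'layer_norm_weight': 'kept (loaded via fused_norm)'}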

Usage

The full TE spec is enabled via the --full-te-spec flag on the Megatron-LM side (separate PR). On the ModelOpt side, no user-facing changes are needed -- the import rules automatically handle both local spec and full TE spec models.

# Convert HF checkpoint to Megatron with full TE spec (megatron-lm side)
unset MLM_MODEL_CKPT && export MLM_MODEL_SAVE=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16_mlm && export HF_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
export PP=2
export MLM_EXTRA_ARGS="--full-te-spec"
bash convert.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# Quantize the converted checkpoint (megatron-lm side)
export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16_mlm
export MLM_MODEL_SAVE=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm
export HF_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
export PP=2 && export TP=4 && export EP=4 && export ETP=1
export MLM_EXTRA_ARGS="--full-te-spec"
bash quantize.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 FP8_DEFAULT_CFG

# Generate
export PP=2 && export TP=4 && export EP=4 && export ETP=1
export MLM_EXTRA_ARGS="--full-te-spec"
export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm && ./generate.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# MMLU
export PP=2 && export TP=4 && export EP=4 && export ETP=1
export MLM_EXTRA_ARGS="--full-te-spec --fraction 0.05 --disable-tqdm"
export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm && ./mmlu.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

Testing

  • Tested end-to-end: HF → Megatron conversion → FP8 quantization → inference (generate) → MMLU evaluation with Nemotron-3-Nano-30B-A3B-BF16.
  • Verified the resulting model structure matches Megatron-Bridge's TE spec output (TELayerNormColumnParallelLinear, TEGroupedMLP, IdentityOp norms, etc.).
  • Verified quantized model produces coherent text generation outputs.
  • Verified backward compatibility: all changes are no-ops for existing local-spec pipelines (guarded by IdentityOp checks, hasattr checks, and "fused_norm" in self.rules checks).

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes -- all changes are guarded by conditions that only activate for full TE spec models. Local spec models follow the exact same code paths as before.
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Additional Information

Companion megatron-lm changes (separate PR):

  • megatron/core/post_training/modelopt/mamba/model_specs.py: Added use_full_te_spec parameter to return canonical mamba_stack_spec from mamba_layer_specs.py.
  • megatron/post_training/model_builder.py: Passes use_full_te_spec=args.full_te_spec to get_mamba_stack_modelopt_spec.
  • megatron/post_training/arguments.py: Added --full-te-spec CLI flag.
  • examples/post_training/modelopt/convert_model.py: Skips the moe_grouped_gemm=False override when --full-te-spec is set. A rough sketch of this wiring follows below.
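
Since these changes live in a separate Megatron-LM PR, the following is only a rough, hedged sketch of the wiring described above; function names and bodies are assumptions, with only the --full-te-spec flag, the use_full_te_spec keyword, get_mamba_stack_modelopt_spec, and the moe_grouped_gemm override behavior taken from the list above.

import argparse

def add_full_te_spec_arg(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # megatron/post_training/arguments.py: new CLI flag
    parser.add_argument("--full-te-spec", action="store_true",
                        help="Build the Mamba hybrid model with the canonical full TE layer spec.")
    return parser

def build_layer_spec(args, get_mamba_stack_modelopt_spec):
    # megatron/post_training/model_builder.py: forward the flag so the spec builder
    # returns the canonical mamba_stack_spec when use_full_te_spec=True.
    return get_mamba_stack_modelopt_spec(use_full_te_spec=args.full_te_spec)

def apply_moe_overrides(args):
    # examples/post_training/modelopt/convert_model.py: only force moe_grouped_gemm=False
    # for the local ModelOpt spec; the full TE spec relies on TEGroupedMLP.
    if not args.full_te_spec:
        args.moe_grouped_gemm = False
    return args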

Summary by CodeRabbit

  • New Features
    • Added support for loading fused normalization weights during model import.
  • Bug Fixes
    • Improved weight mapping logic to correctly skip redundant layer norm weights in specialized model architectures.
  • Refactor
    • Reorganized expert model parallel configuration paths for better compatibility with mixed parallel processing settings.

@yueshen2016 requested a review from a team as a code owner February 12, 2026 22:52
coderabbitai bot (Contributor) commented Feb 12, 2026

📝 Walkthrough

The changes update model import/export handling to support fused layer normalization in transformer-based models. Mapping configurations skip direct layer_norm_weight loading, while new logic routes these weights through dedicated fused_norm rules. Expert layer paths are reworked from MTP to backbone-based paths.

Changes

Cohort / File(s) | Summary
Nemotron Plugin Mappings (modelopt/torch/export/plugins/mcore_nemotron.py) | Updated in_proj and linear_fc1 mappings to skip layer_norm_weight via mapping hints. Introduced a new fused_norm rule to load norm weights via a dedicated path. Reworked expert layer mappings to use backbone paths with COL_ETP/ROW_ETP, removing is_mtp annotations.
Megatron Importer Fusion Logic (modelopt/torch/export/plugins/megatron_importer.py) | Added support for None mapping values to preserve keys. Implemented TE-specific fused norm handling blocks in _import_mamba_layer and _import_transformer_layer that detect IdentityOp norms and route layer_norm_weight through the fused_norm rule when available.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
  • Merge Conflict Detection ⚠️ Warning: merge conflicts detected in 2 files:
      ⚔️ modelopt/torch/export/plugins/mcore_nemotron.py (content)
      ⚔️ modelopt/torch/export/plugins/megatron_importer.py (content)
    These conflicts must be resolved locally and pushed to this branch before merging into main.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title clearly and concisely summarizes the main objective: adding full Transformer Engine (TE) spec support for NemotronH model imports, which is the primary change in this pull request.


codecov bot commented Feb 12, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.74%. Comparing base (ae69d5d) to head (5ddcbe7).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #884   +/-   ##
=======================================
  Coverage   73.74%   73.74%           
=======================================
  Files         199      199           
  Lines       21163    21163           
=======================================
  Hits        15606    15606           
  Misses       5557     5557           


@kevalmorabia97 (Collaborator):
Can we make sure the previous ModelOpt spec continues to work? Pruning doesn't support the full TE spec yet.

@kevalmorabia97 (Collaborator):
No concerns from the pruning-support POV, where we still need to use the older non-full-TE ModelOpt spec.

@yueshen2016 (Contributor, Author):
> Can we make sure the previous ModelOpt spec continues to work? Pruning doesn't support the full TE spec yet.

Yes, the previous spec still works.

Signed-off-by: James Shen <yueshen@nvidia.com>
@yueshen2016 enabled auto-merge (squash) February 13, 2026 23:08
@yueshen2016 merged commit 5c4ef8e into main Feb 14, 2026
37 checks passed
@yueshen2016 deleted the yueshen/TE-spec-MLM branch February 14, 2026 00:34
