
Conversation


@Fridah-nv Fridah-nv commented Feb 5, 2026

What does this PR do?

Type of change: New feature

Overview:
Supports exporting NVFP4StaticQuantizer in the unified Hugging Face checkpoint, as a deployment path for PTQ algorithms such as MSE.

Usage

# checkpoint generation
# checkpoint generation
python examples/llm_ptq/hf_ptq.py --pyt_ckpt_path Qwen/Qwen3-8B --qformat nvfp4_mse --export_path test-Qwen3-8B-Instruct-MSE-FP8-sweep-FP4 --kv_cache_qformat none --trust_remote_code
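
For reference, the same flow can be driven from Python. A minimal sketch, assuming the standard mtq.quantize / export_hf_checkpoint APIs and that the new config is exposed like the existing NVFP4 configs; the calibration loop below is only a placeholder:

# Illustrative sketch only: assumes mtq.quantize and export_hf_checkpoint are available
# and that NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG is exposed like the other NVFP4 configs.
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen3-8B"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

def forward_loop(model):
    # Placeholder calibration: run a few representative samples through the model.
    batch = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt").to(model.device)
    model(**batch)

# MSE-calibrated NVFP4 weights with FP8 scale sweep (the config added in this PR).
model = mtq.quantize(model, mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, forward_loop)

# Write the unified Hugging Face checkpoint.
export_hf_checkpoint(model, export_dir="test-Qwen3-8B-Instruct-MSE-FP8-sweep-FP4")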

Testing

Tested the generated Qwen3 8B checkpoint with trtllm serve and the nv_eval example in Model-Optimizer-Internal/examples/nv_eval.

NV eval results:

| Groups |Version|Filter|n-shot|  Metric   |   |Value |   |Stderr|
|--------|-------|------|------|-----------|---|-----:|---|-----:|
|mmlu_str|       |none  |      |exact_match|↑  |0.7186|±  |0.0036|

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added support for static NVFP4 quantizers that utilize pre-computed calibration scales.
    • Introduced new NVFP4 W4A4 quantization configuration with optional FP8 scale sweep.
  • Performance Improvements

    • Static quantizers now skip unnecessary dynamic scaling factor recalculation.
    • Unified quantization handling for improved consistency and efficiency.


copy-pr-bot bot commented Feb 5, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.



coderabbitai bot commented Feb 5, 2026

📝 Walkthrough

Introduces support for NVFP4StaticQuantizer to handle pre-computed weight scales alongside existing dynamic quantizers. Changes distinguish between static and dynamic NVFP4 variants across export utilities, configuration definitions, and tensor scaling operations. Static quantizers bypass dynamic calibration; dynamic quantizers proceed normally.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Export utilities: `modelopt/torch/export/quant_utils.py`, `modelopt/torch/export/unified_export_hf.py` | Added recognition of NVFP4StaticQuantizer in import and weight scaling logic. Export utilities now conditionally skip dynamic scaling computation for static quantizers and use pre-computed scales instead. Config generation labels static NVFP4 as nvfp4_static, and fusion preprocessing unifies global_amax for static quantizers. |
| Quantization configuration: `modelopt/torch/quantization/config.py` | Introduced the new NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG constant with a static weight quantizer and a dynamic input quantizer, combined with MSE calibration and FP8 scale sweep enablement. |
| Tensor scaling logic: `modelopt/torch/quantization/qtensor/nvfp4_tensor.py` | Added dual-path quantizer handling: _is_static_quantizer detects static quantizers via global_amax, and the new get_weights_scaling_factor_from_quantizer method computes per-block scales for static quantizers from pre-computed values or delegates to the dynamic path, with optional FP8 quantization (a rough sketch follows this table). |
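
As a rough illustration of the dual-path handling in the last row above, a sketch under assumptions rather than the actual implementation: the helper name, argument names, and the use of the quantizer's amax as a per-block amax are assumptions drawn from the review discussion below.

import torch

def nvfp4_block_scales(weight, weight_quantizer, weight_scale_2, block_size=16):
    """Hypothetical helper: per-block FP8 (E4M3) scales for NVFP4, static or dynamic path."""
    if getattr(weight_quantizer, "global_amax", None) is not None:
        # Static path: reuse the pre-computed per-block amax carried by the quantizer.
        per_block_amax = weight_quantizer.amax.float()
    else:
        # Dynamic path: derive per-block amax directly from the weight.
        blocks = weight.reshape(*weight.shape[:-1], weight.shape[-1] // block_size, block_size)
        per_block_amax = blocks.abs().amax(dim=-1).float()
    per_block_scale = per_block_amax / 6.0  # 6.0 = largest FP4 (E2M1) magnitude
    per_block_scale = per_block_scale.view(*weight.shape[:-1], weight.shape[-1] // block_size)
    # Normalize by the per-tensor scale and store the result as FP8.
    return (per_block_scale / weight_scale_2).to(torch.float8_e4m3fn)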

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 3 passed

| Check name | Status | Explanation |
|---|---|---|
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%. |
| Title check | ✅ Passed | The title "support static NVFP4 HF export" accurately summarizes the main change (adding export support for static NVFP4 quantizers in Hugging Face export), which is reflected across all modified files. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |


@Fridah-nv Fridah-nv force-pushed the fridah/static-fp4-export branch from 9f69993 to df4e6a9 Compare February 5, 2026 19:02
Base automatically changed from asma/refactor-scale-sweep to main February 6, 2026 19:47
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
…FP4QTensor

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv force-pushed the fridah/static-fp4-export branch from c1ea842 to e0606cb Compare February 6, 2026 22:56

codecov bot commented Feb 6, 2026

Codecov Report

❌ Patch coverage is 19.23077% with 21 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.37%. Comparing base (ac30686) to head (9725c34).
⚠️ Report is 2 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...odelopt/torch/quantization/qtensor/nvfp4_tensor.py | 16.00% | 21 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #858      +/-   ##
==========================================
- Coverage   73.45%   73.37%   -0.08%     
==========================================
  Files         197      197              
  Lines       20651    20675      +24     
==========================================
+ Hits        15169    15171       +2     
- Misses       5482     5504      +22     


@Fridah-nv Fridah-nv marked this pull request as ready for review February 6, 2026 23:31
@Fridah-nv Fridah-nv requested review from a team as code owners February 6, 2026 23:31
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (2)

55-58: Duck-typing check vs isinstance — inconsistency with other call sites.

_is_static_quantizer uses duck typing (hasattr + is not None), but quant_utils.py and unified_export_hf.py use isinstance(weight_quantizer, NVFP4StaticQuantizer). If any non-NVFP4StaticQuantizer object happens to carry a global_amax attribute, this check could produce false positives. Consider aligning on isinstance for consistency, or documenting why duck typing is preferred here.

Proposed fix
+    from modelopt.torch.quantization.nn import NVFP4StaticQuantizer
+
     @classmethod
     def _is_static_quantizer(cls, weight_quantizer) -> bool:
         """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax."""
-        return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None
+        return isinstance(weight_quantizer, NVFP4StaticQuantizer) and weight_quantizer.global_amax is not None

110-130: Zero-scale handling differs between static and dynamic paths.

In the static path, per_block_scale[per_block_scale == 0] = 1.0 is applied before the 448.0 / per_block_scale_max normalization (Line 118 vs 127). In the dynamic path (get_weights_scaling_factor, Line 167), the same sentinel is applied after the normalization by weights_scaling_factor_2.

For all-zero blocks this is harmless (quantized values are 0, so the scale is irrelevant during dequant), but the resulting FP8 scale values will differ between the two paths for those blocks. This could complicate debugging or round-trip comparisons.

To align, move the zero guard after normalization:

Align zero-handling with dynamic path
             per_block_scale = per_block_amax / 6.0
-            per_block_scale[per_block_scale == 0] = 1.0
 
             # Reshape per_block_scale to match weight's block structure
             num_blocks_per_row = weight.shape[-1] // block_size
             expected_shape = (*weight.shape[:-1], num_blocks_per_row)
             per_block_scale = per_block_scale.view(expected_shape)
 
             # Quantize scales to FP8
             if not keep_high_precision:
                 per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
                     torch.float8_e4m3fn
                 )
+            per_block_scale_float = per_block_scale.float()
+            per_block_scale_float[per_block_scale_float == 0] = 1.0
+            per_block_scale = per_block_scale_float.to(per_block_scale.dtype)

Comment on lines +509 to +520

        # Check if this is a static NVFP4 quantizer (has pre-computed scales from MSE calibration)
        # For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
        is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)

        if not is_nvfp4_static:
            # For dynamic NVFP4, compute scales from weights
            weight_scale = NVFP4QTensor.get_weights_scaling_factor(
                weight,
                block_size=block_size,
                weights_scaling_factor_2=weight_scale_2,
            )[0]
@coderabbitai coderabbitai bot Feb 6, 2026

⚠️ Potential issue | 🟠 Major

Potential shape mismatch for BMM-style expert weights with static NVFP4.

When is_bmm_expert_weight is True, the weight is transposed at Lines 506–508 (e.g., from (E, in_dim, out_dim) to (E, out_dim, in_dim)). The dynamic path (Lines 516–520) correctly recomputes weight_scale from the transposed weight. However, the static path skips recomputation and uses the scale that was computed by get_weight_scaling_factor (Line 461) from the untransposed weight.

Since the static path in NVFP4QTensor.get_weights_scaling_factor_from_quantizer (nvfp4_tensor.py Line 121–123) reshapes the per-block scale using the weight's original shape, the scale would have shape (*untransposed_shape[:-1], num_blocks) which won't match the transposed weight layout expected by to_quantized_weight.

This would fail with a shape error if static NVFP4 quantizers are ever used with Llama4TextExperts or GptOssExperts. If that combination is currently not expected, a guard would prevent a confusing error later:

Proposed guard
         # Check if this is a static NVFP4 quantizer (has pre-computed scales from MSE calibration)
         # For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
         is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
 
+        if is_nvfp4_static and is_bmm_expert_weight:
+            raise NotImplementedError(
+                "Static NVFP4 quantization is not yet supported for BMM-style expert weights "
+                "(Llama4TextExperts, GptOssExperts). Use dynamic NVFP4 quantization instead."
+            )
+
         if not is_nvfp4_static:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        # Check if this is a static NVFP4 quantizer (has pre-computed scales from MSE calibration)
-        # For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
-        is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
-        if not is_nvfp4_static:
-            # For dynamic NVFP4, compute scales from weights
-            weight_scale = NVFP4QTensor.get_weights_scaling_factor(
-                weight,
-                block_size=block_size,
-                weights_scaling_factor_2=weight_scale_2,
-            )[0]
+        # Check if this is a static NVFP4 quantizer (has pre-computed scales from MSE calibration)
+        # For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
+        is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
+        if is_nvfp4_static and is_bmm_expert_weight:
+            raise NotImplementedError(
+                "Static NVFP4 quantization is not yet supported for BMM-style expert weights "
+                "(Llama4TextExperts, GptOssExperts). Use dynamic NVFP4 quantization instead."
+            )
+        if not is_nvfp4_static:
+            # For dynamic NVFP4, compute scales from weights
+            weight_scale = NVFP4QTensor.get_weights_scaling_factor(
+                weight,
+                block_size=block_size,
+                weights_scaling_factor_2=weight_scale_2,
+            )[0]
🤖 Prompt for AI Agents
In `modelopt/torch/export/unified_export_hf.py` around lines 509 - 520, The
static NVFP4 path uses a scale computed from the untransposed weight which will
mismatch when is_bmm_expert_weight is True; update the branch that checks
is_nvfp4_static to either (1) recompute/reshape weight_scale using the
transposed weight by calling NVFP4QTensor.get_weights_scaling_factor (or
get_weights_scaling_factor_from_quantizer) on the transposed weight before
calling to_quantized_weight, or (2) add an explicit guard that raises a clear
error when is_bmm_expert_weight and isinstance(weight_quantizer,
NVFP4StaticQuantizer) to prevent misuse; modify the code around the
is_nvfp4_static check (referencing is_bmm_expert_weight, weight_quantizer,
NVFP4StaticQuantizer, NVFP4QTensor.get_weights_scaling_factor[_from_quantizer],
and to_quantized_weight) accordingly.

Collaborator

@Fridah-nv could you fix this?

Contributor
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Comment on lines +391 to +411
NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {
        "method": "mse",
        "fp8_scale_sweep": True,
    },
}
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's check the file exists and get context around lines 391-411 and 649-678
wc -l modelopt/torch/quantization/config.py

# Check if the new config exists at lines 391-411
echo "=== Lines 391-420 (new config) ==="
sed -n '391,420p' modelopt/torch/quantization/config.py

# Check the choices set around lines 649-678
echo -e "\n=== Lines 645-680 (choices set) ==="
sed -n '645,680p' modelopt/torch/quantization/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1888


🏁 Script executed:

#!/bin/bash
# Check how choices set is used in the codebase
echo "=== Usage of 'choices' in config.py ==="
rg -n --type=py '\bchoices\b' modelopt/torch/quantization/config.py

echo -e "\n=== References to choices in broader codebase ==="
rg -n --type=py '\bchoices\b' modelopt/torch/quantization/ | grep -v "config.py" | head -20

echo -e "\n=== Imports of choices from config ==="
rg -n --type=py 'from.*config.*import.*choices' modelopt/ | head -10

echo -e "\n=== Direct config.choices references ==="
rg -n --type=py 'config\.choices' modelopt/ | head -10

Repository: NVIDIA/Model-Optimizer

Length of output: 2136


🏁 Script executed:

#!/bin/bash
# Check how the choices variable is validated/used
echo "=== Search for validation against choices ==="
rg -n --type=py -A2 -B2 'in.*choices' modelopt/torch/quantization/ | head -40

echo -e "\n=== Check if choices is exported or used at module level ==="
rg -n --type=py 'choices.*=' modelopt/torch/quantization/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 3561


🏁 Script executed:

#!/bin/bash
# Check if there are comments explaining why the new config isn't in choices
echo "=== Context before the new config (lines 380-395) ==="
sed -n '380,395p' modelopt/torch/quantization/config.py

# Check if there's anything after the new config or before choices that explains the pattern
echo -e "\n=== Check around the choices set definition (lines 640-660) ==="
sed -n '640,660p' modelopt/torch/quantization/config.py

# See all NVFP4 configs to understand the pattern
echo -e "\n=== All NVFP4 config definitions (look for pattern) ==="
rg -n 'NVFP4.*_CFG = \{' modelopt/torch/quantization/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1633


🏁 Script executed:

#!/bin/bash
# Check if this is mentioned as experimental/temporary in any comments around these lines
echo "=== Comments around the new config (lines 385-415) ==="
sed -n '385,415p' modelopt/torch/quantization/config.py | grep -E '^\s*(#|""")'

# Check if choices set has any explanatory comments
echo -e "\n=== Comments around choices set definition ==="
sed -n '645,655p' modelopt/torch/quantization/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 387


Add NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG to the choices set.

The new config is not registered in the choices set (line 649), which is documented as containing all "supported quantization format names" and is used in algorithms.py:126 for config discovery. All other similar NVFP4 configs are included in choices. Add the new config name to maintain consistency and ensure it's discoverable through the documented API.

🤖 Prompt for AI Agents
In `modelopt/torch/quantization/config.py` around lines 391 - 411, The new
NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG config is defined but not added to the
exported choices set; update the choices collection (the variable named choices)
to include "NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG" alongside the other NVFP4
entries so it becomes discoverable by algorithms.py (which expects all supported
quantization format names in choices). Locate the choices definition and append
the new config name in the same style/ordering as the other NVFP4_* entries.

@Fridah-nv Fridah-nv changed the title Fridah/static fp4 export support static NVFP4 HF export Feb 7, 2026
@Edwardf0t1 Edwardf0t1 left a comment

@Fridah-nv Could we add a sample command in the PR description and the MMLU results you have?

Also, is the checkpoint format the same comparing static NVFP4 vs our default NVFP4?

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv requested a review from a team as a code owner February 7, 2026 01:09
@Fridah-nv Fridah-nv requested a review from meenchen February 7, 2026 01:09
@Fridah-nv (Contributor Author)

@Fridah-nv Could we add a sample command in the PR description and the MMLU results you have?

Added, thanks for reminding.

Also, is the checkpoint format the same comparing static NVFP4 vs our default NVFP4?

Yes, I checked that hf_quant_config.json and config.json are the same for the two formats.
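
A quick way to repeat such a parity check, as a stdlib-only sketch; the checkpoint directory names below are hypothetical:

import json
from pathlib import Path

def compare_export_configs(dir_a, dir_b):
    # Compare the exported config files between two checkpoint directories.
    for name in ("hf_quant_config.json", "config.json"):
        a = json.loads((Path(dir_a) / name).read_text())
        b = json.loads((Path(dir_b) / name).read_text())
        print(name, "identical" if a == b else "differs")

# Hypothetical local checkpoint directories for the two NVFP4 variants.
compare_export_configs("qwen3-8b-nvfp4-default", "qwen3-8b-nvfp4-static")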


# Handle NVFP4 variants (static or dynamic)
is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
if is_nvfp4_static or quantization_format in [
Collaborator

what's the quantization_format for NVFP4StaticQuantizer? do we need is_nvfp4_static here?

# Calibrate weight quantizer if amax is not set for all NVFP4 variants
if quantization_format in [
is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)
if is_nvfp4_static or quantization_format in [
Collaborator

same

# For static NVFP4, weight_scale is already computed from static _amax values in get_weight_scaling_factor
is_nvfp4_static = isinstance(weight_quantizer, NVFP4StaticQuantizer)

if not is_nvfp4_static:
Collaborator

do we need to handle the else condition?

weight_quantizer
)

if cls._is_static_quantizer(weight_quantizer):
Collaborator

so in static NVFP4 quant case:

You use _amax for per-block amax and global_amax for the per-tensor amax?

return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None

@classmethod
def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
Collaborator

could you add a unittest?

