From 188b74dc4c7d9c482c02cbde668c6734b9d86a75 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 12 Feb 2026 22:35:29 +0000 Subject: [PATCH 1/6] added support for rotate in fp32 Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/config.py | 14 +++++++---- modelopt/torch/quantization/nn/functional.py | 8 +++++-- .../nn/modules/tensor_quantizer.py | 17 ++++++++++--- tests/gpu/torch/quantization/test_hadamard.py | 24 +++++++++++++++---- 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index e1b48ee60..ada417375 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -934,14 +934,20 @@ def validate_calibrator(cls, v, info: ValidationInfo): assert v in ["max", "histogram"] return v - rotate: bool = ModeloptField( + rotate: bool | dict[str, bool] = ModeloptField( default=False, - title="""If rotate the input before quantization.""", - description=""""If true, the input of the quantizer will be rotated with a hadamard matrix + title="""Configuration for rotating the input before quantization.""", + description="""Can be a boolean or a dictionary with the following keys: + - "enable": Boolean to enable/disable rotation (default: False) + - "rotate_fp32": Boolean to compute rotation in float32 precision (default: False) + + If a boolean is provided, it is treated as the "enable" value with "rotate_fp32" defaulting to False. + + When enabled, the input of the quantizer will be rotated with a hadamard matrix given by scipy.linalg.hadamard, i.e. ``input = input @ scipy.linalg.hadamard(input.shape[-1]) / sqrt(input.shape[-1])``. - This can be used for ratation based PTQ methods, e.g. QuaRot or SpinQuant. + This can be used for rotation based PTQ methods, e.g. QuaRot or SpinQuant. 
See https://arxiv.org/abs/2404.00456 for example.""", ) diff --git a/modelopt/torch/quantization/nn/functional.py b/modelopt/torch/quantization/nn/functional.py index df8bcbbcd..662aea66e 100644 --- a/modelopt/torch/quantization/nn/functional.py +++ b/modelopt/torch/quantization/nn/functional.py @@ -93,7 +93,7 @@ def backward(ctx, grad_outputs): return fast_hadamard_transform.hadamard_transform(grad_outputs) # type: ignore[name-defined] -def normalized_hadamard_transform(inputs): +def normalized_hadamard_transform(inputs, rotate_fp32=False): """Normalized fast hadamard transform.""" global fast_hadamard_transform try: @@ -104,6 +104,10 @@ def normalized_hadamard_transform(inputs): "`pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git`" ) - return FastHadamardTransform.apply(inputs) / torch.sqrt( + dtype = inputs.dtype + if rotate_fp32: + inputs = inputs.float() + outputs = FastHadamardTransform.apply(inputs) / torch.sqrt( torch.tensor(inputs.shape[-1], dtype=torch.float32) ) + return outputs.to(dtype) if rotate_fp32 else outputs diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3852d1144..6726a8065 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -992,8 +992,14 @@ def forward(self, inputs): inputs = inputs * self.pre_quant_scale # Rotating the input - if self._rotate: - inputs = normalized_hadamard_transform(inputs) + rotate_fp32 = ( + self._rotate.get("rotate_fp32", False) if isinstance(self._rotate, dict) else False + ) + rotate_enable = ( + self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate + ) + if rotate_enable: + inputs = normalized_hadamard_transform(inputs, rotate_fp32=rotate_fp32) if self._disabled: # if quantizer is disabled, we still need to track the input dtype for saving the model @@ -1105,7 +1111,12 @@ def extra_repr(self): if self.pre_quant_scale is not None else "" ) - s += " rotated" if self._rotate else "" + s += ( + " rotated" + if (isinstance(self._rotate, dict) and self._rotate.get("enable", False)) + or self._rotate + else "" + ) s += ( f" calibrator={self._calibrator.__class__.__name__}" if (self._calibrator is not None) diff --git a/tests/gpu/torch/quantization/test_hadamard.py b/tests/gpu/torch/quantization/test_hadamard.py index c768bc87e..07c179026 100644 --- a/tests/gpu/torch/quantization/test_hadamard.py +++ b/tests/gpu/torch/quantization/test_hadamard.py @@ -41,9 +41,17 @@ def test_hadamard_transform(dim): xxt_h = x_h @ x_h.T # The numerical error can be large, especially for 16-bit floats. 
assert torch.allclose(xxt_h, xxt, atol=0.05) + x_h_fp32 = normalized_hadamard_transform(x, rotate_fp32=True) + xxt_h_fp32 = x_h_fp32 @ x_h_fp32.T + # test the numerical error is smaller when using float32 + assert torch.allclose(xxt_h_fp32, xxt, atol=1e-6) -def test_kv_rotate(): +@pytest.mark.parametrize( + "rotate_fp32", + [True, False], +) +def test_kv_rotate(rotate_fp32): mtq.plugins.register_attention_for_kv_quant(SDPAAttention) model = nn.Sequential(SDPAAttention()) mtq.replace_quant_module(model) @@ -51,27 +59,33 @@ def test_kv_rotate(): set_quantizer_by_cfg(model, {"*": {"enable": False}}) dummy_input = SDPAAttention.get_input(device="cuda") output_ref = model(dummy_input) + if rotate_fp32: + rotate = {"enable": True, "rotate_fp32": True} + atol = 1e-6 + else: + rotate = True + atol = 0.05 with set_quantizer_by_cfg_context( model, { "*[qk]_bmm_quantizer": { - "rotate": True, + "rotate": rotate, }, }, ): output_test = model(dummy_input) - assert torch.allclose(output_ref, output_test, atol=0.05) + assert torch.allclose(output_ref, output_test, atol=atol) # Test the rotation is actually applied by turning on only one of the query, key quantizers with set_quantizer_by_cfg_context( model, { "*k_bmm_quantizer": { - "rotate": True, + "rotate": rotate, }, }, ): output_test1 = model(dummy_input) - assert not torch.allclose(output_ref, output_test1, atol=0.05) + assert not torch.allclose(output_ref, output_test1, atol=atol) mtq.unregister(SDPAAttention) From fc7e476c284c67af1fcfea53ddd5cb04ffd97720 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 12 Feb 2026 22:41:10 +0000 Subject: [PATCH 2/6] updated changelog Signed-off-by: Kinjal Patel --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bbbe6ab9e..2064dc66e 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -22,6 +22,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is. - Add support for image-text data calibration in PTQ for Nemotron VL models. - Add PTQ support for Nemotron Parse. +- Add support for rotating the input before quantization for RHT. 
0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ From 4e23527c599dc6d8c6bbab39a2afbc224747d23a Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 12 Feb 2026 22:44:02 +0000 Subject: [PATCH 3/6] minor Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/nn/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/nn/functional.py b/modelopt/torch/quantization/nn/functional.py index 662aea66e..0beb7c956 100644 --- a/modelopt/torch/quantization/nn/functional.py +++ b/modelopt/torch/quantization/nn/functional.py @@ -106,7 +106,7 @@ def normalized_hadamard_transform(inputs, rotate_fp32=False): dtype = inputs.dtype if rotate_fp32: - inputs = inputs.float() + inputs = inputs.to(torch.float32) outputs = FastHadamardTransform.apply(inputs) / torch.sqrt( torch.tensor(inputs.shape[-1], dtype=torch.float32) ) From 39b8b3269ab021065d25c2bce754104a775fa897 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 12 Feb 2026 22:57:25 +0000 Subject: [PATCH 4/6] minor Signed-off-by: Kinjal Patel --- .../quantization/nn/modules/tensor_quantizer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 6726a8065..0a92efa97 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1111,12 +1111,13 @@ def extra_repr(self): if self.pre_quant_scale is not None else "" ) - s += ( - " rotated" - if (isinstance(self._rotate, dict) and self._rotate.get("enable", False)) - or self._rotate - else "" - ) + if isinstance(self._rotate, dict): + if self._rotate.get("enable", False): + s += " rotated" + if self._rotate.get("rotate_fp32", False): + s += " (fp32)" + elif self._rotate: + s += " rotated" s += ( f" calibrator={self._calibrator.__class__.__name__}" if (self._calibrator is not None) From 044a4abc871185997dd3bd404df0266d50b3183f Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Fri, 13 Feb 2026 18:50:08 +0000 Subject: [PATCH 5/6] minor Signed-off-by: Kinjal Patel --- .../nn/modules/tensor_quantizer.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 0a92efa97..937714f52 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -525,6 +525,20 @@ def is_static_block_quant(self): and self._fake_quant ) + @property + def rotate_is_enabled(self): + """Check if rotate is enabled in quant config.""" + return self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate + + @property + def rotate_is_fp32(self): + """Check if rotation needs to be computed in float32.""" + return ( + self._rotate.get("rotate_fp32", False) + if isinstance(self._rotate, dict) and self.rotate_is_enabled + else False + ) + def disable_calib(self): """Disable calibration.""" self._if_calib = False @@ -992,14 +1006,8 @@ def forward(self, inputs): inputs = inputs * self.pre_quant_scale # Rotating the input - rotate_fp32 = ( - self._rotate.get("rotate_fp32", False) if isinstance(self._rotate, dict) else False - ) - rotate_enable = ( - self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate - ) - if rotate_enable: - inputs = normalized_hadamard_transform(inputs, 
rotate_fp32=rotate_fp32) + if self.rotate_is_enabled: + inputs = normalized_hadamard_transform(inputs, rotate_fp32=self.rotate_is_fp32) if self._disabled: # if quantizer is disabled, we still need to track the input dtype for saving the model @@ -1111,13 +1119,8 @@ def extra_repr(self): if self.pre_quant_scale is not None else "" ) - if isinstance(self._rotate, dict): - if self._rotate.get("enable", False): - s += " rotated" - if self._rotate.get("rotate_fp32", False): - s += " (fp32)" - elif self._rotate: - s += " rotated" + s += " rotated" if self.rotate_is_enabled else "" + s += " (fp32)" if self.rotate_is_fp32 else "" s += ( f" calibrator={self._calibrator.__class__.__name__}" if (self._calibrator is not None) From 9920a1eff9c1059e324ba952c375f47e6b39c215 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Fri, 13 Feb 2026 21:38:58 +0000 Subject: [PATCH 6/6] minor Signed-off-by: Kinjal Patel --- tests/gpu/torch/quantization/test_hadamard.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/gpu/torch/quantization/test_hadamard.py b/tests/gpu/torch/quantization/test_hadamard.py index 07c179026..64dd39e2c 100644 --- a/tests/gpu/torch/quantization/test_hadamard.py +++ b/tests/gpu/torch/quantization/test_hadamard.py @@ -43,8 +43,7 @@ def test_hadamard_transform(dim): assert torch.allclose(xxt_h, xxt, atol=0.05) x_h_fp32 = normalized_hadamard_transform(x, rotate_fp32=True) xxt_h_fp32 = x_h_fp32 @ x_h_fp32.T - # test the numerical error is smaller when using float32 - assert torch.allclose(xxt_h_fp32, xxt, atol=1e-6) + assert torch.allclose(xxt_h_fp32, xxt, atol=0.05) @pytest.mark.parametrize( @@ -61,10 +60,8 @@ def test_kv_rotate(rotate_fp32): output_ref = model(dummy_input) if rotate_fp32: rotate = {"enable": True, "rotate_fp32": True} - atol = 1e-6 else: rotate = True - atol = 0.05 with set_quantizer_by_cfg_context( model, { @@ -74,7 +71,7 @@ def test_kv_rotate(rotate_fp32): }, ): output_test = model(dummy_input) - assert torch.allclose(output_ref, output_test, atol=atol) + assert torch.allclose(output_ref, output_test, atol=0.05) # Test the rotation is actually applied by turning on only one of the query, key quantizers with set_quantizer_by_cfg_context( @@ -86,6 +83,6 @@ def test_kv_rotate(rotate_fp32): }, ): output_test1 = model(dummy_input) - assert not torch.allclose(output_ref, output_test1, atol=atol) + assert not torch.allclose(output_ref, output_test1, atol=0.05) mtq.unregister(SDPAAttention)
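
Reviewer note (not part of the patch series): a minimal usage sketch of the behavior added here. It exercises the new ``rotate_fp32`` argument of ``normalized_hadamard_transform`` and the dict form of the ``rotate`` quantizer option described in the updated ``config.py`` docstring. The tensor shapes, wildcard pattern, ``num_bits`` value, and the commented ``mtq.quantize`` call are illustrative assumptions, not taken from the diffs; the snippet also assumes the ``fast-hadamard-transform`` package is installed and a CUDA device is available.

    # Sketch only: illustrative usage of the APIs touched by this series,
    # assuming fast-hadamard-transform is installed and a CUDA GPU is present.
    import torch

    from modelopt.torch.quantization.nn.functional import normalized_hadamard_transform

    # Direct use of the updated transform: rotate_fp32=True upcasts the input to
    # float32 for the Hadamard rotation and casts the result back to the input dtype.
    x = torch.randn(4, 256, dtype=torch.float16, device="cuda")
    x_rot = normalized_hadamard_transform(x, rotate_fp32=True)
    assert x_rot.dtype == x.dtype  # output dtype is preserved when rotate_fp32=True

    # Quantizer-level option: "rotate" now accepts either a bool or a dict with
    # "enable" and "rotate_fp32" keys (a bare bool is treated as the "enable" value).
    rotate_cfg = {"enable": True, "rotate_fp32": True}

    # Hypothetical quant config using the dict form; the wildcard pattern and
    # num_bits are placeholders, not part of this patch series.
    quant_cfg = {
        "quant_cfg": {
            "*input_quantizer": {"num_bits": 8, "rotate": rotate_cfg},
        },
        "algorithm": "max",
    }
    # import modelopt.torch.quantization as mtq
    # model = ...  # any supported torch.nn.Module
    # mtq.quantize(model, quant_cfg, forward_loop=None)

The key behavioral point, mirrored from the tests above: with ``rotate_fp32=True`` the rotation error against the unrotated reference stays small even for 16-bit inputs, while the output dtype of the quantizer input is unchanged.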