From 5e5b575fb3413fbaf04a949c3d0fa5796b79e4f4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 25 Mar 2026 09:38:49 +0530 Subject: [PATCH 1/2] fix torchao tests --- tests/models/testing_utils/quantization.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 0f1fbde72485..ec74422741c3 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -177,6 +177,11 @@ def _test_quantization_inference(self, config_kwargs): model_quantized.to(torch_device) inputs = self.get_dummy_inputs() + model_dtype = next(model_quantized.parameters()).dtype + inputs = { + k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v + for k, v in inputs.items() + } output = model_quantized(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" @@ -930,6 +935,7 @@ def test_torchao_device_map(self): """Test that device_map='auto' works correctly with quantization.""" self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) + @pytest.mark.xfail(reason="dequantize is not implemented in torchao") def test_torchao_dequantize(self): """Test that dequantize() works correctly.""" self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) From 4e01e02395145cd79e258cd40ad4ec0d62f4c42c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 25 Mar 2026 09:41:04 +0530 Subject: [PATCH 2/2] add mslk for additional dependencies. --- .github/workflows/nightly_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 416d2af3fc2e..e242b4b57cb0 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -341,7 +341,7 @@ jobs: additional_deps: ["peft", "kernels"] - backend: "torchao" test_location: "torchao" - additional_deps: [] + additional_deps: [mslk-cuda] - backend: "optimum_quanto" test_location: "quanto" additional_deps: []