diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 416d2af3fc2e..e242b4b57cb0 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -341,7 +341,7 @@ jobs: additional_deps: ["peft", "kernels"] - backend: "torchao" test_location: "torchao" - additional_deps: [] + additional_deps: ["mslk-cuda"] - backend: "optimum_quanto" test_location: "quanto" additional_deps: [] diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 0f1fbde72485..ec74422741c3 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -177,6 +177,11 @@ def _test_quantization_inference(self, config_kwargs): model_quantized.to(torch_device) inputs = self.get_dummy_inputs() + model_dtype = next(model_quantized.parameters()).dtype + inputs = { + k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v + for k, v in inputs.items() + } output = model_quantized(**inputs, return_dict=False)[0] assert output is not None, "Model output is None" @@ -930,6 +935,7 @@ def test_torchao_device_map(self): """Test that device_map='auto' works correctly with quantization.""" self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) + @pytest.mark.xfail(reason="dequantize is not implemented in torchao") def test_torchao_dequantize(self): """Test that dequantize() works correctly.""" self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])