2 changes: 1 addition & 1 deletion .github/workflows/nightly_tests.yml
@@ -341,7 +341,7 @@ jobs:
additional_deps: ["peft", "kernels"]
- backend: "torchao"
test_location: "torchao"
- additional_deps: []
+ additional_deps: [mslk-cuda]
- backend: "optimum_quanto"
test_location: "quanto"
additional_deps: []
6 changes: 6 additions & 0 deletions tests/models/testing_utils/quantization.py
@@ -177,6 +177,11 @@ def _test_quantization_inference(self, config_kwargs):
model_quantized.to(torch_device)

inputs = self.get_dummy_inputs()
model_dtype = next(model_quantized.parameters()).dtype
Collaborator: This would affect all quantization backends? E.g., with a GGUF backend the parameter dtype could end up as int8, which would then cast the inputs to int8.

Also, I'd prefer to avoid casting inputs after fetching them from self.get_dummy_inputs() inside a test, if we can avoid it.
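The concern above can be illustrated with a minimal standalone sketch (the int8 tensor below is just a stand-in for a quantized parameter, not the real GGUF layout): deriving the target dtype from a model's parameters and casting all tensor inputs to it would silently truncate floating-point inputs when the parameters are stored as integers.

```python
import torch

# Illustrative hazard only: a stand-in for an int8 quantized weight.
int8_weight = torch.zeros(4, 4, dtype=torch.int8)
model_dtype = int8_weight.dtype  # torch.int8

x = torch.randn(2, 4)              # float32 dummy input
x_cast = x.to(dtype=model_dtype)   # fractional values are truncated away
assert x_cast.dtype == torch.int8
```

Guarding on `torch.is_floating_point` of the *target* dtype (not just the input) would be one way to avoid this.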

Member Author: Where should it go then? Should we implement a custom get_dummy_inputs() for the torchao tests? I think it's reasonably safe to keep the input dtypes as bfloat16 there, since that replicates what we do in actual pipelines. LMK.
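The author's suggestion could look roughly like the sketch below (class name, input names, and shapes are all hypothetical, not the real test fixtures): override get_dummy_inputs() in the torchao test class so floating-point inputs are created directly in bfloat16, removing the need for any casting inside the test body.

```python
import torch

class TorchAoQuantizationTests:
    """Hypothetical sketch: produce bfloat16 dummy inputs up front,
    mirroring actual pipelines, instead of casting inside the test."""

    def get_dummy_inputs(self):
        # Illustrative names/shapes only; the real fixtures differ.
        return {
            "hidden_states": torch.randn(1, 4, 16, dtype=torch.bfloat16),
            "timestep": torch.tensor([1]),  # integer tensors stay integer
        }
```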

inputs = {
    k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v
    for k, v in inputs.items()
}
output = model_quantized(**inputs, return_dict=False)[0]

assert output is not None, "Model output is None"
@@ -930,6 +935,7 @@ def test_torchao_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])

@pytest.mark.xfail(reason="dequantize is not implemented in torchao")
def test_torchao_dequantize(self):
"""Test that dequantize() works correctly."""
self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])