Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions docs/source/en/api/pipelines/cogvideox.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,15 @@ The quantized CogVideoX 5B model below requires ~16GB of VRAM.

```py
import torch
from diffusers import CogVideoXPipeline, AutoModel
from diffusers import CogVideoXPipeline, AutoModel, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from torchao.quantization import Int8WeightOnlyConfig

# quantize weights to int8 with torchao
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="torchao",
quant_kwargs={"quant_type": "int8wo"},
components_to_quantize="transformer"
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())}
)

# fp8 layerwise weight-casting
Expand Down
44 changes: 13 additions & 31 deletions docs/source/en/quantization/torchao.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
from torchao.quantization import Int8WeightOnlyConfig

pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
```

For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below.

```py
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig

pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig("int8wo")}
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128, version=2))}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
Expand Down Expand Up @@ -91,18 +74,15 @@ Weight-only quantization stores the model weights in a specific low-bit data typ

Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.

The quantization methods supported are as follows:
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods. An exhaustive list of configuration options are available [here](https://docs.pytorch.org/ao/main/workflows/inference.html#inference-workflows).

| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
Some example popular quantization configurations are as follows:

Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.

Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
| **Category** | **Configuration Classes** |
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's practically nothing preventing the users from using the configs supported through TorchAO and they might not be limited to the ones we're including the in following table. For example, we can use the more recent NVFP4 and MXFP8 schemes (their respective config classes) here as well.

So, how about we provide examples to the popular config classes like Int8DynamicActivationInt4WeightConfig, Int8WeightOnlyConfig, and Float8DynamicActivationFloat8WeightConfig (with hyperlinks) and then provide a link to available config options (this will be a TorchAO doc link) for the users to explore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the link to the above

|---|---|
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this is cool!

| **Integer quantization** | [`Int4WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int4WeightOnlyConfig.html), [`Int8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8WeightOnlyConfig.html), [`Int8DynamicActivationInt8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8DynamicActivationInt8WeightConfig.html) |
| **Floating point 8-bit quantization** | [`Float8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8WeightOnlyConfig.html), [`Float8DynamicActivationFloat8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8DynamicActivationFloat8WeightConfig.html) |
| **Unsigned integer quantization** | [`IntxWeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.IntxWeightOnlyConfig.html) |

## Serializing and Deserializing quantized models

Expand All @@ -111,8 +91,9 @@ To serialize a quantized model in a given dtype, first load the model with the d
```python
import torch
from diffusers import AutoModel, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

quantization_config = TorchAoConfig("int8wo")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
transformer = AutoModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
Expand All @@ -137,18 +118,19 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```

If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4` weight-only, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.

```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
from torchao.quantization import IntxWeightOnlyConfig

# Serialize the model
transformer = AutoModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
quantization_config=TorchAoConfig(IntxWeightOnlyConfig(dtype=torch.uint4)),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
Expand Down
Loading
Loading