89 changes: 89 additions & 0 deletions README.md
@@ -42,6 +42,11 @@ M \times \left\lceil K / gs \right\rceil
$$

## 🚀 What is new in QuTLASS v0.2:
- **V2 API with Arbitrary K Support** using CUTLASS 4.x:
- Supports K = 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, ... (removes K≤256 limitation)
- Automatic dispatch for K > 256 via `fusedQuantizeMx(..., use_v2=None)`
- Batched input support `[B, M, K]` with automatic reshaping
- CollectiveBuilder pattern for optimal tile sizes across all K dimensions
- **FlashInfer backend** support for **B200 GPUs**
- **Quantization-Aware Training (QAT)** via MXFP types:
- Quartet clipping mask computation integrated in quantization routines
@@ -110,14 +115,98 @@ in the root folder of this repository.

Correctness tests can be executed via ```python tests/mxfp4_test.py``` and benchmarks via ```python benchmarks/bench_mxfp4.py```.

### Basic Usage

The fused quantization kernel can be invoked directly through ```qutlass.fusedQuantizeMx(a, h, method)```. Here, ```a``` is the input tensor to quantize, ```h``` is the Hadamard matrix, and ```method``` is the quantization scheme, specified as ```Literal["quest", "abs_max"]```.
The kernel interface is defined in ```qutlass/csrc/fused_quantize_mx.cu```.
The outputs are ```aq```, the quantized data in FP4 (```e2m1```), and ```a_sf```, the corresponding scaling factors in FP8 (```e8m0```).

```python
import torch
import qutlass
from scipy.linalg import hadamard

# Setup
M, K = 512, 256
device = torch.device('cuda')
H = torch.tensor(hadamard(K) * K**-0.5, dtype=torch.bfloat16, device=device)
A = torch.randn(M, K, dtype=torch.bfloat16, device=device)

# Quantize
A_e2m1, A_e8m0 = qutlass.fusedQuantizeMx(A, H, method='quest')
```
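
For reference, the packed shapes follow from the formats; the sketch below shows what to expect (the exact layouts, including any padding, are defined by the kernel):

```python
# Two e2m1 (FP4) values are packed per byte -> K // 2 bytes per row;
# one e8m0 scale covers each 32-element group -> M * ceil(K / 32) scales.
# (Assumed layout; scale tensors may be padded for downstream swizzling.)
print(A_e2m1.shape)    # expected: torch.Size([512, 128]), i.e. [M, K // 2]
print(A_e8m0.numel())  # expected: 512 * (256 // 32) = 4096
```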

### V2 API: Arbitrary K Support (K > 256)

QuTLASS v0.2+ includes a **v2 API** using CUTLASS 4.x that supports **arbitrary K dimensions** (32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, ...), removing the K≤256 limitation of the legacy API.

The v2 API is **automatically selected** when K > 256 (the default `use_v2=None` dispatches based on K). You can also force v2 explicitly with `use_v2=True`:

```python
import torch
import qutlass
from scipy.linalg import hadamard

# Large K dimension (beyond legacy 256 limit)
M, K = 4096, 1024
device = torch.device('cuda')

# Generate Hadamard matrix
H = torch.tensor(hadamard(K) * K**-0.5, dtype=torch.bfloat16, device=device)
A = torch.randn(M, K, dtype=torch.bfloat16, device=device)

# Automatically uses v2 for K > 256
A_e2m1, A_e8m0 = qutlass.fusedQuantizeMx(A, H, method='quest')

# Or manually force v2 API
A_e2m1, A_e8m0 = qutlass.fusedQuantizeMx(A, H, method='quest', use_v2=True)
```

**Batched Inputs**: The v2 API supports batched inputs `[B, M, K]` by reshaping them internally; the equivalence is sketched after the example below:

```python
# 3D batched input
B, M, K = 4, 256, 1024
A_batched = torch.randn(B, M, K, dtype=torch.bfloat16, device=device)

# Outputs preserve batch dimension: [B, M, K/2] and [B, M, K/32]
A_e2m1, A_e8m0 = qutlass.fusedQuantizeMx(A_batched, H, method='quest')
```
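
The batched path is equivalent, in principle, to flattening the batch, quantizing in 2D, and restoring the leading dimension. A sketch of that assumed semantics (not the internal implementation):

```python
# Hypothetical equivalence sketch, assuming unpadded outputs:
A_flat = A_batched.reshape(B * M, K)                    # [B*M, K]
Aq_flat, Asf_flat = qutlass.fusedQuantizeMx(A_flat, H, method='quest')
Aq_ref = Aq_flat.reshape(B, M, K // 2)                  # should match A_e2m1
Asf_ref = Asf_flat.reshape(B, M, K // 32)               # should match A_e8m0
```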

**Direct V2 Call**: You can also call the v2 API directly (requires 2D input):

```python
A_e2m1, A_e8m0 = qutlass.fusedQuantizeMx_v2(A, H, method='quest')
```

**Requirements** (a pre-flight check is sketched after this list):
- K must be divisible by 32
- Requires Blackwell GPU (SM100 or SM120)
- BF16 input dtype
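
These constraints can be validated up front. Below is a minimal, hypothetical helper (not part of the QuTLASS API) that mirrors the list above:

```python
import torch

def can_fused_quantize(a: torch.Tensor) -> bool:
    """Hypothetical pre-flight check mirroring the documented requirements."""
    major, _ = torch.cuda.get_device_capability(a.device)
    return (
        a.dtype == torch.bfloat16   # BF16 input dtype
        and a.shape[-1] % 32 == 0   # K divisible by 32
        and major in (10, 12)       # Blackwell: SM100 reports 10.x, SM120 reports 12.x
    )
```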

### Matmul Integration

The matmul kernel can be called via ```qutlass.matmul_mxf4_bf16_tn(aq, bq, a_sf, b_sf, alpha)```. Its implementation can be found in ```qutlass/csrc/gemm.cu```.
To use this matmul kernel, the scaling factors must first be rearranged into a block-scaled swizzle [format](https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout).
The ```qutlass.utils.to_blocked``` helper, located in ```qutlass/utils.py```, handles this reordering.

```python
# End-to-end example with the v2 API and matmul
M, N, K = 512, 512, 1024

# Inputs (B is the second operand, quantized with the same Hadamard matrix)
H = torch.tensor(hadamard(K) * K**-0.5, dtype=torch.bfloat16, device=device)
A = torch.randn(M, K, dtype=torch.bfloat16, device=device)
B = torch.randn(N, K, dtype=torch.bfloat16, device=device)

# Quantize both matrices
A_e2m1, A_e8m0 = qutlass.fusedQuantizeMx(A, H, method='quest')
B_e2m1, B_e8m0 = qutlass.fusedQuantizeMx(B, H, method='quest')

# Convert scales to the blocked (swizzled) format
A_scale = qutlass.utils.to_blocked(A_e8m0, use_triton_kernel=True)
B_scale = qutlass.utils.to_blocked(B_e8m0, use_triton_kernel=True)

# MXFP4 matmul (TN layout: A is [M, K], B is [N, K])
alpha = torch.tensor([1.0], device=device)
C = qutlass.matmul_mxf4_bf16_tn(A_e2m1, B_e2m1, A_scale, B_scale, alpha, backend='cutlass')
```
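
Because ```H``` is orthogonal ($H H^\top = I$), the rotations cancel in the product: $(AH)(BH)^\top = AB^\top$. A quick sanity check against the BF16 reference, as a sketch (the acceptable tolerance depends on MXFP4 quantization error):

```python
# The Hadamard rotations cancel, so C should approximate A @ B.T.
C_ref = (A @ B.T).float()
rel_err = (C.float() - C_ref).norm() / C_ref.norm()
print(f"relative error: {rel_err.item():.4f}")  # small but nonzero
```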

In addition to the CUTLASS-powered MXFP4 matmul kernel above, we provide a custom prototype kernel that can be called via ```qutlass.matmul_ada_mxf4_bf16_tn(...)```.
This implementation is located in ```qutlass/csrc/gemm_ada.cu``` and does **not** require the preceding call to ```to_blocked```.
Optimization efforts for this kernel have primarily targeted small batch sizes (i.e., $bs=1\sim 32$); for larger sizes, ```qutlass.matmul_mxf4_bf16_tn``` is recommended.
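
Since this kernel skips the swizzle step, a plausible call pattern passes the raw ```e8m0``` scales directly. The signature is elided above, so the following is a hypothetical sketch (argument order assumed to mirror ```matmul_mxf4_bf16_tn```; see ```qutlass/csrc/gemm_ada.cu``` for the authoritative interface):

```python
# Hypothetical: unswizzled e8m0 scales passed directly, no to_blocked().
C_ada = qutlass.matmul_ada_mxf4_bf16_tn(A_e2m1, B_e2m1, A_e8m0, B_e8m0, alpha)
```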
71 changes: 71 additions & 0 deletions environment.yml
@@ -0,0 +1,71 @@
name: qutlass
channels:
- pytorch
- nvidia
- conda-forge
- defaults

dependencies:
# Python version - use 3.10 for compatibility with qutlass tests
- python=3.10

# Build tools (required for CUDA compilation)
- cmake>=3.26
- ninja

# CUDA toolkit and cuDNN (for Blackwell sm_120 support)
# Note: Use system CUDA 12.8+ if available, otherwise conda will install
- cudnn=9.15.1.9 # Required for FlashInfer MXFP4 hardware acceleration

# Core dependencies
- numpy>=1.21.0
- scipy # Required for Hadamard matrices in tests
- matplotlib
- pandas

# Python package manager
- pip

# Pip-installable dependencies
- pip:
    # PyTorch 2.8+ with CUDA 12.8 support (for Blackwell sm_120)
    - torch>=2.8.0
    - triton>=2.1.0

    # Testing framework
    - pytest>=7.0.0
    - pytest-xdist # Parallel test execution

    # FlashInfer for 4x speedup on Blackwell (optional)
    # Install from local fork if you've fixed bugs, otherwise:
    # - --extra-index-url https://flashinfer.ai/whl/cu124/torch2.9/
    # - flashinfer-python>=0.5.3

    # Development tools
    - black
    - isort
    - flake8

    # Utilities
    - tqdm
    - tabulate

# Installation instructions:
# 1. Create environment:
# conda env create -f environment.yml
#
# 2. Activate:
# conda activate qutlass
#
# 3. Install qutlass:
# pip install --no-build-isolation -e .
#
# 4. Install FlashInfer (if you have local fork with fixes):
# cd ../flashinfer && pip install -e .
#
# 5. Run tests:
# pytest tests/mxfp4_test.py -v
#
# Notes:
# - Python 3.10 is required (tests fail on 3.14 due to dataclasses changes)
# - CUDA 12.8+ required for Blackwell sm_120 support
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -4,9 +4,10 @@ build-backend = "setuptools.build_meta"

[project]
name = "qutlass"
version = "0.2.0"
version = "0.4.0"
description = "qutlass"
authors = [{name = "Roberto L. Castro", email = "Roberto.LopezCastro@ist.ac.at"}]
authors = [{name = "Roberto L. Castro", email = "Roberto.LopezCastro@ist.ac.at"},
{name = "Moiz A. Yousufi", email = "moiz.yousufi@gatech.edu"}]
license = {text = "Apache-2.0"}
dependencies = []
