# Feature Request: Quantized KV Cache Support in `mx.fast.scaled_dot_product_attention`

## Summary
Add native support for reading quantized (TurboQuant codebook) KV cache data in `mx.fast.scaled_dot_product_attention`, eliminating the need for Python-level custom Metal kernels or full-cache dequantization during attention.
## Motivation
TurboQuant (Google Research, ICLR 2026) compresses KV caches to 2-4 bits via rotation + codebook quantization, and the MLX ecosystem has adopted it widely.
**The problem:** `mx.fast.scaled_dot_product_attention` only accepts fp16/bf16/fp32 tensors. When the KV cache is quantized, there are only two options today:
- **Dequantize the entire cache to fp16 before attention** — works and is fast (uses native SDPA), but creates a full fp16 copy of the KV cache in memory (a 31 GB activation spike at 128K context on gemma-4-31B). This defeats the purpose of quantization at long context.
- **Custom Metal kernel via `mx.fast.metal_kernel`** — we built a working proof-of-concept that fuses score + online softmax + value accumulation in one dispatch. Correctness is exact (cosine similarity 1.0 vs. the dequantize path) and memory stays bounded. But it's 3-4x slower than native SDPA due to Python dispatch overhead and the inability to leverage MLX's internal kernel fusion and memory planning.
Neither option is satisfactory. llama.cpp and vLLM solved this by implementing quantized-cache attention in their C++/CUDA/Metal cores — the quantized format is a first-class citizen in the attention kernel. MLX could do the same.
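For concreteness, a minimal sketch of the full-dequantize path (option 1) as it looks today. `dequantize_keys`/`dequantize_values` are hypothetical placeholder methods for whatever the TurboQuant cache wrapper exposes, not real mlx-vlm API:

```python
import mlx.core as mx

def attention_via_full_dequant(q, cache, scale, mask=None):
    # Materialize a full fp16 copy of the cache, then call native SDPA.
    # `dequantize_keys` / `dequantize_values` are hypothetical stand-ins
    # for the TurboQuant cache's dequantization entry points.
    k = cache.dequantize_keys()    # (B, n_kv_heads, T, D) fp16 -- full-size copy
    v = cache.dequantize_values()  # (B, n_kv_heads, T, D) fp16 -- full-size copy
    # Fast (native kernel), but this fp16 copy is the activation spike
    # described above at long context.
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
```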
## Prior Art in the MLX Ecosystem
- **TheTom's MLX fork** (TheTom/mlx:feature/turboquant-plus) — adds `mx.fast.scaled_dot_product_attention_qv` as a C++ Metal kernel reading 4-bit quantized values. Achieves near-native decode speed. Decode-only (L=1); uses the `mx.quantize` affine format.
- **mlx-vlm `_fused_integer_decode_single_tile_kernel`** — existing Python-generated Metal kernel for TQ decode attention. Fuses score + online softmax + value accumulation. Works well for decode but is invoked via `mx.fast.metal_kernel` (Python dispatch).
- **Our proof-of-concept** (Landon-Molt/mlx-vlm:feat/fused-tq-prefill) — extends the decode kernel to L>1 (prefill) with grid-parallel queries. Validates the algorithm and correctness. Benchmarks below.
## Benchmarks (Our Proof-of-Concept)
Single layer, B=1, H=8 KV heads, D=128, GQA 4:1, L=128 queries, MacBook Pro M5 Max 128GB:
| KV tokens | Fused TQ (Python Metal) | Dequantize + native SDPA | Ratio |
|-----------|-------------------------|--------------------------|-------|
| 2K        | 0.058s                  | 0.003s                   | 17.9x |
| 8K        | 0.025s                  | 0.009s                   | 2.9x  |
| 32K       | 0.097s                  | 0.029s                   | 3.4x  |
| 64K       | 0.191s                  | 0.056s                   | 3.4x  |
The 3-4x gap at long context is entirely Python→Metal dispatch overhead vs. a native C++ launch path. The Metal kernel itself is efficient — at 8K, where that overhead weighs least in these runs, the gap narrows to 2.9x.
For a 60-layer model at 32K context, this translates to ~5.8s of attention time alone (vs ~1.7s native) — the difference between responsive and sluggish prefill.
## Proposed Approach
Extend the existing SDPA Metal kernels to read TurboQuant codebook-quantized KV data:
**Decode (L=1)** — extend `sdpa_vector.h` (reference sketch after the list):
- Score: unpack key codebook indices inline, dot product with query in rotated space
- Value: unpack value codebook indices inline during weighted accumulation
- Same online softmax, same threading model, same simdgroup reduction
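As a semantic reference for what the decode kernel computes, here is a sketch in plain NumPy: one head, one query, single-codebook keys. It omits the GQA head mapping and the `TurboQuantProdState` residual/sign term for brevity, and all names are illustrative, not the actual kernel's:

```python
import numpy as np

def decode_attention_reference(q, k_idx, k_norm, v_idx, v_norm,
                               k_codebook, v_codebook, scale):
    """Reference for the proposed fused decode path (L=1), one head.

    q:              (D,) float query in the rotated space
    k_idx, v_idx:   (T, D) unpacked codebook indices per token/channel
    k_norm, v_norm: (T,) per-token norms
    *_codebook:     (n_levels,) scalar codebooks
    The real kernel would read packed uint32 words instead of dense indices.
    """
    m = -np.inf                 # running max (online softmax)
    denom = 0.0                 # running softmax denominator
    acc = np.zeros_like(q)      # running weighted-value accumulator
    for t in range(k_idx.shape[0]):
        # Inline dequant of key t: codebook lookup scaled by per-token norm.
        k_t = k_codebook[k_idx[t]] * k_norm[t]
        s = scale * float(q @ k_t)
        # Online softmax update: rescale running state when the max moves.
        m_new = max(m, s)
        corr = np.exp(m - m_new) if np.isfinite(m) else 0.0
        w = np.exp(s - m_new)
        denom = denom * corr + w
        # Inline dequant of value t during weighted accumulation.
        v_t = v_codebook[v_idx[t]] * v_norm[t]
        acc = acc * corr + w * v_t
        m = m_new
    return acc / denom
```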
**Prefill (L>1)** — extend the steel flash attention kernels (sketch after the list):
- Grid-parallel over query positions (each threadgroup handles one query)
- Tile over KV tokens with online softmax (same as current flash attention)
- Inline codebook dequantization per element during score and value phases
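A matching sketch of the prefill structure, again a NumPy reference rather than Metal: the outer loop stands in for the grid-parallel threadgroups (one per query), the inner loop for KV tiling, and `K_deq`/`V_deq` stand in for tiles that the real kernel would dequantize inline from codebook indices. Mask handling is omitted:

```python
import numpy as np

def prefill_attention_reference(Q, K_deq, V_deq, scale, tile=128):
    """Tile-structured reference for the proposed prefill path (L>1), one head."""
    L, D = Q.shape
    T = K_deq.shape[0]
    out = np.zeros((L, D))
    for i in range(L):                      # grid-parallel in the kernel
        m, denom, acc = -np.inf, 0.0, np.zeros(D)
        for t0 in range(0, T, tile):        # tile over KV tokens
            Kt = K_deq[t0:t0 + tile]        # kernel: unpack + dequant here
            Vt = V_deq[t0:t0 + tile]
            s = scale * (Kt @ Q[i])         # scores for this tile
            # Online softmax across tiles: rescale the running state.
            m_new = max(m, float(s.max()))
            corr = np.exp(m - m_new) if np.isfinite(m) else 0.0
            w = np.exp(s - m_new)
            denom = denom * corr + float(w.sum())
            acc = acc * corr + w @ Vt
            m = m_new
        out[i] = acc / denom
    return out
```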
**Quantized state format** (from mlx-vlm's `TurboQuantKVCache`):
- Keys: `TurboQuantProdState` — `norms` (fp16, per-token), `mse_indices` (packed uint32), `residual_norms` (fp16), `qjl_signs` (packed uint32), plus a codebook (small, constant)
- Values: `TurboQuantMSEState` — `norms` (fp16, per-token), `indices` (packed uint32), plus a codebook (small, constant)
- Dequant per element (sketched below): `codebook[unpack_bits(packed, d)] * norm[t]`
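A sketch of that per-element dequant in Python. The little-endian 4-bit packing layout here is an assumption for illustration, not necessarily mlx-vlm's actual layout:

```python
import numpy as np

def unpack_bits(packed, d, bits=4):
    """Extract the `bits`-wide index for channel d from a row of uint32 words.
    Assumed layout: 32 // bits indices per word, little-endian within a word."""
    per_word = 32 // bits
    word = int(packed[d // per_word])
    shift = (d % per_word) * bits
    return (word >> shift) & ((1 << bits) - 1)

def dequant_element(codebook, packed_indices, norms, t, d, bits=4):
    # Per-element dequant as stated above:
    #   codebook[unpack_bits(packed, d)] * norm[t]
    idx = unpack_bits(packed_indices[t], d, bits)
    return codebook[idx] * norms[t]
```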
## API Options
**Option A: New function**

```python
mx.fast.scaled_dot_product_attention_tq(
    queries, keys_state, values_state,
    key_codebook, value_codebook,
    scale=..., mask=...,
)
```
**Option B: Extend existing SDPA to detect quantized inputs**

```python
# If keys/values are QuantizedStateProxy objects, dispatch to the TQ kernel
mx.fast.scaled_dot_product_attention(queries, keys, values, scale=..., mask=...)
```
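Under Option B, the Python-visible dispatch could look like the sketch below. `QuantizedStateProxy` and `scaled_dot_product_attention_tq` are the proposed names from this issue, not existing API:

```python
import mlx.core as mx

def sdpa_dispatch(q, keys, values, *, scale, mask=None):
    # Sketch of Option B: if K/V are quantized-state wrappers rather than
    # plain arrays, route to the TQ kernel; otherwise fall through to the
    # existing native path. Both TQ names are proposals, not existing API.
    if isinstance(keys, QuantizedStateProxy):
        return mx.fast.scaled_dot_product_attention_tq(
            q, keys, values,
            keys.codebook, values.codebook,
            scale=scale, mask=mask)
    return mx.fast.scaled_dot_product_attention(
        q, keys, values, scale=scale, mask=mask)
```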
## Related Issues
- mlx-vlm#1016 — root cause: `prefill_attention()` dead after ProdCodec removal
- mlx-vlm#939 — tiled prefill (2 dispatches, 17x slower)
- mlx-lm#1060 — TurboQuant community discussion
- #3302 — GPU watchdog on long-context SDPA (related: long prefill)
- #3361 — SDPA fix for >32K KV (actively maintained area)