-
Notifications
You must be signed in to change notification settings - Fork 29
feat: port TQ3_0 KV cache from llama-turboquant #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: prism
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -429,7 +429,8 @@ extern "C" { | |||||
| GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) | ||||||
| GGML_TYPE_Q1_0 = 40, | ||||||
| GGML_TYPE_Q1_0_g128 = 41, | ||||||
| GGML_TYPE_COUNT = 42, | ||||||
| GGML_TYPE_TQ3_0 = 42, // TurboQuant 3-bit polar + QJL (with per-block scale) | ||||||
|
||||||
| GGML_TYPE_TQ3_0 = 42, // TurboQuant 3-bit polar + QJL (no per-block scale) | |
| GGML_TYPE_TQ3_0 = 42, // TurboQuant 3-bit polar + QJL (with per-block scale) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -276,6 +276,21 @@ typedef struct { | |||||||||||||||||||||||||||||||||||||||||||||||
| } block_tq2_0; | ||||||||||||||||||||||||||||||||||||||||||||||||
| static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding"); | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| // TurboQuant 3-bit quantization (3.5 bpw) | ||||||||||||||||||||||||||||||||||||||||||||||||
| // Per TurboQuant paper (Algorithm 2: TurboQuant_prod), ICLR 2026 | ||||||||||||||||||||||||||||||||||||||||||||||||
| // Each block of 32 values is quantized as: | ||||||||||||||||||||||||||||||||||||||||||||||||
| // - 2-bit MSE codebook indices (after random rotation Π·x) | ||||||||||||||||||||||||||||||||||||||||||||||||
| // - 1-bit QJL residual signs (sign(S·r) where r = x - dequant_mse(quant_mse(x))) | ||||||||||||||||||||||||||||||||||||||||||||||||
| // - FP16 residual norm ||r||₂ for QJL scaling | ||||||||||||||||||||||||||||||||||||||||||||||||
| // Rotation uses a fixed WHT32 + sign-flip preconditioner built into the kernels (see tq3_wht32_forward); no external per-model rotation matrices are stored | ||||||||||||||||||||||||||||||||||||||||||||||||
| #define QK_TQ3_0 32 | ||||||||||||||||||||||||||||||||||||||||||||||||
| typedef struct { | ||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t qs[QK_TQ3_0 / 4]; // 2-bit codebook indices, 32 × 2 bits = 8 bytes | ||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t qr[QK_TQ3_0 / 8]; // QJL residual signs, 32 × 1 bit = 4 bytes | ||||||||||||||||||||||||||||||||||||||||||||||||
| ggml_half gamma; // ||residual||₂ for QJL correction scaling | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+280
to
+290
|
||||||||||||||||||||||||||||||||||||||||||||||||
| // Per TurboQuant paper (Algorithm 2: TurboQuant_prod), ICLR 2026 | |
| // Each block of 32 values is quantized as: | |
| // - 2-bit MSE codebook indices (after random rotation Π·x) | |
| // - 1-bit QJL residual signs (sign(S·r) where r = x - dequant_mse(quant_mse(x))) | |
| // - FP16 residual norm ||r||₂ for QJL scaling | |
| // Requires per-model rotation matrices Π and S (stored externally) | |
| #define QK_TQ3_0 32 | |
| typedef struct { | |
| uint8_t qs[QK_TQ3_0 / 4]; // 2-bit codebook indices, 32 × 2 bits = 8 bytes | |
| uint8_t qr[QK_TQ3_0 / 8]; // QJL residual signs, 32 × 1 bit = 4 bytes | |
| ggml_half gamma; // ||residual||₂ for QJL correction scaling | |
| // Implementation note: the on-wire format used here follows the current | |
| // ggml implementation, which applies a fixed WHT+sign preconditioner rather | |
| // than storing or requiring external rotation matrices. | |
| // Each block of 32 values stores: | |
| // - 2-bit quantized indices | |
| // - 1-bit residual/sign bits | |
| // - FP16 per-block scale d | |
| #define QK_TQ3_0 32 | |
| typedef struct { | |
| uint8_t qs[QK_TQ3_0 / 4]; // 2-bit codebook indices, 32 × 2 bits = 8 bytes | |
| uint8_t qr[QK_TQ3_0 / 8]; // residual/sign bits, 32 × 1 bit = 4 bytes | |
| ggml_half gamma; // per-block FP16 scale d |
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -396,6 +396,10 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { | |||
| .vec_dot_type = GGML_TYPE_Q8_K, | ||||
| .nrows = 1, | ||||
| }, | ||||
| [GGML_TYPE_TQ3_0] = { | ||||
| .from_float = quantize_row_tq3_0, | ||||
|
||||
| .from_float = quantize_row_tq3_0, |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -486,6 +486,50 @@ static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_ | |
| } | ||
| } | ||
|
|
||
// TurboQuant TQ3_0: 2-bit codebook dequantization + inverse WHT
// Dequantize to rotated space, then apply inverse WHT32 cooperatively.
// Launch shape: one CUDA block per TQ3_0 quant block, 32 threads (one per
// element) — see dequantize_row_tq3_0_cuda.
// NOTE(review): the qr residual/sign bits of block_tq3_0 are never read in
// this kernel — confirm the QJL residual correction is applied elsewhere or
// is intentionally dropped on plain dequantization.
template<typename dst_t>
static __global__ void dequantize_block_tq3_0(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    // 4-level scalar codebook for the 2-bit indices (values live in rotated space).
    const float centroids[4] = { -1.510f, -0.4528f, 0.4528f, 1.510f };
    // Fixed sign pattern of the preconditioner; must match the pattern used by
    // the forward transform at quantization time (tq3_wht32_forward_device).
    const int8_t signs[32] = {
        +1, -1, +1, +1, -1, -1, +1, -1, +1, +1, -1, +1, -1, +1, -1, -1,
        +1, -1, -1, +1, +1, -1, +1, -1, -1, +1, +1, +1, -1, -1, +1, -1
    };

    const int64_t i = blockIdx.x;  // quant-block index
    const block_tq3_0 * x = (const block_tq3_0 *)vx;
    const int tid = threadIdx.x;   // element index within the 32-value block
    // Early return is safe only while blockDim.x == 32: a larger launch would
    // leave exited threads missing the __syncthreads() barriers below.
    if (tid >= 32) return;

    const float d = __half2float(x[i].gamma);  // per-block FP16 scale

    // Step 1: Each thread dequantizes its value (in rotated space)
    const int byte_idx = tid / 4;        // four 2-bit indices packed per byte
    const int bit_shift = 2 * (tid % 4); // LSB-first within the byte
    const int idx = (x[i].qs[byte_idx] >> bit_shift) & 3;

    __shared__ float shmem[32];
    shmem[tid] = d * centroids[idx];
    __syncthreads();

    // Step 2: Cooperative inverse WHT (5 butterfly stages)
    // Each pair (tid, tid^step) is combined by the lower-numbered thread.
    // Both members read BEFORE the barrier and only one writes AFTER it, so
    // every thread sees the previous stage's values — do not reorder the
    // loads relative to the __syncthreads().
    for (int step = 1; step < 32; step <<= 1) {
        int partner = tid ^ step; // butterfly partner
        float a = shmem[tid];
        float b = shmem[partner];
        __syncthreads();
        if (tid < partner) {
            shmem[tid] = a + b;
            shmem[partner] = a - b;
        }
        __syncthreads();
    }

    // Step 3: Normalize and undo sign flips
    const float inv_sqrt32 = 0.17677669529663688f; // 1/sqrt(32): orthonormal scaling
    yy[i * QK_TQ3_0 + tid] = shmem[tid] * inv_sqrt32 * signs[tid];
}
|
Comment on lines
+489
to
+531
|
||
|
|
||
| template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> | ||
| static void dequantize_block_cuda(const void * vx, dst_t * y, | ||
| const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, | ||
|
|
@@ -617,6 +661,12 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t | |
| dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y); | ||
| } | ||
|
|
||
// TQ3_0 row dequantization launcher: one CUDA block of 32 threads per
// 32-element quant block.
// vx:     packed block_tq3_0 data
// y:      output buffer of k dequantized values
// k:      element count; must be a multiple of QK_TQ3_0
// stream: CUDA stream the kernel is enqueued on
template<typename dst_t>
static void dequantize_row_tq3_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    // Keep the block count 64-bit: k is int64_t, and truncating k/QK_TQ3_0
    // into a plain int could overflow for very large tensors.
    const int64_t nb = k / QK_TQ3_0;
    dequantize_block_tq3_0<<<nb, 32, 0, stream>>>(vx, y);
}
|
|
||
| template <typename src_t, typename dst_t> | ||
| static __global__ void convert_unary( | ||
| const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, | ||
|
|
@@ -719,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { | |
| return dequantize_row_iq3_s_cuda; | ||
| case GGML_TYPE_MXFP4: | ||
| return dequantize_row_mxfp4_cuda; | ||
| case GGML_TYPE_TQ3_0: | ||
| return dequantize_row_tq3_0_cuda; | ||
| case GGML_TYPE_F32: | ||
| return convert_unary_cont_cuda<float>; | ||
| case GGML_TYPE_BF16: | ||
|
|
@@ -774,6 +826,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { | |
| return dequantize_row_iq3_s_cuda; | ||
| case GGML_TYPE_MXFP4: | ||
| return dequantize_row_mxfp4_cuda; | ||
| case GGML_TYPE_TQ3_0: | ||
| return dequantize_row_tq3_0_cuda; | ||
| case GGML_TYPE_F16: | ||
| return convert_unary_cont_cuda<half>; | ||
| case GGML_TYPE_BF16: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -211,6 +211,79 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { | |||||||||||||||||
| quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
// TQ3_0: forward 32-point Walsh-Hadamard transform used as the rotation step.
// Applies the fixed sign flip, then five in-place butterfly stages, then
// orthonormal scaling by 1/sqrt(32). The sign table must stay byte-identical
// to the CPU implementation so data round-trips across backends.
static __device__ __forceinline__ void tq3_wht32_forward_device(float * x) {
    const int8_t flip[32] = {
        +1, -1, +1, +1, -1, -1, +1, -1, +1, +1, -1, +1, -1, +1, -1, -1,
        +1, -1, -1, +1, +1, -1, +1, -1, -1, +1, +1, +1, -1, -1, +1, -1
    };
    // Diagonal +/-1 preconditioning before the transform.
    for (int k = 0; k < 32; ++k) {
        x[k] *= flip[k];
    }
    // Butterfly stages: pair distance doubles each pass (1, 2, 4, 8, 16).
    for (int half = 1; half < 32; half <<= 1) {
        const int span = half << 1;
        for (int base = 0; base < 32; base += span) {
            for (int off = base; off < base + half; ++off) {
                const float lo = x[off];
                const float hi = x[off + half];
                x[off]        = lo + hi;
                x[off + half] = lo - hi;
            }
        }
    }
    // Orthonormal normalization.
    const float norm = 0.17677669529663688f; // 1/sqrt(32)
    for (int k = 0; k < 32; ++k) {
        x[k] *= norm;
    }
}
|
|
||||||||||||||||||
// TQ3_0: inverse of tq3_wht32_forward_device. The WHT butterfly network is
// self-inverse up to scaling, so the butterflies run first and the sign flip
// is undone together with the 1/sqrt(32) normalization in a single fused pass.
static __device__ __forceinline__ void tq3_wht32_inverse_device(float * x) {
    // Butterfly stages, identical structure to the forward transform.
    for (int half = 1; half < 32; half <<= 1) {
        const int span = half << 1;
        for (int base = 0; base < 32; base += span) {
            for (int off = base; off < base + half; ++off) {
                const float lo = x[off];
                const float hi = x[off + half];
                x[off]        = lo + hi;
                x[off + half] = lo - hi;
            }
        }
    }
    // Same fixed sign pattern as the forward transform; must match the CPU path.
    const int8_t flip[32] = {
        +1, -1, +1, +1, -1, -1, +1, -1, +1, +1, -1, +1, -1, +1, -1, -1,
        +1, -1, -1, +1, +1, -1, +1, -1, -1, +1, +1, +1, -1, -1, +1, -1
    };
    const float norm = 0.17677669529663688f; // 1/sqrt(32)
    // Fused normalization + sign undo.
    for (int k = 0; k < 32; ++k) {
        x[k] *= norm * flip[k];
    }
}
|
|
||||||||||||||||||
| // TQ3_0: GPU-side 2-bit scalar codebook quantization with WHT rotation | ||||||||||||||||||
| static __device__ void quantize_f32_tq3_0_block(const float * __restrict__ x, block_tq3_0 * __restrict__ y) { | ||||||||||||||||||
| const float centroids[4] = { -1.510f, -0.4528f, 0.4528f, 1.510f }; | ||||||||||||||||||
|
|
||||||||||||||||||
| // Copy and apply WHT rotation | ||||||||||||||||||
| float rotated[QK_TQ3_0]; | ||||||||||||||||||
| for (int j = 0; j < QK_TQ3_0; j++) rotated[j] = x[j]; | ||||||||||||||||||
| tq3_wht32_forward_device(rotated); | ||||||||||||||||||
|
|
||||||||||||||||||
| memset(y, 0, sizeof(block_tq3_0)); | ||||||||||||||||||
|
||||||||||||||||||
| memset(y, 0, sizeof(block_tq3_0)); | |
| y->gamma = __float2half(0.0f); | |
| for (int j = 0; j < (int)(sizeof(y->qs) / sizeof(y->qs[0])); ++j) { | |
| y->qs[j] = 0; | |
| } | |
| for (int j = 0; j < (int)(sizeof(y->qr) / sizeof(y->qr[0])); ++j) { | |
| y->qr[j] = 0; | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TQ3_0 is advertised as an allowed KV cache type in the CLI, but the CPU backend has no vec_dot support for GGML_TYPE_TQ3_0 (see ggml/src/ggml-cpu/ggml-cpu.c), so CPU-only runs selecting this type will fail during graph planning/execution. Either add CPU support, or gate this option based on backend capabilities / emit a clear error earlier.