vulkan: add Q1_0_g128 (1-bit ternary) shader support#9
vulkan: add Q1_0_g128 (1-bit ternary) shader support#9claudlos wants to merge 1 commit intoPrismML-Eng:prismfrom
Conversation
Add Vulkan compute shader support for the GGML_TYPE_Q1_0_g128 quantization format (1-bit sign / binary quantization, group size 128). New files: - dequant_q1_0_g128.comp: standalone dequantization shader - mul_mat_vec_q1_0_g128.comp: fused matrix-vector multiply shader (4 threads/block, 32 elements/thread, 8x dot(vec4,vec4)) Modified files: - types.glsl: block_q1_0_g128 struct, QUANT_K=128, QUANT_R=1 - dequant_funcs.glsl: dequantize/dequantize4 + single-scale get_dm - mul_mm_funcs.glsl: branchless FMA load path for batched matmul - vulkan-shaders-gen.cpp: type registration, LOAD_VEC_A=4, excluded from coopmat2 flash attention and integer dot product Q8_1 paths - ggml-vulkan.cpp: pipeline registration for dequant, get_rows, mul_mat_vec (f32/f16/id), mul_mat_mat, mul_mat_mat_id, supports_op - test-backend-ops.cpp: Q1_0_g128 test cases for get_rows, mul_mat, mul_mat_id Performance on AMD Radeon 680M (RDNA2 iGPU): eval: 0.28 -> 23.4 t/s (84x), prompt: 0.31 -> 38.3 t/s (124x) graph splits: 291 -> 2
|
Nice, thanks this is very good, we had a Vulkan and opencl backends too but did not have time to test and benchmark them properly, so did not release it, I will try to put them in a brach for people to try. This looks great too. Curios which phone was this? Also the x86 cpu giving wrong output shoud be fixed now (stilll not optimized to be fast but its correct now), please check this PR #8 |
Thanks Khosravipasha I'm happy to help This was done with a mini computer not a phone. R9 6900HX with AMD Radeon 680M (RDNA2 iGPU) I edited the PR and removed the notes about x86 cpu dequant. |
|
It is working on My Vega56 GPU. Thank you very much to @claudlos |

Add Vulkan shader support for Q1_0_g128 quantization
Summary
This PR adds Vulkan shader support for the
GGML_TYPE_Q1_0_g128(1-bit sign / binary quantization, group size 128) format. The primary validated paths today are dequantization,get_rows, and fusedmul_mat_vec. Without these shaders, Q1_0_g128 models fall back to CPU dequantization on Vulkan devices, resulting in ~291 graph splits and extremely poor performance. With this PR, the tested inference path runs almost entirely on GPU with only 2 graph splits.Performance Results
Comparison: Qwen2.5-3B Q4_K_M on the same hardware achieves 27.8 t/s — our Bonsai 8B Q1_0_g128 reaches 84% of that speed despite being 2.7x larger, validating the efficiency of 1-bit inference.
Files Changed
New Files
ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0_g128.comp— Standalone dequantization compute shader. Each thread processes one 128-element block (16 bytes of packed sign bits + fp16 scale), used forget_rowsand general dequantization pipelines.ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q1_0_g128.comp— Custom fused matrix-vector multiply shader. Uses 4 threads per 128-element block (32 elements/thread = one uint32 of packed sign bits). Maps each bit to±1.0, which compiles to efficientv_cndmask-style code on RDNA GPUs. Accumulates via 8dot(vec4, vec4)operations per thread withfmafor the final scale multiply.Modified Files
ggml/src/ggml-vulkan/vulkan-shaders/types.glsl— Addedblock_q1_0_g128struct definition (fp16 scaled+ 16-byteqsarray) andDATA_A_Q1_0_G128preprocessor configuration withQUANT_K=128,QUANT_R=1,QUANT_AUXF=1.ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl— Addeddequantize()anddequantize4()functions for Q1_0_g128 (bit extraction → sign mapping). Added Q1_0_g128 to the single-scaleget_dm()path (returnsvec2(d, 0)).ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl— Added Q1_0_g128 matrix-matrix multiply load path inload_a_to_shmem. Uses branchless FMA dequantization:d*(2*bit - 1) = fma(2d, bit_float, -d)for efficient SIMD utilization.ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp— Registeredq1_0_g128in the type list, marked as legacy quant withLOAD_VEC_A=4. Excluded fromcoopmat2flash attention paths (nodequantFuncdefined), excluded from integer dot productq8_1MMQ paths (no Q8_1 quantization mapping exists for 1-bit types).ggml/src/ggml-vulkan/ggml-vulkan.cpp— Registered Q1_0_g128 pipelines across all shader pipeline arrays:dequant,mul_mat_vec(f32 and f16 B-type variants),mul_mat_vec_id,get_rows,get_rows_f32. Added Q1_0_g128 tosupports_opswitch cases forGGML_OP_MUL_MATandGGML_OP_GET_ROWS. Usesrm_iqrow multiplier and subgroup16 configuration (matching IQ-type pipeline parameters).Technical Design Decisions
Custom fused mul_mat_vec shader rather than generic path: The 1-bit format has a unique structure (128 elements packed as 16 bytes of sign bits + single fp16 scale) that doesn't fit the standard
mul_mat_vec.comptemplate. The fused shader avoids intermediate dequantization and directly computes dot products from packed bits.4 threads per block: Each thread handles 32 elements (one uint32 worth of bits), loading 4 bytes and expanding to 8 vec4 dot products. This maps well to GPU wavefronts.
Excluded from coopmat2/flash attention: Q1_0_g128 is a weights-only quantization (KV cache uses f16). There's no
dequantFuncsymbol required for the cooperative matrix path, and flash attention operates on KV cache types, not weight types.Excluded from integer dot MMQ: The Q8_1 quantized matmul path requires a compatible requantization scheme that doesn't exist for 1-bit types.
Branchless FMA in mul_mm_funcs: The matrix-matrix path uses
fma(2d, bit_float, -d)instead of conditional selection, which is more efficient for the wider SIMD paths used in batched matmul.Testing
llama-cliwith Vulkan backend (-ngl 99)Known Limitations
No cooperative matrix (coopmat2) support: Q1_0_g128 does not participate in cooperative matrix matmul or flash attention paths. This is by design — these paths require a
dequantFuncsymbol and Q1_0_g128 is weights-only.No integer dot product (MMQ) support: The
q8_1integer dot product optimization path is excluded for Q1_0_g128 since no compatible requantization scheme exists.No F16 B-type
mul_matsupport: Q1_0_g128 with F16 input tensors is explicitly blocked insupports_op. Only F32 B-type is supported. This avoids the complexity of F16 pipeline wiring for a weights-only quantization format (KV cache uses f16, but input activations are f32).Tested on RDNA2 only: While the shaders use standard Vulkan compute (no vendor-specific extensions), they have only been validated on AMD RDNA2. Testing on NVIDIA and Intel GPUs is recommended.
Flash Attention: N/A for weight quantization types. Q1_0_g128 models use f16 KV cache, which already has full Vulkan support.
Hardware Tested