Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

37 changes: 36 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,41 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
}
#endif

#if defined(DATA_A_Q1_0_G128)
// Q1_0_g128: each element is one sign bit packed LSB-first into qs[16]
// (bit == 1 -> +1.0, bit == 0 -> -1.0). The per-block scale d is applied
// by the caller through get_dm(), so only the raw signs are produced here.

// Decode the sign (+/-1) of element `pos` (0..127) within block `ib`.
float q1_sign(uint ib, uint pos, uint a_offset) {
    const uint packed = uint(data_a[a_offset + ib].qs[pos >> 3]);
    return ((packed >> (pos & 7u)) & 1u) != 0u ? 1.0f : -1.0f;
}

vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    // iqs is the element index within the block (0..127)
    return vec2(q1_sign(ib, iqs,     a_offset),
                q1_sign(ib, iqs + 1, a_offset));
}

vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
    return vec4(q1_sign(ib, iqs,     a_offset),
                q1_sign(ib, iqs + 1, a_offset),
                q1_sign(ib, iqs + 2, a_offset),
                q1_sign(ib, iqs + 3, a_offset));
}
#endif

#if defined(DATA_A_Q4_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
Expand Down Expand Up @@ -448,7 +483,7 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif

#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_Q1_0_G128) || defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), 0);
}
Expand Down
29 changes: 29 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0_g128.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#version 450

#include "dequant_head.glsl"

layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {block_q1_0_g128 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};

void main() {
    // One invocation per 128-element block; with local_size_x = 256 this is
    // the same index as gl_WorkGroupID.x * 256 + gl_LocalInvocationID.x.
    const uint blk = gl_GlobalInvocationID.x;

    // p.nel is the total element count, so p.nel / 128 is the block count.
    if (blk >= p.nel / 128) {
        return;
    }

    const uint out_base = blk * 128;
    const float scale = float(data_a[blk].d);

    // 16 bytes of qs[] hold 128 sign bits, packed LSB-first:
    // bit == 1 -> +scale, bit == 0 -> -scale.
    [[unroll]] for (uint el = 0; el < 128; ++el) {
        const uint packed = uint(data_a[blk].qs[el >> 3]);
        const float v = ((packed >> (el & 7u)) & 1u) != 0u ? scale : -scale;
        data_b[out_base + el] = D_TYPE(v);
    }
}
108 changes: 108 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q1_0_g128.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_8bit_storage : require

#include "mul_mat_vec_base.glsl"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Fused 1-bit matrix-vector multiply for Q1_0_g128.
// 4 threads per block, each handles 32 elements (one uint32 of packed bits).
// Uses simple ternary sign selection which compiles to v_cndmask on RDNA.

// Per-invocation partial accumulators; combined across the workgroup by
// reduce_result() at the end of compute_outputs().
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

// Accumulate the contribution of quant block `i` (128 elements of a row of A)
// into temp[][] for every B column and every row this workgroup owns.
// itid (0..3) selects which 32-element / 32-bit slice of the block this
// invocation decodes.
void calc_block(const uint a_offset, const uint b_offset, const uint itid, const uint i,
const uint num_blocks_per_row, const uint first_row, const uint num_rows) {

// Index of the first B element covered by this thread's 32-bit slice.
const uint y_idx_base = i * 128 + itid * 32;

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
// Load the 32 B values for this slice as eight vec4s
// (data_b_v4 is vec4-indexed, hence the division by 4).
const uint base_b = (j * p.batch_stride_b + b_offset + y_idx_base) / 4;
const vec4 bv0 = vec4(data_b_v4[base_b]);
const vec4 bv1 = vec4(data_b_v4[base_b + 1]);
const vec4 bv2 = vec4(data_b_v4[base_b + 2]);
const vec4 bv3 = vec4(data_b_v4[base_b + 3]);
const vec4 bv4 = vec4(data_b_v4[base_b + 4]);
const vec4 bv5 = vec4(data_b_v4[base_b + 5]);
const vec4 bv6 = vec4(data_b_v4[base_b + 6]);
const vec4 bv7 = vec4(data_b_v4[base_b + 7]);

// Block index for the first processed row of A; stepped by
// num_blocks_per_row to walk down the rows below.
uint ibi = a_offset + first_row * num_blocks_per_row + i;

[[unroll]] for (uint n = 0; n < num_rows; ++n) {
// Per-block scale.
const float d = float(data_a[ibi].d);

// Assemble this thread's 32 sign bits from 4 consecutive bytes
// of qs[] (little-endian, LSB-first within each byte).
const uint byte_base = itid * 4;
const uint bits = uint(data_a[ibi].qs[byte_base])
| (uint(data_a[ibi].qs[byte_base + 1]) << 8)
| (uint(data_a[ibi].qs[byte_base + 2]) << 16)
| (uint(data_a[ibi].qs[byte_base + 3]) << 24);

FLOAT_TYPE partial = FLOAT_TYPE(0);

// Decode the 32 bits four at a time (bit == 1 -> +1, bit == 0 -> -1)
// and dot each sign vec4 against the matching slice of B. The scale d
// is factored out and applied once in the fma below.
partial += FLOAT_TYPE(dot(vec4(
(bits & 0x1u) != 0 ? 1.0 : -1.0, (bits & 0x2u) != 0 ? 1.0 : -1.0,
(bits & 0x4u) != 0 ? 1.0 : -1.0, (bits & 0x8u) != 0 ? 1.0 : -1.0), bv0));
partial += FLOAT_TYPE(dot(vec4(
(bits & 0x10u) != 0 ? 1.0 : -1.0, (bits & 0x20u) != 0 ? 1.0 : -1.0,
(bits & 0x40u) != 0 ? 1.0 : -1.0, (bits & 0x80u) != 0 ? 1.0 : -1.0), bv1));
partial += FLOAT_TYPE(dot(vec4(
(bits & 0x100u) != 0 ? 1.0 : -1.0, (bits & 0x200u) != 0 ? 1.0 : -1.0,
(bits & 0x400u) != 0 ? 1.0 : -1.0, (bits & 0x800u) != 0 ? 1.0 : -1.0), bv2));
partial += FLOAT_TYPE(dot(vec4(
(bits & 0x1000u) != 0 ? 1.0 : -1.0, (bits & 0x2000u) != 0 ? 1.0 : -1.0,
(bits & 0x4000u) != 0 ? 1.0 : -1.0, (bits & 0x8000u) != 0 ? 1.0 : -1.0), bv3));
partial += FLOAT_TYPE(dot(vec4(
(bits & 0x10000u) != 0 ? 1.0 : -1.0, (bits & 0x20000u) != 0 ? 1.0 : -1.0,
(bits & 0x40000u) != 0 ? 1.0 : -1.0, (bits & 0x80000u) != 0 ? 1.0 : -1.0), bv4));
partial += FLOAT_TYPE(dot(vec4(
(bits & 0x100000u) != 0 ? 1.0 : -1.0, (bits & 0x200000u) != 0 ? 1.0 : -1.0,
(bits & 0x400000u) != 0 ? 1.0 : -1.0, (bits & 0x800000u) != 0 ? 1.0 : -1.0), bv5));
partial += FLOAT_TYPE(dot(vec4(
(bits & 0x1000000u) != 0 ? 1.0 : -1.0, (bits & 0x2000000u) != 0 ? 1.0 : -1.0,
(bits & 0x4000000u) != 0 ? 1.0 : -1.0, (bits & 0x8000000u) != 0 ? 1.0 : -1.0), bv6));
partial += FLOAT_TYPE(dot(vec4(
(bits & 0x10000000u) != 0 ? 1.0 : -1.0, (bits & 0x20000000u) != 0 ? 1.0 : -1.0,
(bits & 0x40000000u) != 0 ? 1.0 : -1.0, (bits & 0x80000000u) != 0 ? 1.0 : -1.0), bv7));

// temp += d * partial, then advance to the same block of the next row.
temp[j][n] = fma(FLOAT_TYPE(d), partial, temp[j][n]);
ibi += num_blocks_per_row;
}
}
}

// Compute `num_rows` output rows starting at `first_row`, striding the
// row's quant blocks across the workgroup.
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

const uint num_blocks_per_row = p.ncols / 128;
// 4 threads cooperate on each 128-element block.
const uint blocks_per_wg = gl_WorkGroupSize.x / 4;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 4;
const uint ix = tid / 4;

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}

// Grid-stride over the blocks of a row: thread-group `ix` handles blocks
// ix, ix + blocks_per_wg, ix + 2*blocks_per_wg, ...
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_block(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);

reduce_result(temp, d_offset, first_row, num_rows, tid);
}

void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

// Full tile when possible; otherwise clamp the tail tile to stride_d.
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
31 changes: 31 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,37 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;

buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_Q1_0_G128)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

// LOAD_VEC_A = 4, so each load processes 4 elements.
// 128 elements per block / 4 = 32 loads per block.
const uint ib = idx / 32; // block index
const uint iel = (idx % 32) * 4; // element offset within block (0,4,8,...124)

const float d = float(data_a[ib].d);
const float d2 = d + d;
const float neg_d = -d;

// Mirror Metal's chunking more directly: q1_0_g128 is 8 chunks of 16 sign bits.
// Decode the containing 16-bit chunk, then select the 4-bit sub-group for this load.
const uint chunk16 = iel / 16;
const uint chunk_bit = iel % 16;
const uint byte_offset = chunk16 * 2;
const uint bits16 = uint(data_a[ib].qs[byte_offset])
| (uint(data_a[ib].qs[byte_offset + 1]) << 8);
const uint bits = (bits16 >> chunk_bit) & 0xFu;

// Branchless FMA: d*(2*bit-1) = fma(2d, bit_float, -d)
const vec4 bit_floats = vec4(
float(bits & 1u), float((bits >> 1) & 1u),
float((bits >> 2) & 1u), float((bits >> 3) & 1u)
);
const vec4 v = fma(vec4(d2), bit_floats, vec4(neg_d));

buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_Q2_K)
Expand Down
18 changes: 18 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_16bit_storage : require

#if defined(DATA_A_F32)
Expand Down Expand Up @@ -46,6 +47,23 @@
#endif
#endif

// Q1_0_g128: 1-bit quantization with a group (block) size of 128 elements.
#define QUANT_K_Q1_0_G128 128
#define QUANT_R_Q1_0_G128 1

// One block: a single fp16 scale plus 128 sign bits packed into 16 bytes
// (per the dequant shaders: bit == 1 -> +d, bit == 0 -> -d, LSB-first).
struct block_q1_0_g128
{
float16_t d;
uint8_t qs[16];
};

#if defined(DATA_A_Q1_0_G128)
#define QUANT_K QUANT_K_Q1_0_G128
#define QUANT_R QUANT_R_Q1_0_G128
#define QUANT_AUXF 1
#define A_TYPE block_q1_0_g128
// Treated as a legacy-style quant (simple per-block scale) by shared shader code.
#define DATA_A_QUANT_LEGACY
#endif

#define QUANT_K_Q4_0 32
#define QUANT_R_Q4_0 2

Expand Down
17 changes: 9 additions & 8 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const std::vector<std::string> type_names = {
"q5_0",
"q5_1",
"q8_0",
"q1_0_g128",
"q2_k",
"q3_k",
"q4_k",
Expand Down Expand Up @@ -220,7 +221,7 @@ bool is_quantized_type(const std::string& type_name) {
}

// Returns true for the "legacy" (non-K, non-IQ) block-quantized formats that
// share the simple per-block-scale layout. q1_0_g128 is included so it reuses
// the legacy-quant shader generation paths.
// NOTE(review): the pasted diff showed both the old and new return statements;
// this keeps the post-change behavior, which the rest of the PR relies on.
bool is_legacy_quant(const std::string& type_name) {
    return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" ||
           type_name == "q5_1" || type_name == "q8_0" || type_name == "q1_0_g128";
}

bool is_k_quant(const std::string& type_name) {
Expand Down Expand Up @@ -554,7 +555,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q1_0_g128") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
load_vec_quant = "4";

if (tname == "bf16") {
Expand All @@ -580,14 +581,14 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}

if (tname != "f16" && tname != "f32") {
if (tname != "f16" && tname != "f32" && !(coopmat2 && tname == "q1_0_g128")) {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
// Integer dot mmq performs better with f32 accumulators
if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
if (!f16acc && !coopmat && !coopmat2 && tname != "q1_0_g128" && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
Expand Down Expand Up @@ -645,7 +646,7 @@ void process_shaders() {
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
} else {
} else if (tname != "q1_0_g128") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
Expand Down Expand Up @@ -680,7 +681,7 @@ void process_shaders() {
for (const auto& tname : type_names) {
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_") || tname == "q1_0_g128") ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";

string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
Expand All @@ -697,7 +698,7 @@ void process_shaders() {

// mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") {
if (tname != "q1_0_g128" && (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m")) {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
Expand Down Expand Up @@ -1139,7 +1140,7 @@ void write_output_files() {

for (const std::string& btype : btypes) {
for (const auto& tname : type_names) {
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
if (btype == "q8_1" && (!is_legacy_quant(tname) || tname == "q1_0_g128") && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
continue;
}
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
Expand Down
3 changes: 3 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7094,6 +7094,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}

test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false));
test_cases.emplace_back(new test_get_rows(GGML_TYPE_Q1_0_g128, 256, 5, 4, 1, 1, false));
for (ggml_type type : all_types) {
for (int b : {1, 7}) {
for (bool v : {false, true}) {
Expand Down Expand Up @@ -7796,6 +7797,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
#endif

test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q1_0_g128, GGML_TYPE_F32, 16, 16, 256, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q1_0_g128, GGML_TYPE_F32, 8, 2, false, 16, 16, 256));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 128, { 8, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 128, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 64, { 8, 1}, {4, 1}));
Expand Down