From ff49890ddb12c2d0727e70fbae70746f44051c30 Mon Sep 17 00:00:00 2001
From: taimur-10x
Date: Mon, 15 Dec 2025 19:27:20 +0500
Subject: [PATCH 01/13] ggml-cpu: add rvv ggml_quantize_mat_4x8 for q8_0

Co-authored-by: Rehan Qasim
---
 ggml/src/ggml-cpu/arch/riscv/repack.cpp | 105 +++++++++++++++++++++++++
 1 file changed, 105 insertions(+)

diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp
index 2a35ff9ad87..9f37197ab4e 100644
--- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp
@@ -24,6 +24,111 @@
 
 #define UNUSED GGML_UNUSED
 
+// Quantizes four rows of `x` to Q8_0 and packs them as `block_q8_0x4` into `vy`.
+//   x  - source floats; row r of the 4-row tile starts at x + r * k
+//   vy - destination, written as k / QK8_0 blocks of type block_q8_0x4
+//   k  - elements per row; must be a multiple of QK8_0
+// Per block, d[r] receives the scale of row r and qs receives the quants of the
+// four rows interleaved in 8-byte groups (r0, r1, r2, r3, r0, ...).
+void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+    // Declared before the #if so that UNUSED(y) in the scalar branch compiles.
+    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
+
+#if defined(__riscv_v_intrinsic)
+    const size_t vl_calc = __riscv_vsetvl_e32m8(QK8_0); // floats handled per row and block
+    const size_t vl_save = __riscv_vsetvl_e64m2(4);     // 8-byte groups stored per row and block
+    // Start value for the max reduction (absolute values are never negative).
+    vfloat32m1_t v_scalar_zero = __riscv_vfmv_s_f_f32m1(0.0f, __riscv_vsetvl_e32m1(1));
+
+    for (int i = 0; i < nb; i++) {
+        const float *x_block_base = x + i * QK8_0;
+        vint8m2_t q_r0, q_r1, q_r2, q_r3;
+        // Row 0: d = max|x| / 127, then q = round(x / d) narrowed f32 -> i16 -> i8.
+        {
+            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 0 * k, vl_calc);
+            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
+            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
+            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
+
+            float d = amax / 127.0f;
+            y[i].d[0] = GGML_CPU_FP32_TO_FP16(d);
+
+            float id = d ? 1.0f / d : 0.0f; // an all-zero row quantizes to zeros
+            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
+            // Rounding mode 4 is RMM (to nearest, ties away from zero) in the RVV intrinsics.
+            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
+            q_r0 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
+        }
+        // Compiler-only barrier. NOTE(review): presumably keeps the rows from being interleaved - confirm.
+        asm volatile ("" ::: "memory");
+
+        // Row 1: same steps as row 0.
+        {
+            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 1 * k, vl_calc);
+            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
+            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
+            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
+
+            float d = amax / 127.0f;
+            y[i].d[1] = GGML_CPU_FP32_TO_FP16(d);
+            float id = d ? 1.0f / d : 0.0f;
+
+            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
+            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
+            q_r1 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
+        }
+        asm volatile ("" ::: "memory");
+        // Row 2: same steps as row 0.
+        {
+            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 2 * k, vl_calc);
+            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
+            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
+            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
+
+            float d = amax / 127.0f;
+            y[i].d[2] = GGML_CPU_FP32_TO_FP16(d);
+            float id = d ? 1.0f / d : 0.0f;
+
+            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
+            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
+            q_r2 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
+        }
+        asm volatile ("" ::: "memory");
+        // Row 3: same steps as row 0.
+        {
+            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 3 * k, vl_calc);
+            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
+            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
+            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
+
+            float d = amax / 127.0f;
+            y[i].d[3] = GGML_CPU_FP32_TO_FP16(d);
+            float id = d ? 1.0f / d : 0.0f;
+
+            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
+            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
+            q_r3 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
+        }
+        // View each row's 32 quants as four 64-bit elements; the 4-field segment
+        // store then writes r0[j], r1[j], r2[j], r3[j] for j = 0..3.
+        vint64m2_t v_q64_r0 = __riscv_vreinterpret_v_i8m2_i64m2(q_r0);
+        vint64m2_t v_q64_r1 = __riscv_vreinterpret_v_i8m2_i64m2(q_r1);
+        vint64m2_t v_q64_r2 = __riscv_vreinterpret_v_i8m2_i64m2(q_r2);
+        vint64m2_t v_q64_r3 = __riscv_vreinterpret_v_i8m2_i64m2(q_r3);
+        vint64m2x4_t v_quant_tuple = __riscv_vcreate_v_i64m2x4(v_q64_r0, v_q64_r1, v_q64_r2, v_q64_r3);
+        __riscv_vsseg4e64_v_i64m2x4((int64_t*)y[i].qs, v_quant_tuple, vl_save);
+    }
+#else
+    UNUSED(nb);
+    UNUSED(y);
+    // Scalar fallback must produce the same 4x8 interleave as the vector path.
+    ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);
+#endif
+}
+
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
From 9bd6c9e3ff14a82fd4dbc738da1c4b4b18d9d355 Mon Sep 17 00:00:00 2001
From: taimur-10x
Date: Mon, 15 Dec 2025 19:27:59 +0500
Subject: [PATCH 02/13] ggml-cpu: add rvv repacking for iq4_nl

---
 ggml/src/ggml-cpu/arch-fallback.h | 15 +-
 ggml/src/ggml-cpu/arch/riscv/repack.cpp | 987 ++++++++++++++++++++++++
 ggml/src/ggml-cpu/repack.cpp | 425 +++++++++-
 ggml/src/ggml-cpu/repack.h | 20 +-
 4 files changed, 1441 insertions(+), 6 deletions(-)

diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h
index ebbd4b47e05..c4d1cf314c2 100644
--- a/ggml/src/ggml-cpu/arch-fallback.h
+++ b/ggml/src/ggml-cpu/arch-fallback.h
@@ -47,6 +47,7 @@
 #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_mxfp4_4x4_q8_0_generic
ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 @@ -63,6 +64,7 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 @@ -72,9 +74,11 @@ // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K @@ -127,6 +131,7 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 @@ -143,6 +148,7 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define 
ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 @@ -171,6 +177,7 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 @@ -187,6 +194,7 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 @@ -203,9 +211,10 @@ #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 -#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_0_4x16_generic ggml_quantize_mat_q8_0_4x16 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -265,6 +274,7 @@ #define 
ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 @@ -281,6 +291,7 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 @@ -317,6 +328,7 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 @@ -333,6 +345,7 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp 
b/ggml/src/ggml-cpu/arch/riscv/repack.cpp
index 9f37197ab4e..667d889a7d8 100644
--- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp
@@ -203,6 +203,527 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+// Dot products of one Q8_0 activation row (`vy`, n / QK8_0 blocks) with `nc`
+// IQ4_NL weight columns (`vx`, 4 columns interleaved in 8-byte groups); writes
+// `nc` floats to `s`. `bs` and `nr` are only forwarded to the generic fallback.
+// NOTE(review): the 64-byte e8/m2 and 4 x e64/m1 operations below need
+// VLEN >= 256 - confirm the dispatch guarantees this.
+void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+        vfloat32mf2_t sumf = __riscv_vfmv_v_f_f32mf2(0.0, 4);
+        for (int l = 0; l < nb; l++) {
+            // Read the 32 activation quants of this block as four 8-byte words.
+            // NOTE(review): qs is read through an int64_t lvalue; this relies on the
+            // target tolerating type-punned, possibly unaligned loads - confirm.
+            const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
+            const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
+            const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
+            const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
+            __asm__ __volatile__("" ::: "memory");
+
+            // Load `b_ptr` and map low / high nibbles through the codebook.
+            const vuint8m2_t b_0_packed = __riscv_vle8_v_u8m2((const uint8_t *)b_ptr[l].qs, QK4_NL * 2);
+            const vint8m2_t b_0_lo = __riscv_vrgather_vv_i8m2(values, __riscv_vand_vx_u8m2(b_0_packed, 0xf, QK4_NL * 2), QK4_NL * 2);
+            const vint8m2_t b_0_hi = __riscv_vrgather_vv_i8m2(values, __riscv_vsrl_vx_u8m2(b_0_packed, 4, QK4_NL * 2), QK4_NL * 2);
+
+            // Create 4 segments from `b`.
+            const vint8m1_t b_lo_0 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 0);
+            const vint8m1_t b_lo_1 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 1);
+            const vint8m1_t b_hi_0 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 0);
+            const vint8m1_t b_hi_1 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 1);
+
+            // Broadcast `a_ptr` across 4 registers (8 bytes / register).
+            const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4));
+            const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4));
+            const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4));
+            const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4));
+
+            // Multiply and accumulate.
+            const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL);
+            const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL);
+            const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL);
+            const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL);
+            const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL);
+            const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL);
+            const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL);
+
+            // In-place reduction: view lane pairs as 64-bit elements and add their
+            // low and high halves, shrinking 32 -> 16 -> 8 -> 4 lanes (one per column).
+            const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi));
+            const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2);
+            const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2);
+            const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2);
+            const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2);
+            const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4);
+            const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4);
+            const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4);
+            const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4);
+            const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8));
+            const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8));
+            const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8);
+            const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8);
+
+            // Multiply with scales.
+            const vfloat16mf4_t b_d = __riscv_vle16_v_f16mf4((const _Float16 *)b_ptr[l].d, 4);
+            const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d, 4);
+            sumf = __riscv_vfmacc_vv_f32mf2(sumf, facc, d_0, QK4_NL / 8);
+        }
+        __riscv_vse32_v_f32mf2(s + x * ncols_interleaved, sumf, QK4_NL / 8);
+    }
+    return;
+
+#endif
+    ggml_gemv_iq4_nl_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+// Dot products of one Q8_0 activation row (`vy`, n / QK8_0 blocks) with `nc`
+// IQ4_NL weight columns (`vx`, 4 columns interleaved in 16-byte groups); writes
+// `nc` floats to `s`. `bs` and `nr` are only forwarded to the generic fallback.
+// NOTE(review): the 32-element e8/m1 operations below need VLEN >= 256 -
+// confirm the dispatch guarantees this.
+void ggml_gemv_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 16;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+        vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0f, 4);
+        // Blocks are consumed one at a time, so any block count (odd included) is covered.
+        for (int l = 0; l < nb; l++) {
+            // 16 packed bytes (32 nibbles) for each of the 4 columns.
+            const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 0, 16);
+            const vuint8m1_t b_1_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 16, 16);
+            const vuint8m1_t b_2_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 32, 16);
+            const vuint8m1_t b_3_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 48, 16);
+
+            const vuint8m1_t b_0_lo = __riscv_vand_vx_u8m1(b_0_packed, 0xf, 16);
+            const vuint8m1_t b_0_hi = __riscv_vsrl_vx_u8m1(b_0_packed, 4, 16);
+            const vuint8m1_t b_1_lo = __riscv_vand_vx_u8m1(b_1_packed, 0xf, 16);
+            const vuint8m1_t b_1_hi = __riscv_vsrl_vx_u8m1(b_1_packed, 4, 16);
+            const vuint8m1_t b_2_lo = __riscv_vand_vx_u8m1(b_2_packed, 0xf, 16);
+            const vuint8m1_t b_2_hi = __riscv_vsrl_vx_u8m1(b_2_packed, 4, 16);
+            const vuint8m1_t b_3_lo = __riscv_vand_vx_u8m1(b_3_packed, 0xf, 16);
+            const vuint8m1_t b_3_hi = __riscv_vsrl_vx_u8m1(b_3_packed, 4, 16);
+
+            // Low nibbles go to lanes 0..15, high nibbles to lanes 16..31, then
+            // every lane is mapped through the 16-entry codebook.
+            const vint8m1_t b_0 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_0_lo, b_0_hi, 16, 32), 32);
+            const vint8m1_t b_1 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_1_lo, b_1_hi, 16, 32), 32);
+            const vint8m1_t b_2 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_2_lo, b_2_hi, 16, 32), 32);
+            const vint8m1_t b_3 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_3_lo, b_3_hi, 16, 32), 32);
+
+            const vint8m1_t a_0 = __riscv_vle8_v_i8m1(a_ptr[l].qs, 32);
+
+            // Widening multiply, then sum the 32 products of each column into lane 0.
+            const vint32m1_t sumi_0 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_0, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+            const vint32m1_t sumi_1 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_1, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+            const vint32m1_t sumi_2 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_2, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+            const vint32m1_t sumi_3 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_3, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+
+            // Collect the four scalar sums into one vector through memory.
+            int sumi_temp[4];
+            __riscv_vse32_v_i32m1(&sumi_temp[0], sumi_0, 1);
+            __riscv_vse32_v_i32m1(&sumi_temp[1], sumi_1, 1);
+            __riscv_vse32_v_i32m1(&sumi_temp[2], sumi_2, 1);
+            __riscv_vse32_v_i32m1(&sumi_temp[3], sumi_3, 1);
+            const vint32m1_t sum_0 = __riscv_vle32_v_i32m1(&sumi_temp[0], 4);
+
+            // Per-column weight scale times the activation block scale.
+            const vfloat16mf2_t b_d_0 = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 4);
+            const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d_0, *(const _Float16 *)&a_ptr[l].d, 4);
+
+            sumf = __riscv_vfmacc_vv_f32m1(sumf, d_0, __riscv_vfcvt_f_x_v_f32m1(sum_0, 4), 4);
+        }
+        __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, 4);
+    }
+
return;
+#endif
+    ggml_gemv_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+// Dot products of one Q8_0 activation row (`vy`, n / QK8_0 blocks) with `nc`
+// IQ4_NL weight columns (`vx`, 16 columns per block_iq4_nlx16); writes `nc`
+// floats to `s`. `bs` and `nr` are only forwarded to the generic fallback.
+// NOTE(review): the 16-element e8/mf2 and e32/m2 operations below need
+// VLEN >= 256 - confirm the dispatch guarantees this.
+void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
+
+        // 1x16 accumulator
+        vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+        for (int l = 0; l < nb; l++) {
+            // 1x16 integer accumulator
+            vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16);
+
+            // Accumulation loop: step i pairs activation elements i and 16 + i with
+            // the low / high nibbles of packed byte i of every column.
+            // Assumes byte i of column c is stored at qs[i * 16 + c] (16 columns
+            // interleaved one byte at a time) - confirm against the repacking code.
+            for (int i = 0; i < 16; i++) {
+                const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)b_ptr[l].qs + i * 16, 16);
+                const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
+                const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
+
+                const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16);
+                const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16);
+                sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16);
+            }
+
+            // Per-column weight scale times the activation block scale.
+            const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
+            const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
+            sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
+        }
+
+        __riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
+    }
+    return;
+#endif
+    ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+// Multiplies `nr` Q8_0 activation rows (`vy`, packed 4 rows per block_q8_0x4)
+// with `nc` IQ4_NL weight columns (`vx`, 4 columns interleaved in 16-byte
+// groups). Row r, column c of the result is written to s[r * bs + c].
+// `n` is the shared dimension in elements; `nr` must be a multiple of 4.
+// NOTE(review): the 32-element e8/m1 operations below need VLEN >= 256 -
+// confirm the dispatch guarantees this.
+void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 16;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16);
+    // Byte offsets of the four 8-byte words of one activation row inside a
+    // block_q8_0x4: two 16-byte groups that lie 64 bytes apart.
+    const uint8_t index[4] = {0, 8, 64, 72};
+    const vuint8mf8_t i_vec = __riscv_vle8_v_u8mf8(&index[0], 4);
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            // 4x4 accumulators, one vector of 4 columns per activation row.
+            vfloat32m1_t sumf_0 = __riscv_vfmv_v_f_f32m1(0.0f, 4);
+            vfloat32m1_t sumf_1 = __riscv_vfmv_v_f_f32m1(0.0f, 4);
+            vfloat32m1_t sumf_2 = __riscv_vfmv_v_f_f32m1(0.0f, 4);
+            vfloat32m1_t sumf_3 = __riscv_vfmv_v_f_f32m1(0.0f, 4);
+
+            for (int l = 0; l < nb; l++) {
+                int sumi_temp[16];
+                // 16 packed bytes (32 nibbles) for each of the 4 columns.
+                const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 0, 16);
+                const vuint8m1_t b_1_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 16, 16);
+                const vuint8m1_t b_2_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 32, 16);
+                const vuint8m1_t b_3_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 48, 16);
+
+                const vuint8m1_t b_0_lo = __riscv_vand_vx_u8m1(b_0_packed, 0xf, 16);
+                const vuint8m1_t b_0_hi = __riscv_vsrl_vx_u8m1(b_0_packed, 4, 16);
+                const vuint8m1_t b_1_lo = __riscv_vand_vx_u8m1(b_1_packed, 0xf, 16);
+                const vuint8m1_t b_1_hi = __riscv_vsrl_vx_u8m1(b_1_packed, 4, 16);
+                const vuint8m1_t b_2_lo = __riscv_vand_vx_u8m1(b_2_packed, 0xf, 16);
+                const vuint8m1_t b_2_hi = __riscv_vsrl_vx_u8m1(b_2_packed, 4, 16);
+                const vuint8m1_t b_3_lo = __riscv_vand_vx_u8m1(b_3_packed, 0xf, 16);
+                const vuint8m1_t b_3_hi = __riscv_vsrl_vx_u8m1(b_3_packed, 4, 16);
+
+                // Low nibbles go to lanes 0..15, high nibbles to lanes 16..31, then
+                // every lane is mapped through the 16-entry codebook.
+                const vint8m1_t b_0 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_0_lo, b_0_hi, 16, 32), 32);
+                const vint8m1_t b_1 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_1_lo, b_1_hi, 16, 32), 32);
+                const vint8m1_t b_2 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_2_lo, b_2_hi, 16, 32), 32);
+                const vint8m1_t b_3 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_3_lo, b_3_hi, 16, 32), 32);
+
+                #pragma unroll 4
+                for (int i = 0; i < 4; i++) {
+                    // Gather the 32 quants of activation row i with an indexed load.
+                    const vint8m1_t a_i = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vloxei8_v_i64m1((const int64_t *)(a_ptr[l].qs + i * 16), i_vec, 4));
+                    const vint32m1_t sumi_0 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_0, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+                    const vint32m1_t sumi_1 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_1, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+                    const vint32m1_t sumi_2 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_2, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+                    const vint32m1_t sumi_3 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_3, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+                    __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 0], sumi_0, 1);
+                    __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 1], sumi_1, 1);
+                    __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 2], sumi_2, 1);
+                    __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 3], sumi_3, 1);
+                }
+
+                // sumi_temp[r * 4 + c] holds the integer dot product of row r and column c.
+                const vint32m1_t sum_0 = __riscv_vle32_v_i32m1(&sumi_temp[0], 4);
+                const vint32m1_t sum_1 = __riscv_vle32_v_i32m1(&sumi_temp[4], 4);
+                const vint32m1_t sum_2 = __riscv_vle32_v_i32m1(&sumi_temp[8], 4);
+                const vint32m1_t sum_3 = __riscv_vle32_v_i32m1(&sumi_temp[12], 4);
+
+                // Per-column weight scales times the scale of each activation row.
+                const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 4);
+                const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[0], 4);
+                const vfloat32m1_t d_1 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[1], 4);
+                const vfloat32m1_t d_2 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[2], 4);
+                const vfloat32m1_t d_3 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[3], 4);
+
+                sumf_0 = __riscv_vfmacc_vv_f32m1(sumf_0, d_0, __riscv_vfcvt_f_x_v_f32m1(sum_0, 4), 4);
+                sumf_1 = __riscv_vfmacc_vv_f32m1(sumf_1, d_1, __riscv_vfcvt_f_x_v_f32m1(sum_1, 4), 4);
+                sumf_2 = __riscv_vfmacc_vv_f32m1(sumf_2, d_2, __riscv_vfcvt_f_x_v_f32m1(sum_2, 4), 4);
+                sumf_3 = __riscv_vfmacc_vv_f32m1(sumf_3, d_3, __riscv_vfcvt_f_x_v_f32m1(sum_3, 4), 4);
+            }
+
+            __riscv_vse32_v_f32m1(s + (y * 4 + 0) * bs + x * 4, sumf_0, 4);
+            __riscv_vse32_v_f32m1(s + (y * 4 + 1) * bs + x * 4, sumf_1, 4);
+            __riscv_vse32_v_f32m1(s + (y * 4 + 2) * bs + x * 4, sumf_2, 4);
+            __riscv_vse32_v_f32m1(s + (y * 4 + 3) * bs + x * 4, sumf_3, 4);
+        }
+    }
+    return;
+#endif
+    // Scalar fallback for the same 4x16 layout.
+    ggml_gemm_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined
__riscv_v_intrinsic
+    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
+
+            // 4x16 accumulators, one vector of 16 columns per activation row.
+            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+            for (int l = 0; l < nb; l++) {
+                // 4x16 integer accumulators
+                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);
+
+                // Accumulation loop: step i pairs activation elements i and 16 + i of
+                // the four rows with the low / high nibbles of packed byte i of every
+                // column. Assumes byte i of column c is stored at qs[i * 16 + c]
+                // (16 columns interleaved one byte at a time) - confirm against the
+                // repacking code.
+                for (int i = 0; i < 16; i++) {
+                    const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)b_ptr[l].qs + i * 16, 16);
+                    const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
+                    const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
+
+                    const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16);
+                    const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16);
+                    const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16);
+                    const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], 16);
+
+                    const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], 16);
+                    const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], 16);
+                    const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], 16);
+                    const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], 16);
+
+                    sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, 16), 16);
+                    sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, 16), 16);
+                    sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, 16), 16);
+                    sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16);
+                }
+
+                // Per-column weight scales times the scale of each activation row.
+                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
+                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
+                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
+                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
+                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);
+
+                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
+                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
+                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
+                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
+            }
+
+            // Row r, column c of the result goes to s[r * bs + c].
+            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
+        }
+    }
+    return;
+#endif
+    ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
+
+        vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, 8);
+        for (int l = 0; l < nb; l++) {
+            const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
+            const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
+            const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
+            const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
+            __asm__ __volatile__("" ::: "memory");
+
+            // Broadcast `a_ptr` across 4 registers (8 bytes / register).
+            const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8));
+            const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8));
+            const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8));
+            const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8));
+
+            // Load `b_ptr`.
+ const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); + const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); + const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); + + // Create 4 segments from `b`. + const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); + const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); + const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); + const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); + + // Multiply and accumulate. + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. 
+ const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); + + // Multiply with scales. 
+ const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); + const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d, 8); + sumf = __riscv_vfmacc_vv_f32m1(sumf, facc, d_0, QK4_NL / 4); + } + __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, QK4_NL / 4); + } + return; + +#endif + ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -428,3 +949,469 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } + +void ggml_gemm_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + // 4x4 accumulators. + vfloat32mf2_t sumf0 = __riscv_vfmv_v_f_f32mf2(0.0, 4); + vfloat32mf2_t sumf1 = __riscv_vfmv_v_f_f32mf2(0.0, 4); + vfloat32mf2_t sumf2 = __riscv_vfmv_v_f_f32mf2(0.0, 4); + vfloat32mf2_t sumf3 = __riscv_vfmv_v_f_f32mf2(0.0, 4); + + for (int l = 0; l < nb; l++) { + // Load `b_ptr`. 
+ const vuint8m2_t b_0_packed = __riscv_vle8_v_u8m2((const uint8_t *)b_ptr[l].qs, QK4_NL * 2); + const vint8m2_t b_0_lo = __riscv_vrgather_vv_i8m2(values, __riscv_vand_vx_u8m2(b_0_packed, 0xf, QK4_NL * 2), QK4_NL * 2); + const vint8m2_t b_0_hi = __riscv_vrgather_vv_i8m2(values, __riscv_vsrl_vx_u8m2(b_0_packed, 4, QK4_NL * 2), QK4_NL * 2); + + // Create 4 segments from `b`. + const vint8m1_t b_lo_0 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 0); + const vint8m1_t b_lo_1 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 1); + const vint8m1_t b_hi_0 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 0); + const vint8m1_t b_hi_1 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 1); + + // Load scales for `b`. + const vfloat16mf4_t b_d = __riscv_vle16_v_f16mf4((const _Float16 *)b_ptr[l].d, 4); + + // Load first 8 bytes of `a`. + { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); + const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); + const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); + const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + + // Multiply and accumulate. 
+ const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); + const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); + const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); + const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); + const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); + const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); + const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + + // In-place reduction. + const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); + const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); + const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); + const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); + const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); + const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); + const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); + const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); + const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); + const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); + const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); + + // Multiply with scales. + const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[0], 4); + sumf0 = __riscv_vfmacc_vv_f32mf2(sumf0, facc, d_0, QK4_NL / 8); + } + + // Load second 8 bytes of `a`. 
+ { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); + const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); + const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); + const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + + // Multiply and accumulate. + const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); + const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); + const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); + const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); + const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); + const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); + const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + + // In-place reduction. 
+ const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); + const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); + const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); + const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); + const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); + const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); + const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); + const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); + const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); + const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); + const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); + + // Multiply with scales. + const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[1], 4); + sumf1 = __riscv_vfmacc_vv_f32mf2(sumf1, facc, d_0, QK4_NL / 8); + } + + // Load third 8 bytes of `a`. + { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[48]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[80]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[112]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). 
+ const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); + const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); + const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); + const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + + // Multiply and accumulate. + const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); + const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); + const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); + const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); + const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); + const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); + const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + + // In-place reduction. + const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); + const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); + const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); + const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); + const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); + const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); + const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); + const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); + const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); + const vint32mf2_t 
sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); + const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); + + // Multiply with scales. + const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[2], 4); + sumf2 = __riscv_vfmacc_vv_f32mf2(sumf2, facc, d_0, QK4_NL / 8); + } + + // Load fourth 8 bytes of `a`. + { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[24]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[56]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[88]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[120]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); + const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); + const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); + const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + + // Multiply and accumulate. + const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); + const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); + const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); + const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); + const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); + const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); + const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + + // In-place reduction. 
+ const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); + const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); + const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); + const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); + const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); + const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); + const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); + const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); + const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); + const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); + const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); + + // Multiply with scales. 
+ const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[3], 4); + sumf3 = __riscv_vfmacc_vv_f32mf2(sumf3, facc, d_0, QK4_NL / 8); + } + } + + __riscv_vse32_v_f32mf2(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, ncols_interleaved); /* vl must equal the 4-column tile width: lanes 4..7 of the mf2 accumulators are tail elements and must not be stored into the neighbouring tile or past the row end */ + __riscv_vse32_v_f32mf2(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, ncols_interleaved); + __riscv_vse32_v_f32mf2(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, ncols_interleaved); + __riscv_vse32_v_f32mf2(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, ncols_interleaved); + } + } + return; + +#endif + ggml_gemm_iq4_nl_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); + + // 4x8 accumulators. + vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, 8); + vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, 8); + vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, 8); + vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, 8); + + for (int l = 0; l < nb; l++) { + // Load `b_ptr`. 
+ const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); + const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); + const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); + + // Create 4 segments from `b`. + const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); + const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); + const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); + const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); + + // Load scales for `b`. + const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); + + { + // Load first 8 bytes of `a`. + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + + // Multiply and accumulate. 
+ const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. + const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); + + // Multiply with scales. + const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[0], 8); + sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, facc, d_0, QK4_NL / 4); + } + + // Load second 8 bytes of `a`. 
+ { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + + // Multiply and accumulate. + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. 
+ const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); + + // Multiply with scales. + const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[1], 8); + sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, facc, d_0, QK4_NL / 4); + } + + // Load third 8 bytes of `a`. + { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[48]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[80]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[112]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). 
+ const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + + // Multiply and accumulate. + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. + const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const 
vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); + + // Multiply with scales. + const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[2], 8); + sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, facc, d_0, QK4_NL / 4); + } + + { + // Load fourth 8 bytes of `a`. + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[24]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[56]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[88]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[120]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + + // Multiply and accumulate. + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. 
+ const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); + + // Multiply with scales. 
+ const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[3], 8); + sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, facc, d_0, QK4_NL / 4); + } + } + + __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, 8); + __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, 8); + __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, 8); + __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, 8); + } + } + return; + +#endif + ggml_gemm_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 5edba4212f6..cff46066715 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1,3 +1,4 @@ +#include "ggml.h" #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -48,6 +49,44 @@ static inline int nearest_int(float fval) { extern "C" { +void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + + // scalar + const int blck_size_interleave = 1; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 
1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +} + void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -124,6 +163,43 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG } } +void ggml_quantize_mat_q8_0_4x16_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + + // scalar + const int blck_size_interleave = 16; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 
1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +} void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); @@ -238,12 +314,24 @@ template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTR ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row); } +template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_0_4x1(x, vy, n_per_row); +} + template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); } +template <> void ggml_quantize_mat_t<16, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_0_4x16(x, vy, n_per_row); +} + template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); @@ -1060,6 +1148,82 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemv_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc 
% ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = 
kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1758,6 +1922,118 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemm_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen 
+ i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +void ggml_gemm_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } 
+ } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -2577,7 +2853,31 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); } - } else { + } else if (blck_size_interleave == 8) { + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + for (int b = 0; b < 8; ++b) { + out.qs[dst_offset + b] = in[src_id].qs[src_offset + b]; + } + + // Generates bus error on RVV as this is auto-vectorized and the + // source might possibly not be 8-byte aligned + // + // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + } + } else if (blck_size_interleave == 16) { + for (int i = 0; i < end; ++i) { + int src_id = i; + int src_offset = 0; + int dst_offset = i * 16; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], 4 * sizeof(uint32_t)); + } + } + else { GGML_ASSERT(false); } @@ -2586,7 +2886,7 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); - GGML_ASSERT(interleave_block == 4); + // GGML_ASSERT(interleave_block == 4); const block_iq4_nl * src = (const block_iq4_nl *)data; block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data; @@ -2632,7 +2932,14 @@ static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_s int src_offset = (i / 8) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - 
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + for (int b = 0; b < 8; ++b) { + out.qs[dst_offset + b] = in[src_id].qs[src_offset + b]; + } + + // Generates bus error on RVV as this is auto-vectorized and the + // source might possibly not be 8-byte aligned + // + // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); } } else { GGML_ASSERT(false); @@ -2674,6 +2981,67 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } +static block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * in, unsigned int blck_size_interleave) { + block_iq4_nlx16 out; + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_NL * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = (i / 16) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + + // Generates bus error on RVV as this is auto-vectorized and the + // source might possibly not be 8-byte aligned + // + // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); + GGML_ASSERT(interleave_block == 8); + + const block_iq4_nl * src = (const block_iq4_nl *)data; + block_iq4_nlx16 * dst = ( block_iq4_nlx16 *)t->data; + + block_iq4_nl dst_tmp[8]; + + int nrow = ggml_nrows(t); + int nrows_interleaved = 8; + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < 
nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_iq4_nlx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) { block_mxfp4x4 out; @@ -2844,10 +3212,22 @@ template <> int repack(struct ggml_tensor * t, const void * // return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); //} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_4_bl(t, 16, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size); } @@ -2919,10 +3299,22 @@ template <> void gemv(int n, float * s, size ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_4x16_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, 
vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -2994,14 +3386,25 @@ template <> void gemm(int n, float * s, size ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_4x16_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } - template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -3412,7 +3815,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons // instance for IQ4 static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_4x16_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_4x8_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; // instance for MXFP4 static const ggml::cpu::repack::tensor_traits mxfp4_4x4_q8_0; @@ -3494,6 +3900,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &iq4_nl_4x4_q8_0; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch 
(__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 4 == 0) { return &iq4_nl_16x1_q8_0; } break; } + case 512: { if (cur->ne[1] % 8 == 0) { return &iq4_nl_8x8_q8_0; } break; } + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_MXFP4) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index b9f821630c4..f9da6b5c9f1 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -97,6 +97,13 @@ struct block_iq4_nlx8 { static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding"); +struct block_iq4_nlx16 { + ggml_half d[16]; // deltas for 16 iq4_nl blocks + uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks +}; + +static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding"); + struct block_mxfp4x4 { uint8_t e[4]; uint8_t qs[QK_MXFP4 * 2]; @@ -109,13 +116,14 @@ struct block_mxfp4x8 { }; static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding"); - #if defined(__cplusplus) extern "C" { #endif +void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_0_4x16(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ 
-129,6 +137,8 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -143,6 +153,8 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -152,8 +164,10 @@ void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); // Native implementations +void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_0_4x16_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -167,6 +181,8 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void 
ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -181,6 +197,8 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From dbd16674816c6ea95bad5e546e2dee2c918e3c20 Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Tue, 30 Dec 2025 17:40:51 +0500 Subject: [PATCH 03/13] ggml-cpu: add generic impl for iq4_nl gemm/gemv --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 49 +++++++-------- ggml/src/ggml-cpu/repack.cpp | 82 +++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 667d889a7d8..de3ac886afc 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -424,22 +424,23 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const // 1x16 integer accumulator vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); - // Load `b_ptr`. - const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)b_ptr[l].qs, 16); - const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); - const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); - // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); - // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); - // Accumulation loop. - for (int i = 0; i < 16; i++) { + for (int i = 0; i < QK4_NL / 2; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); + // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); + // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); + const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16); const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16); sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16); } - vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); - vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); } @@ -545,7 +546,7 @@ void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const } return; #endif - ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); + ggml_gemm_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -589,15 +590,15 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); - // Load `b_ptr`. 
- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)b_ptr[l].qs, 16); - const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); - const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); - // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); - // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); - // Accumulation loop. - for (int i = 0; i < 16; i++) { + for (int i = 0; i < QK4_NL / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); + // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); + // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); + const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16); const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16); const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16); @@ -614,11 +615,11 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16); } - vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); - vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); - vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); - vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); - vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 
*)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index cff46066715..b3cb78abcf7 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1262,6 +1262,44 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * 
a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -2078,6 +2116,50 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; 
m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; From d9fbb69f546847a1a354413b0685d7e0b0dabe2a Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Tue, 30 Dec 2025 21:33:46 +0500 Subject: [PATCH 04/13] ggml-cpu: add rvv repacking for q8_0 --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 127 ++++++++++++++++++++++++ ggml/src/ggml-cpu/repack.cpp | 73 ++++++++++++++ ggml/src/ggml-cpu/repack.h | 13 +++ 3 files changed, 213 insertions(+) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index de3ac886afc..923902786e6 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -203,6 +203,59 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + // 1x16 Accumulator1 + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 1x32 integer accumulator + vint32m2_t sumi = 
__riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK8_0; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + + sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } + return; +#endif + ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -638,6 +691,80 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + 
vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 4x16 integer accumulators + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK8_0; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + + sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], 16), 16); + sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], 16), 16); + sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], 16), 16); + sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * 
bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } + return; +#endif + ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index b3cb78abcf7..11702427f5c 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -2907,6 +2907,55 @@ static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, return 0; } +static block_q8_0x16 make_block_q8_0x16(block_q8_0 * in, unsigned int blck_size_interleave) { + block_q8_0x16 out; + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + const int end = QK8_0 * 16 / blck_size_interleave; + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = (i / 16) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave); + } + + return out; +} + +static int repack_q8_0_to_q8_0_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + constexpr int nrows_interleaved = 16; + + block_q8_0x16 * dst = (block_q8_0x16 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; 
i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q8_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) { block_iq4_nlx4 out; @@ -3326,6 +3375,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size); +} + // gemv template void gemv(int, float *, size_t, const void *, const void *, int, int); @@ -3413,6 +3466,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + // gemm template void gemm(int, float *, size_t, const void *, const void *, int, int); @@ -3499,6 +3556,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + class tensor_traits_base : public ggml::cpu::tensor_traits { public: virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; @@ -3909,6 +3970,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons // instance for Q8_0 static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; + static const ggml::cpu::repack::tensor_traits q8_0_16x1_q8_0; if (cur->type == GGML_TYPE_Q4_0) { if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) @@ -4015,6 +4077,17 @@ static const 
ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q8_0_4x4_q8_0; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q8_0_16x1_q8_0; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } return nullptr; diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index f9da6b5c9f1..b5186b39132 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -35,6 +35,7 @@ using block_q4_0x4 = block<4, 4>; using block_q4_0x8 = block<4, 8>; using block_q8_0x4 = block<8, 4>; using block_q8_0x8 = block<8, 8>; +using block_q8_0x16 = block<8, 16>; struct block_q4_Kx8 { ggml_half d[8]; // super-block scale for quantized scales @@ -142,6 +143,10 @@ void ggml_gemv_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const 
void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -160,8 +165,10 @@ void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); // Native implementations void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); @@ -186,6 +193,10 @@ void ggml_gemv_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t b void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -204,8 +215,10 @@ void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, 
size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined(__cplusplus) } // extern "C" From 6502d03bb7e35ff0a163feb295fcf72c42d1609c Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Fri, 2 Jan 2026 18:05:07 +0500 Subject: [PATCH 05/13] ggml-cpu: refactor; add rvv repacking for q4_0, q4_K --- ggml/src/ggml-cpu/arch-fallback.h | 22 +- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 1110 ++++++++++---------- ggml/src/ggml-cpu/repack.cpp | 1283 ++++++++++++++++------- ggml/src/ggml-cpu/repack.h | 62 +- 4 files changed, 1526 insertions(+), 951 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index c4d1cf314c2..901510a9045 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -47,7 +47,6 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 @@ -64,7 +63,6 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic 
ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 @@ -74,11 +72,9 @@ // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K @@ -94,7 +90,10 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +<<<<<<< HEAD #define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -105,7 +104,10 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +<<<<<<< HEAD #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 
+======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) @@ -131,7 +133,6 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 @@ -148,7 +149,6 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 @@ -177,7 +177,6 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 @@ -194,7 +193,6 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic 
ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 @@ -210,11 +208,11 @@ #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp +#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 -#define ggml_quantize_mat_q8_0_4x16_generic ggml_quantize_mat_q8_0_4x16 +#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -274,7 +272,6 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 @@ -291,7 +288,6 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define 
ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 @@ -328,7 +324,6 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 @@ -345,7 +340,6 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 923902786e6..21dda5c0a9c 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -1,3 +1,4 @@ +#include #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -203,7 +204,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 16; @@ -223,26 +224,32 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, 
size_t bs, const v UNUSED(blocklen); #if defined __riscv_v_intrinsic + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); - // 1x16 Accumulator1 + // 1x16 Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); for (int l = 0; l < nb; l++) { - // 1x32 integer accumulator - vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); + // 1x16 Integer Accumulator + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); // Accumulation loop. - for (int i = 0; i < QK8_0; i++) { + for (int i = 0; i < QK4_0 / 2; i++) { // Load `b_ptr`. - const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); - // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); - sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, 16); + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, 16); } + const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); @@ -253,14 +260,22 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } return; #endif - ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); + 
ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; +void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + // TODO +} + +void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + // TODO +} + +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; + const int ncols_interleaved = 16; + const int blocklen = 1; assert (n % qk == 0); assert (nc % ncols_interleaved == 0); @@ -276,79 +291,124 @@ void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(blocklen); #if defined __riscv_v_intrinsic - const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + const block_q8_K * a_ptr = (const block_q8_K *) vy; - vfloat32mf2_t sumf = __riscv_vfmv_v_f_f32mf2(0.0, 4); - for (int l = 0; l < nb; l++) { - // Load first 8 bytes of `a`. - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; - __asm__ __volatile__("" ::: "memory"); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); - // Load `b_ptr`. 
- const vuint8m2_t b_0_packed = __riscv_vle8_v_u8m2((const uint8_t *)b_ptr[l].qs, QK4_NL * 2); - const vint8m2_t b_0_lo = __riscv_vrgather_vv_i8m2(values, __riscv_vand_vx_u8m2(b_0_packed, 0xf, QK4_NL * 2), QK4_NL * 2); - const vint8m2_t b_0_hi = __riscv_vrgather_vv_i8m2(values, __riscv_vsrl_vx_u8m2(b_0_packed, 4, QK4_NL * 2), QK4_NL * 2); + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); - // Create 4 segments from `b`. - const vint8m1_t b_lo_0 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 0); - const vint8m1_t b_lo_1 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 1); - const vint8m1_t b_hi_0 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 0); - const vint8m1_t b_hi_1 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 1); + for (int l = 0; l < nb; l++) { + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16); + + // We process 4 sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); + + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. 
+ vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16); + sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16); + + // Accumulation for 2 sub-blocks. + { + // 4x16 integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, 16); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, 16); + } - // Broadcast `a_ptr` across 4 registers (8 bytes / register). 
- const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); - const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); - const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); - const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_s_0_16, 16); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_s_1_16, 16); + } + { + // 4x16 integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, 16); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, 16); + } - // Multiply and accumulate. 
- const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); - const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); - const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); - const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); - const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); - const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); - const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_s_0_16, 16); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_s_1_16, 16); + } + } - // In-place reduction. - const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); - const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); - const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); - const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); - const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); - const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); - const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); - const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); - const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); - const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); - const vfloat32mf2_t facc = 
__riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], 16), 16); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, 16); - // Multiply with scales. - const vfloat16mf4_t b_d = __riscv_vle16_v_f16mf4((const _Float16 *)b_ptr[l].d, 4); - const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d, 4); - sumf = __riscv_vfmacc_vv_f32mf2(sumf, facc, d_0, QK4_NL / 8); + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); } - __riscv_vse32_v_f32mf2(s + x * ncols_interleaved, sumf, QK4_NL / 8); + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } return; - #endif - ggml_gemv_iq4_nl_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); + ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; + const int ncols_interleaved = 8; + const int blocklen = 8; assert (n % qk == 0); assert (nc % ncols_interleaved == 0); @@ -364,85 +424,71 @@ void ggml_gemv_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(blocklen); #if defined __riscv_v_intrinsic - const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0f, 4); - for (int l = 0; l + 1 < nb; l += 2) { - vuint8m1_t b_0_packed = 
__riscv_vle8_v_u8m1(b_ptr[l].qs + 0, 16); - vuint8m1_t b_1_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 16, 16); - vuint8m1_t b_2_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 32, 16); - vuint8m1_t b_3_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 48, 16); - vuint8m1_t b_4_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 0, 16); - vuint8m1_t b_5_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 16, 16); - vuint8m1_t b_6_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 32, 16); - vuint8m1_t b_7_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 48, 16); - - vuint8m1_t b_0_lo = __riscv_vand_vx_u8m1(b_0_packed, 0xf, 16); - vuint8m1_t b_0_hi = __riscv_vsrl_vx_u8m1(b_0_packed, 4, 16); - vuint8m1_t b_1_lo = __riscv_vand_vx_u8m1(b_1_packed, 0xf, 16); - vuint8m1_t b_1_hi = __riscv_vsrl_vx_u8m1(b_1_packed, 4, 16); - vuint8m1_t b_2_lo = __riscv_vand_vx_u8m1(b_2_packed, 0xf, 16); - vuint8m1_t b_2_hi = __riscv_vsrl_vx_u8m1(b_2_packed, 4, 16); - vuint8m1_t b_3_lo = __riscv_vand_vx_u8m1(b_3_packed, 0xf, 16); - vuint8m1_t b_3_hi = __riscv_vsrl_vx_u8m1(b_3_packed, 4, 16); - vuint8m1_t b_4_lo = __riscv_vand_vx_u8m1(b_4_packed, 0xf, 16); - vuint8m1_t b_4_hi = __riscv_vsrl_vx_u8m1(b_4_packed, 4, 16); - vuint8m1_t b_5_lo = __riscv_vand_vx_u8m1(b_5_packed, 0xf, 16); - vuint8m1_t b_5_hi = __riscv_vsrl_vx_u8m1(b_5_packed, 4, 16); - vuint8m1_t b_6_lo = __riscv_vand_vx_u8m1(b_6_packed, 0xf, 16); - vuint8m1_t b_6_hi = __riscv_vsrl_vx_u8m1(b_6_packed, 4, 16); - vuint8m1_t b_7_lo = __riscv_vand_vx_u8m1(b_7_packed, 0xf, 16); - vuint8m1_t b_7_hi = __riscv_vsrl_vx_u8m1(b_7_packed, 4, 16); - - vint8m1_t b_0 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_0_lo, b_0_hi, 16, 32), 32); - vint8m1_t b_1 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_1_lo, b_1_hi, 16, 32), 32); - vint8m1_t b_2 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_2_lo, b_2_hi, 16, 32), 32); - vint8m1_t b_3 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_3_lo, b_3_hi, 16, 32), 32); - 
vint8m1_t b_4 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_4_lo, b_4_hi, 16, 32), 32); - vint8m1_t b_5 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_5_lo, b_5_hi, 16, 32), 32); - vint8m1_t b_6 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_6_lo, b_6_hi, 16, 32), 32); - vint8m1_t b_7 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_7_lo, b_7_hi, 16, 32), 32); - - vint8m1_t a_0 = __riscv_vle8_v_i8m1(a_ptr[l].qs, 32); - vint8m1_t a_1 = __riscv_vle8_v_i8m1(a_ptr[l + 1].qs, 32); - - vint32m1_t sumi_0 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_0, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_1 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_1, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_2 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_2, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_3 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_3, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_4 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_4, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_5 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_5, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_6 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_6, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_7 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_7, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - - int sumi_temp[8]; - __riscv_vse32_v_i32m1(&sumi_temp[0], sumi_0, 1); - __riscv_vse32_v_i32m1(&sumi_temp[1], sumi_1, 1); - __riscv_vse32_v_i32m1(&sumi_temp[2], sumi_2, 1); - __riscv_vse32_v_i32m1(&sumi_temp[3], sumi_3, 1); - __riscv_vse32_v_i32m1(&sumi_temp[4], sumi_4, 1); - __riscv_vse32_v_i32m1(&sumi_temp[5], sumi_5, 1); - __riscv_vse32_v_i32m1(&sumi_temp[6], sumi_6, 1); - __riscv_vse32_v_i32m1(&sumi_temp[7], sumi_7, 1); - vint32m1_t sum_0 
= __riscv_vle32_v_i32m1(&sumi_temp[0], 4); - vint32m1_t sum_1 = __riscv_vle32_v_i32m1(&sumi_temp[4], 4); - - vfloat16mf2_t b_d_0 = __riscv_vle16_v_f16mf2((_Float16 *)b_ptr[l].d, 4); - vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d_0, *(const _Float16 *)&a_ptr[l].d, 4); - vfloat16mf2_t b_d_1 = __riscv_vle16_v_f16mf2((_Float16 *)b_ptr[l + 1].d, 4); - vfloat32m1_t d_1 = __riscv_vfwmul_vf_f32m1(b_d_1, *(const _Float16 *)&a_ptr[l + 1].d, 4); - - sumf = __riscv_vfmacc_vv_f32m1(sumf, d_0, __riscv_vfcvt_f_x_v_f32m1(sum_0, 4), 4); - sumf = __riscv_vfmacc_vv_f32m1(sumf, d_1, __riscv_vfcvt_f_x_v_f32m1(sum_1, 4), 4); + const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); + + vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, 8); + for (int l = 0; l < nb; l++) { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + + // Load `b_ptr`. + const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); + const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); + const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); + + // Create 4 segments from `b`. 
+ const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); + const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); + const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); + const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); + + // Multiply and accumulate. + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. + const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc 
= __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); + + // Multiply with scales. + const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); + const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d, 8); + sumf = __riscv_vfmacc_vv_f32m1(sumf, facc, d_0, QK4_NL / 4); } - __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, 4); + __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, QK4_NL / 4); } return; + #endif - ggml_gemv_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc); + ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -504,11 +550,64 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 1x16 Integer Accumulator + vint32m2_t sumi = 
__riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK8_0; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + + sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } + return; +#endif + ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; assert (n % qk == 0); assert (nr % 4 == 0); @@ -525,81 +624,354 @@ void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(blocklen); #if defined __riscv_v_intrinsic - const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 4x16 integer accumulators + 
vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK4_0 / 2; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16); + sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16); + sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16); + sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16); + + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16); + sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16); + sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16); + sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16); + } + + // Do the final accumulation in i32 to prevent overflow. 
+ const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16); + const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16); + const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16); + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } + return; +#endif + ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + 
UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); - // 4x4 Accumulators - vfloat32m1_t sumf_0 = __riscv_vfmv_v_f_f32m1(0.0f, 4); - vfloat32m1_t sumf_1 = __riscv_vfmv_v_f_f32m1(0.0f, 4); - vfloat32m1_t sumf_2 = __riscv_vfmv_v_f_f32m1(0.0f, 4); - vfloat32m1_t sumf_3 = __riscv_vfmv_v_f_f32m1(0.0f, 4); + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); for (int l = 0; l < nb; l++) { - int sumi_temp[16]; - uint8_t index[4] = {0, 8, 64, 72}; - vuint8mf8_t i_vec = __riscv_vle8_v_u8mf8(&index[0], 4); - vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 0, 16); - vuint8m1_t b_1_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 16, 16); - vuint8m1_t b_2_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 32, 16); - vuint8m1_t b_3_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 48, 16); - - vuint8m1_t b_0_lo = __riscv_vand_vx_u8m1(b_0_packed, 0xf, 16); - vuint8m1_t b_0_hi = __riscv_vsrl_vx_u8m1(b_0_packed, 4, 16); - vuint8m1_t b_1_lo = __riscv_vand_vx_u8m1(b_1_packed, 0xf, 16); - vuint8m1_t b_1_hi = __riscv_vsrl_vx_u8m1(b_1_packed, 4, 16); - vuint8m1_t b_2_lo = __riscv_vand_vx_u8m1(b_2_packed, 0xf, 16); - vuint8m1_t b_2_hi = __riscv_vsrl_vx_u8m1(b_2_packed, 4, 16); - vuint8m1_t b_3_lo = __riscv_vand_vx_u8m1(b_3_packed, 0xf, 16); - vuint8m1_t b_3_hi = __riscv_vsrl_vx_u8m1(b_3_packed, 4, 16); - - vint8m1_t b_0 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_0_lo, b_0_hi, 16, 32), 32); - vint8m1_t b_1 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_1_lo, b_1_hi, 16, 32), 32); - vint8m1_t b_2 = 
__riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_2_lo, b_2_hi, 16, 32), 32); - vint8m1_t b_3 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_3_lo, b_3_hi, 16, 32), 32); - - #pragma unroll 4 - for (int i = 0; i < 4; i++) { - vint8m1_t a_i = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vloxei8_v_i64m1((int64_t*)(a_ptr[l].qs + i * 16), i_vec, 4)); - vint32m1_t sumi_0 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_0, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_1 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_1, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_2 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_2, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_3 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_3, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 0], sumi_0, 1); - __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 1], sumi_1, 1); - __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 2], sumi_2, 1); - __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 3], sumi_3, 1); + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16); + + // We process 4 sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); + + // High bits. 
+ vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16); + + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 
+ 8 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[0], 16); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[1], 16); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[2], 16); + const 
vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16); + sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16); + sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, 16), 16), 16); + sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, 16), 16), 16); + + + // Accumulation for 2 sub-blocks. + { + // 4x8 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, 16); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, 16); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, 16); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, 16); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, 16); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, 16); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, 16); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, 16); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_0_s_0_16, 16); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_0_s_1_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_1_s_0_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_1_s_1_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_2_s_0_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + 
sumi_2_s_1_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_3_s_0_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_3_s_1_16, 16); + } + { + // 4x8 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, 16); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, 16); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, 16); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, 16); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, 16); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, 16); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], 
b_s_1, 16); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, 16); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_0_s_0_16, 16); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_0_s_1_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_1_s_0_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_1_s_1_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_2_s_0_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_2_s_1_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_3_s_0_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_3_s_1_16, 16); + } } - vint32m1_t sum_0 = __riscv_vle32_v_i32m1(&sumi_temp[0], 4); - vint32m1_t sum_1 = __riscv_vle32_v_i32m1(&sumi_temp[4], 4); - vint32m1_t sum_2 = __riscv_vle32_v_i32m1(&sumi_temp[8], 4); - vint32m1_t sum_3 = __riscv_vle32_v_i32m1(&sumi_temp[12], 4); - - vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((_Float16 *)b_ptr[l].d, 4); - vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[0], 4); - vfloat32m1_t d_1 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[1], 4); - vfloat32m1_t d_2 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[2], 4); - vfloat32m1_t d_3 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[3], 4); - - sumf_0 = __riscv_vfmacc_vv_f32m1(sumf_0, d_0, 
__riscv_vfcvt_f_x_v_f32m1(sum_0, 4), 4); - sumf_1 = __riscv_vfmacc_vv_f32m1(sumf_1, d_1, __riscv_vfcvt_f_x_v_f32m1(sum_1, 4), 4); - sumf_2 = __riscv_vfmacc_vv_f32m1(sumf_2, d_2, __riscv_vfcvt_f_x_v_f32m1(sum_2, 4), 4); - sumf_3 = __riscv_vfmacc_vv_f32m1(sumf_3, d_3, __riscv_vfcvt_f_x_v_f32m1(sum_3, 4), 4); + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16), 16); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); } - __riscv_vse32_v_f32m1(s + (y * 4 + 0) * bs + x * 4, sumf_0, 4); - __riscv_vse32_v_f32m1(s + (y * 4 + 1) * bs + x * 4, sumf_1, 4); - __riscv_vse32_v_f32m1(s + (y * 4 + 2) * bs + x * 4, sumf_2, 4); - __riscv_vse32_v_f32m1(s + (y * 4 + 3) * bs + x * 4, sumf_3, 4); + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } return; #endif - ggml_gemm_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc); + ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + // TODO +} + +void ggml_gemm_q4_K_8x8_q8_K(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + // TODO } void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -724,7 +1096,7 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); for (int l = 0; l < nb; l++) { - // 4x16 integer accumulators + // 4x16 Integer Accumulators vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); @@ -765,93 +1137,6 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); - - vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, 8); - for (int l = 0; l < nb; l++) { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 
4 registers (8 bytes / register). - const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Load `b_ptr`. - const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); - const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); - const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); - - // Create 4 segments from `b`. - const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); - const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); - const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); - const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); - - // Multiply and accumulate. - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. 
- const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. 
- const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); - const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d, 8); - sumf = __riscv_vfmacc_vv_f32m1(sumf, facc, d_0, QK4_NL / 4); - } - __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, QK4_NL / 4); - } - return; - -#endif - ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} - void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1078,239 +1363,6 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - // 4x4 accumulators. - vfloat32mf2_t sumf0 = __riscv_vfmv_v_f_f32mf2(0.0, 4); - vfloat32mf2_t sumf1 = __riscv_vfmv_v_f_f32mf2(0.0, 4); - vfloat32mf2_t sumf2 = __riscv_vfmv_v_f_f32mf2(0.0, 4); - vfloat32mf2_t sumf3 = __riscv_vfmv_v_f_f32mf2(0.0, 4); - - for (int l = 0; l < nb; l++) { - // Load `b_ptr`. 
- const vuint8m2_t b_0_packed = __riscv_vle8_v_u8m2((const uint8_t *)b_ptr[l].qs, QK4_NL * 2); - const vint8m2_t b_0_lo = __riscv_vrgather_vv_i8m2(values, __riscv_vand_vx_u8m2(b_0_packed, 0xf, QK4_NL * 2), QK4_NL * 2); - const vint8m2_t b_0_hi = __riscv_vrgather_vv_i8m2(values, __riscv_vsrl_vx_u8m2(b_0_packed, 4, QK4_NL * 2), QK4_NL * 2); - - // Create 4 segments from `b`. - const vint8m1_t b_lo_0 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 0); - const vint8m1_t b_lo_1 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 1); - const vint8m1_t b_hi_0 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 0); - const vint8m1_t b_hi_1 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 1); - - // Load scales for `b`. - const vfloat16mf4_t b_d = __riscv_vle16_v_f16mf4((const _Float16 *)b_ptr[l].d, 4); - - // Load first 8 bytes of `a`. - { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[32]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[64]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[96]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); - const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); - const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); - const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); - - // Multiply and accumulate. 
- const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); - const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); - const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); - const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); - const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); - const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); - const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); - - // In-place reduction. - const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); - const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); - const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); - const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); - const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); - const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); - const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); - const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); - const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); - const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); - const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); - - // Multiply with scales. - const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[0], 4); - sumf0 = __riscv_vfmacc_vv_f32mf2(sumf0, facc, d_0, QK4_NL / 8); - } - - // Load second 8 bytes of `a`. 
- { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[40]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[72]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[104]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); - const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); - const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); - const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); - - // Multiply and accumulate. - const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); - const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); - const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); - const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); - const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); - const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); - const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); - - // In-place reduction. 
- const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); - const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); - const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); - const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); - const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); - const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); - const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); - const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); - const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); - const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); - const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); - - // Multiply with scales. - const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[1], 4); - sumf1 = __riscv_vfmacc_vv_f32mf2(sumf1, facc, d_0, QK4_NL / 8); - } - - // Load third 8 bytes of `a`. - { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[48]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[80]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[112]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). 
- const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); - const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); - const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); - const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); - - // Multiply and accumulate. - const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); - const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); - const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); - const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); - const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); - const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); - const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); - - // In-place reduction. - const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); - const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); - const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); - const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); - const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); - const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); - const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); - const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); - const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); - const vint32mf2_t 
sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); - const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); - - // Multiply with scales. - const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[2], 4); - sumf2 = __riscv_vfmacc_vv_f32mf2(sumf2, facc, d_0, QK4_NL / 8); - } - - // Load fourth 8 bytes of `a`. - { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[24]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[56]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[88]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[120]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); - const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); - const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); - const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); - - // Multiply and accumulate. - const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); - const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); - const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); - const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); - const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); - const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); - const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); - - // In-place reduction. 
- const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); - const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); - const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); - const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); - const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); - const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); - const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); - const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); - const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); - const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); - const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); - - // Multiply with scales. 
- const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[3], 4); - sumf3 = __riscv_vfmacc_vv_f32mf2(sumf3, facc, d_0, QK4_NL / 8); - } - } - - __riscv_vse32_v_f32mf2(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, 8); - __riscv_vse32_v_f32mf2(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, 8); - __riscv_vse32_v_f32mf2(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, 8); - __riscv_vse32_v_f32mf2(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, 8); - } - } - return; - -#endif - ggml_gemm_iq4_nl_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} - void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 11702427f5c..b9ac85cabfa 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -49,6 +49,7 @@ static inline int nearest_int(float fval) { extern "C" { +#if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -87,45 +88,52 @@ void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GG } } -void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; +void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; - block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; - // scalar - const int blck_size_interleave = 4; - float srcv[4][QK8_0]; - float id[4]; + const int blck_size_interleave = 1; + float srcv[4][QK_K]; + float iscale[4]; for (int i 
= 0; i < nb; i++) { for (int row_iter = 0; row_iter < 4; row_iter++) { float amax = 0.0f; // absolute max + float max = 0; - for (int j = 0; j < QK8_0; j++) { - srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; - amax = MAX(amax, fabsf(srcv[row_iter][j])); + for (int j = 0; j < QK_K; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK_K + j]; + // Update the maximum value of the corresponding super block + if(amax < fabsf(srcv[row_iter][j])) { + amax = fabsf(srcv[row_iter][j]); + max = srcv[row_iter][j]; + } } - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d); + iscale[row_iter] = amax ? -127.f/max : 0; + y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0; } - for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (j % blck_size_interleave); + for (int j = 0; j < QK_K / 4; j++) { + y[i].bsums[j] = 0; + } + for (int j = 0; j < QK_K * 4; j++) { + int src_id = j % 4; + int src_offset = j / 4; + int index = ((j >> 6) << 2) + (j & 3); - float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0); + float x0 = srcv[src_id][src_offset] * iscale[src_id]; + y[i].qs[j] = nearest_int(x0); + y[i].bsums[index] += y[i].qs[j]; } } } +#endif -void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -133,7 +141,7 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; // scalar - const int blck_size_interleave = 8; + const int blck_size_interleave = 4; float srcv[4][QK8_0]; float id[4]; @@ -163,7 +171,7 @@ void 
ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG } } -void ggml_quantize_mat_q8_0_4x16_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -171,7 +179,7 @@ void ggml_quantize_mat_q8_0_4x16_generic(const float * GGML_RESTRICT x, void * G block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; // scalar - const int blck_size_interleave = 16; + const int blck_size_interleave = 8; float srcv[4][QK8_0]; float id[4]; @@ -314,35 +322,37 @@ template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTR ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); - ggml_quantize_mat_q8_0_4x1(x, vy, n_per_row); + ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); - ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); + ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<16, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); - 
ggml_quantize_mat_q8_0_4x16(x, vy, n_per_row); + ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +#if defined __riscv_zvfh +template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); - ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row); + ggml_quantize_mat_q8_0_4x1(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); - ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); + ggml_quantize_mat_q8_K_4x1(x, vy, n_per_row); } +#endif template static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n, @@ -1148,82 +1158,6 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -void ggml_gemv_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; 
++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - -void ggml_gemv_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void 
* GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1262,6 +1196,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +<<<<<<< HEAD void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1376,6 +1311,8 @@ void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -1470,14 +1407,14 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, } } -void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh +void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; + const int ncols_interleaved = 16; + const int blocklen = 1; assert (n % qk == 0); - assert (nr % 4 == 0); assert (nc % ncols_interleaved == 0); UNUSED(s); @@ -1490,52 +1427,38 @@ void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, UNUSED(ncols_interleaved); UNUSED(blocklen); - { - float sumf[4][4]; - int sumi; + float sumf[16]; + int sumi; - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for 
(int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); } } } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } } -void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + 
const int qk = QK_K; const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - + const int ncols_interleaved = 16; + const int blocklen = 1; assert (n % qk == 0); - assert (nr % 4 == 0); assert (nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); UNUSED(vx); @@ -1545,46 +1468,258 @@ void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, UNUSED(nb); UNUSED(ncols_interleaved); UNUSED(blocklen); - - float sumf[4][4]; + float sumf[16]; + float sum_minf[16]; + uint8_t scales[128]; + uint8_t mins[128]; + int sumi1; + int sumi2; int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + sum_minf[j] = 0.0f; + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < 128; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * 
GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } + for (int i = 0; i < 64; i++) { + scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); + mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * 16]; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; } } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * 16]; + uint8_t *scales_1 = &scales[(sb + 1) * 16]; + for (int i = 0; i < QK4_0; i++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]); + sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } } } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } } } -void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + 
assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k 
* blocklen + i]; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} +#endif + +void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +void 
ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; assert (n % qk == 
0); assert (nr % 4 == 0); @@ -1960,118 +2095,6 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -void ggml_gemm_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - { - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } -} - -void ggml_gemm_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - { - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } -} - void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -2116,6 +2139,7 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +<<<<<<< HEAD void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -2248,6 +2272,8 @@ void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -2300,6 +2326,8 @@ void ggml_gemm_q8_0_4x4_q8_0_generic(int n, } } + + void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -2322,7 +2350,246 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; + } + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +#if defined __riscv_zvfh +void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert 
(n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][16]; + float sum_minf[4][16]; + 
uint8_t scales[128]; + uint8_t mins[128]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < 128; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < 64; i++) { + scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); + mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; + } + + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * 16]; + for(int m = 0; m < 4; m++) { + const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4]; + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * 16]; + uint8_t *scales_1 = &scales[(sb + 1) * 16]; + + for (int i = 0; i < QK4_0; i++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + + const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]); + sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + 
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 
0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumf[m][j] = 0.0; @@ -2351,6 +2618,7 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, } } } +#endif } // extern "C" @@ -2440,6 +2708,31 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + const uint8_t xor_mask = 0x88; + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset] ^ xor_mask; + } + } else { + GGML_ASSERT(false); + } + + return out; +} + static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { block_q4_Kx8 out; //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure @@ -2454,10 +2747,11 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in const int end = QK_K * 4 / blck_size_interleave; // Interleave Q4_K quants by taking 8 bytes at a time - for (int i = 0; i < end; ++i) { - int src_id = i % 8; - int src_offset = (i / 8) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; + if (blck_size_interleave == 8) { + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; // buffer large enough for the max interleave block size (8 bytes) uint64_t elems; 
@@ -2465,53 +2759,139 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in memcpy(&out.qs[dst_offset], &elems, blck_size_interleave); } - // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K - // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) - // The output Q4_Kx8 structure has 96 bytes - // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure - // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures - uint8_t s[8], m[8]; + // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K + // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) + // The output Q4_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures + uint8_t s[8], m[8]; + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = in[j].scales[i] & 63; + m[j] = in[j].scales[i + 4] & 63; + } + + out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 11] = (s[7] & 
15) + ((m[7] & 15) << 4); - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[j] = in[j].scales[i] & 63; - m[j] = in[j].scales[i + 4] & 63; } - out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); - out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); - out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); - out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); - out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); - out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); - out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); - out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); - out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); - out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); - out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); - out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + + out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); + + } + + } else if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = i / 8; + int dst_offset 
= i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + // RVV repacking. + // scales[0..63]: low nibbles (scale | min << 4); scales[64..95]: high 2 bits of two scales and two mins per byte. + // Extract scales and mins for all 8 sub-blocks for each block of Q4_K. + uint8_t s[64], m[64]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[i * 8 + j] = in[j].scales[i] & 63; + m[i * 8 + j] = in[j].scales[i + 4] & 63; + } + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[32 + i * 8 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[32 + i * 8 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + } + for (int i = 0; i < 64; i++) { + out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); + } + for (int i = 0; i < 32; i++) { + out.scales[64 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[32 + i] & 48) | ((m[32 + i] & 48) << 2); + } } - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); - m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + return out; +} + +static block_q4_Kx16 make_block_q4_Kx16(block_q4_K * in, unsigned int blck_size_interleave) { + block_q4_Kx16 out; + //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 16; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; } - out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); - out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); - out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); - out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); - out.scales[i * 12 + 52] = 
(m[0] & 63) + ((m[4] & 48) << 2); - out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); - out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); - out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); - out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); - out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); - out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); - out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); + // RVV repacking. + // + // Extract sums and mins for all 8 sub-blocks for each block of Q4_K. + uint8_t s[128], m[128]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 16; j++) { + s[i * 16 + j] = in[j].scales[i] & 63; + m[i * 16 + j] = in[j].scales[i + 4] & 63; + } + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 16; j++) { + s[64 + i * 16 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[64 + i * 16 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + } + for (int i = 0; i < 128; i++) { + out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); + } + for (int i = 0; i < 64; i++) { + out.scales[128 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[64 + i] & 48) | ((m[64 + i] & 48) << 2); + } + } else { + GGML_ASSERT(false); } return out; @@ -2723,7 +3103,7 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 8 || interleave_block == 4); + GGML_ASSERT(interleave_block == 8 || interleave_block == 4 || interleave_block == 1); constexpr int nrows_interleaved = 8; block_q4_Kx8 * dst = (block_q4_Kx8*)t->data; @@ -2752,6 +3132,36 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q4_K_to_q4_K_16_bl(struct ggml_tensor * t, int 
interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + constexpr int nrows_interleaved = 16; + + block_q4_Kx16 * dst = (block_q4_Kx16*)t->data; + const block_q4_K * src = (const block_q4_K*) data; + block_q4_K dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_Kx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q2_K); GGML_ASSERT(interleave_block == 8); @@ -2783,6 +3193,36 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16*)t->data; + const block_q4_0 * src = (const block_q4_0*) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += 
nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, @@ -2915,11 +3355,16 @@ static block_q8_0x16 make_block_q8_0x16(block_q8_0 * in, unsigned int blck_size_ } const int end = QK8_0 * 16 / blck_size_interleave; - for (int i = 0; i < end; ++i) { - int src_id = i % 16; - int src_offset = (i / 16) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave); + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + } else { + GGML_ASSERT(false); } return out; @@ -2990,25 +3435,9 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s int src_offset = (i / 4) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - for (int b = 0; b < 8; ++b) { - out.qs[dst_offset + b] = in[src_id].qs[src_offset + b]; - } - - // Generates bus error on RVV as this is auto-vectorized and the - // source might possible not be 8-byte aligned - // - // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); - } - } else if (blck_size_interleave == 16) { - for (int i = 0; i < end; ++i) { - int src_id = i; - int src_offset = 0; - int dst_offset = i * 16; - - memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], 4 * sizeof(uint32_t)); + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); } - } - else { + } else { GGML_ASSERT(false); } @@ -3017,7 +3446,7 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); - // GGML_ASSERT(interleave_block == 
4); + GGML_ASSERT(interleave_block == 4); const block_iq4_nl * src = (const block_iq4_nl *)data; block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data; @@ -3124,15 +3553,10 @@ static block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * in, unsigned int blck if (blck_size_interleave == 1) { for (int i = 0; i < end; ++i) { int src_id = i % 16; - int src_offset = (i / 16) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; + int src_offset = i / 16; + int dst_offset = i; out.qs[dst_offset] = in[src_id].qs[src_offset]; - - // Generates bus error on RVV as this is auto-vectorized and the - // source might possible not be 8-byte aligned - // - // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); } } else { GGML_ASSERT(false); @@ -3343,17 +3767,10 @@ template <> int repack(struct ggml_tensor * t, const void * // return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); //} -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); -} - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_iq4_nl_to_iq4_nl_4_bl(t, 16, data, data_size); -} - template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } +<<<<<<< HEAD template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size); @@ -3367,6 +3784,8 @@ template <> int repack(struct ggml_tensor * t, const void * d return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size); } +======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size); } @@ -3375,9 +3794,27 @@ template <> int repack(struct ggml_tensor * t, const void * da return 
repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); } +#if defined __riscv_zvfh +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_8_bl(t, 1, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_16_bl(t, 1, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size); } +#endif // gemv template @@ -3434,17 +3871,10 @@ template <> void gemv(int n, float * s, size ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_iq4_nl_4x8_q8_0(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_iq4_nl_4x16_q8_0(n, s, bs, vx, vy, nr, nc); -} - template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +<<<<<<< HEAD template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -3458,6 +3888,8 @@ template <> void gemv(int n, float * s, size_ ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ 
-3466,9 +3898,27 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } +#endif // gemm template @@ -3525,14 +3975,11 @@ template <> void gemm(int n, float * s, size ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_iq4_nl_4x8_q8_0(n, s, bs, vx, vy, nr, nc); -} - template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +<<<<<<< HEAD template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_4x16_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -3548,6 +3995,8 @@ template <> void gemm(int n, float * s, size_ ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -3556,9 +4005,27 @@ template <> 
void gemm(int n, float * s, size_t ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } +#endif class tensor_traits_base : public ggml::cpu::tensor_traits { public: @@ -3958,10 +4425,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons // instance for IQ4 static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0; - static const ggml::cpu::repack::tensor_traits iq4_nl_4x16_q8_0; - static const ggml::cpu::repack::tensor_traits iq4_nl_4x8_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0; - static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; // instance for MXFP4 static const ggml::cpu::repack::tensor_traits mxfp4_4x4_q8_0; @@ -3970,11 +4434,20 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons // instance for Q8_0 static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; + + // instances for RISC-V + // + // These implement outer-product style multiplication with interleave of 1. 
+#if defined __riscv_zvfh + static const ggml::cpu::repack::tensor_traits q4_0_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits q4_K_8x1_q8_K; + static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_16x1_q8_0; +#endif if (cur->type == GGML_TYPE_Q4_0) { - if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) - || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) { + if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { if (cur->ne[1] % 8 == 0) { return &q4_0_8x8_q8_0; } @@ -3989,6 +4462,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q4_0_4x4_q8_0; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q4_0_16x1_q8_0; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_Q4_K) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { @@ -4005,6 +4489,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q4_K_8x4_q8_K; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q4_K_16x1_q8_K; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_Q2_K) { if (ggml_cpu_has_avx512()) { if (cur->ne[1] % 8 == 0) { @@ -4048,8 +4543,8 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { case 128: { break; 
} // TODO - case 256: { if (cur->ne[1] % 4 == 0) { return &iq4_nl_16x1_q8_0; } break; } - case 512: { if (cur->ne[1] % 8 == 0) { return &iq4_nl_8x8_q8_0; } break; } + case 256: { if (cur->ne[1] % 16 == 0) { return &iq4_nl_16x1_q8_0; } break; } + case 512: { break; } // TODO case 1024: { break; } // TODO default: { return nullptr; } } diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index b5186b39132..ae4ed6a5e41 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -28,11 +28,14 @@ template struct block { // control size static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding"); static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding"); +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding"); static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding"); using block_q4_0x4 = block<4, 4>; using block_q4_0x8 = block<4, 8>; +using block_q4_0x16 = block<4, 16>; using block_q8_0x4 = block<8, 4>; using block_q8_0x8 = block<8, 8>; using block_q8_0x16 = block<8, 16>; @@ -45,7 +48,14 @@ struct block_q4_Kx8 { }; static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); +struct block_q4_Kx16 { + ggml_half d[16]; // super-block scale for quantized scales + ggml_half dmin[16]; // super-block scale for quantized mins + uint8_t scales[192]; // scales and mins, quantized with 6 bits + uint8_t qs[2048]; // 4--bit quants +}; +static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding"); struct block_q2_Kx8 { 
ggml_half d[8]; // super-block scale for quantized scales ggml_half dmin[8]; // super-block scale for quantized mins @@ -121,10 +131,8 @@ static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block si extern "C" { #endif -void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); -void ggml_quantize_mat_q8_0_4x16(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -138,15 +146,15 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int 
nr, int nc); +<<<<<<< HEAD void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -158,23 +166,35 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +<<<<<<< HEAD void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#if defined __riscv_zvfh +void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void 
* GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#endif // Native implementations -void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); -void ggml_quantize_mat_q8_0_4x16_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void 
ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -188,15 +208,15 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +<<<<<<< HEAD void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void 
* GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -208,17 +228,31 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +<<<<<<< HEAD void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +======= +>>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#if defined __riscv_zvfh +void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void 
* GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#endif #if defined(__cplusplus) } // extern "C" From 981bbc58b0cf51da5d55877f6637a8e2c577b76e Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Mon, 12 Jan 2026 17:38:13 +0500 Subject: [PATCH 06/13] ggml-cpu: refactor; add rvv repacking for q2_K Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/arch-fallback.h | 6 - ggml/src/ggml-cpu/arch/riscv/repack.cpp | 1499 ++++++++++++----------- ggml/src/ggml-cpu/repack.cpp | 464 +++++-- ggml/src/ggml-cpu/repack.h | 15 +- 4 files changed, 1182 insertions(+), 802 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 901510a9045..48315610f2f 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -90,10 +90,7 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -<<<<<<< HEAD #define ggml_gemv_mxfp4_4x4_q8_0_generic 
ggml_gemv_mxfp4_4x4_q8_0 -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -104,10 +101,7 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -<<<<<<< HEAD #define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 21dda5c0a9c..3265843bc0a 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -1,4 +1,3 @@ -#include #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -224,7 +223,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(blocklen); #if defined __riscv_v_intrinsic - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); @@ -250,7 +248,7 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 
*)&a_ptr[l].d, 16); sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); @@ -263,14 +261,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - // TODO -} - -void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - // TODO -} - void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -339,15 +329,15 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16); // Accumulation for 2 sub-blocks. - { - // 4x16 integer accumulators + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); @@ -364,15 +354,16 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), sumi_s_1_16, 16); } - { - // 4x16 integer accumulators + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. 
+ // + // Recheck. + for (int k = 0; k < 2; k++) { vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); @@ -404,93 +395,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); - - vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, 8); - for (int l = 0; l < nb; l++) { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). 
- const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Load `b_ptr`. - const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); - const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); - const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); - - // Create 4 segments from `b`. - const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); - const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); - const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); - const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); - - // Multiply and accumulate. - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. 
- const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. 
- const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); - const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d, 8); - sumf = __riscv_vfmacc_vv_f32m1(sumf, facc, d_0, QK4_NL / 4); - } - __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, QK4_NL / 4); - } - return; - -#endif - ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} - void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -537,7 +441,7 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); @@ -568,6 +472,7 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(nb); UNUSED(ncols_interleaved); UNUSED(blocklen); + UNUSED(bs); #if defined __riscv_v_intrinsic const block_q8_0 * a_ptr = (const block_q8_0 *) vy; @@ -590,7 +495,7 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); @@ -603,103 +508,184 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT 
s, size_t bs, const v ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); +void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % 16 == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); -#if defined __riscv_v_intrinsic - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + const int N_COLS_TILE = 16; + const int num_k_blocks = n / QK_K; - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { - // 4x16 Accumulators - vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy; + const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; - for (int l = 0; l < nb; l++) { - // 4x16 integer accumulators - vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - 
vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl); - // Accumulation loop. - for (int i = 0; i < QK4_0 / 2; i++) { - // Load `b_ptr`. - const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); - const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); - const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_K* lhs_current = &lhs_base_ptr[k_block]; + const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; - sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16); - sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16); - sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16); - sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16); + // 1. 
Prepare Global Min Scales + vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); + vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl); - sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16); - sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16); - sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16); - sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16); + vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, vl); + + vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl); + + const uint8_t* rhs_qs_ptr = rhs_current->qs; + const uint8_t* rhs_sc_ptr = rhs_current->scales; + const int8_t* lhs_qs_ptr = lhs_current->qs; + + // --- Phase Loop (4 phases x 64 elements) --- + for (int phase = 0; phase < 4; ++phase) { + + // A. Load Scales/Mins + vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3; + vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3; + + { + vuint8mf2_t v_raw; + // Sub-block 0 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); + v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 1 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); + v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 2 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); + v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 3 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); + v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_3 = 
__riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + rhs_sc_ptr += 64; } - // Do the final accumulation in i32 to prevent overflow. - const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); - const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16); - const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16); - const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16); + int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16); + int k_offsets[4] = {0, 32, 64, 96}; - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); - const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); - const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); - const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + // B. 
Inner Dot Product Loop + for (int l = 0; l < 16; ++l) { + vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += 16; - sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); - sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); - sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); - sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); - } + // Sub-block 0 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[0] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 1 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[1] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 2 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[2] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 3 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), 
vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[3] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + } + + // correction + int sb_base_abs = base_k_phase / 16; + + // Sub-block 0 + { + int sb_idx = sb_base_abs + (k_offsets[0] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 1 + { + int sb_idx = sb_base_abs + (k_offsets[1] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 2 + { + int sb_idx = sb_base_abs + (k_offsets[2] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 3 + { + int sb_idx = sb_base_abs + (k_offsets[3] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + + } // End Phase Loop + + // Apply global Scales + vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl); + vfloat32m2_t v_g_all_base = 
__riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl); + + vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, vl); + v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, vl); + + } // End K-Block + __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); - __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); - } } - return; -#endif - ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; +void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; + const int ncols_interleaved = 8; + const int blocklen = 8; assert (n % qk == 0); assert (nr % 4 == 0); @@ -715,53 +701,369 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); +#if defined __riscv_v + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; - // 4x16 Accumulators - vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t 
sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); - for (int l = 0; l < nb; l++) { - vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16); + // vector version needs Zvfhmin extension + const float a_scales[4] = { + GGML_CPU_FP16_TO_FP32(a_ptr[l].d[0]), + GGML_CPU_FP16_TO_FP32(a_ptr[l].d[1]), + GGML_CPU_FP16_TO_FP32(a_ptr[l].d[2]), + GGML_CPU_FP16_TO_FP32(a_ptr[l].d[3]) + }; + const float b_scales[8] = { + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[0]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[1]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[2]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[3]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[4]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[5]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[6]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = 
__riscv_vle32_v_f32m1(b_scales, vl / 4); - // We process 4 sub-blocks at once. - for (int j = 0; j < QK_K / 128; j++) { - // Extract the scales and the mins. - // - // Low bits. - vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); - vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); - vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); + const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l0; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - // High bits. 
- vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); - vuint8m2_t scales_hi; - vuint8m2_t mins_hi; - if (!j) { - scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); - mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); - } else { - scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); - mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); + sumi_l0 = sumi_hi_m; } - vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); - vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); - // Reduce the mins and multiply with `dmin`. - // - // Correct in `sumf`. - vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16); + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = 
__riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); + sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l1; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l1 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = 
__riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4); + sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48]; + const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80]; + const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l2; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l2 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = 
__riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); + sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24]; + const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56]; + const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88]; + const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l3; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl 
/ 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l3 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); + sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); + } + } + __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4); + 
__riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4); + } + } + + return; + } + +#endif + ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 4x16 integer accumulators + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK4_0 / 2; i++) { + // Load `b_ptr`. 
+ const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16); + sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16); + sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16); + sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16); + + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16); + sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16); + sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16); + sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16); + } + + // Do the final accumulation in i32 to prevent overflow. 
+ const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16); + const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16); + const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16); + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } + return; +#endif + ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); 
+ UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16); + + // We process 4 sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); + + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. 
+ vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16); vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16); bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, @@ -829,8 +1131,12 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v // Accumulation for 2 sub-blocks. - { - // 4x8 integer accumulators + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4x16 integer accumulators vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); @@ -840,10 +1146,7 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); @@ -885,8 +1188,13 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), sumi_3_s_1_16, 16); } - { - // 4x8 integer accumulators + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. 
+ for (int k = 0; k < 2; k++) { + // 4x16 integer accumulators vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); @@ -896,10 +1204,7 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); @@ -943,7 +1248,7 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } } - const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16), 16); + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16), 16); const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16); const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16); const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16); @@ -966,14 +1271,6 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - // TODO -} - -void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - // TODO -} - void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -996,7 +1293,6 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const #if defined __riscv_v_intrinsic const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); - for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -1040,7 +1336,7 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); @@ -1114,7 +1410,7 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); @@ -1137,461 +1433,300 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT 
vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v - if (__riscv_vlenb() >= QK4_0) { - const size_t vl = QK4_0; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - for (int l = 0; l < nb; l++) { - const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); - const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); - const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); - const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); - const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); - const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); - const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); - - // vector version needs Zvfhmin extension - const float a_scales[4] = { - GGML_CPU_FP16_TO_FP32(a_ptr[l].d[0]), - GGML_CPU_FP16_TO_FP32(a_ptr[l].d[1]), - GGML_CPU_FP16_TO_FP32(a_ptr[l].d[2]), - GGML_CPU_FP16_TO_FP32(a_ptr[l].d[3]) - }; - const float b_scales[8] = { - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[0]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[1]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[2]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[3]), - 
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[4]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[5]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[6]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[7]) - }; - const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); - - const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; - const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; - const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l0; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l0 = sumi_hi_m; - } - +void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + const int num_k_blocks = n / QK_K; + const int N_ROWS_TILE = 4; + const int N_COLS_TILE = 16; + assert(nr % N_ROWS_TILE == 0); + assert(nc % N_COLS_TILE == 0); + + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + // --- Tiling Loops --- +#pragma GCC unroll 1 + for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) { +#pragma GCC unroll 1 + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { 
+ // Base Pointers + const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks; + const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; + + // Persistent Float Accumulators + vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + // --- Super-Block Loop (K=0..255) --- +#pragma GCC unroll 1 + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_Kx4* lhs_current = &lhs_base_ptr[k_block]; + const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; + + // 1. Load Global Min Scales (Keep as F16/LMUL=1 to save registers) + vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); + vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl); + + // 2. Initialize Integer Accumulators + vint32m2_t v_isum_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_3 = __riscv_vmv_v_x_i32m2(0, vl); + + const uint8_t* rhs_qs_ptr = rhs_current->qs; + const uint8_t* rhs_sc_ptr = rhs_current->scales; + const int8_t* lhs_qs_ptr = lhs_current->qs; + + // --- Phase Loop (4 phases x 64 elements) --- +#pragma GCC unroll 1 + for (int phase = 0; phase < 4; ++phase) { + + // A. 
Load Scales/Mins for the 4 interleaved sub-blocks + vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3; + vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3; + + // Unrolled Load Logic { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); - sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); + vuint8mf2_t v_raw; + // Sub-block 0 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); + v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 1 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); + v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 2 + v_raw = 
__riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); + v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 3 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); + v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + rhs_sc_ptr += 64; } - const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; - const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; - const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l1; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l1 = sumi_hi_m; + int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16); + int k_offsets[4] = {0, 32, 64, 96}; + + // B. 
Inner Dot Product Loop +#pragma GCC unroll 1 + for (int l = 0; l < 16; ++l) { + vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += 16; + + // Unroll over 4 sub-blocks (0, 1, 2, 3 relative to phase) + + // --- Sub-block 0 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[0] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 1 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[1] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 2 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[2] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = 
__riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 3 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[3] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } } - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = 
__riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4); - sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); - } + // C CORRECTION + int sb_base_abs = base_k_phase / 16; - const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48]; - const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80]; - const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l2; + // --- Correction Sub-block 0 --- { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l2 = sumi_hi_m; + int sb_abs = sb_base_abs + (k_offsets[0] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0); + + // Row 0 + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + // Row 1 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = 
__riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + // Row 2 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + // Row 3 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); } + // --- Correction Sub-block 1 --- { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - 
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); - sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); + int sb_abs = sb_base_abs + (k_offsets[1] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1); + + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); } - const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24]; - const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56]; - const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88]; - const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l3; + // --- Correction Sub-block 2 --- { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4)); - const vint8m2_t lhs_1_8 
=__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l3 = sumi_hi_m; + int sb_abs = sb_base_abs + (k_offsets[2] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2); + + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); } + 
// --- Correction Sub-block 3 --- { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); - sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); + int sb_abs = sb_base_abs + (k_offsets[3] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3); + + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, 
vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); } - } - __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4); - } - } - - return; - } - -#endif - ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} - -void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); - - // 4x8 accumulators. 
- vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, 8); - vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, 8); - vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, 8); - vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, 8); - - for (int l = 0; l < nb; l++) { - // Load `b_ptr`. - const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); - const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); - const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); - // Create 4 segments from `b`. - const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); - const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); - const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); - const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); + } // End Phase Loop - // Load scales for `b`. - const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); + // --- Apply Main Scales --- + vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl); + vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl); { - // Load first 8 bytes of `a`. - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[32]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[64]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[96]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Multiply and accumulate. 
- const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. - const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. 
- const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[0], 8); - sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, facc, d_0, QK4_NL / 4); + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[0], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_0, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_0 = __riscv_vfadd_vv_f32m2(v_sumf_0, v_sum, vl); } - - // Load second 8 bytes of `a`. + // Row 1 { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[40]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[72]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[104]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Multiply and accumulate. - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. 
- const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. - const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[1], 8); - sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, facc, d_0, QK4_NL / 4); + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[1], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_1, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_1 = __riscv_vfadd_vv_f32m2(v_sumf_1, v_sum, vl); } - - // Load third 8 bytes of `a`. + // Row 2 { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[48]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[80]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[112]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). 
- const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Multiply and accumulate. - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. - const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const 
vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. - const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[2], 8); - sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, facc, d_0, QK4_NL / 4); + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[2], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_2, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_2 = __riscv_vfadd_vv_f32m2(v_sumf_2, v_sum, vl); } - + // Row 3 { - // Load fourth 8 bytes of `a`. - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[24]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[56]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[88]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[120]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Multiply and accumulate. - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. 
- const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. 
- const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[3], 8); - sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, facc, d_0, QK4_NL / 4); + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[3], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_3, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_3 = __riscv_vfadd_vv_f32m2(v_sumf_3, v_sum, vl); } - } - __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, 8); - __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, 8); - __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, 8); - __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, 8); + } // End K-Block + + __riscv_vse32_v_f32m2(s + (row_tile + 0) * bs + col_tile, v_sumf_0, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 1) * bs + col_tile, v_sumf_1, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 2) * bs + col_tile, v_sumf_2, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 3) * bs + col_tile, v_sumf_3, vl); } } - return; - -#endif - ggml_gemm_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index b9ac85cabfa..ec2444a373e 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1,4 +1,3 @@ -#include "ggml.h" #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -1603,6 +1602,96 @@ void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } } + +void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % 16 == 0); + + UNUSED(bs); + + const int nb = n / QK_K; + const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q8_K * y = (const block_q8_K *)vy; + + // Layout: Even-Low(0,2,4,6), 
Odd-Low(1,3,5,7), Even-High(8...), Odd-High(9...) + const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, // 0-7 + 8, 12, 9, 13, 10, 14, 11, 15 // 8-15 + }; + + for (int col_tile = 0; col_tile < nc; col_tile += 16) { + const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q8_K * y_ptr = y; + + float sumf[16] = {0}; + + // Loop over K-blocks + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[16] = {0}; + int32_t summs[16] = {0}; + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + // Iterate over sub-blocks 0..15 + for (int sb = 0; sb < 16; ++sb) { + // Correction Term + int16_t bsum = bs_lhs[sb]; + int scale_offset = sb_perm[sb] * 16; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits + } + + // Main Dot Product + // Calculate base offsets for Q2 unpacking based on SB + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; + + int shift = ((sb / 2) % 4) * 2; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits + + // Process 16 elements (l=0..15) + for (int l = 0; l < 16; ++l) { + // Q2: Interleaved by column. Byte `l` contains 4 k-values. 
+ int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; + + // Q8: Linear access + int k = sb * 16 + l; + int8_t q8_val = qs_lhs[k]; + + isum[col] += q8_val * q2_val * d_sb; + } + } + } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_lhs = y_ptr[k_block].d; + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + + sumf[col] += (isum[col] * d_all) - (summs[col] * d_min); + } + } + + for (int col = 0; col < 16; ++col) { + s[col_tile + col] = sumf[col]; + } + } +} #endif void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -2618,6 +2707,102 @@ void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } } + + +void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr % 4 == 0); + assert(nc % 16 == 0); + const int nb = n / QK_K; + const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q8_Kx4 * y = (const block_q8_Kx4 *)vy; + + const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, + 8, 12, 9, 13, 10, 14, 11, 15 + }; + + // Iterate Rows in tiles of 4 + for (int row_tile = 0; row_tile < nr; row_tile += 4) { + // Iterate Columns in tiles of 16 + for (int col_tile = 0; col_tile < nc; col_tile += 16) { + + const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; + + float sumf[4][16]; + memset(sumf, 0, sizeof(sumf)); + + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[4][16]; + int32_t summs[4][16]; + memset(isum, 0, sizeof(isum)); + memset(summs, 0, sizeof(summs)); + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs 
= x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + for (int sb = 0; sb < 16; ++sb) { + int scale_offset = sb_perm[sb] * 16; + + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; + int shift = ((sb / 2) % 4) * 2; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; + int32_t m_sb = sc_val >> 4; + + // Correction Term + for (int r = 0; r < 4; ++r) { + int bsum_idx = (sb / 4) * 16 + r * 4 + (sb % 4); + summs[r][col] += bs_lhs[bsum_idx] * m_sb; + } + + // Main Dot Product + for (int l = 0; l < 16; ++l) { + int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; + + // Calculate Q8 index for this specific k and row + int k = sb * 16 + l; + int q8_idx = (k / 4) * 16 + (k % 4); + + for (int r = 0; r < 4; ++r) { + // Add r*4 to jump to the correct row within the 4x4 chunk + int8_t q8_val = qs_lhs[q8_idx + r * 4]; + isum[r][col] += q8_val * q2_val * d_sb; + } + } + } + } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + for (int r = 0; r < 4; ++r) { + float d_lhs = y_ptr[k_block].d[r]; + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + sumf[r][col] += (isum[r][col] * d_all) - (summs[r][col] * d_min); + } + } + } + + for (int r = 0; r < 4; ++r) { + for (int col = 0; col < 16; ++col) { + s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; + } + } + } + } +} #endif } // extern "C" @@ -2747,11 +2932,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in const int end = QK_K * 4 / blck_size_interleave; // Interleave Q4_K quants by taking 8 bytes at a time - if (blck_size_interleave == 8) { - for (int i = 0; i < end; ++i) { - int src_id = i % 8; - int src_offset = (i 
/ 8) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; // buffer large enough for the max interleave block size (8 bytes) uint64_t elems; @@ -2759,87 +2943,53 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in memcpy(&out.qs[dst_offset], &elems, blck_size_interleave); } - // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K - // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) - // The output Q4_Kx8 structure has 96 bytes - // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure - // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures - uint8_t s[8], m[8]; - - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[j] = in[j].scales[i] & 63; - m[j] = in[j].scales[i + 4] & 63; - } - - out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); - out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); - out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); - out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); - out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); - out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); - out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); - out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); - out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); - out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); - out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); - out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); + // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K + // Currently the Q4_K structure has 8 scales and 8 mins packed in 
12 bytes ( 6 bits for each value) + // The output Q4_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures + uint8_t s[8], m[8]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = in[j].scales[i] & 63; + m[j] = in[j].scales[i + 4] & 63; } - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); - m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); - } - - out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); - out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); - out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); - out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); - out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); - out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); - out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); - out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); - out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); - out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); - out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); - out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); - - } + out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 
9] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); - } else if (blck_size_interleave == 1) { - for (int i = 0; i < end; ++i) { - int src_id = i % 8; - int src_offset = i / 8; - int dst_offset = i; + } - out.qs[dst_offset] = in[src_id].qs[src_offset]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); } - // RVV repacking. - // - // Extract sums and mins for all 8 sub-blocks for each block of Q4_K. - uint8_t s[64], m[64]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[i * 8 + j] = in[j].scales[i] & 63; - m[i * 8 + j] = in[j].scales[i + 4] & 63; - } - } - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[32 + i * 8 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); - m[32 + i * 8 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); - } - } + out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); - for (int i = 0; i < 64; i++) { - out.scales[i] = (s[i] & 15) + (m[i] & 15 << 4); - } - for (int i = 0; i < 32; i++) { - out.scales[64 + i] = (s[i] & 48 >> 
4) + (m[i] & 48 >> 2) + (s[32 + i] & 48) + (m[32 + i] & 48 << 2); - } } return out; @@ -3070,6 +3220,68 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in return out; } +static block_q2_Kx16 make_block_q2_Kx16(const block_q2_K * in, unsigned int blck_size_interleave) { + block_q2_Kx16 out; + constexpr int N_COLS = 16; + + // 1. Copy Super-Scales (d) and Super-Mins (dmin) + for (int i = 0; i < N_COLS; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + // 2. Interleave Q2_K Data + const int bytes_per_col = 64; + const int total_bytes = N_COLS * bytes_per_col; + const int end = total_bytes / blck_size_interleave; + + for (int i = 0; i < end; ++i) { + int src_col_id = i % N_COLS; + int src_offset = (i / N_COLS) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], blck_size_interleave); + } + + // 3. 
Repack Scales into the Optimized "Sequential-Parallel" Layout + int out_idx = 0; + + // Arrays define the sub-block order for each group + const int even_low_sbs[] = {0, 2, 4, 6}; + const int odd_low_sbs[] = {1, 3, 5, 7}; + const int even_high_sbs[] = {8, 10, 12, 14}; + const int odd_high_sbs[] = {9, 11, 13, 15}; + + // Pack Group 1: Even-Low + for (int sb : even_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 2: Odd-Low + for (int sb : odd_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 3: Even-High + for (int sb : even_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 4: Odd-High + for (int sb : odd_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + return out; +} + static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 4 || interleave_block == 8); @@ -3103,7 +3315,7 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 8 || interleave_block == 4 || interleave_block == 1); + GGML_ASSERT(interleave_block == 8 || interleave_block == 4); constexpr int nrows_interleaved = 8; block_q4_Kx8 * dst = (block_q4_Kx8*)t->data; @@ -3193,6 +3405,41 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q2_K_to_q2_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + 
GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + constexpr int nrows_interleaved = 16; + + block_q2_Kx16 * dst = (block_q2_Kx16*)t->data; + const block_q2_K * src = (const block_q2_K*) data; + + block_q2_K dst_tmp[nrows_interleaved]; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + // This loop gathers 16 separate blocks (one from each column) + // that correspond to the same K-dimension chunk. + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + + *dst++ = make_block_q2_Kx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); constexpr int nrows_interleaved = 16; @@ -3429,14 +3676,6 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); } - } else if (blck_size_interleave == 8) { - for (int i = 0; i < end; ++i) { - int src_id = i % 4; - int src_offset = (i / 4) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); - } } else { GGML_ASSERT(false); } @@ -3492,14 +3731,7 @@ static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_s int src_offset = (i / 8) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - for (int b = 0; b < 8; ++b) { - out.qs[dst_offset + b] = in[src_id].qs[src_offset + b]; - } - - // Generates bus error on RVV as this is auto-vectorized and the - // source might possible not be 8-byte 
aligned - // - // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); } } else { GGML_ASSERT(false); @@ -3799,10 +4031,6 @@ template <> int repack(struct ggml_tensor * t, const void * d return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size); } -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_K_to_q4_K_8_bl(t, 1, data, data_size); -} - template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q4_K_to_q4_K_16_bl(t, 1, data, data_size); } @@ -3814,6 +4042,10 @@ template <> int repack(struct ggml_tensor * t, const void * template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size); } + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_16_bl(t, 1, data, data_size); +} #endif // gemv @@ -3903,10 +4135,6 @@ template <> void gemv(int n, float * s, size_ ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); -} - template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -3918,6 +4146,10 @@ template <> void gemv(int n, float * s, siz template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} #endif // gemm @@ -4010,10 +4242,6 @@ template <> void gemm(int n, float * s, size_ ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); 
} -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); -} - template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -4025,6 +4253,10 @@ template <> void gemm(int n, float * s, siz template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} #endif class tensor_traits_base : public ggml::cpu::tensor_traits { @@ -4437,13 +4669,14 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons // instances for RISC-V // - // These implement outer-product style multiplication with interleave of 1. + // These implement outer-product style matrix multiplication kernels with + // an interleave of 1. 
#if defined __riscv_zvfh static const ggml::cpu::repack::tensor_traits q4_0_16x1_q8_0; - static const ggml::cpu::repack::tensor_traits q4_K_8x1_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K; static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits q2_K_16x1_q8_K; #endif if (cur->type == GGML_TYPE_Q4_0) { @@ -4506,6 +4739,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q2_K_8x8_q8_K; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q2_K_16x1_q8_K; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_Q5_K) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index ae4ed6a5e41..249aaa4a186 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -64,6 +64,13 @@ struct block_q2_Kx8 { }; static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); +struct block_q2_Kx16 { + ggml_half d[16]; // Super-block scale for quantized scales + ggml_half dmin[16]; // Super-block scale for quantized mins + uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks) + uint8_t qs[1024]; // Data (16 cols * 64 bytes per block) +}; +static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding"); struct block_q5_Kx8 { ggml_half d[8]; // super-block scale for quantized scales @@ -181,15 +188,15 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * 
GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif // Native implementations @@ 
-243,15 +250,15 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t 
bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif #if defined(__cplusplus) From a037b851e1c06aa63f5252854b8024a4babc5c78 Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Wed, 4 Mar 2026 17:27:14 +0500 Subject: [PATCH 07/13] ggml-cpu: refactor rvv repack --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 21 ++-- ggml/src/ggml-cpu/repack.cpp | 122 +----------------------- ggml/src/ggml-cpu/repack.h | 25 +---- 3 files changed, 18 insertions(+), 150 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 3265843bc0a..cd5807879ea 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -292,6 +292,10 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v for (int l = 0; l < nb; l++) { vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16); + // Load `dmin`. + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16); + // We process 4 sub-blocks at once. for (int j = 0; j < QK_K / 128; j++) { // Extract the scales and the mins. 
@@ -324,8 +328,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16); bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16); sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16); // Accumulation for 2 sub-blocks. @@ -1035,6 +1037,9 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16); vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16); + // Load `dmin`. + const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16); + // We process 4 sub-blocks at once. for (int j = 0; j < QK_K / 128; j++) { // Extract the scales and the mins. 
@@ -1115,14 +1120,10 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[0], 16); - const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[1], 16); - const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[2], 16); - const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[3], 16); + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], 16); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], 16); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], 16); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], 16); sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16); sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index ec2444a373e..d87f866bb1a 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1195,45 +1195,6 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -<<<<<<< HEAD -void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - 
const int ncols_interleaved = 16; - const int blocklen = 1; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[16]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1310,8 +1271,6 @@ void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -2228,51 +2187,6 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -<<<<<<< HEAD -void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - 
const int blocklen = 1; - - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - float sumf[4][16]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -} - void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -2361,8 +2275,6 @@ void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -3799,16 +3711,16 @@ static block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * in, unsigned int blck static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_block, 
const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); - GGML_ASSERT(interleave_block == 8); + GGML_ASSERT(interleave_block == 1); const block_iq4_nl * src = (const block_iq4_nl *)data; block_iq4_nlx16 * dst = ( block_iq4_nlx16 *)t->data; - block_iq4_nl dst_tmp[8]; + block_iq4_nl dst_tmp[16]; int nrow = ggml_nrows(t); - int nrows_interleaved = 8; - int nblocks = t->ne[0] / QK_MXFP4; + int nrows_interleaved = 16; + int nblocks = t->ne[0] / QK4_NL; GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); @@ -4002,11 +3914,6 @@ template <> int repack(struct ggml_tensor * t, const void * template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } -<<<<<<< HEAD - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size); -} template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size); @@ -4016,8 +3923,6 @@ template <> int repack(struct ggml_tensor * t, const void * d return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size); } -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size); } @@ -4106,11 +4011,6 @@ template <> void gemv(int n, float * s, size template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -<<<<<<< HEAD - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); -} template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { 
ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc); @@ -4120,8 +4020,6 @@ template <> void gemv(int n, float * s, size_ ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -4211,24 +4109,14 @@ template <> void gemm(int n, float * s, size ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -<<<<<<< HEAD -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_iq4_nl_4x16_q8_0(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); -} - template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 249aaa4a186..cb21edf6239 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -121,7 +121,6 @@ struct block_iq4_nlx16 { }; static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding"); - struct block_mxfp4x4 { uint8_t e[4]; uint8_t qs[QK_MXFP4 * 2]; @@ -154,12 +153,8 @@ void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT 
s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -<<<<<<< HEAD void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -174,14 +169,8 @@ void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -<<<<<<< HEAD void 
ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh @@ -216,12 +205,8 @@ void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -<<<<<<< HEAD void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, 
size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -236,14 +221,8 @@ void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -<<<<<<< HEAD void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT 
vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -======= ->>>>>>> 7d3d5df3a (ggml-cpu: refactor; add rvv repacking for q4_0, q4_K) void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh @@ -251,14 +230,14 @@ void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GG void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void 
ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif #if defined(__cplusplus) From 85c6e8fed7ef980e6ec6def12845332729d1ec21 Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Tue, 3 Feb 2026 01:42:41 +0500 Subject: [PATCH 08/13] ggml-cpu: extend rvv gemm, gemv to other vlens --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 1130 ++++++++-- ggml/src/ggml-cpu/repack.cpp | 2515 +++++++++++------------ ggml/src/ggml-cpu/repack.h | 176 +- 3 files changed, 2290 insertions(+), 1531 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index cd5807879ea..9f336bdd50f 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -203,10 +203,10 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +template +static inline void ggml_gemv_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n 
/ qk; - const int ncols_interleaved = 16; const int blocklen = 1; assert (n % qk == 0); @@ -222,43 +222,58 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); +<<<<<<< HEAD #if defined __riscv_v_intrinsic +======= +>>>>>>> aaf8e8a16 (ggml-cpu: extend rvv gemm, gemv to other vlens) const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); // 1x16 Accumulator - vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { // 1x16 Integer Accumulator - vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); // Accumulation loop. for (int i = 0; i < QK4_0 / 2; i++) { // Load `b_ptr`. 
- const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); - const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, ncols_interleaved), 4, ncols_interleaved); const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); - sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, 16); - sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, 16); + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, ncols_interleaved); + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, ncols_interleaved); } - const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, ncols_interleaved); - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, ncols_interleaved); - sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); } - __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); } - return; +} + +void ggml_gemv_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + 
ggml_gemv_q4_0_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q4_0_8x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); #endif +} +void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q4_0_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +#else ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +<<<<<<< HEAD } void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -393,73 +408,29 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } return; +======= +>>>>>>> aaf8e8a16 (ggml-cpu: extend rvv gemm, gemv to other vlens) #endif - ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } - -void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - - // 1x16 Accumulator1 - vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); - - for (int l = 0; l < nb; l++) { - // 1x16 integer accumulator - vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); - - // Accumulation loop. - for (int i = 0; i < QK4_NL / 2; i++) { - // Load `b_ptr`. 
- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16); - const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); - const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); - // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); - // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); - - const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16); - const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16); - sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16); - } - - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); - - sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); - } - - __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); - } - return; +void ggml_gemv_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q4_0_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q4_0_32x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q4_0_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q4_0_64x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); #endif - ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +template +void ggml_gemv_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const 
void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 16; const int blocklen = 1; assert (n % qk == 0); @@ -476,61 +447,86 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(blocklen); UNUSED(bs); -#if defined __riscv_v_intrinsic const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); // 1x16 Accumulator - vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { // 1x16 Integer Accumulator - vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); // Accumulation loop. for (int i = 0; i < QK8_0; i++) { // Load `b_ptr`. 
- const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); - sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); + sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], ncols_interleaved), ncols_interleaved); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, ncols_interleaved); - sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); } - __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); } - return; +} + +void ggml_gemv_q8_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q8_0_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q8_0_8x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); #endif +} +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q8_0_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +#else ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q8_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { 
+#if defined __riscv_zvfh + ggml_gemv_q8_0_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q8_0_32x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q8_0_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q8_0_64x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif } -void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +template +void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); assert(nr == 1); assert(nc % 16 == 0); UNUSED(bs); - const int N_COLS_TILE = 16; const int num_k_blocks = n / QK_K; - const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); - for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + const size_t vl = __riscv_vsetvl_e32m2(ncols_interleaved); + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy; - const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; + const block_q2_Kx* rhs_base_ptr = (const block_q2_Kx*)vx + (col_tile / ncols_interleaved) * num_k_blocks; vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl); for (int k_block = 0; k_block < num_k_blocks; ++k_block) { const block_q8_K* lhs_current = &lhs_base_ptr[k_block]; - const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; + const block_q2_Kx* rhs_current = &rhs_base_ptr[k_block]; // 1. 
Prepare Global Min Scales vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); @@ -683,6 +679,276 @@ void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } } +void ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q2_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q2_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q2_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q2_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q2_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q2_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q2_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q2_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + +template +void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + 
+ for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 4 sub-blocks at once. + const int vl = ncols_interleaved * 4; + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * vl], vl); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, vl); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, vl); + + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl * 2], vl); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, vl), 4, vl); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, vl), 2, vl); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, vl); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, vl), 2, vl); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, vl), vl); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, vl), vl)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. 
+ vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); + + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); + sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved); + + // Accumulation for 2 sub-blocks. + { + // 4x16 integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, ncols_interleaved); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, ncols_interleaved); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_s_0_16, 16); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_s_1_16, 16); + } + { + // 4x16 integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, ncols_interleaved); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, ncols_interleaved); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_s_0_16, ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_s_1_16, ncols_interleaved); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, ncols_interleaved); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); + } +} + +void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q4_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q4_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q4_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, 
vx, vy, nr, nc); +#endif +} +void ggml_gemv_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q4_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q4_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q4_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q4_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + +template +void ggml_gemv_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx * b_ptr = (const block_iq4_nlx *) vx + (x * nb); + + // 1x16 Accumulator1 + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + // 1x16 integer accumulator + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + + // Accumulation loop. + for (int i = 0; i < QK4_NL / 2; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + + const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], ncols_interleaved); + const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], ncols_interleaved); + sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, ncols_interleaved), ncols_interleaved); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, ncols_interleaved); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); + } +} + +void ggml_gemv_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_iq4_nl_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_iq4_nl_8x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_iq4_nl_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + 
ggml_gemv_iq4_nl_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_iq4_nl_32x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_iq4_nl_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_iq4_nl_64x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -909,10 +1175,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +template +void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 16; const int blocklen = 1; assert (n % qk == 0); @@ -929,74 +1195,91 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); +<<<<<<< HEAD #if defined __riscv_v_intrinsic +======= + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + +>>>>>>> aaf8e8a16 (ggml-cpu: extend rvv gemm, gemv to other vlens) for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); // 4x16 Accumulators - vfloat32m2_t sumf_0 = 
__riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { // 4x16 integer accumulators - vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); // Accumulation loop. for (int i = 0; i < QK4_0 / 2; i++) { // Load `b_ptr`. 
- const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); - const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); - const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); - - sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16); - sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16); - sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16); - sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16); - - sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16); - sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16); - sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16); - sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16); + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, ncols_interleaved), 4, ncols_interleaved); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, ncols_interleaved); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, ncols_interleaved); + sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, ncols_interleaved); + sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, ncols_interleaved); + sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, ncols_interleaved); + + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, ncols_interleaved); + sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 
ncols_interleaved); + sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, ncols_interleaved); + sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, ncols_interleaved); } // Do the final accumulation in i32 to prevent overflow. - const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); - const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16); - const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16); - const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16); + const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, ncols_interleaved); + const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, ncols_interleaved); + const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, ncols_interleaved); + const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, ncols_interleaved); + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, 
__riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); + } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); - const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); - const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); - const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); - - sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); - sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); - sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); - sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); - } - - __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); } } - return; +} + +void ggml_gemm_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q4_0_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q4_0_8x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); #endif +} +void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, 
const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q4_0_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +#else ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +<<<<<<< HEAD } void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1268,13 +1551,30 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } } return; +======= +>>>>>>> aaf8e8a16 (ggml-cpu: extend rvv gemm, gemv to other vlens) +#endif +} +void ggml_gemm_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q4_0_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q4_0_32x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q4_0_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q4_0_64x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); #endif - ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +template +void ggml_gemm_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; +<<<<<<< HEAD const int ncols_interleaved = 16; const int blocklen = 1; @@ -1364,6 +1664,8 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 16; +======= +>>>>>>> aaf8e8a16 (ggml-cpu: extend rvv gemm, gemv to other vlens) const 
int blocklen = 1; assert (n % qk == 0); @@ -1384,49 +1686,48 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); // 4x16 Accumulators - vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { // 4x16 Integer Accumulators - vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); // Accumulation loop. for (int i = 0; i < QK8_0; i++) { // Load `b_ptr`. 
- const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); - // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); - sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], 16), 16); - sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], 16), 16); - sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], 16), 16); - sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16); + sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], ncols_interleaved), ncols_interleaved); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); - const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); - const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); - const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 
ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], ncols_interleaved); - sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); - sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); - sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); - sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); } - __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); } } return; @@ -1434,23 +1735,52 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q2_K_16x1_q8_K(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_q8_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q8_0_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q8_0_8x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q8_0_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q8_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q8_0_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q8_0_32x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q8_0_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q8_0_64x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + +template +void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); const int num_k_blocks = n / QK_K; const int N_ROWS_TILE = 4; - const int N_COLS_TILE = 16; assert(nr % N_ROWS_TILE == 0); - assert(nc % N_COLS_TILE == 0); + assert(nc % ncols_interleaved == 0); - const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + const size_t vl = __riscv_vsetvl_e32m2(ncols_interleaved); // --- Tiling Loops --- #pragma GCC unroll 1 for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) { #pragma GCC 
unroll 1 - for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { // Base Pointers const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks; - const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; + const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / ncols_interleaved) * num_k_blocks; // Persistent Float Accumulators vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); @@ -1731,3 +2061,447 @@ void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } } } + +void ggml_gemm_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q2_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q2_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q2_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q2_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q2_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q2_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q2_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q2_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + +template +void 
ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 4 sub-blocks at once. + const int vl = ncols_interleaved * 4; + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * vl], vl); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, vl); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, vl); + + // High bits. 
+ vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl], vl); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, vl), 4, vl); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, vl), 2, vl); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, vl); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, vl), 2, vl); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, vl), vl); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, vl), vl)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_2 = 
__riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[1], ncols_interleaved); + 
const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, ncols_interleaved), ncols_interleaved), ncols_interleaved); + + + // Accumulation for 2 sub-blocks. + { + // 4x8 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, ncols_interleaved); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, ncols_interleaved); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, ncols_interleaved); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, ncols_interleaved); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, ncols_interleaved); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, ncols_interleaved); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, ncols_interleaved); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, ncols_interleaved); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_0_s_0_16, ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_0_s_1_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_1_s_0_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_1_s_1_16, ncols_interleaved); + sumi_2 = 
__riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_2_s_0_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_2_s_1_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_3_s_0_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_3_s_1_16, ncols_interleaved); + } + { + // 4x16 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, ncols_interleaved); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, ncols_interleaved); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, ncols_interleaved); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, ncols_interleaved); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, ncols_interleaved); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, ncols_interleaved); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, ncols_interleaved); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, ncols_interleaved); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_0_s_0_16, ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_0_s_1_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_1_s_0_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + 
sumi_1_s_1_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_2_s_0_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_2_s_1_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_3_s_0_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_3_s_1_16, ncols_interleaved); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y 
* 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); + } + } +} + +void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q4_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q4_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q4_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q4_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q4_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q4_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q4_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + +template +void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = 
(const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx * b_ptr = (const block_iq4_nlx *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + // 4x16 integer accumulators + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + + // Accumulation loop. + for (int i = 0; i < QK4_NL / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); + // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); + + const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], ncols_interleaved); + const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], ncols_interleaved); + const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], ncols_interleaved); + const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], ncols_interleaved); + + const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], ncols_interleaved); + 
const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], ncols_interleaved); + const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], ncols_interleaved); + const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], ncols_interleaved); + + sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, ncols_interleaved), ncols_interleaved); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + 
__riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); + } + } + return; +#endif + ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_iq4_nl_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_iq4_nl_8x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_iq4_nl_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_iq4_nl_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_iq4_nl_32x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_iq4_nl_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_iq4_nl_64x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index d87f866bb1a..7f101d946b9 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1,3 +1,4 @@ +#include "ggml.h" #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -353,191 +354,387 @@ 
template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_K>(const float * GGML_RESTR } #endif -template -static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - constexpr int blocklen = M; - constexpr int ncols_interleaved = N; - const int qk = QK_K; - const int nb = n / qk; - const int blocks_per_half = 64 / blocklen; +#if defined __riscv_zvfh +template <int ncols_interleaved> +static inline void ggml_gemv_q4_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { // Scalar reference GEMV: one q8_0 row (vy) times 4-bit weights interleaved ncols_interleaved columns per block (vx); writes nc floats to s. + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[ncols_interleaved]; // one accumulator per interleaved column; a fixed 16 overflowed for 32 and 64 column tiles + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +template <int ncols_interleaved> +static inline void
ggml_gemv_q8_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { // Scalar reference GEMV: one q8_0 row (vy) times 8-bit weights interleaved ncols_interleaved columns per block (vx); writes nc floats to s. + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + assert(nr == 1); assert(n % qk == 0); assert(nc % ncols_interleaved == 0); UNUSED(bs); UNUSED(nr); - float sumf[8]; + float sumf[ncols_interleaved]; // one accumulator per interleaved column; a fixed 16 overflowed for 32 and 64 column tiles + int sumi; - const block_q8_K * a_ptr = (const block_q8_K *) vy; + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0f; + sumf[j] = 0.0; } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; - const int base_h = base_l + 64; + for (int k = 0; k < (qk / blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * blocklen + i]; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; +template <int ncols_interleaved> +static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % ncols_interleaved == 0); - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; + UNUSED(bs); - const int qh_half_l = (base_l / 128) * 32; - const int
qh_half_h = (base_h / 128) * 32; + const int nb = n / QK_K; + const block_q2_Kx * x = (const block_q2_Kx *)vx; + const block_q8_K * y = (const block_q8_K *)vy; - for (int j = 0; j < ncols_interleaved; j++) { - const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; - const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + // Layout: Even-Low(0,2,4,6), Odd-Low(1,3,5,7), Even-High(8...), Odd-High(9...) + const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, // 0-7 + 8, 12, 9, 13, 10, 14, 11, 15 // 8-15 + }; + + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { + const block_q2_Kx * x_ptr = x + (col_tile / ncols_interleaved) * nb; + const block_q8_K * y_ptr = y; + + float sumf[16] = {0}; + + // Loop over K-blocks + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[16] = {0}; + int32_t summs[16] = {0}; + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + // Iterate over sub-blocks 0..15 + for (int sb = 0; sb < 16; ++sb) { + // Correction Term + int16_t bsum = bs_lhs[sb]; + int scale_offset = sb_perm[sb] * 16; - int sumi_l = 0; - int sumi_h = 0; + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits + } - for (int i = 0; i < blocklen; i++) { - const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; - const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; - const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + // Main Dot Product + // Calculate base offsets for Q2 unpacking based on SB + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 
32 : 48; - const int qh_idx_l = qh_half_l + ((base_l + i) % 32); - const int qh_chunk_l = qh_idx_l / blocklen; - const int qh_pos_l = qh_idx_l % blocklen; - const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; - const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + int shift = ((sb / 2) % 4) * 2; - const int qh_idx_h = qh_half_h + ((base_h + i) % 32); - const int qh_chunk_h = qh_idx_h / blocklen; - const int qh_pos_h = qh_idx_h % blocklen; - const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; - const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits - const int q_l = ((hi_2_l << 4) | l_4) - 32; - const int q_h = ((hi_2_h << 4) | hi_4) - 32; + // Process 16 elements (l=0..15) + for (int l = 0; l < 16; ++l) { + // Q2: Interleaved by column. Byte `l` contains 4 k-values. 
+ int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; - const int8_t a_l = a_ptr[l].qs[base_l + i]; - const int8_t a_h = a_ptr[l].qs[base_h + i]; + // Q8: Linear access + int k = sb * 16 + l; + int8_t q8_val = qs_lhs[k]; - sumi_l += q_l * a_l; - sumi_h += q_h * a_h; + isum[col] += q8_val * q2_val * d_sb; } - - sumf[j] += - (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; } } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_lhs = y_ptr[k_block].d; + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + + sumf[col] += (isum[col] * d_all) - (summs[col] * d_min); + } + } + + for (int col = 0; col < 16; ++col) { + s[col_tile + col] = sumf[col]; } + } +} +template +static inline void ggml_gemv_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + float sumf[ncols_interleaved]; + float sum_minf[ncols_interleaved]; + uint8_t scales[ncols_interleaved * 8]; + uint8_t mins[ncols_interleaved * 8]; + int sumi1; + int sumi2; + int sumi; + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j]; + sumf[j] = 0.0f; + sum_minf[j] = 0.0f; + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < ncols_interleaved * 8; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = 
b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < ncols_interleaved * 4; i++) { + scales[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x0C) << 2; + scales[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x30); + mins[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0xC0) >> 2; + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * ncols_interleaved]; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * ncols_interleaved]; + uint8_t *scales_1 = &scales[(sb + 1) * ncols_interleaved]; + for (int i = 0; i < QK4_0; i++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + const int v0 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]); + sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; } } } -template -static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - constexpr int blocklen = M; - constexpr int ncols_interleaved = N; - const int qk = QK_K; - const int nb = n / qk; - const int blocks_per_half = 64 / blocklen; - const int q8_half_stride = 512; - const int q8_low_high_step = 256; +template +static 
inline void ggml_gemv_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { // Scalar reference GEMV: one q8_0 row (vy) times iq4_nl weights (4-bit indices into kvalues_iq4nl) interleaved ncols_interleaved columns per block (vx); writes nc floats to s. + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + assert(nr == 1); assert(n % qk == 0); - assert(nr % 4 == 0); assert(nc % ncols_interleaved == 0); UNUSED(bs); + UNUSED(nr); - float sumf[4][8]; + float sumf[ncols_interleaved]; // one accumulator per interleaved column; a fixed 16 overflowed for 32 and 64 column tiles + int sumi; - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx<ncols_interleaved> * b_ptr = (const block_iq4_nlx<ncols_interleaved> *) vx + (x * nb); - for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0f; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); } } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} +#endif + +#if defined __riscv_zvfh +template <int ncols_interleaved> +static inline void ggml_gemm_q4_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc %
ncols_interleaved == 0); + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][ncols_interleaved]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } for (int l = 0; l < nb; l++) { for (int k = 0; k < (qk / (2 * blocklen)); k++) { - const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; - const int base_h = base_l + 64; - - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; +template +static inline void ggml_gemm_q8_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; - const int 
qh_half_l = (base_l / 128) * 32; - const int qh_half_h = (base_h / 128) * 32; + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); - const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4); + float sumf[4][ncols_interleaved]; + int sumi; + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { - const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; - const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; - - int sumi_l = 0; - int sumi_h = 0; - - for (int i = 0; i < blocklen; i++) { - const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; - const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; - const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; - - const int qh_idx_l = qh_half_l + ((base_l + i) % 32); - const int qh_chunk_l = qh_idx_l / blocklen; - const int qh_pos_l = qh_idx_l % blocklen; - const int qh_offset_l = - qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; - const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; - - const int qh_idx_h = qh_half_h + ((base_h + i) % 32); - const int qh_chunk_h = qh_idx_h / blocklen; - const int qh_pos_h = qh_idx_h % blocklen; - const int qh_offset_h = - qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; - const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; - - const int q_l = ((hi_2_l << 4) | l_4) - 32; - const int q_h = ((hi_2_h << 4) | hi_4) - 32; - - const int8_t q8_l = 
a_ptr[l].qs[q8_base + m * blocklen + i]; - const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step]; - - sumi_l += q_l * q8_l; - sumi_h += q_h * q8_h; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; } - - sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * - a_ptr[l].d[m]; + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } } } } - for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; @@ -547,207 +744,240 @@ static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, } } -template -static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - constexpr int blocklen = M; - constexpr int ncols_interleaved = N; - const int qk = QK_K; - const int nb = n / qk; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; +template +static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr % 4 == 0); + assert(nc % 16 == 0); + const int nb = n / QK_K; + const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q8_Kx4 * y = (const block_q8_Kx4 *)vy; - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); + const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, + 8, 12, 9, 13, 10, 14, 11, 15 + }; - UNUSED(bs); - UNUSED(nr); + // Iterate Rows in tiles of 4 + for (int row_tile = 0; row_tile < nr; row_tile += 4) { + // Iterate Columns in tiles of 16 + for (int col_tile = 0; col_tile 
< nc; col_tile += 16) { - float sumf[ncols_interleaved]; - float sum_minf[ncols_interleaved]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; + const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; - const block_q8_K * a_ptr = (const block_q8_K *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + float sumf[4][16]; + memset(sumf, 0, sizeof(sumf)); - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - sum_minf[j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; - } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - constexpr int scale_stride = 32; - uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; - uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[4][16]; + int32_t summs[4][16]; + memset(isum, 0, sizeof(isum)); + memset(summs, 0, sizeof(summs)); - const int qh_shift = (k / (32 / blocklen)) * 2; - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + for (int sb = 0; sb < 16; ++sb) { + int scale_offset = sb_perm[sb] * 16; + + int byte_base; 
+ if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; + int shift = ((sb / 2) % 4) * 2; - const int qh_idx = (k * blocklen + i) % 32; - const int qh_chunk = qh_idx / blocklen; - const int qh_pos = qh_idx % blocklen; - const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; + int32_t m_sb = sc_val >> 4; - const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; - const uint8_t h0 = (qh_val >> qh_shift) & 1; - const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + // Correction Term + for (int r = 0; r < 4; ++r) { + int bsum_idx = (sb / 4) * 16 + r * 4 + (sb % 4); + summs[r][col] += bs_lhs[bsum_idx] * m_sb; + } - const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); - const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + // Main Dot Product + for (int l = 0; l < 16; ++l) { + int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; - const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i; + // Calculate Q8 index for this specific k and row + int k = sb * 16 + l; + int q8_idx = (k / 4) * 16 + (k % 4); - sumi1 = (v0 * a_ptr[l].qs[q8_offset]); - sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; + for (int r = 0; r < 4; ++r) { + // Add r*4 to jump to the correct row within the 4x4 chunk + int8_t q8_val = qs_lhs[q8_idx + r * 4]; + isum[r][col] += q8_val * q2_val * d_sb; + } + } + } + } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + for (int r = 0; r < 4; ++r) { + float d_lhs = y_ptr[k_block].d[r]; + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + sumf[r][col] 
+= (isum[r][col] * d_all) - (summs[r][col] * d_min); } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; } } - for (int sb = 0; sb < 8; sb++) { - uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * - GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + + for (int r = 0; r < 4; ++r) { + for (int col = 0; col < 16; ++col) { + s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; } } } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; - } } } -template -static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - constexpr int blocklen = M; - constexpr int ncols_interleaved = N; - const int qk = QK_K; - const int nb = n / qk; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; +template +static inline void ggml_gemm_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); - float sumf[4][ncols_interleaved]; - float sum_minf[4][ncols_interleaved]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][ncols_interleaved]; + float sum_minf[4][ncols_interleaved]; + uint8_t scales[8 * ncols_interleaved]; + uint8_t mins[8 * ncols_interleaved]; + int sumi1; + int sumi2; + 
int sumi; for (int y = 0; y < nr / 4; y++) { const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; + sumf[m][j] = 0.0; sum_minf[m][j] = 0.0; } } for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; + for (int i = 0; i < ncols_interleaved * 8; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < ncols_interleaved * 4; i++) { + scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x30); + mins[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0xC0) >> 2; } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - constexpr int scale_stride = 32; - uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; - uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; - - const int qh_shift = (k / (32 / blocklen)) * 2; - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; - - const int qh_idx = (k * blocklen + i) % 32; - const int qh_chunk = qh_idx / 
blocklen; - const int qh_pos = qh_idx % blocklen; - const int b_qh_offset = - qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; - const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; - const uint8_t h0 = (qh_val >> qh_shift) & 1; - const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * ncols_interleaved]; + for(int m = 0; m < 4; m++) { + const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4]; + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } - const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); - const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * ncols_interleaved]; + uint8_t *scales_1 = &scales[(sb + 1) * ncols_interleaved]; - const int q8_offset = (k / (32 / blocklen)) * 256 + - (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i; + for (int i = 0; i < QK4_0; i++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; - sumi1 = (v0 * a_ptr[l].qs[q8_offset]); - sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); + const int v0 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]); + sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]); sumi1 = sumi1 * scales_0[j]; sumi2 = sumi2 * scales_1[j]; sumi += sumi1 + sumi2; + + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; } } } - for (int sb = 0; sb < 8; sb++) { - uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + } + for (int m = 0; m < 4; m++) { + 
for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + +template +static inline void ggml_gemm_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx * b_ptr = (const block_iq4_nlx *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { for (int m = 0; m < 4; m++) { - const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * - GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } } } } for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; - } + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = 
sumf[m][j]; } } } } +#endif extern "C" { @@ -1366,290 +1596,74 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, } #if defined __riscv_zvfh +// Q4_0 +void ggml_gemv_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - float sumf[16]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } + ggml_gemv_q4_0_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); } - -void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT 
vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - float sumf[16]; - float sum_minf[16]; - uint8_t scales[128]; - uint8_t mins[128]; - int sumi1; - int sumi2; - int sumi; - const block_q8_K * a_ptr = (const block_q8_K *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0f; - sum_minf[j] = 0.0f; - } - for (int l = 0; l < nb; l++) { - for (int i = 0; i < 128; i++) { - scales[i] = b_ptr[l].scales[i] & 0x0F; - mins[i] = b_ptr[l].scales[i] >> 4; - } - for (int i = 0; i < 64; i++) { - scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; - mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; - scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); - mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; - } - for (int sb = 0; sb < 8; sb++) { - uint8_t *min = &mins[sb * 16]; - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; - } - } - for (int sb = 0; sb < 8; sb += 2) { - uint8_t *scales_0 = &scales[sb * 16]; - uint8_t *scales_1 = &scales[(sb + 1) * 16]; - for (int i = 0; i < QK4_0; i++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); - const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); - sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]); - sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; - 
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - } - } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; - } - } +void ggml_gemv_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_0_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); } - -void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[16]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } +void ggml_gemv_q4_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_0_Mx1_q8_0_generic<64>(n, s, bs, vx, 
vy, nr, nc); } +// Q8_0 +void ggml_gemv_q8_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q8_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[16]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * a_ptr[l].qs[k * blocklen + i]; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j]; - } - } + ggml_gemv_q8_0_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q8_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q8_0_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q8_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q8_0_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q2_K +void 
ggml_gemv_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q2_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - assert(n % QK_K == 0); - assert(nr == 1); - assert(nc % 16 == 0); - - UNUSED(bs); - - const int nb = n / QK_K; - const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; - const block_q8_K * y = (const block_q8_K *)vy; - - // Layout: Even-Low(0,2,4,6), Odd-Low(1,3,5,7), Even-High(8...), Odd-High(9...) - const int sb_perm[16] = { - 0, 4, 1, 5, 2, 6, 3, 7, // 0-7 - 8, 12, 9, 13, 10, 14, 11, 15 // 8-15 - }; - - for (int col_tile = 0; col_tile < nc; col_tile += 16) { - const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; - const block_q8_K * y_ptr = y; - - float sumf[16] = {0}; - - // Loop over K-blocks - for (int k_block = 0; k_block < nb; ++k_block) { - int32_t isum[16] = {0}; - int32_t summs[16] = {0}; - - const uint8_t * qs_rhs = x_ptr[k_block].qs; - const uint8_t * sc_rhs = x_ptr[k_block].scales; - const int8_t * qs_lhs = y_ptr[k_block].qs; - const int16_t * bs_lhs = y_ptr[k_block].bsums; - - // Iterate over sub-blocks 0..15 - for (int sb = 0; sb < 16; ++sb) { - // Correction Term - int16_t bsum = bs_lhs[sb]; - int scale_offset = sb_perm[sb] * 16; - - for (int col = 0; col < 16; ++col) { - uint8_t sc_val = sc_rhs[scale_offset + col]; - summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits - } - - // Main Dot Product - // Calculate base offsets for Q2 unpacking based on SB - int byte_base; - if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; - else byte_base = (sb % 2 == 0) ? 
32 : 48; - - int shift = ((sb / 2) % 4) * 2; - - for (int col = 0; col < 16; ++col) { - uint8_t sc_val = sc_rhs[scale_offset + col]; - int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits - - // Process 16 elements (l=0..15) - for (int l = 0; l < 16; ++l) { - // Q2: Interleaved by column. Byte `l` contains 4 k-values. - int qs_idx = (byte_base + l) * 16 + col; - uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; - - // Q8: Linear access - int k = sb * 16 + l; - int8_t q8_val = qs_lhs[k]; - - isum[col] += q8_val * q2_val * d_sb; - } - } - } - - // Finalize K-Block - for (int col = 0; col < 16; ++col) { - float d_lhs = y_ptr[k_block].d; - float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); - float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); - - float d_all = d_lhs * d_rhs; - float d_min = d_lhs * dm_rhs; + ggml_gemv_q2_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q2_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q2_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} - sumf[col] += (isum[col] * d_all) - (summs[col] * d_min); - } - } +// Q4_K +void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, 
const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} - for (int col = 0; col < 16; ++col) { - s[col_tile + col] = sumf[col]; - } - } +// IQ4_NL +void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } #endif @@ -2272,325 +2286,86 @@ void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; } } - } -} - -void ggml_gemm_q8_0_4x4_q8_0_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) 
{ - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - } - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; - } - sumf[m][j] += - sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } -} - - - -void ggml_gemm_q8_0_4x8_q8_0_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - } - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * 
a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; - } - sumf[m][j] += - sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } -} - -#if defined __riscv_zvfh -void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - float sumf[4][16]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; 
j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -} - -void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - float sumf[4][16]; - float sum_minf[4][16]; - uint8_t scales[128]; - uint8_t mins[128]; - int sumi1; - int sumi2; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - sum_minf[m][j] = 0.0; - } - } - for (int l = 0; l < nb; l++) { - for (int i = 0; i < 128; i++) { - scales[i] = b_ptr[l].scales[i] & 0x0F; - mins[i] = b_ptr[l].scales[i] >> 4; - } - for (int i = 0; i < 64; i++) { - scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; - mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; - scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); - mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; - } - - for (int sb = 0; sb < 8; sb++) { - uint8_t *min = &mins[sb * 16]; - for(int m = 0; m < 4; m++) { - const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4]; - for(int j = 0; j < ncols_interleaved; j++) { - sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; - } - } - } - - for (int sb = 0; sb < 8; sb += 2) { - uint8_t *scales_0 = &scales[sb * 16]; - uint8_t *scales_1 = &scales[(sb + 1) * 16]; - - for (int i = 0; i < QK4_0; i++) { 
- for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - - const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); - const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); - sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]); - sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; - - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; - } - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; - } - } - } - } -} - -void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; + } +} + +void ggml_gemm_q8_0_4x4_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; assert(n % qk == 0); assert(nr % 4 == 0); assert(nc % ncols_interleaved == 0); - float sumf[4][16]; - int sumi; + float sumf[4][4]; + int sumi; for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } } for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); 
k++) { + for (int k = 0; k < (qk / blocklen); k++) { for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumi = 0; for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } } } } for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) + for (int j = 0; j < ncols_interleaved; j++) { s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } } } } } -void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + + +void ggml_gemm_q8_0_4x8_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; + const int ncols_interleaved = 4; + const int blocklen = 8; assert(n % qk == 0); assert(nr % 4 == 0); assert(nc % ncols_interleaved == 0); - float sumf[4][16]; + float sumf[4][4]; int sumi; for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + const block_q8_0x4 * b_ptr = (const block_q8_0x4 
*) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumf[m][j] = 0.0; @@ -2620,100 +2395,75 @@ void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +#if defined __riscv_zvfh +// Q4_0 +void ggml_gemm_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); +} -void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - assert(n % QK_K == 0); - assert(nr % 4 == 0); - assert(nc % 16 == 0); - const int nb = n / QK_K; - const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; - const block_q8_Kx4 * y = (const block_q8_Kx4 *)vy; - - const int sb_perm[16] = { - 0, 4, 1, 5, 2, 6, 3, 7, - 8, 12, 9, 13, 10, 14, 11, 15 - }; - - // Iterate Rows in tiles of 4 - for (int row_tile = 0; row_tile < nr; row_tile += 4) { - // Iterate Columns in tiles of 16 - for (int col_tile = 0; col_tile < nc; col_tile += 16) { - - const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; - const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; - - float sumf[4][16]; - memset(sumf, 0, sizeof(sumf)); - - 
for (int k_block = 0; k_block < nb; ++k_block) { - int32_t isum[4][16]; - int32_t summs[4][16]; - memset(isum, 0, sizeof(isum)); - memset(summs, 0, sizeof(summs)); - - const uint8_t * qs_rhs = x_ptr[k_block].qs; - const uint8_t * sc_rhs = x_ptr[k_block].scales; - const int8_t * qs_lhs = y_ptr[k_block].qs; - const int16_t * bs_lhs = y_ptr[k_block].bsums; - - for (int sb = 0; sb < 16; ++sb) { - int scale_offset = sb_perm[sb] * 16; - - int byte_base; - if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; - else byte_base = (sb % 2 == 0) ? 32 : 48; - int shift = ((sb / 2) % 4) * 2; - - for (int col = 0; col < 16; ++col) { - uint8_t sc_val = sc_rhs[scale_offset + col]; - int32_t d_sb = sc_val & 0xF; - int32_t m_sb = sc_val >> 4; - - // Correction Term - for (int r = 0; r < 4; ++r) { - int bsum_idx = (sb / 4) * 16 + r * 4 + (sb % 4); - summs[r][col] += bs_lhs[bsum_idx] * m_sb; - } - - // Main Dot Product - for (int l = 0; l < 16; ++l) { - int qs_idx = (byte_base + l) * 16 + col; - uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; - - // Calculate Q8 index for this specific k and row - int k = sb * 16 + l; - int q8_idx = (k / 4) * 16 + (k % 4); - - for (int r = 0; r < 4; ++r) { - // Add r*4 to jump to the correct row within the 4x4 chunk - int8_t q8_val = qs_lhs[q8_idx + r * 4]; - isum[r][col] += q8_val * q2_val * d_sb; - } - } - } - } +// Q8_0 +void ggml_gemm_q8_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q8_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q8_0_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q8_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + 
ggml_gemm_q8_0_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q8_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q8_0_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); +} - // Finalize K-Block - for (int col = 0; col < 16; ++col) { - float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); - float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); +// Q2_K +void ggml_gemm_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q2_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q2_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q2_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q2_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} - for (int r = 0; r < 4; ++r) { - float d_lhs = y_ptr[k_block].d[r]; - float d_all = d_lhs * d_rhs; - float d_min = d_lhs * dm_rhs; - sumf[r][col] += (isum[r][col] * d_all) - (summs[r][col] * d_min); - } - } - } +// Q4_K +void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, 
int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} - for (int r = 0; r < 4; ++r) { - for (int col = 0; col < 16; ++col) { - s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; - } - } - } - } +// IQ4_NL +void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } #endif @@ -2805,31 +2555,6 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } -static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x16 out; - - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].d; - } - - const int end = QK4_0 * 8 / blck_size_interleave; - - if 
(blck_size_interleave == 1) { - const uint8_t xor_mask = 0x88; - for (int i = 0; i < end; ++i) { - int src_id = i % 16; - int src_offset = i / 16; - int dst_offset = i; - - out.qs[dst_offset] = in[src_id].qs[src_offset] ^ xor_mask; - } - } else { - GGML_ASSERT(false); - } - - return out; -} - static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { block_q4_Kx8 out; //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure @@ -2907,58 +2632,6 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in return out; } -static block_q4_Kx16 make_block_q4_Kx16(block_q4_K * in, unsigned int blck_size_interleave) { - block_q4_Kx16 out; - //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; - } - - for (int i = 0; i < 16; i++) { - out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; - } - - const int end = QK_K * 8 / blck_size_interleave; - - if (blck_size_interleave == 1) { - for (int i = 0; i < end; ++i) { - int src_id = i % 16; - int src_offset = i / 16; - int dst_offset = i; - - out.qs[dst_offset] = in[src_id].qs[src_offset]; - } - - // RVV repacking. - // - // Extract sums and mins for all 8 sub-blocks for each block of Q4_K. 
- uint8_t s[128], m[128]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 16; j++) { - s[i * 16 + j] = in[j].scales[i] & 63; - m[i * 16 + j] = in[j].scales[i + 4] & 63; - } - } - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 16; j++) { - s[64 + i * 16 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); - m[64 + i * 16 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); - } - } - - for (int i = 0; i < 128; i++) { - out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); - } - for (int i = 0; i < 64; i++) { - out.scales[128 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[64 + i] & 48) | ((m[64 + i] & 48) << 2); - } - } else { - GGML_ASSERT(false); - } - - return out; -} - static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) { block_q2_Kx8 out; @@ -3100,94 +2773,32 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in int src_offset = (i / n_blocks) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - uint64_t elem_ls; - memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave); - memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave); - } - - // Interleave high bits using same chunk size as low bits - const int end_hs = end_ls / 2; - for (int i = 0; i < end_hs; ++i) { - int src_id = i % n_blocks; - int src_offset = (i / n_blocks) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - uint64_t elem_hs; - memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave); - memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave); - } - - // The below logic is designed so as to unpack and rearrange scales in Q6_K - // The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants - // Q6_K structure has an 8-bit scale per 16 elements -> 16 scales - // scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 
15 bl7] (bl = block) - constexpr int n_scales = QK_K / 16; - - for (int i = 0; i < n_blocks; i++) { - for (int j = 0; j < n_scales; j++) { - out.scales[j * n_blocks + i] = in[i].scales[j]; - } - } - - return out; -} - -static block_q2_Kx16 make_block_q2_Kx16(const block_q2_K * in, unsigned int blck_size_interleave) { - block_q2_Kx16 out; - constexpr int N_COLS = 16; - - // 1. Copy Super-Scales (d) and Super-Mins (dmin) - for (int i = 0; i < N_COLS; i++) { - out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; - out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; - } - - // 2. Interleave Q2_K Data - const int bytes_per_col = 64; - const int total_bytes = N_COLS * bytes_per_col; - const int end = total_bytes / blck_size_interleave; - - for (int i = 0; i < end; ++i) { - int src_col_id = i % N_COLS; - int src_offset = (i / N_COLS) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], blck_size_interleave); - } - - // 3. 
Repack Scales into the Optimized "Sequential-Parallel" Layout - int out_idx = 0; - - // Arrays define the sub-block order for each group - const int even_low_sbs[] = {0, 2, 4, 6}; - const int odd_low_sbs[] = {1, 3, 5, 7}; - const int even_high_sbs[] = {8, 10, 12, 14}; - const int odd_high_sbs[] = {9, 11, 13, 15}; - - // Pack Group 1: Even-Low - for (int sb : even_low_sbs) { - for (int col = 0; col < N_COLS; col++) { - out.scales[out_idx++] = in[col].scales[sb]; - } + uint64_t elem_ls; + memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave); + memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave); } - // Pack Group 2: Odd-Low - for (int sb : odd_low_sbs) { - for (int col = 0; col < N_COLS; col++) { - out.scales[out_idx++] = in[col].scales[sb]; - } - } + // Interleave high bits using same chunk size as low bits + const int end_hs = end_ls / 2; + for (int i = 0; i < end_hs; ++i) { + int src_id = i % n_blocks; + int src_offset = (i / n_blocks) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; - // Pack Group 3: Even-High - for (int sb : even_high_sbs) { - for (int col = 0; col < N_COLS; col++) { - out.scales[out_idx++] = in[col].scales[sb]; - } + uint64_t elem_hs; + memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave); + memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave); } - // Pack Group 4: Odd-High - for (int sb : odd_high_sbs) { - for (int col = 0; col < N_COLS; col++) { - out.scales[out_idx++] = in[col].scales[sb]; + // The below logic is designed so as to unpack and rearrange scales in Q6_K + // The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants + // Q6_K structure has an 8-bit scale per 16 elements -> 16 scales + // scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 
15 bl7] (bl = block) + constexpr int n_scales = QK_K / 16; + + for (int i = 0; i < n_blocks; i++) { + for (int j = 0; j < n_scales; j++) { + out.scales[j * n_blocks + i] = in[i].scales[j]; } } @@ -3256,17 +2867,18 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } -static int repack_q4_K_to_q4_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - constexpr int nrows_interleaved = 16; +static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; - block_q4_Kx16 * dst = (block_q4_Kx16*)t->data; - const block_q4_K * src = (const block_q4_K*) data; - block_q4_K dst_tmp[16]; + block_q2_Kx8 * dst = (block_q2_Kx8*)t->data; + const block_q2_K * src = (const block_q2_K*) data; + block_q2_K dst_tmp[8]; int nrow = ggml_nrows(t); int nblocks = t->ne[0] / QK_K; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { return -1; @@ -3274,10 +2886,10 @@ static int repack_q4_K_to_q4_K_16_bl(struct ggml_tensor * t, int interleave_bloc for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++ ) { + for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q4_Kx16(dst_tmp, interleave_block); + *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block); } src += nrows_interleaved * nblocks; } @@ -3286,18 +2898,21 @@ static int repack_q4_K_to_q4_K_16_bl(struct ggml_tensor * t, int interleave_bloc GGML_UNUSED(data_size); } -static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int 
interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q2_K); - GGML_ASSERT(interleave_block == 8); +static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); constexpr int nrows_interleaved = 8; - block_q2_Kx8 * dst = (block_q2_Kx8*)t->data; - const block_q2_K * src = (const block_q2_K*) data; - block_q2_K dst_tmp[8]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; + block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + block_q5_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K)); if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { return -1; @@ -3308,28 +2923,25 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block); + *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block); } src += nrows_interleaved * nblocks; } return 0; - - GGML_UNUSED(data_size); } -static int repack_q2_K_to_q2_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q2_K); - constexpr int nrows_interleaved = 16; - - block_q2_Kx16 * dst = (block_q2_Kx16*)t->data; - const block_q2_K * src = (const block_q2_K*) data; - - block_q2_K dst_tmp[nrows_interleaved]; +static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 4 || interleave_block == 
8); + constexpr int nrows_interleaved = 8; + block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q6_K dst_tmp[8]; int nrow = ggml_nrows(t); int nblocks = t->ne[0] / QK_K; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K)); if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { return -1; @@ -3337,28 +2949,24 @@ static int repack_q2_K_to_q2_K_16_bl(struct ggml_tensor * t, int interleave_bloc for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - // This loop gathers 16 separate blocks (one from each column) - // that correspond to the same K-dimension chunk. - for (int i = 0; i < nrows_interleaved; i++ ) { + for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } - - *dst++ = make_block_q2_Kx16(dst_tmp, interleave_block); + *dst++ = make_block_q6_Kx8(dst_tmp, interleave_block); } src += nrows_interleaved * nblocks; } return 0; - - GGML_UNUSED(data_size); } -static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { +static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - constexpr int nrows_interleaved = 16; + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; - block_q4_0x16 * dst = (block_q4_0x16*)t->data; + block_q4_0x8 * dst = (block_q4_0x8*)t->data; const block_q4_0 * src = (const block_q4_0*) data; - block_q4_0 dst_tmp[16]; + block_q4_0 dst_tmp[8]; int nrow = ggml_nrows(t); int nblocks = t->ne[0] / QK4_0; @@ -3373,7 +2981,7 @@ static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_bloc for (int i = 0; i < nrows_interleaved; i++ ) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q4_0x16(dst_tmp, 
interleave_block); + *dst++ = make_block_q4_0x8(dst_tmp, interleave_block); } src += nrows_interleaved * nblocks; } @@ -3382,21 +2990,21 @@ static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_bloc GGML_UNUSED(data_size); } -static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, +static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); GGML_ASSERT(interleave_block == 4 || interleave_block == 8); - constexpr int nrows_interleaved = 8; + constexpr int nrows_interleaved = 4; - block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data; - const block_q5_K * src = (const block_q5_K *) data; - block_q5_K dst_tmp[8]; + block_q8_0x4 * dst = (block_q8_0x4 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[4]; int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; + int nblocks = t->ne[0] / QK8_0; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { return -1; @@ -3407,25 +3015,62 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block); + *dst++ = make_block_q8_0x4(dst_tmp, interleave_block); } src += nrows_interleaved * nblocks; } return 0; } -static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q6_K); - GGML_ASSERT(interleave_block == 4 || interleave_block == 8); - constexpr int nrows_interleaved = 8; +static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) { + block_iq4_nlx4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = 
in[i].d; + } + + const int end = QK4_NL * 2 / blck_size_interleave; + + // TODO: this branch seems wrong + //if (blck_size_interleave == 8) { + // for (int i = 0; i < end; ++i) { + // int src_id = i % 4; + // int src_offset = (i / 4) * blck_size_interleave; + // int dst_offset = i * blck_size_interleave; + + // // Using memcpy to avoid unaligned memory accesses + // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + // } + //} else + if (blck_size_interleave == 4) { + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); + GGML_ASSERT(interleave_block == 4); + + const block_iq4_nl * src = (const block_iq4_nl *)data; + block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data; + + block_iq4_nl dst_tmp[4]; - block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data; - const block_q6_K * src = (const block_q6_K *) data; - block_q6_K dst_tmp[8]; int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; + int nrows_interleaved = 4; + int nblocks = t->ne[0] / QK4_NL; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { return -1; @@ -3436,36 +3081,64 @@ static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q6_Kx8(dst_tmp, interleave_block); + *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block); } src += nrows_interleaved * nblocks; } return 0; + + GGML_UNUSED(data_size); } -static 
int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); +static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_size_interleave) { + block_iq4_nlx8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_NL * 4 / blck_size_interleave; + + if (blck_size_interleave == 8) { + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); GGML_ASSERT(interleave_block == 8); - constexpr int nrows_interleaved = 8; - block_q4_0x8 * dst = (block_q4_0x8*)t->data; - const block_q4_0 * src = (const block_q4_0*) data; - block_q4_0 dst_tmp[8]; + const block_iq4_nl * src = (const block_iq4_nl *)data; + block_iq4_nlx8 * dst = ( block_iq4_nlx8 *)t->data; + + block_iq4_nl dst_tmp[8]; + int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; + int nrows_interleaved = 8; + int nblocks = t->ne[0] / QK4_NL; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + if (t->ne[1] % nrows_interleaved != 0) { return -1; } for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++ ) { + for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q4_0x8(dst_tmp, interleave_block); + *dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block); } src += 
nrows_interleaved * nblocks; } @@ -3474,21 +3147,40 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } -static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q8_0); - GGML_ASSERT(interleave_block == 4 || interleave_block == 8); - constexpr int nrows_interleaved = 4; +#if defined __riscv_zvfh +template +static block<4, nrows_interleaved> make_block_q4_0xMx1(block_q4_0 * in) { + block<4, nrows_interleaved> out; + + for (int i = 0; i < nrows_interleaved; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * nrows_interleaved / 2; + + const uint8_t xor_mask = 0x88; + for (int i = 0; i < end; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset] ^ xor_mask; + } + + return out; +} + +template +static int repack_q4_0_to_q4_0_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - block_q8_0x4 * dst = (block_q8_0x4 *) t->data; - const block_q8_0 * src = (const block_q8_0 *) data; - block_q8_0 dst_tmp[4]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK8_0; + block<4, nrows_interleaved> * dst = (block<4, nrows_interleaved>*)t->data; + const block_q4_0 * src = (const block_q4_0*) data; + block_q4_0 dst_tmp[nrows_interleaved]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { return -1; @@ -3496,49 +3188,45 @@ static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { + 
for (int i = 0; i < nrows_interleaved; i++ ) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q8_0x4(dst_tmp, interleave_block); + *dst++ = make_block_q4_0xMx1(dst_tmp); } src += nrows_interleaved * nblocks; } return 0; + + GGML_UNUSED(data_size); } -static block_q8_0x16 make_block_q8_0x16(block_q8_0 * in, unsigned int blck_size_interleave) { - block_q8_0x16 out; +template +static block<8, nrows_interleaved> make_block_q8_0xMx1(block_q8_0 * in) { + block<8, nrows_interleaved> out; - for (int i = 0; i < 16; i++) { + for (int i = 0; i < nrows_interleaved; i++) { out.d[i] = in[i].d; } - const int end = QK8_0 * 16 / blck_size_interleave; + const int end = QK8_0 * nrows_interleaved; - if (blck_size_interleave == 1) { - for (int i = 0; i < end; ++i) { - int src_id = i % 16; - int src_offset = i / 16; - int dst_offset = i; - out.qs[dst_offset] = in[src_id].qs[src_offset]; - } - } else { - GGML_ASSERT(false); + for (int i = 0; i < end; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; + out.qs[dst_offset] = in[src_id].qs[src_offset]; } return out; } -static int repack_q8_0_to_q8_0_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { +template +static int repack_q8_0_to_q8_0_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q8_0); - constexpr int nrows_interleaved = 16; - block_q8_0x16 * dst = (block_q8_0x16 *) t->data; - const block_q8_0 * src = (const block_q8_0 *) data; - block_q8_0 dst_tmp[16]; + block<8, nrows_interleaved> * dst = (block<8, nrows_interleaved> *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[nrows_interleaved]; int nrow = ggml_nrows(t); int nblocks = t->ne[0] / QK8_0; @@ -3553,62 +3241,89 @@ static int repack_q8_0_to_q8_0_16_bl(struct ggml_tensor * t, for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * 
nblocks]; } - *dst++ = make_block_q8_0x16(dst_tmp, interleave_block); + *dst++ = make_block_q8_0xMx1(dst_tmp); } src += nrows_interleaved * nblocks; } return 0; } -static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) { - block_iq4_nlx4 out; +template +static block_q2_Kx make_block_q2_KxMx1(const block_q2_K * in) { + block_q2_Kx out; + constexpr int N_COLS = nrows_interleaved; - for (int i = 0; i < 4; i++) { - out.d[i] = in[i].d; + // 1. Copy Super-Scales (d) and Super-Mins (dmin) + for (int i = 0; i < N_COLS; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; } - const int end = QK4_NL * 2 / blck_size_interleave; + // 2. Interleave Q2_K Data + const int bytes_per_col = 64; + const int total_bytes = N_COLS * bytes_per_col; + const int end = total_bytes; - // TODO: this branch seems wrong - //if (blck_size_interleave == 8) { - // for (int i = 0; i < end; ++i) { - // int src_id = i % 4; - // int src_offset = (i / 4) * blck_size_interleave; - // int dst_offset = i * blck_size_interleave; + for (int i = 0; i < end; ++i) { + int src_col_id = i % N_COLS; + int src_offset = (i / N_COLS); + int dst_offset = i * 1; + memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], 1); + } - // // Using memcpy to avoid unaligned memory accesses - // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); - // } - //} else - if (blck_size_interleave == 4) { - for (int i = 0; i < end; ++i) { - int src_id = i % 4; - int src_offset = (i / 4) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; + // 3. 
Repack Scales into the Optimized "Sequential-Parallel" Layout + int out_idx = 0; - memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); + // Arrays define the sub-block order for each group + const int even_low_sbs[] = {0, 2, 4, 6}; + const int odd_low_sbs[] = {1, 3, 5, 7}; + const int even_high_sbs[] = {8, 10, 12, 14}; + const int odd_high_sbs[] = {9, 11, 13, 15}; + + // Pack Group 1: Even-Low + for (int sb : even_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 2: Odd-Low + for (int sb : odd_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 3: Even-High + for (int sb : even_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 4: Odd-High + for (int sb : odd_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; } - } else { - GGML_ASSERT(false); } return out; } -static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); - GGML_ASSERT(interleave_block == 4); +template +static int repack_q2_K_to_q2_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); - const block_iq4_nl * src = (const block_iq4_nl *)data; - block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data; + block_q2_Kx * dst = (block_q2_Kx*)t->data; + const block_q2_K * src = (const block_q2_K*) data; - block_iq4_nl dst_tmp[4]; + block_q2_K dst_tmp[nrows_interleaved]; int nrow = ggml_nrows(t); - int nrows_interleaved = 4; - int nblocks = t->ne[0] / QK4_NL; + int nblocks = t->ne[0] / QK_K; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); if 
(t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { return -1; @@ -3616,10 +3331,13 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { + // This loop gathers 16 separate blocks (one from each column) + // that correspond to the same K-dimension chunk. + for (int i = 0; i < nrows_interleaved; i++ ) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block); + + *dst++ = make_block_q2_KxMx1(dst_tmp); } src += nrows_interleaved * nblocks; } @@ -3628,55 +3346,77 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } -static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_size_interleave) { - block_iq4_nlx8 out; +template +static block_q4_Kx make_block_q4_KxMx1(block_q4_K * in) { + block_q4_Kx out; + //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure + for (int i = 0; i < nrows_interleaved; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } - for (int i = 0; i < 8; i++) { - out.d[i] = in[i].d; + for (int i = 0; i < nrows_interleaved; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; } - const int end = QK4_NL * 4 / blck_size_interleave; + const int end = QK_K * nrows_interleaved / 2; - if (blck_size_interleave == 8) { - for (int i = 0; i < end; ++i) { - int src_id = i % 8; - int src_offset = (i / 8) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; + for (int i = 0; i < end; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; - memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + // RVV repacking. 
+ // + // Extract sums and mins for all 8 sub-blocks for each block of Q4_K. + uint8_t s[8 * nrows_interleaved], m[8 * nrows_interleaved]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + s[i * nrows_interleaved + j] = in[j].scales[i] & 63; + m[i * nrows_interleaved + j] = in[j].scales[i + 4] & 63; } - } else { - GGML_ASSERT(false); + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + s[nrows_interleaved * 8 / 2 + i * nrows_interleaved + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[nrows_interleaved * 8 / 2 + i * nrows_interleaved + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + } + + for (int i = 0; i < 8 * nrows_interleaved; i++) { + out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); + } + for (int i = 0; i < 8 * nrows_interleaved / 2; i++) { + out.scales[nrows_interleaved * 8 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[nrows_interleaved * 8 / 2 + i] & 48) | ((m[nrows_interleaved * 8 / 2 + i] & 48) << 2); } return out; } -static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); - GGML_ASSERT(interleave_block == 8); - - const block_iq4_nl * src = (const block_iq4_nl *)data; - block_iq4_nlx8 * dst = ( block_iq4_nlx8 *)t->data; - - block_iq4_nl dst_tmp[8]; +template +static int repack_q4_K_to_q4_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + block_q4_Kx * dst = (block_q4_Kx*)t->data; + const block_q4_K * src = (const block_q4_K*) data; + block_q4_K dst_tmp[nrows_interleaved]; int nrow = ggml_nrows(t); - int nrows_interleaved = 8; - int nblocks = t->ne[0] / QK4_NL; + int nblocks = t->ne[0] / QK_K; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); - if 
(t->ne[1] % nrows_interleaved != 0) { + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { return -1; } for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { + for (int i = 0; i < nrows_interleaved; i++ ) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block); + *dst++ = make_block_q4_KxMx1(dst_tmp); } src += nrows_interleaved * nblocks; } @@ -3685,41 +3425,37 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } -static block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * in, unsigned int blck_size_interleave) { - block_iq4_nlx16 out; +template +static block_iq4_nlx make_block_iq4_nlxMx1(block_iq4_nl * in) { + block_iq4_nlx out; - for (int i = 0; i < 16; i++) { + for (int i = 0; i < nrows_interleaved; i++) { out.d[i] = in[i].d; } - const int end = QK4_NL * 8 / blck_size_interleave; + const int end = QK4_NL * nrows_interleaved / 2; - if (blck_size_interleave == 1) { - for (int i = 0; i < end; ++i) { - int src_id = i % 16; - int src_offset = i / 16; - int dst_offset = i; + for (int i = 0; i < end; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; - out.qs[dst_offset] = in[src_id].qs[src_offset]; - } - } else { - GGML_ASSERT(false); + out.qs[dst_offset] = in[src_id].qs[src_offset]; } return out; } -static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { +template +static int repack_iq4_nl_to_iq4_nl_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); - GGML_ASSERT(interleave_block == 1); - const block_iq4_nl * src = (const block_iq4_nl *)data; - block_iq4_nlx16 * dst = ( block_iq4_nlx16 *)t->data; + const block_iq4_nl * src = (const block_iq4_nl *)data; + 
block_iq4_nlx * dst = ( block_iq4_nlx *)t->data; - block_iq4_nl dst_tmp[16]; + block_iq4_nl dst_tmp[nrows_interleaved]; int nrow = ggml_nrows(t); - int nrows_interleaved = 16; int nblocks = t->ne[0] / QK4_NL; GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); @@ -3733,7 +3469,7 @@ static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_ for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_iq4_nlx16(dst_tmp, interleave_block); + *dst++ = make_block_iq4_nlxMx1(dst_tmp); } src += nrows_interleaved * nblocks; } @@ -3741,6 +3477,7 @@ static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_ GGML_UNUSED(data_size); } +#endif static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) { block_mxfp4x4 out; @@ -3932,24 +3669,74 @@ template <> int repack(struct ggml_tensor * t, const void * da } #if defined __riscv_zvfh +// Q4_0 +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_Mx1_bl<8>(t, data, data_size); +} template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size); + return repack_q4_0_to_q4_0_Mx1_bl<16>(t, data, data_size); } - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_K_to_q4_K_16_bl(t, 1, data, data_size); +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_Mx1_bl<32>(t, data, data_size); } - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size); +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_Mx1_bl<64>(t, data, data_size); } +// Q8_0 +template <> int repack(struct ggml_tensor * t, const 
void * data, size_t data_size) { + return repack_q8_0_to_q8_0_Mx1_bl<8>(t, data, data_size); +} template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size); + return repack_q8_0_to_q8_0_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_Mx1_bl<64>(t, data, data_size); } +// Q2_K +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_Mx1_bl<8>(t, data, data_size); +} template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q2_K_to_q2_K_16_bl(t, 1, data, data_size); + return repack_q2_K_to_q2_K_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_Mx1_bl<64>(t, data, data_size); +} + +// Q4_K +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_Mx1_bl<64>(t, data, data_size); +} + +// IQ4_NL +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return 
repack_iq4_nl_to_iq4_nl_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_Mx1_bl<64>(t, data, data_size); } #endif @@ -4029,25 +3816,75 @@ template <> void gemv(int n, float * s, size_t } #if defined __riscv_zvfh +// Q4_0 +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_32x1_q8_0(n, s, bs, vx, vy, nr, nc); } - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_64x1_q8_0(n, s, bs, vx, vy, nr, nc); } +// Q8_0 +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const 
void * vy, int nr, int nc) { + ggml_gemv_q8_0_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +// Q2_K +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +// Q4_K +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +// IQ4_NL +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + 
ggml_gemv_iq4_nl_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} #endif // gemm @@ -4126,25 +3963,75 @@ template <> void gemm(int n, float * s, size_t } #if defined __riscv_zvfh +// Q4_0 +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } - -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_32x1_q8_0(n, s, bs, vx, vy, nr, nc); } - -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_64x1_q8_0(n, s, bs, vx, vy, nr, nc); } +// Q8_0 +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +// Q2_K +template <> void gemm(int n, float * s, 
size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +// Q4_K +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +// IQ4_NL +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} #endif class tensor_traits_base : public ggml::cpu::tensor_traits { @@ -4560,11 
+4447,35 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons // These implement outer-product style matrix multiplication kernels with // an interleave of 1. #if defined __riscv_zvfh + // Q4_0 + static const ggml::cpu::repack::tensor_traits q4_0_8x1_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_16x1_q8_0; - static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K; - static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits q4_0_32x1_q8_0; + static const ggml::cpu::repack::tensor_traits q4_0_64x1_q8_0; + + // Q8_0 + static const ggml::cpu::repack::tensor_traits q8_0_8x1_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits q8_0_32x1_q8_0; + static const ggml::cpu::repack::tensor_traits q8_0_64x1_q8_0; + + // Q2_K + static const ggml::cpu::repack::tensor_traits q2_K_8x1_q8_K; static const ggml::cpu::repack::tensor_traits q2_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits q2_K_32x1_q8_K; + static const ggml::cpu::repack::tensor_traits q2_K_64x1_q8_K; + + // Q4_K + static const ggml::cpu::repack::tensor_traits q4_K_8x1_q8_K; + static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits q4_K_32x1_q8_K; + static const ggml::cpu::repack::tensor_traits q4_K_64x1_q8_K; + + // IQ4_NL + static const ggml::cpu::repack::tensor_traits iq4_nl_8x1_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_32x1_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_64x1_q8_0; #endif if (cur->type == GGML_TYPE_Q4_0) { @@ -4586,10 +4497,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { break; } // TODO + case 128: { if (cur->ne[1] % 8 == 0) { return 
&q4_0_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q4_0_16x1_q8_0; } break; } - case 512: { break; } // TODO - case 1024: { break; } // TODO + case 512: { if (cur->ne[1] % 32 == 0) { return &q4_0_32x1_q8_0; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q4_0_64x1_q8_0; } break; } default: { return nullptr; } } #endif @@ -4613,10 +4524,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { break; } // TODO + case 128: { if (cur->ne[1] % 8 == 0) { return &q4_K_8x1_q8_K; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q4_K_16x1_q8_K; } break; } - case 512: { break; } // TODO - case 1024: { break; } // TODO + case 512: { if (cur->ne[1] % 32 == 0) { return &q4_K_32x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q4_K_64x1_q8_K; } break; } default: { return nullptr; } } #endif @@ -4630,10 +4541,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { break; } // TODO + case 128: { if (cur->ne[1] % 8 == 0) { return &q2_K_8x1_q8_K; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q2_K_16x1_q8_K; } break; } - case 512: { break; } // TODO - case 1024: { break; } // TODO + case 512: { if (cur->ne[1] % 32 == 0) { return &q2_K_32x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q2_K_64x1_q8_K; } break; } default: { return nullptr; } } #endif @@ -4674,10 +4585,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { break; } // TODO + case 128: { if (cur->ne[1] % 8 == 0) { return &iq4_nl_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &iq4_nl_16x1_q8_0; } break; } - case 
512: { break; } // TODO - case 1024: { break; } // TODO + case 512: { if (cur->ne[1] % 32 == 0) { return &iq4_nl_32x1_q8_0; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &iq4_nl_64x1_q8_0; } break; } default: { return nullptr; } } #endif @@ -4707,10 +4618,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { break; } // TODO + case 128: { if (cur->ne[1] % 8 == 0) { return &q8_0_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q8_0_16x1_q8_0; } break; } - case 512: { break; } // TODO - case 1024: { break; } // TODO + case 512: { if (cur->ne[1] % 32 == 0) { return &q8_0_32x1_q8_0; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q8_0_64x1_q8_0; } break; } default: { return nullptr; } } #endif diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index cb21edf6239..0d97c9b9b53 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -29,48 +29,58 @@ template struct block { static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding"); static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding"); static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block<4, 32>) == 32 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<4,32> size/padding"); +static_assert(sizeof(block<4, 64>) == 64 * sizeof(ggml_half) + QK8_0 * 32, "wrong block<4,64> size/padding"); static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding"); +static_assert(sizeof(block<8, 32>) 
== 32 * sizeof(ggml_half) + QK8_0 * 32, "wrong block<8,32> size/padding"); +static_assert(sizeof(block<8, 64>) == 64 * sizeof(ggml_half) + QK8_0 * 64, "wrong block<8,64> size/padding"); using block_q4_0x4 = block<4, 4>; using block_q4_0x8 = block<4, 8>; using block_q4_0x16 = block<4, 16>; +using block_q4_0x32 = block<4, 32>; +using block_q4_0x64 = block<4, 64>; using block_q8_0x4 = block<8, 4>; using block_q8_0x8 = block<8, 8>; using block_q8_0x16 = block<8, 16>; +using block_q8_0x32 = block<8, 32>; +using block_q8_0x64 = block<8, 64>; -struct block_q4_Kx8 { - ggml_half d[8]; // super-block scale for quantized scales - ggml_half dmin[8]; // super-block scale for quantized mins - uint8_t scales[96]; // scales and mins, quantized with 6 bits - uint8_t qs[1024]; // 4--bit quants +template struct block_q4_Kx{ + ggml_half d[N]; // super-block scale for quantized scales + ggml_half dmin[N]; // super-block scale for quantized mins + uint8_t scales[12 * N]; // scales and mins, quantized with 6 bits + uint8_t qs[128 * N]; // 4--bit quants }; -static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); -struct block_q4_Kx16 { - ggml_half d[16]; // super-block scale for quantized scales - ggml_half dmin[16]; // super-block scale for quantized mins - uint8_t scales[192]; // scales and mins, quantized with 6 bits - uint8_t qs[2048]; // 4--bit quants -}; +using block_q4_Kx8 = block_q4_Kx<8>; +using block_q4_Kx16 = block_q4_Kx<16>; +using block_q4_Kx32 = block_q4_Kx<32>; +using block_q4_Kx64 = block_q4_Kx<64>; +static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding"); -struct block_q2_Kx8 { - ggml_half d[8]; // super-block scale for quantized scales - ggml_half dmin[8]; // super-block scale for quantized mins - uint8_t 
scales[128]; // scales and mins, quantized with 4 bits - uint8_t qs[512]; // 2--bit quants +static_assert(sizeof(block_q4_Kx32) == sizeof(ggml_half) * 64 + K_SCALE_SIZE * 32 + QK_K * 16, "wrong q4_K block size/padding"); +static_assert(sizeof(block_q4_Kx64) == sizeof(ggml_half) * 128 + K_SCALE_SIZE * 64 + QK_K * 32, "wrong q4_K block size/padding"); + +template struct block_q2_Kx { + ggml_half d[N]; // super-block scale for quantized scales + ggml_half dmin[N]; // super-block scale for quantized mins + uint8_t scales[16 * N]; // scales and mins, quantized with 4 bits + uint8_t qs[64 * N]; // 2--bit quants }; +using block_q2_Kx8 = block_q2_Kx<8>; +using block_q2_Kx16 = block_q2_Kx<16>; +using block_q2_Kx32 = block_q2_Kx<32>; +using block_q2_Kx64 = block_q2_Kx<64>; + static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); -struct block_q2_Kx16 { - ggml_half d[16]; // Super-block scale for quantized scales - ggml_half dmin[16]; // Super-block scale for quantized mins - uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks) - uint8_t qs[1024]; // Data (16 cols * 64 bytes per block) -}; static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding"); +static_assert(sizeof(block_q2_Kx32) == sizeof(ggml_half) * 64 + QK_K * 2 + QK_K * 8, "wrong q2_K block size/padding"); +static_assert(sizeof(block_q2_Kx64) == sizeof(ggml_half) * 128 + QK_K * 4 + QK_K * 16, "wrong q2_K block size/padding"); struct block_q5_Kx8 { ggml_half d[8]; // super-block scale for quantized scales @@ -83,15 +93,22 @@ struct block_q5_Kx8 { static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, "wrong q5_K block size/padding"); -struct block_q6_Kx8 { - ggml_half d[8]; - int8_t scales[QK_K / 16 * 8]; - uint8_t ql[QK_K / 2 * 8]; // low bits of 6-bit quants (groups of 2) - uint8_t qh[QK_K / 4 * 8]; // high bits of 6-bit quants (groups of 4) +template 
struct block_q6_Kx { + ggml_half d[N]; + int8_t scales[QK_K / 16 * N]; + uint8_t ql[QK_K / 2 * N]; // low bits of 6-bit quants (groups of 2) + uint8_t qh[QK_K / 4 * N]; // high bits of 6-bit quants (groups of 4) }; -static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, - "wrong q6_K block size/padding"); +using block_q6_Kx8 = block_q6_Kx<8>; +using block_q6_Kx16 = block_q6_Kx<16>; +using block_q6_Kx32 = block_q6_Kx<32>; +using block_q6_Kx64 = block_q6_Kx<64>; + +static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, "wrong q6_K block size/padding"); +static_assert(sizeof(block_q6_Kx16) == sizeof(ggml_half) * 16 + QK_K / 16 * 16 + 3 * QK_K / 4 * 16, "wrong q6_K block size/padding"); +static_assert(sizeof(block_q6_Kx32) == sizeof(ggml_half) * 32 + QK_K / 16 * 32 + 3 * QK_K / 4 * 32, "wrong q6_K block size/padding"); +static_assert(sizeof(block_q6_Kx64) == sizeof(ggml_half) * 64 + QK_K / 16 * 64 + 3 * QK_K / 4 * 64, "wrong q6_K block size/padding"); struct block_q8_Kx4 { float d[4]; // delta @@ -101,26 +118,23 @@ struct block_q8_Kx4 { static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding"); -struct block_iq4_nlx4 { - ggml_half d[4]; // deltas for 4 iq4_nl blocks - uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks +template struct block_iq4_nlx { + ggml_half d[N]; // deltas for `N` iq4_nl blocks + uint8_t qs[QK4_NL * N / 2]; // nibbles / quants for N iq4_nl blocks }; -static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding"); - -struct block_iq4_nlx8 { - ggml_half d[8]; // deltas for 8 iq4_nl blocks - uint8_t qs[QK4_NL * 4]; // nibbles / quants for 8 iq4_nl blocks -}; +using block_iq4_nlx4 = block_iq4_nlx<4>; +using block_iq4_nlx8 = block_iq4_nlx<8>; +using block_iq4_nlx16 = block_iq4_nlx<16>; +using block_iq4_nlx32 = block_iq4_nlx<32>; +using 
block_iq4_nlx64 = block_iq4_nlx<64>; +static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding"); static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding"); - -struct block_iq4_nlx16 { - ggml_half d[16]; // deltas for 16 iq4_nl blocks - uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks -}; - static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding"); +static_assert(sizeof(block_iq4_nlx32) == 32 * sizeof(ggml_half) + QK4_NL * 16, "wrong iq4_nlx32 block size/padding"); +static_assert(sizeof(block_iq4_nlx64) == 64 * sizeof(ggml_half) + QK4_NL * 32, "wrong iq4_nlx64 block size/padding"); + struct block_mxfp4x4 { uint8_t e[4]; uint8_t qs[QK_MXFP4 * 2]; @@ -176,16 +190,46 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t 
bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void 
* GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif // Native implementations @@ -228,16 +272,46 @@ void 
ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t 
bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT 
vy, int nr, int nc); -void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float 
* GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif #if defined(__cplusplus) From d646c0d58047e25bcbc74e22d44c7fb6ec4c3b0a Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Fri, 6 Feb 2026 06:29:42 +0500 Subject: [PATCH 09/13] ggml-cpu: refactor; add rvv repacking for q6_K --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 571 ++++++++++++++++++++---- ggml/src/ggml-cpu/repack.cpp | 319 ++++++++++++- ggml/src/ggml-cpu/repack.h | 18 +- 3 files changed, 814 insertions(+), 94 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 9f336bdd50f..a8965bcc1a2 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -230,11 +230,11 @@ static inline void ggml_gemv_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_ for (int x = 0; x < nc / ncols_interleaved; x++) { const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); - // 
1x16 Accumulator + // 1xM Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 1x16 Integer Accumulator + // 1xM Integer Accumulator vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); @@ -243,7 +243,7 @@ static inline void ggml_gemv_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_ // Load `b_ptr`. const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, ncols_interleaved), 4, ncols_interleaved); - const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, ncols_interleaved); sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, ncols_interleaved); sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, ncols_interleaved); @@ -451,11 +451,11 @@ void ggml_gemv_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int x = 0; x < nc / ncols_interleaved; x++) { const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); - // 1x16 Accumulator + // 1xM Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 1x16 Integer Accumulator + // 1xM Integer Accumulator vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); // Accumulation loop. 
@@ -732,7 +732,7 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); - // 1x16 Accumulator + // 1xM Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { @@ -776,15 +776,16 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved); // Accumulation for 2 sub-blocks. - { - // 4x16 integer accumulators + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); @@ -796,20 +797,22 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_s_0_16, 16); + sumi_s_0_16, ncols_interleaved); sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_s_1_16, 16); + sumi_s_1_16, ncols_interleaved); } - { - // 4x16 integer accumulators + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. 
+ for (int k = 0; k < 2; k++) { + // 4xM integer accumulators vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); @@ -867,6 +870,192 @@ void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +template +void ggml_gemv_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + + // 1xM Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 8 16-element sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Load the scales. + // + // Low bits. + vint16m8_t scales = __riscv_vwcvt_x_x_v_i16m8(__riscv_vle8_v_i8m4(&b_ptr[l].scales[j * 8 * ncols_interleaved], 8 * ncols_interleaved), 8 * ncols_interleaved); + + // Sub-blocks 0, 2, 4, 6. + { + for (int i = 0; i < QK8_0 / 2; i++) { + // Load the high bits. 
+ const vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 32 + i) * ncols_interleaved], ncols_interleaved); + + // Sub-blocks 0, 4 + { + // Load the low bits. + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_0_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_4_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_0_hi = __riscv_vand_vx_u8mf2(b_hi, 0x3, ncols_interleaved); + const vuint8mf2_t b_4_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x3 , ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_0_m = __riscv_vor_vv_u8mf2(b_0_lo, __riscv_vsll_vx_u8mf2(b_0_hi, 4, ncols_interleaved), ncols_interleaved); + const vuint8mf2_t b_4_m = __riscv_vor_vv_u8mf2(b_4_lo, __riscv_vsll_vx_u8mf2(b_4_hi, 4, ncols_interleaved), ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_0_m), 32, ncols_interleaved); + const vint8mf2_t b_4 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_4_m), 32, ncols_interleaved); + + // Multiply and accumulate. + sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 0), __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[j * 128 + i], ncols_interleaved), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 4), __riscv_vwmul_vx_i16m1(b_4, a_ptr[l].qs[j * 128 + 64 + i], ncols_interleaved), ncols_interleaved); + } + // Sub-blocks 2, 6 + { + // Load the low bits. 
+ const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 32 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_2_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_6_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_2_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x3, ncols_interleaved); + const vuint8mf2_t b_6_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 6, ncols_interleaved), 0x3 , ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_2_m = __riscv_vor_vv_u8mf2(b_2_lo, __riscv_vsll_vx_u8mf2(b_2_hi, 4, ncols_interleaved), ncols_interleaved); + const vuint8mf2_t b_6_m = __riscv_vor_vv_u8mf2(b_6_lo, __riscv_vsll_vx_u8mf2(b_6_hi, 4, ncols_interleaved), ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_2_m), 32, ncols_interleaved); + const vint8mf2_t b_6 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_6_m), 32, ncols_interleaved); + + // Multiply and accumulate. + sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 2), __riscv_vwmul_vx_i16m1(b_2, a_ptr[l].qs[j * 128 + 32 + i], ncols_interleaved), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 6), __riscv_vwmul_vx_i16m1(b_6, a_ptr[l].qs[j * 128 + 96 + i], ncols_interleaved), ncols_interleaved); + } + } + } + // Sub-blocks 1, 3, 5, 7. + { + for (int i = 0; i < QK8_0 / 2; i++) { + // Load the high bits. + const vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 32 + 16 + i) * ncols_interleaved], ncols_interleaved); + + // Sub-blocks 1, 5 + { + // Load the low bits. 
+ const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 16 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_1_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_5_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_1_hi = __riscv_vand_vx_u8mf2(b_hi, 0x3, ncols_interleaved); + const vuint8mf2_t b_5_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x3 , ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_1_m = __riscv_vor_vv_u8mf2(b_1_lo, __riscv_vsll_vx_u8mf2(b_1_hi, 4, ncols_interleaved), ncols_interleaved); + const vuint8mf2_t b_5_m = __riscv_vor_vv_u8mf2(b_5_lo, __riscv_vsll_vx_u8mf2(b_5_hi, 4, ncols_interleaved), ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_1_m), 32, ncols_interleaved); + const vint8mf2_t b_5 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_5_m), 32, ncols_interleaved); + + // Multiply and accumulate. + sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 1), __riscv_vwmul_vx_i16m1(b_1, a_ptr[l].qs[j * 128 + 16 + i], ncols_interleaved), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 5), __riscv_vwmul_vx_i16m1(b_5, a_ptr[l].qs[j * 128 + 80 + i], ncols_interleaved), ncols_interleaved); + } + // Sub-blocks 3, 7 + { + // Load the low bits. + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 48 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_3_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_7_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. 
+ const vuint8mf2_t b_3_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x3, ncols_interleaved); + const vuint8mf2_t b_7_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 6, ncols_interleaved), 0x3, ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_3_m = __riscv_vor_vv_u8mf2(b_3_lo, __riscv_vsll_vx_u8mf2(b_3_hi, 4, ncols_interleaved), ncols_interleaved); + const vuint8mf2_t b_7_m = __riscv_vor_vv_u8mf2(b_7_lo, __riscv_vsll_vx_u8mf2(b_7_hi, 4, ncols_interleaved), ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_3_m), 32, ncols_interleaved); + const vint8mf2_t b_7 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_7_m), 32, ncols_interleaved); + + // Multiply and accumulate. + sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 3), __riscv_vwmul_vx_i16m1(b_3, a_ptr[l].qs[j * 128 + 48 + i], ncols_interleaved), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 7), __riscv_vwmul_vx_i16m1(b_7, a_ptr[l].qs[j * 128 + 112 + i], ncols_interleaved), ncols_interleaved); + } + } + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, ncols_interleaved); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); + } +} + +void ggml_gemv_q6_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q6_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q6_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void 
ggml_gemv_q6_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q6_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q6_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q6_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q6_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q6_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q6_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q6_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q6_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + template void ggml_gemv_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -886,27 +1075,27 @@ void ggml_gemv_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(ncols_interleaved); UNUSED(blocklen); - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { const block_iq4_nlx * b_ptr = (const block_iq4_nlx *) vx + (x * nb); - // 1x16 Accumulator1 + // 1xM Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 1x16 integer accumulator + // 1xM integer accumulator vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); // Accumulation loop. for (int i = 0; i < QK4_NL / 2; i++) { // Load `b_ptr`. 
- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); - const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); - const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved); - const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], ncols_interleaved); - const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], ncols_interleaved); + const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i], ncols_interleaved); + const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), a_ptr[l].qs[16 + i], ncols_interleaved); sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, ncols_interleaved), ncols_interleaved); } @@ -1206,14 +1395,14 @@ void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int x = 0; x < nc / ncols_interleaved; x++) { const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); - // 4x16 Accumulators + // 4xM Accumulators vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) 
{ - // 4x16 integer accumulators + // 4xM integer accumulators vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); @@ -1688,14 +1877,14 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v for (int x = 0; x < nc / ncols_interleaved; x++) { const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); - // 4x16 Accumulators + // 4xM Accumulators vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 4x16 Integer Accumulators + // 4xM Integer Accumulators vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); @@ -2116,7 +2305,7 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); - // 4x16 Accumulators + // 4xM Accumulators vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); @@ -2139,7 +2328,7 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, vl); // High bits. 
- vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl], vl); + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl * 2], vl); vuint8m2_t scales_hi; vuint8m2_t mins_hi; if (!j) { @@ -2162,52 +2351,52 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); bsums_0 = 
__riscv_vwmacc_vx_i32m2(bsums_0, a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved); @@ -2225,8 +2414,12 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, 
const vo // Accumulation for 2 sub-blocks. - { - // 4x8 integer accumulators + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); @@ -2236,12 +2429,9 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. - const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], 16); + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved)); @@ -2281,8 +2471,13 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), sumi_3_s_1_16, ncols_interleaved); } - { - // 4x16 integer accumulators + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. 
+ for (int k = 0; k < 2; k++) { + // 4xM integer accumulators vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); @@ -2292,10 +2487,7 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); @@ -2388,6 +2580,231 @@ void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +template +void ggml_gemm_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + + // 4xM Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = 
__riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 8 16-element sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Load the scales. + // + // Low bits. + vint16m8_t scales = __riscv_vwcvt_x_x_v_i16m8(__riscv_vle8_v_i8m4(&b_ptr[l].scales[j * 8 * ncols_interleaved], 8 * ncols_interleaved), 8 * ncols_interleaved); + + // Loop 1: Sub-blocks 0, 2, 4, 6. + // Loop 2: Sub-blocks 8, 10, 12, 14. + for (int i = 0; i < QK8_0 / 2; i++) { + // Load the high bits. + const vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 32 + i) * ncols_interleaved], ncols_interleaved); + + // Sub-blocks 0, 4 + { + // Load the low bits. + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_0_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_4_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_0_hi = __riscv_vand_vx_u8mf2(b_hi, 0x3, ncols_interleaved); + const vuint8mf2_t b_4_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x3, ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_0_m = __riscv_vor_vv_u8mf2(b_0_lo, __riscv_vsll_vx_u8mf2(b_0_hi, 4, ncols_interleaved), ncols_interleaved); + const vuint8mf2_t b_4_m = __riscv_vor_vv_u8mf2(b_4_lo, __riscv_vsll_vx_u8mf2(b_4_hi, 4, ncols_interleaved), ncols_interleaved); + + // Bias adjustment. 
+ const vint8mf2_t b_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_0_m), 32, ncols_interleaved); + const vint8mf2_t b_4 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_4_m), 32, ncols_interleaved); + + // Multiply and accumulate. + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 0), __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[j * 512 + 64 * 0 + i * 4 + 0], ncols_interleaved), ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 4), __riscv_vwmul_vx_i16m1(b_4, a_ptr[l].qs[j * 512 + 64 * 4 + i * 4 + 0], ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 0), __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[j * 512 + 64 * 0 + i * 4 + 1], ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 4), __riscv_vwmul_vx_i16m1(b_4, a_ptr[l].qs[j * 512 + 64 * 4 + i * 4 + 1], ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 0), __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[j * 512 + 64 * 0 + i * 4 + 2], ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 4), __riscv_vwmul_vx_i16m1(b_4, a_ptr[l].qs[j * 512 + 64 * 4 + i * 4 + 2], ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 0), __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[j * 512 + 64 * 0 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 4), __riscv_vwmul_vx_i16m1(b_4, a_ptr[l].qs[j * 512 + 64 * 4 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + } + // Sub-blocks 2, 6 + { + // Load the low bits. 
+ const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 32 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_2_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_6_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_2_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x3, ncols_interleaved); + const vuint8mf2_t b_6_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 6, ncols_interleaved), 0x3 , ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_2_m = __riscv_vor_vv_u8mf2(b_2_lo, __riscv_vsll_vx_u8mf2(b_2_hi, 4, ncols_interleaved), ncols_interleaved); + const vuint8mf2_t b_6_m = __riscv_vor_vv_u8mf2(b_6_lo, __riscv_vsll_vx_u8mf2(b_6_hi, 4, ncols_interleaved), ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_2_m), 32, ncols_interleaved); + const vint8mf2_t b_6 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_6_m), 32, ncols_interleaved); + + // Multiply and accumulate. 
+ sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 2), __riscv_vwmul_vx_i16m1(b_2, a_ptr[l].qs[j * 512 + 64 * 2 + i * 4 + 0], ncols_interleaved), ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 6), __riscv_vwmul_vx_i16m1(b_6, a_ptr[l].qs[j * 512 + 64 * 6 + i * 4 + 0], ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 2), __riscv_vwmul_vx_i16m1(b_2, a_ptr[l].qs[j * 512 + 64 * 2 + i * 4 + 1], ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 6), __riscv_vwmul_vx_i16m1(b_6, a_ptr[l].qs[j * 512 + 64 * 6 + i * 4 + 1], ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 2), __riscv_vwmul_vx_i16m1(b_2, a_ptr[l].qs[j * 512 + 64 * 2 + i * 4 + 2], ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 6), __riscv_vwmul_vx_i16m1(b_6, a_ptr[l].qs[j * 512 + 64 * 6 + i * 4 + 2], ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 2), __riscv_vwmul_vx_i16m1(b_2, a_ptr[l].qs[j * 512 + 64 * 2 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 6), __riscv_vwmul_vx_i16m1(b_6, a_ptr[l].qs[j * 512 + 64 * 6 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + } + } + // Loop 1: Sub-blocks 1, 3, 5, 7. + // Loop 2: Sub-blocks 9, 11, 13, 15. + for (int i = 0; i < QK8_0 / 2; i++) { + // Load the high bits. + const vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 32 + 16 + i) * ncols_interleaved], ncols_interleaved); + + // Sub-blocks 1, 5 + { + // Load the low bits. 
+ const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 16 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_1_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_5_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_1_hi = __riscv_vand_vx_u8mf2(b_hi, 0x3, ncols_interleaved); + const vuint8mf2_t b_5_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x3 , ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_1_m = __riscv_vor_vv_u8mf2(b_1_lo, __riscv_vsll_vx_u8mf2(b_1_hi, 4, ncols_interleaved), ncols_interleaved); + const vuint8mf2_t b_5_m = __riscv_vor_vv_u8mf2(b_5_lo, __riscv_vsll_vx_u8mf2(b_5_hi, 4, ncols_interleaved), ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_1_m), 32, ncols_interleaved); + const vint8mf2_t b_5 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_5_m), 32, ncols_interleaved); + + // Multiply and accumulate. 
+ sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 1), __riscv_vwmul_vx_i16m1(b_1, a_ptr[l].qs[j * 512 + 64 * 1 + i * 4 + 0], ncols_interleaved), ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 5), __riscv_vwmul_vx_i16m1(b_5, a_ptr[l].qs[j * 512 + 64 * 5 + i * 4 + 0], ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 1), __riscv_vwmul_vx_i16m1(b_1, a_ptr[l].qs[j * 512 + 64 * 1 + i * 4 + 1], ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 5), __riscv_vwmul_vx_i16m1(b_5, a_ptr[l].qs[j * 512 + 64 * 5 + i * 4 + 1], ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 1), __riscv_vwmul_vx_i16m1(b_1, a_ptr[l].qs[j * 512 + 64 * 1 + i * 4 + 2], ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 5), __riscv_vwmul_vx_i16m1(b_5, a_ptr[l].qs[j * 512 + 64 * 5 + i * 4 + 2], ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 1), __riscv_vwmul_vx_i16m1(b_1, a_ptr[l].qs[j * 512 + 64 * 1 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 5), __riscv_vwmul_vx_i16m1(b_5, a_ptr[l].qs[j * 512 + 64 * 5 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + } + // Sub-blocks 3, 7 + { + // Load the low bits. + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 48 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_3_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_7_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. 
+ const vuint8mf2_t b_3_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x3, ncols_interleaved); + const vuint8mf2_t b_7_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 6, ncols_interleaved), 0x3, ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_3_m = __riscv_vor_vv_u8mf2(b_3_lo, __riscv_vsll_vx_u8mf2(b_3_hi, 4, ncols_interleaved), ncols_interleaved); + const vuint8mf2_t b_7_m = __riscv_vor_vv_u8mf2(b_7_lo, __riscv_vsll_vx_u8mf2(b_7_hi, 4, ncols_interleaved), ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_3_m), 32, ncols_interleaved); + const vint8mf2_t b_7 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_7_m), 32, ncols_interleaved); + + // Multiply and accumulate. + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 3), __riscv_vwmul_vx_i16m1(b_3, a_ptr[l].qs[j * 512 + 64 * 3 + i * 4 + 0], ncols_interleaved), ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 7), __riscv_vwmul_vx_i16m1(b_7, a_ptr[l].qs[j * 512 + 64 * 7 + i * 4 + 0], ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 3), __riscv_vwmul_vx_i16m1(b_3, a_ptr[l].qs[j * 512 + 64 * 3 + i * 4 + 1], ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 7), __riscv_vwmul_vx_i16m1(b_7, a_ptr[l].qs[j * 512 + 64 * 7 + i * 4 + 1], ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 3), __riscv_vwmul_vx_i16m1(b_3, a_ptr[l].qs[j * 512 + 64 * 3 + i * 4 + 2], ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 7), __riscv_vwmul_vx_i16m1(b_7, a_ptr[l].qs[j * 512 + 64 * 7 + i * 4 + 2], ncols_interleaved), 
ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 3), __riscv_vwmul_vx_i16m1(b_3, a_ptr[l].qs[j * 512 + 64 * 3 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 7), __riscv_vwmul_vx_i16m1(b_7, a_ptr[l].qs[j * 512 + 64 * 7 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + } + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); + } + } +} + +void ggml_gemm_q6_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined 
__riscv_zvfh + ggml_gemm_q6_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q6_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q6_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q6_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q6_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q6_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q6_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q6_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q6_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q6_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q6_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + template void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -2409,21 +2826,21 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(blocklen); #if defined __riscv_v_intrinsic - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { const block_iq4_nlx * b_ptr = (const block_iq4_nlx *) vx + (x * nb); - // 4x16 Accumulators + // 4xM Accumulators vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t 
sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 4x16 integer accumulators + // 4xM integer accumulators vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); @@ -2432,21 +2849,21 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const // Accumulation loop. for (int i = 0; i < QK4_NL / 2; i++) { // Load `b_ptr`. - const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); - const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); - const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved); // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); - const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], ncols_interleaved); - const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], ncols_interleaved); - const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], ncols_interleaved); - const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], ncols_interleaved); + const vint16m1_t 
sumi_0_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4], ncols_interleaved); + const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4 + 1], ncols_interleaved); + const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4 + 2], ncols_interleaved); + const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4 + 3], ncols_interleaved); - const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], ncols_interleaved); - const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], ncols_interleaved); - const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], ncols_interleaved); - const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], ncols_interleaved); + const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), a_ptr[l].qs[64 + i * 4], ncols_interleaved); + const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), a_ptr[l].qs[64 + i * 4 + 1], ncols_interleaved); + const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), a_ptr[l].qs[64 + i * 4 + 2], ncols_interleaved); + const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), a_ptr[l].qs[64 + i * 4 + 3], ncols_interleaved); sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, ncols_interleaved), ncols_interleaved); sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, ncols_interleaved), ncols_interleaved); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 7f101d946b9..5dc2b62ab24 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -361,6 +361,7 
@@ static inline void ggml_gemv_q4_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT const int nb = n / qk; const int blocklen = 1; + assert(nr == 1); assert (n % qk == 0); assert (nc % ncols_interleaved == 0); @@ -374,7 +375,7 @@ static inline void ggml_gemv_q4_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT UNUSED(ncols_interleaved); UNUSED(blocklen); - float sumf[16]; + float sumf[ncols_interleaved]; int sumi; const block_q8_0 * a_ptr = (const block_q8_0 *) vy; @@ -412,7 +413,7 @@ static inline void ggml_gemv_q8_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT UNUSED(bs); UNUSED(nr); - float sumf[16]; + float sumf[ncols_interleaved]; int sumi; const block_q8_0 * a_ptr = (const block_q8_0 *) vy; @@ -462,7 +463,7 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT const block_q2_Kx * x_ptr = x + (col_tile / ncols_interleaved) * nb; const block_q8_K * y_ptr = y; - float sumf[16] = {0}; + float sumf[ncols_interleaved] = {0}; // Loop over K-blocks for (int k_block = 0; k_block < nb; ++k_block) { @@ -536,8 +537,11 @@ static inline void ggml_gemv_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; + + assert(nr == 1); assert (n % qk == 0); assert (nc % ncols_interleaved == 0); + UNUSED(s); UNUSED(bs); UNUSED(vx); @@ -547,6 +551,7 @@ static inline void ggml_gemv_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT UNUSED(nb); UNUSED(ncols_interleaved); UNUSED(blocklen); + float sumf[ncols_interleaved]; float sum_minf[ncols_interleaved]; uint8_t scales[ncols_interleaved * 8]; @@ -604,6 +609,65 @@ static inline void ggml_gemv_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } +template +static inline void ggml_gemv_q6_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % 
ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[ncols_interleaved]; + + int sumi0; + int sumi4; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + // Processing 2 sub-blocks at once. + for (int sb = 0; sb < 8; sb++) { + int scales_idx = (sb / 4) * 4 + sb; + int qh_idx = (sb / 4) * 32 + (sb % 2) * 16; + const int8_t *scales_0 = &b_ptr[l].scales[(scales_idx) * ncols_interleaved]; + const int8_t *scales_4 = &b_ptr[l].scales[(scales_idx + 4) * ncols_interleaved]; + for (int i = 0; i < QK8_0 / 2; i++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi0 = 0; + sumi4 = 0; + sumi = 0; + const uint8_t v0 = (uint8_t) (b_ptr[l].ql[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); + const uint8_t v4 = (uint8_t) (b_ptr[l].ql[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); + const int8_t a0 = (int8_t)(v0 | (((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> ((sb / 2) % 2) * 2) & 0x03)) << 4) - 32; + const int8_t a4 = (int8_t)(v4 | (((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> (((sb / 2) % 2) * 2 + 4)) & 0x03)) << 4) - 32; + sumi0 = (a0 * a_ptr[l].qs[(sb / 4) * 64 + sb * 16 + i]); + sumi4 = (a4 * a_ptr[l].qs[(sb / 4) * 64 + sb * 16 + 64 + i]); + sumi0 = sumi0 * scales_0[j]; + sumi4 = sumi4 * scales_4[j]; + sumi += sumi0 + sumi4; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + template static inline void ggml_gemv_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -617,7 +681,7 @@ static inline void 
ggml_gemv_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRIC UNUSED(bs); UNUSED(nr); - float sumf[16]; + float sumf[ncols_interleaved]; int sumi; const block_q8_0 * a_ptr = (const block_q8_0 *) vy; @@ -714,7 +778,7 @@ static inline void ggml_gemm_q8_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); + const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumf[m][j] = 0.0; @@ -884,8 +948,8 @@ static inline void ggml_gemm_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT mins[i] = b_ptr[l].scales[i] >> 4; } for (int i = 0; i < ncols_interleaved * 4; i++) { - scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; - mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x0C) << 2; scales[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x30); mins[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0xC0) >> 2; } @@ -934,6 +998,77 @@ static inline void ggml_gemm_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } +template +static inline void ggml_gemm_q6_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + + float sumf[4][ncols_interleaved]; + int sumi0; + int sumi4; + int sumi; 
+ for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } + } + + for (int l = 0; l < nb; l++) { + // Processing 2 sub-blocks at once. + int total = 0; + for (int sb = 0; sb < 8; sb++) { + int scales_idx = (sb / 4) * 4 + sb; + int qh_idx = (sb / 4) * 32 + (sb % 2) * 16; + const int8_t *scales_0 = &b_ptr[l].scales[(scales_idx) * ncols_interleaved]; + const int8_t *scales_4 = &b_ptr[l].scales[(scales_idx + 4) * ncols_interleaved]; + for (int i = 0; i < QK8_0 / 2; i++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi0 = 0; + sumi4 = 0; + sumi = 0; + const uint8_t v0 = (uint8_t) (b_ptr[l].ql[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); + const uint8_t v4 = (uint8_t) (b_ptr[l].ql[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); + const int8_t a0 = (int8_t)(v0 | (((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> ((sb / 2) % 2) * 2) & 0x03)) << 4) - 32; + const int8_t a4 = (int8_t)(v4 | (((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> (((sb / 2) % 2) * 2 + 4)) & 0x03)) << 4) - 32; + sumi0 = (a0 * a_ptr[l].qs[(sb / 4) * 64 * 4 + sb * 4 * 16 + i * 4 + m]); + sumi4 = (a4 * a_ptr[l].qs[(sb / 4) * 64 * 4 + sb * 4 * 16 + 4 * 64 + i * 4 + m]); + sumi0 = sumi0 * scales_0[j]; + sumi4 = sumi4 * scales_4[j]; + sumi += sumi0 + sumi4; + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + template static inline void ggml_gemm_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const 
void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -944,7 +1079,7 @@ static inline void ggml_gemm_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRIC assert(nr % 4 == 0); assert(nc % ncols_interleaved == 0); - float sumf[4][16]; + float sumf[4][ncols_interleaved]; int sumi; for (int y = 0; y < nr / 4; y++) { @@ -1340,11 +1475,14 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } +<<<<<<< HEAD void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); } +======= +>>>>>>> 4044254f5 (ggml-cpu: refactor; add rvv repacking for q6_K) void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } @@ -1598,7 +1736,7 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, #if defined __riscv_zvfh // Q4_0 void ggml_gemv_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - ggml_gemm_q4_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); + ggml_gemv_q4_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_q4_0_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); @@ -1652,6 +1790,20 @@ void ggml_gemv_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemv_q4_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q6_K +void ggml_gemv_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, 
int nc) { + ggml_gemv_q6_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q6_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q6_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q6_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_iq4_nl_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -2452,6 +2604,20 @@ void ggml_gemm_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemm_q4_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q6_K +void ggml_gemm_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q6_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q6_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q6_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_iq4_nl_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -3349,7 +3515,6 @@ static int repack_q2_K_to_q2_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_ template static block_q4_Kx make_block_q4_KxMx1(block_q4_K * in) { block_q4_Kx out; - //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure for (int i = 0; i < nrows_interleaved; i++) { out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; } @@ -3425,6 +3590,70 @@ static int repack_q4_K_to_q4_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_ GGML_UNUSED(data_size); } +template +static block_q6_Kx make_block_q6_KxMx1(block_q6_K * in) { + block_q6_Kx out; + for (int i = 0; i < nrows_interleaved; i++) { + out.d[i] = in[i].d; + } + + const int end_ls = QK_K * nrows_interleaved / 2; + for (int i = 0; i < end_ls; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; + + out.ql[dst_offset] = in[src_id].ql[src_offset]; + } + + const int end_hs = QK_K * nrows_interleaved / 4; + for (int i = 0; i < end_hs; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; + + out.qh[dst_offset] = in[src_id].qh[src_offset]; + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int j = 0; j < 16; j++) { + out.scales[j * nrows_interleaved + i] = in[i].scales[j]; + } + } + + return out; +} + +template +static int repack_q6_K_to_q6_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + + block_q6_Kx * dst = (block_q6_Kx*)t->data; + const block_q6_K * src = (const block_q6_K*) data; + block_q6_K 
dst_tmp[nrows_interleaved]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q6_KxMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + template static block_iq4_nlx make_block_iq4_nlxMx1(block_iq4_nl * in) { block_iq4_nlx out; @@ -3725,6 +3954,20 @@ template <> int repack(struct ggml_tensor * t, const void * d return repack_q4_K_to_q4_K_Mx1_bl<64>(t, data, data_size); } +// Q6_K +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_Mx1_bl<64>(t, data, data_size); +} + // IQ4_NL template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_Mx1_bl<8>(t, data, data_size); @@ -3872,6 +4115,20 @@ template <> void gemv(int n, float * s, size_ ggml_gemv_q4_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q6_K +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_16x1_q8_K(n, s, 
bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_8x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -4019,6 +4276,20 @@ template <> void gemm(int n, float * s, size_ ggml_gemm_q4_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q6_K +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_8x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -4471,6 +4742,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_K_32x1_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_64x1_q8_K; + // Q6_K + static const ggml::cpu::repack::tensor_traits q6_K_8x1_q8_K; + static const ggml::cpu::repack::tensor_traits q6_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits q6_K_32x1_q8_K; + static const ggml::cpu::repack::tensor_traits q6_K_64x1_q8_K; + // IQ4_NL static const 
ggml::cpu::repack::tensor_traits iq4_nl_8x1_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; @@ -4497,7 +4774,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { if (cur->ne[1] % 8 == 0) { return &q4_0_8x1_q8_0; } break; } + case 128: { if (cur->ne[1] % 8 == 0) { return &q4_0_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q4_0_16x1_q8_0; } break; } case 512: { if (cur->ne[1] % 32 == 0) { return &q4_0_32x1_q8_0; } break; } case 1024: { if (cur->ne[1] % 64 == 0) { return &q4_0_64x1_q8_0; } break; } @@ -4524,7 +4801,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { if (cur->ne[1] % 8 == 0) { return &q4_K_8x1_q8_K; } break; } + case 128: { if (cur->ne[1] % 8 == 0) { return &q4_K_8x1_q8_K; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q4_K_16x1_q8_K; } break; } case 512: { if (cur->ne[1] % 32 == 0) { return &q4_K_32x1_q8_K; } break; } case 1024: { if (cur->ne[1] % 64 == 0) { return &q4_K_64x1_q8_K; } break; } @@ -4541,10 +4818,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { if (cur->ne[1] % 8 == 0) { return &q2_K_8x1_q8_K; } break; } + case 128: { if (cur->ne[1] % 8 == 0) { return &q2_K_8x1_q8_K; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q2_K_16x1_q8_K; } break; } case 512: { if (cur->ne[1] % 32 == 0) { return &q2_K_32x1_q8_K; } break; } - case 1024: { if (cur->ne[1] % 64 == 0) { return &q2_K_64x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q2_K_64x1_q8_K; } break; } default: { return nullptr; } } #endif @@ -4570,6 +4847,16 @@ static const ggml::cpu::tensor_traits * 
ggml_repack_get_optimal_repack_type(cons if (cur->ne[1] % 8 == 0) { return &q6_K_8x4_q8_K; } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { if (cur->ne[1] % 8 == 0) { return &q6_K_8x1_q8_K; } break; } + case 256: { if (cur->ne[1] % 16 == 0) { return &q6_K_16x1_q8_K; } break; } + case 512: { if (cur->ne[1] % 32 == 0) { return &q6_K_32x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q6_K_64x1_q8_K; } break; } + default: { return nullptr; } + } + #endif } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { @@ -4585,7 +4872,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { if (cur->ne[1] % 8 == 0) { return &iq4_nl_8x1_q8_0; } break; } + case 128: { if (cur->ne[1] % 8 == 0) { return &iq4_nl_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &iq4_nl_16x1_q8_0; } break; } case 512: { if (cur->ne[1] % 32 == 0) { return &iq4_nl_32x1_q8_0; } break; } case 1024: { if (cur->ne[1] % 64 == 0) { return &iq4_nl_64x1_q8_0; } break; } @@ -4618,7 +4905,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { if (cur->ne[1] % 8 == 0) { return &q8_0_8x1_q8_0; } break; } + case 128: { if (cur->ne[1] % 8 == 0) { return &q8_0_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q8_0_16x1_q8_0; } break; } case 512: { if (cur->ne[1] % 32 == 0) { return &q8_0_32x1_q8_0; } break; } case 1024: { if (cur->ne[1] % 64 == 0) { return &q8_0_64x1_q8_0; } break; } diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 0d97c9b9b53..7ef2e8341bc 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -119,7 +119,7 @@ struct block_q8_Kx4 { static_assert(sizeof(block_q8_Kx4) 
== sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding"); template struct block_iq4_nlx { - ggml_half d[N]; // deltas for `N` iq4_nl blocks + ggml_half d[N]; // deltas for `N` iq4_nl blocks uint8_t qs[QK4_NL * N / 2]; // nibbles / quants for N iq4_nl blocks }; @@ -206,6 +206,10 @@ void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -226,6 +230,10 @@ void ggml_gemm_q4_K_8x1_q8_K(int n, float * 
GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -288,6 +296,10 @@ void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void 
ggml_gemv_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -308,6 +320,10 @@ void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From 90c9244cc1cb28d3a437c1b14b5dde0eee748809 Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Wed, 11 Feb 2026 17:54:33 +0500 Subject: [PATCH 10/13] ggml-cpu: refactor; add rvv repacking for q5_K --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 818 ++++++++++++++++++------ ggml/src/ggml-cpu/repack.cpp | 494 ++++++++++++-- ggml/src/ggml-cpu/repack.h | 39 +- 3 files changed, 1109 insertions(+), 242 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index a8965bcc1a2..b9822dc9c03 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -871,7 +871,7 @@ void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemv_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -892,7 +892,7 @@ void ggml_gemv_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo const block_q8_K * a_ptr = (const block_q8_K *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + const block_q5_Kx * b_ptr = (const block_q5_Kx *) vx + (x * nb); // 1xM Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); @@ -900,120 +900,258 @@ void ggml_gemv_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int l = 0; l < nb; l++) { vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); - // We process 8 16-element sub-blocks at once. + // We process 4 sub-blocks at once. + const int vl = ncols_interleaved * 4; for (int j = 0; j < QK_K / 128; j++) { - // Load the scales. + // Extract the scales and the mins. // // Low bits. - vint16m8_t scales = __riscv_vwcvt_x_x_v_i16m8(__riscv_vle8_v_i8m4(&b_ptr[l].scales[j * 8 * ncols_interleaved], 8 * ncols_interleaved), 8 * ncols_interleaved); + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * vl], vl); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, vl); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, vl); - // Sub-blocks 0, 2, 4, 6. - { - for (int i = 0; i < QK8_0 / 2; i++) { - // Load the high bits. - const vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 32 + i) * ncols_interleaved], ncols_interleaved); + // High bits. 
+ vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl * 2], vl); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, vl), 4, vl); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, vl), 2, vl); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, vl); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, vl), 2, vl); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, vl), vl); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, vl), vl)); - // Sub-blocks 0, 4 - { - // Load the low bits. - const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + i) * ncols_interleaved], ncols_interleaved); - const vuint8mf2_t b_0_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); - const vuint8mf2_t b_4_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); - // Unpack the high bits. 
- const vuint8mf2_t b_0_hi = __riscv_vand_vx_u8mf2(b_hi, 0x3, ncols_interleaved); - const vuint8mf2_t b_4_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x3 , ncols_interleaved); + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); + sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved); - // Merge the low bits with the corresponding high bits. - const vuint8mf2_t b_0_m = __riscv_vor_vv_u8mf2(b_0_lo, __riscv_vsll_vx_u8mf2(b_0_hi, 4, ncols_interleaved), ncols_interleaved); - const vuint8mf2_t b_4_m = __riscv_vor_vv_u8mf2(b_4_lo, __riscv_vsll_vx_u8mf2(b_4_hi, 4, ncols_interleaved), ncols_interleaved); + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // Bias adjustment. - const vint8mf2_t b_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_0_m), 32, ncols_interleaved); - const vint8mf2_t b_4 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_4_m), 32, ncols_interleaved); + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_lo_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_lo_packed, 4, ncols_interleaved)); - // Multiply and accumulate. 
- sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 0), __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[j * 128 + i], ncols_interleaved), ncols_interleaved); - sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 4), __riscv_vwmul_vx_i16m1(b_4, a_ptr[l].qs[j * 128 + 64 + i], ncols_interleaved), ncols_interleaved); - } - // Sub-blocks 2, 6 - { - // Load the low bits. - const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 32 + i) * ncols_interleaved], ncols_interleaved); - const vuint8mf2_t b_2_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); - const vuint8mf2_t b_6_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + // Load high bits and merge with low bits. + const vuint8mf2_t b_hi_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[i * ncols_interleaved], ncols_interleaved); + const vbool16_t b_hi_0_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 0), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vadd_vx_i8mf2_mu(b_hi_0_mask, b_s_lo_0, b_s_lo_0, 16, ncols_interleaved); + const vbool16_t b_hi_1_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 1), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_1 = __riscv_vadd_vx_i8mf2_mu(b_hi_1_mask, b_s_lo_1, b_s_lo_1, 16, ncols_interleaved); - // Unpack the high bits. - const vuint8mf2_t b_2_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x3, ncols_interleaved); - const vuint8mf2_t b_6_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 6, ncols_interleaved), 0x3 , ncols_interleaved); + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, ncols_interleaved); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, ncols_interleaved); + } - // Merge the low bits with the corresponding high bits. 
- const vuint8mf2_t b_2_m = __riscv_vor_vv_u8mf2(b_2_lo, __riscv_vsll_vx_u8mf2(b_2_hi, 4, ncols_interleaved), ncols_interleaved); - const vuint8mf2_t b_6_m = __riscv_vor_vv_u8mf2(b_6_lo, __riscv_vsll_vx_u8mf2(b_6_hi, 4, ncols_interleaved), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_s_0_16, ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_s_1_16, ncols_interleaved); + } + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // Bias adjustment. - const vint8mf2_t b_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_2_m), 32, ncols_interleaved); - const vint8mf2_t b_6 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_6_m), 32, ncols_interleaved); + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_lo_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_lo_packed, 4, ncols_interleaved)); - // Multiply and accumulate. 
- sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 2), __riscv_vwmul_vx_i16m1(b_2, a_ptr[l].qs[j * 128 + 32 + i], ncols_interleaved), ncols_interleaved); - sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 6), __riscv_vwmul_vx_i16m1(b_6, a_ptr[l].qs[j * 128 + 96 + i], ncols_interleaved), ncols_interleaved); - } + // Load high bits and merge with low bits. + const vuint8mf2_t b_hi_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[i * ncols_interleaved], ncols_interleaved); + const vbool16_t b_hi_0_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 2), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vadd_vx_i8mf2_mu(b_hi_0_mask, b_s_lo_0, b_s_lo_0, 16, ncols_interleaved); + const vbool16_t b_hi_1_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 3), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_1 = __riscv_vadd_vx_i8mf2_mu(b_hi_1_mask, b_s_lo_1, b_s_lo_1, 16, ncols_interleaved); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, ncols_interleaved); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, ncols_interleaved); } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_s_0_16, ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_s_1_16, ncols_interleaved); } - // Sub-blocks 1, 3, 5, 7. 
- { - for (int i = 0; i < QK8_0 / 2; i++) { + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, ncols_interleaved); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); + } +} + +void ggml_gemv_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q5_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q5_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q5_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q5_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q5_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q5_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q5_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q5_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + +template +void ggml_gemv_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 
1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + + // 1xM Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 2 16-element sub-blocks at once. + for (int j = 0; j < QK_K / 16; j += 4) { + // Load the scales. + // + // Low bits. + vint16m4_t scales = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(&b_ptr[l].scales[j * ncols_interleaved], 4 * ncols_interleaved), 4 * ncols_interleaved); + + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_2_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_3_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + for (int i = k * 8; i < k * 8 + QK8_0 / 4; i++) { // Load the high bits. - const vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 32 + 16 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 4 + i) * ncols_interleaved], ncols_interleaved); - // Sub-blocks 1, 5 { // Load the low bits. 
- const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 16 + i) * ncols_interleaved], ncols_interleaved); - const vuint8mf2_t b_1_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); - const vuint8mf2_t b_5_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 8 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_0_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_1_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); // Unpack the high bits. - const vuint8mf2_t b_1_hi = __riscv_vand_vx_u8mf2(b_hi, 0x3, ncols_interleaved); - const vuint8mf2_t b_5_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x3 , ncols_interleaved); + const vuint8mf2_t b_0_hi = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x30, ncols_interleaved); + const vuint8mf2_t b_1_hi = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x30, ncols_interleaved); // Merge the low bits with the corresponding high bits. - const vuint8mf2_t b_1_m = __riscv_vor_vv_u8mf2(b_1_lo, __riscv_vsll_vx_u8mf2(b_1_hi, 4, ncols_interleaved), ncols_interleaved); - const vuint8mf2_t b_5_m = __riscv_vor_vv_u8mf2(b_5_lo, __riscv_vsll_vx_u8mf2(b_5_hi, 4, ncols_interleaved), ncols_interleaved); + const vuint8mf2_t b_0_m = __riscv_vor_vv_u8mf2(b_0_lo, b_0_hi, ncols_interleaved); + const vuint8mf2_t b_1_m = __riscv_vor_vv_u8mf2(b_1_lo, b_1_hi, ncols_interleaved); // Bias adjustment. + const vint8mf2_t b_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_0_m), 32, ncols_interleaved); const vint8mf2_t b_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_1_m), 32, ncols_interleaved); - const vint8mf2_t b_5 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_5_m), 32, ncols_interleaved); - // Multiply and accumulate. 
- sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 1), __riscv_vwmul_vx_i16m1(b_1, a_ptr[l].qs[j * 128 + 16 + i], ncols_interleaved), ncols_interleaved); - sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 5), __riscv_vwmul_vx_i16m1(b_5, a_ptr[l].qs[j * 128 + 80 + i], ncols_interleaved), ncols_interleaved); + // Multiply and accumulate in int16. + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 16 + 0 + i], b_0, ncols_interleaved); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 16 + 16 + i], b_1, ncols_interleaved); } - // Sub-blocks 3, 7 + __asm__ __volatile__("" ::: "memory"); { // Load the low bits. - const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 48 + i) * ncols_interleaved], ncols_interleaved); - const vuint8mf2_t b_3_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); - const vuint8mf2_t b_7_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 8 + 16 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_2_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_3_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); // Unpack the high bits. - const vuint8mf2_t b_3_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x3, ncols_interleaved); - const vuint8mf2_t b_7_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 6, ncols_interleaved), 0x3, ncols_interleaved); + const vuint8mf2_t b_2_hi = __riscv_vand_vx_u8mf2(b_hi, 0x30, ncols_interleaved); + const vuint8mf2_t b_3_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x30, ncols_interleaved); // Merge the low bits with the corresponding high bits. 
- const vuint8mf2_t b_3_m = __riscv_vor_vv_u8mf2(b_3_lo, __riscv_vsll_vx_u8mf2(b_3_hi, 4, ncols_interleaved), ncols_interleaved); - const vuint8mf2_t b_7_m = __riscv_vor_vv_u8mf2(b_7_lo, __riscv_vsll_vx_u8mf2(b_7_hi, 4, ncols_interleaved), ncols_interleaved); + const vuint8mf2_t b_2_m = __riscv_vor_vv_u8mf2(b_2_lo, b_2_hi, ncols_interleaved); + const vuint8mf2_t b_3_m = __riscv_vor_vv_u8mf2(b_3_lo, b_3_hi, ncols_interleaved); // Bias adjustment. + const vint8mf2_t b_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_2_m), 32, ncols_interleaved); const vint8mf2_t b_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_3_m), 32, ncols_interleaved); - const vint8mf2_t b_7 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_7_m), 32, ncols_interleaved); - // Multiply and accumulate. - sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 3), __riscv_vwmul_vx_i16m1(b_3, a_ptr[l].qs[j * 128 + 48 + i], ncols_interleaved), ncols_interleaved); - sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vget_v_i16m8_i16m1(scales, 7), __riscv_vwmul_vx_i16m1(b_7, a_ptr[l].qs[j * 128 + 112 + i], ncols_interleaved), ncols_interleaved); + // Multiply and accumulate in int16. + sumi_s_2_16 = __riscv_vwmacc_vx_i16m1(sumi_s_2_16, a_ptr[l].qs[j * 16 + 32 + i], b_2, ncols_interleaved); + sumi_s_3_16 = __riscv_vwmacc_vx_i16m1(sumi_s_3_16, a_ptr[l].qs[j * 16 + 48 + i], b_3, ncols_interleaved); } + __asm__ __volatile__("" ::: "memory"); } + + // Multiply and accumulate in int32. 
+ sumi = __riscv_vwmacc_vv_i32m2(sumi, sumi_s_0_16, __riscv_vget_v_i16m4_i16m1(scales, 0), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, sumi_s_1_16, __riscv_vget_v_i16m4_i16m1(scales, 1), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, sumi_s_2_16, __riscv_vget_v_i16m4_i16m1(scales, 2), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, sumi_s_3_16, __riscv_vget_v_i16m4_i16m1(scales, 3), ncols_interleaved); } } @@ -2581,7 +2719,7 @@ void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemm_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -2603,7 +2741,7 @@ void ggml_gemm_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int y = 0; y < nr / 4; y++) { const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + const block_q5_Kx * b_ptr = (const block_q5_Kx *) vx + (x * nb); // 4xM Accumulators vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); @@ -2617,142 +2755,438 @@ void ggml_gemm_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); - // We process 8 16-element sub-blocks at once. + // We process 4 sub-blocks at once. + const int vl = ncols_interleaved * 4; for (int j = 0; j < QK_K / 128; j++) { - // Load the scales. + // Extract the scales and the mins. // // Low bits. 
- vint16m8_t scales = __riscv_vwcvt_x_x_v_i16m8(__riscv_vle8_v_i8m4(&b_ptr[l].scales[j * 8 * ncols_interleaved], 8 * ncols_interleaved), 8 * ncols_interleaved); + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * vl], vl); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, vl); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, vl); - // Loop 1: Sub-blocks 0, 2, 4, 6. - // Loop 2: Sub-blocks 8, 10, 12, 14. - for (int i = 0; i < QK8_0 / 2; i++) { - // Load the high bits. - const vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 32 + i) * ncols_interleaved], ncols_interleaved); + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl * 2], vl); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, vl), 4, vl); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, vl), 2, vl); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, vl); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, vl), 2, vl); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, vl), vl); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, vl), vl)); - // Sub-blocks 0, 4 - { - // Load the low bits. - const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + i) * ncols_interleaved], ncols_interleaved); - const vuint8mf2_t b_0_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); - const vuint8mf2_t b_4_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. 
+ vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); - // Unpack the high bits. - const vuint8mf2_t b_0_hi = __riscv_vand_vx_u8mf2(b_hi, 0x3, ncols_interleaved); - const vuint8mf2_t b_4_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x3, ncols_interleaved); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 2), 
ncols_interleaved); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); - // Merge the low bits with the corresponding high bits. 
- const vuint8mf2_t b_0_m = __riscv_vor_vv_u8mf2(b_0_lo, __riscv_vsll_vx_u8mf2(b_0_hi, 4, ncols_interleaved), ncols_interleaved); - const vuint8mf2_t b_4_m = __riscv_vor_vv_u8mf2(b_4_lo, __riscv_vsll_vx_u8mf2(b_4_hi, 4, ncols_interleaved), ncols_interleaved); + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[3], ncols_interleaved); - // Bias adjustment. - const vint8mf2_t b_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_0_m), 32, ncols_interleaved); - const vint8mf2_t b_4 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_4_m), 32, ncols_interleaved); - - // Multiply and accumulate. 
- sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 0), __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[j * 512 + 64 * 0 + i * 4 + 0], ncols_interleaved), ncols_interleaved); - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 4), __riscv_vwmul_vx_i16m1(b_4, a_ptr[l].qs[j * 512 + 64 * 4 + i * 4 + 0], ncols_interleaved), ncols_interleaved); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 0), __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[j * 512 + 64 * 0 + i * 4 + 1], ncols_interleaved), ncols_interleaved); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 4), __riscv_vwmul_vx_i16m1(b_4, a_ptr[l].qs[j * 512 + 64 * 4 + i * 4 + 1], ncols_interleaved), ncols_interleaved); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 0), __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[j * 512 + 64 * 0 + i * 4 + 2], ncols_interleaved), ncols_interleaved); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 4), __riscv_vwmul_vx_i16m1(b_4, a_ptr[l].qs[j * 512 + 64 * 4 + i * 4 + 2], ncols_interleaved), ncols_interleaved); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 0), __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[j * 512 + 64 * 0 + i * 4 + 3], ncols_interleaved), ncols_interleaved); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 4), __riscv_vwmul_vx_i16m1(b_4, a_ptr[l].qs[j * 512 + 64 * 4 + i * 4 + 3], ncols_interleaved), ncols_interleaved); - } - // Sub-blocks 2, 6 - { - // Load the low bits. 
- const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 32 + i) * ncols_interleaved], ncols_interleaved); - const vuint8mf2_t b_2_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); - const vuint8mf2_t b_6_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, ncols_interleaved), ncols_interleaved), ncols_interleaved); - // Unpack the high bits. - const vuint8mf2_t b_2_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x3, ncols_interleaved); - const vuint8mf2_t b_6_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 6, ncols_interleaved), 0x3 , ncols_interleaved); - // Merge the low bits with the corresponding high bits. - const vuint8mf2_t b_2_m = __riscv_vor_vv_u8mf2(b_2_lo, __riscv_vsll_vx_u8mf2(b_2_hi, 4, ncols_interleaved), ncols_interleaved); - const vuint8mf2_t b_6_m = __riscv_vor_vv_u8mf2(b_6_lo, __riscv_vsll_vx_u8mf2(b_6_hi, 4, ncols_interleaved), ncols_interleaved); + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. 
+ for (int k = 0; k < 2; k++) { + // 4xM integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // Bias adjustment. - const vint8mf2_t b_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_2_m), 32, ncols_interleaved); - const vint8mf2_t b_6 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_6_m), 32, ncols_interleaved); - - // Multiply and accumulate. - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 2), __riscv_vwmul_vx_i16m1(b_2, a_ptr[l].qs[j * 512 + 64 * 2 + i * 4 + 0], ncols_interleaved), ncols_interleaved); - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 6), __riscv_vwmul_vx_i16m1(b_6, a_ptr[l].qs[j * 512 + 64 * 6 + i * 4 + 0], ncols_interleaved), ncols_interleaved); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 2), __riscv_vwmul_vx_i16m1(b_2, a_ptr[l].qs[j * 512 + 64 * 2 + i * 4 + 1], ncols_interleaved), ncols_interleaved); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 6), __riscv_vwmul_vx_i16m1(b_6, a_ptr[l].qs[j * 512 + 64 * 6 + i * 4 + 1], ncols_interleaved), ncols_interleaved); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 2), __riscv_vwmul_vx_i16m1(b_2, a_ptr[l].qs[j * 512 + 64 * 2 + i * 4 + 2], ncols_interleaved), ncols_interleaved); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, 
__riscv_vget_v_i16m8_i16m1(scales, 6), __riscv_vwmul_vx_i16m1(b_6, a_ptr[l].qs[j * 512 + 64 * 6 + i * 4 + 2], ncols_interleaved), ncols_interleaved); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 2), __riscv_vwmul_vx_i16m1(b_2, a_ptr[l].qs[j * 512 + 64 * 2 + i * 4 + 3], ncols_interleaved), ncols_interleaved); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 6), __riscv_vwmul_vx_i16m1(b_6, a_ptr[l].qs[j * 512 + 64 * 6 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_lo_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_lo_packed, 4, ncols_interleaved)); + + // Load high bits and merge with low bits. 
+ const vuint8mf2_t b_hi_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[i * ncols_interleaved], ncols_interleaved); + const vbool16_t b_hi_0_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 0), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vadd_vx_i8mf2_mu(b_hi_0_mask, b_s_lo_0, b_s_lo_0, 16, ncols_interleaved); + const vbool16_t b_hi_1_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 1), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_1 = __riscv_vadd_vx_i8mf2_mu(b_hi_1_mask, b_s_lo_1, b_s_lo_1, 16, ncols_interleaved); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, ncols_interleaved); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, ncols_interleaved); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, ncols_interleaved); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, ncols_interleaved); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, ncols_interleaved); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, ncols_interleaved); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, ncols_interleaved); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, ncols_interleaved); } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_0_s_0_16, ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_0_s_1_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + 
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_1_s_0_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_1_s_1_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_2_s_0_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_2_s_1_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_3_s_0_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_3_s_1_16, ncols_interleaved); } - // Loop 1: Sub-blocks 1, 3, 5, 7. - // Loop 2: Sub-blocks 9, 11, 13, 15. - for (int i = 0; i < QK8_0 / 2; i++) { - // Load the high bits. - const vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 32 + 16 + i) * ncols_interleaved], ncols_interleaved); + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. 
+ for (int k = 0; k < 2; k++) { + // 4xM integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // Sub-blocks 1, 5 - { - // Load the low bits. - const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 16 + i) * ncols_interleaved], ncols_interleaved); - const vuint8mf2_t b_1_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); - const vuint8mf2_t b_5_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_lo_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_lo_packed, 4, ncols_interleaved)); - // Unpack the high bits. - const vuint8mf2_t b_1_hi = __riscv_vand_vx_u8mf2(b_hi, 0x3, ncols_interleaved); - const vuint8mf2_t b_5_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x3 , ncols_interleaved); + // Load high bits and merge with low bits. 
+ const vuint8mf2_t b_hi_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[i * ncols_interleaved], ncols_interleaved); + const vbool16_t b_hi_0_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 2), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vadd_vx_i8mf2_mu(b_hi_0_mask, b_s_lo_0, b_s_lo_0, 16, ncols_interleaved); + const vbool16_t b_hi_1_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 3), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_1 = __riscv_vadd_vx_i8mf2_mu(b_hi_1_mask, b_s_lo_1, b_s_lo_1, 16, ncols_interleaved); - // Merge the low bits with the corresponding high bits. - const vuint8mf2_t b_1_m = __riscv_vor_vv_u8mf2(b_1_lo, __riscv_vsll_vx_u8mf2(b_1_hi, 4, ncols_interleaved), ncols_interleaved); - const vuint8mf2_t b_5_m = __riscv_vor_vv_u8mf2(b_5_lo, __riscv_vsll_vx_u8mf2(b_5_hi, 4, ncols_interleaved), ncols_interleaved); + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, ncols_interleaved); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, ncols_interleaved); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, ncols_interleaved); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, ncols_interleaved); - // Bias adjustment. - const vint8mf2_t b_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_1_m), 32, ncols_interleaved); - const vint8mf2_t b_5 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_5_m), 32, ncols_interleaved); - - // Multiply and accumulate. 
- sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 1), __riscv_vwmul_vx_i16m1(b_1, a_ptr[l].qs[j * 512 + 64 * 1 + i * 4 + 0], ncols_interleaved), ncols_interleaved); - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 5), __riscv_vwmul_vx_i16m1(b_5, a_ptr[l].qs[j * 512 + 64 * 5 + i * 4 + 0], ncols_interleaved), ncols_interleaved); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 1), __riscv_vwmul_vx_i16m1(b_1, a_ptr[l].qs[j * 512 + 64 * 1 + i * 4 + 1], ncols_interleaved), ncols_interleaved); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 5), __riscv_vwmul_vx_i16m1(b_5, a_ptr[l].qs[j * 512 + 64 * 5 + i * 4 + 1], ncols_interleaved), ncols_interleaved); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 1), __riscv_vwmul_vx_i16m1(b_1, a_ptr[l].qs[j * 512 + 64 * 1 + i * 4 + 2], ncols_interleaved), ncols_interleaved); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 5), __riscv_vwmul_vx_i16m1(b_5, a_ptr[l].qs[j * 512 + 64 * 5 + i * 4 + 2], ncols_interleaved), ncols_interleaved); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 1), __riscv_vwmul_vx_i16m1(b_1, a_ptr[l].qs[j * 512 + 64 * 1 + i * 4 + 3], ncols_interleaved), ncols_interleaved); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 5), __riscv_vwmul_vx_i16m1(b_5, a_ptr[l].qs[j * 512 + 64 * 5 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, ncols_interleaved); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, ncols_interleaved); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, ncols_interleaved); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i 
* 4 + 3], b_s_1, ncols_interleaved); } - // Sub-blocks 3, 7 - { - // Load the low bits. - const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 64 + 48 + i) * ncols_interleaved], ncols_interleaved); - const vuint8mf2_t b_3_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); - const vuint8mf2_t b_7_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); - // Unpack the high bits. - const vuint8mf2_t b_3_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x3, ncols_interleaved); - const vuint8mf2_t b_7_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 6, ncols_interleaved), 0x3, ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_0_s_0_16, ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_0_s_1_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_1_s_0_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_1_s_1_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_2_s_0_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_2_s_1_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_3_s_0_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_3_s_1_16, ncols_interleaved); + } + } - // Merge the low bits with the corresponding high bits. 
- const vuint8mf2_t b_3_m = __riscv_vor_vv_u8mf2(b_3_lo, __riscv_vsll_vx_u8mf2(b_3_hi, 4, ncols_interleaved), ncols_interleaved); - const vuint8mf2_t b_7_m = __riscv_vor_vv_u8mf2(b_7_lo, __riscv_vsll_vx_u8mf2(b_7_hi, 4, ncols_interleaved), ncols_interleaved); + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], ncols_interleaved); - // Bias adjustment. - const vint8mf2_t b_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_3_m), 32, ncols_interleaved); - const vint8mf2_t b_7 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_7_m), 32, ncols_interleaved); - - // Multiply and accumulate. 
- sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 3), __riscv_vwmul_vx_i16m1(b_3, a_ptr[l].qs[j * 512 + 64 * 3 + i * 4 + 0], ncols_interleaved), ncols_interleaved); - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m8_i16m1(scales, 7), __riscv_vwmul_vx_i16m1(b_7, a_ptr[l].qs[j * 512 + 64 * 7 + i * 4 + 0], ncols_interleaved), ncols_interleaved); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 3), __riscv_vwmul_vx_i16m1(b_3, a_ptr[l].qs[j * 512 + 64 * 3 + i * 4 + 1], ncols_interleaved), ncols_interleaved); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m8_i16m1(scales, 7), __riscv_vwmul_vx_i16m1(b_7, a_ptr[l].qs[j * 512 + 64 * 7 + i * 4 + 1], ncols_interleaved), ncols_interleaved); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 3), __riscv_vwmul_vx_i16m1(b_3, a_ptr[l].qs[j * 512 + 64 * 3 + i * 4 + 2], ncols_interleaved), ncols_interleaved); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m8_i16m1(scales, 7), __riscv_vwmul_vx_i16m1(b_7, a_ptr[l].qs[j * 512 + 64 * 7 + i * 4 + 2], ncols_interleaved), ncols_interleaved); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 3), __riscv_vwmul_vx_i16m1(b_3, a_ptr[l].qs[j * 512 + 64 * 3 + i * 4 + 3], ncols_interleaved), ncols_interleaved); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m8_i16m1(scales, 7), __riscv_vwmul_vx_i16m1(b_7, a_ptr[l].qs[j * 512 + 64 * 7 + i * 4 + 3], ncols_interleaved), ncols_interleaved); + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 
ncols_interleaved), d_3, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); + } + } +} +/* Entry points ggml_gemm_q5_K_{8,16,32,64}x1_q8_K. Each one forwards its arguments unchanged: to ggml_gemm_q5_K_Mx1_q8_K<N> when __riscv_zvfh is defined, otherwise to the matching ggml_gemm_q5_K_Nx1_q8_K_generic, where N is the column count in the function name. */ +void ggml_gemm_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q5_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q5_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q5_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q5_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q5_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q5_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q5_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q5_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + +template +void ggml_gemm_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr
% 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + + // 4xM Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 2 16-element sub-blocks at once. + for (int j = 0; j < QK_K / 16; j += 4) { + // Load the scales. + // + // Low bits. + vint16m4_t scales = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(&b_ptr[l].scales[j * ncols_interleaved], 4 * ncols_interleaved), 4 * ncols_interleaved); + + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. 
+ #pragma GCC unroll 1 + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_2_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_2_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_2_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_2_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_3_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_3_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_3_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_3_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + for (int i = k * 8; i < k * 8 + QK8_0 / 4; i++) { + // Load the high bits. + vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 4 + i) * ncols_interleaved], ncols_interleaved); + + { + // Load the low bits. + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 8 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_0_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_1_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. 
+ const vuint8mf2_t b_0_hi = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x30, ncols_interleaved); + const vuint8mf2_t b_1_hi = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x30, ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_0_m = __riscv_vor_vv_u8mf2(b_0_lo, b_0_hi, ncols_interleaved); + const vuint8mf2_t b_1_m = __riscv_vor_vv_u8mf2(b_1_lo, b_1_hi, ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_0_m), 32, ncols_interleaved); + const vint8mf2_t b_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_1_m), 32, ncols_interleaved); + + // Multiply and accumulate in int16. + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 64 + i * 4 + 0], b_0, ncols_interleaved); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 64 + i * 4 + 1], b_0, ncols_interleaved); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 64 + i * 4 + 2], b_0, ncols_interleaved); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 64 + i * 4 + 3], b_0, ncols_interleaved); + // + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 64 + 64 + i * 4 + 0], b_1, ncols_interleaved); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 64 + 64 + i * 4 + 1], b_1, ncols_interleaved); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 64 + 64 + i * 4 + 2], b_1, ncols_interleaved); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 64 + 64 + i * 4 + 3], b_1, ncols_interleaved); + } + asm volatile ("" ::: "memory"); + { + // Load the low bits. 
+ const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 8 + 16 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_2_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_3_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_2_hi = __riscv_vand_vx_u8mf2(b_hi, 0x30, ncols_interleaved); + const vuint8mf2_t b_3_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x30, ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_2_m = __riscv_vor_vv_u8mf2(b_2_lo, b_2_hi, ncols_interleaved); + const vuint8mf2_t b_3_m = __riscv_vor_vv_u8mf2(b_3_lo, b_3_hi, ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_2_m), 32, ncols_interleaved); + const vint8mf2_t b_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_3_m), 32, ncols_interleaved); + + // Multiply and accumulate in int16. 
+ sumi_0_s_2_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_2_16, a_ptr[l].qs[j * 64 + 128 + i * 4 + 0], b_2, ncols_interleaved); + sumi_1_s_2_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_2_16, a_ptr[l].qs[j * 64 + 128 + i * 4 + 1], b_2, ncols_interleaved); + sumi_2_s_2_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_2_16, a_ptr[l].qs[j * 64 + 128 + i * 4 + 2], b_2, ncols_interleaved); + sumi_3_s_2_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_2_16, a_ptr[l].qs[j * 64 + 128 + i * 4 + 3], b_2, ncols_interleaved); + // + sumi_0_s_3_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_3_16, a_ptr[l].qs[j * 64 + 192 + i * 4 + 0], b_3, ncols_interleaved); + sumi_1_s_3_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_3_16, a_ptr[l].qs[j * 64 + 192 + i * 4 + 1], b_3, ncols_interleaved); + sumi_2_s_3_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_3_16, a_ptr[l].qs[j * 64 + 192 + i * 4 + 2], b_3, ncols_interleaved); + sumi_3_s_3_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_3_16, a_ptr[l].qs[j * 64 + 192 + i * 4 + 3], b_3, ncols_interleaved); + } + asm volatile ("" ::: "memory"); } + + // Multiply and accumulate in int32. 
+ sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m4_i16m1(scales, 0), sumi_0_s_0_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m4_i16m1(scales, 0), sumi_1_s_0_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m4_i16m1(scales, 0), sumi_2_s_0_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m4_i16m1(scales, 0), sumi_3_s_0_16, ncols_interleaved); + // + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m4_i16m1(scales, 1), sumi_0_s_1_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m4_i16m1(scales, 1), sumi_1_s_1_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m4_i16m1(scales, 1), sumi_2_s_1_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m4_i16m1(scales, 1), sumi_3_s_1_16, ncols_interleaved); + // + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m4_i16m1(scales, 2), sumi_0_s_2_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m4_i16m1(scales, 2), sumi_1_s_2_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m4_i16m1(scales, 2), sumi_2_s_2_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m4_i16m1(scales, 2), sumi_3_s_2_16, ncols_interleaved); + // + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m4_i16m1(scales, 3), sumi_0_s_3_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m4_i16m1(scales, 3), sumi_1_s_3_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m4_i16m1(scales, 3), sumi_2_s_3_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m4_i16m1(scales, 3), sumi_3_s_3_16, ncols_interleaved); } } diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 5dc2b62ab24..0418c7f9128 
100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -609,6 +609,85 @@ static inline void ggml_gemv_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } +template +static inline void ggml_gemv_q5_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert(nr == 1); + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[ncols_interleaved]; + float sum_minf[ncols_interleaved]; + uint8_t scales[ncols_interleaved * 8]; + uint8_t mins[ncols_interleaved * 8]; + int sumi1; + int sumi2; + int sumi; + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx * b_ptr = (const block_q5_Kx *) vx + (x * nb); + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + sum_minf[j] = 0.0f; + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < ncols_interleaved * 8; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < ncols_interleaved * 4; i++) { + scales[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x0C) << 2; + scales[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x30); + mins[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0xC0) >> 2; + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * ncols_interleaved]; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb += 2) { + 
uint8_t *scales_0 = &scales[sb * ncols_interleaved]; + uint8_t *scales_1 = &scales[(sb + 1) * ncols_interleaved]; + for (int i = 0; i < QK4_0; i++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + int v0 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); + int v1 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); + if (b_ptr[l].qh[i * ncols_interleaved + j] & (1 << (sb + 0))) { v0 += 16; } + if (b_ptr[l].qh[i * ncols_interleaved + j] & (1 << (sb + 1))) { v1 += 16; } + sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]); + sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + template static inline void ggml_gemv_q6_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { constexpr int qk = QK_K; @@ -624,7 +703,9 @@ static inline void ggml_gemv_q6_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT float sumf[ncols_interleaved]; int sumi0; - int sumi4; + int sumi1; + int sumi2; + int sumi3; int sumi; const block_q8_K * a_ptr = (const block_q8_K *) vy; @@ -636,26 +717,35 @@ static inline void ggml_gemv_q6_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } for (int l = 0; l < nb; l++) { - // Processing 2 sub-blocks at once. - for (int sb = 0; sb < 8; sb++) { - int scales_idx = (sb / 4) * 4 + sb; - int qh_idx = (sb / 4) * 32 + (sb % 2) * 16; - const int8_t *scales_0 = &b_ptr[l].scales[(scales_idx) * ncols_interleaved]; - const int8_t *scales_4 = &b_ptr[l].scales[(scales_idx + 4) * ncols_interleaved]; - for (int i = 0; i < QK8_0 / 2; i++) { + // Processing 4 sub-blocks at once. 
+ for (int sb = 0; sb < QK_K/16; sb += 4) { + const int8_t *scales_0 = &b_ptr[l].scales[sb * ncols_interleaved]; + const int8_t *scales_1 = &b_ptr[l].scales[(sb + 1) * ncols_interleaved]; + const int8_t *scales_2 = &b_ptr[l].scales[(sb + 2) * ncols_interleaved]; + const int8_t *scales_3 = &b_ptr[l].scales[(sb + 3) * ncols_interleaved]; + const int qh_idx = sb * 4; + for (int i = 0; i < QK8_0/2; i++) { for (int j = 0; j < ncols_interleaved; j++) { sumi0 = 0; - sumi4 = 0; + sumi1 = 0; sumi = 0; - const uint8_t v0 = (uint8_t) (b_ptr[l].ql[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); - const uint8_t v4 = (uint8_t) (b_ptr[l].ql[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); - const int8_t a0 = (int8_t)(v0 | (((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> ((sb / 2) % 2) * 2) & 0x03)) << 4) - 32; - const int8_t a4 = (int8_t)(v4 | (((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> (((sb / 2) % 2) * 2 + 4)) & 0x03)) << 4) - 32; - sumi0 = (a0 * a_ptr[l].qs[(sb / 4) * 64 + sb * 16 + i]); - sumi4 = (a4 * a_ptr[l].qs[(sb / 4) * 64 + sb * 16 + 64 + i]); + const uint8_t v0 = (uint8_t) (b_ptr[l].ql[(sb * 8 + i) * ncols_interleaved + j] & 0xF); + const uint8_t v1 = (uint8_t) (b_ptr[l].ql[(sb * 8 + i) * ncols_interleaved + j] >> 4); + const uint8_t v2 = (uint8_t) (b_ptr[l].ql[(sb * 8 + 16 + i) * ncols_interleaved + j] & 0xF); + const uint8_t v3 = (uint8_t) (b_ptr[l].ql[(sb * 8 + 16 + i) * ncols_interleaved + j] >> 4); + const int8_t a0 = (int8_t)(v0 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] << 4) & 0x30)) - 32; + const int8_t a1 = (int8_t)(v1 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] << 2) & 0x30)) - 32; + const int8_t a2 = (int8_t)(v2 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j]) & 0x30)) - 32; + const int8_t a3 = (int8_t)(v3 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> 2) & 0x30)) - 32; + sumi0 = (a0 * a_ptr[l].qs[sb * 16 + i]); + sumi1 = (a1 * a_ptr[l].qs[sb * 16 + 16 + i]); + sumi2 = 
(a2 * a_ptr[l].qs[sb * 16 + 32 + i]); + sumi3 = (a3 * a_ptr[l].qs[sb * 16 + 48 + i]); sumi0 = sumi0 * scales_0[j]; - sumi4 = sumi4 * scales_4[j]; - sumi += sumi0 + sumi4; + sumi1 = sumi1 * scales_1[j]; + sumi2 = sumi2 * scales_2[j]; + sumi3 = sumi3 * scales_3[j]; + sumi += sumi0 + sumi1 + sumi2 + sumi3; sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; } } @@ -998,6 +1088,102 @@ static inline void ggml_gemm_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } +template +static inline void ggml_gemm_q5_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][ncols_interleaved]; + float sum_minf[4][ncols_interleaved]; + uint8_t scales[8 * ncols_interleaved]; + uint8_t mins[8 * ncols_interleaved]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx * b_ptr = (const block_q5_Kx *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < ncols_interleaved * 8; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < ncols_interleaved * 4; i++) { + scales[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x0C) << 2; + scales[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x30); + 
mins[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0xC0) >> 2; + } + + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * ncols_interleaved]; + for(int m = 0; m < 4; m++) { + const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4]; + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * ncols_interleaved]; + uint8_t *scales_1 = &scales[(sb + 1) * ncols_interleaved]; + + for (int i = 0; i < QK4_0; i++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + + int v0 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); + int v1 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); + if (b_ptr[l].qh[i * ncols_interleaved + j] & (1 << (sb + 0))) { v0 += 16; } + if (b_ptr[l].qh[i * ncols_interleaved + j] & (1 << (sb + 1))) { v1 += 16; } + sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]); + sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + template static inline void ggml_gemm_q6_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; @@ -1018,7 +1204,9 @@ static inline void ggml_gemm_q6_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT float sumf[4][ncols_interleaved]; int sumi0; - int sumi4; + int sumi1; + int sumi2; + int 
sumi3; int sumi; for (int y = 0; y < nr / 4; y++) { const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); @@ -1032,28 +1220,36 @@ static inline void ggml_gemm_q6_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } for (int l = 0; l < nb; l++) { - // Processing 2 sub-blocks at once. - int total = 0; - for (int sb = 0; sb < 8; sb++) { - int scales_idx = (sb / 4) * 4 + sb; - int qh_idx = (sb / 4) * 32 + (sb % 2) * 16; - const int8_t *scales_0 = &b_ptr[l].scales[(scales_idx) * ncols_interleaved]; - const int8_t *scales_4 = &b_ptr[l].scales[(scales_idx + 4) * ncols_interleaved]; - for (int i = 0; i < QK8_0 / 2; i++) { + // Processing 4 sub-blocks at once. + for (int sb = 0; sb < QK_K/16; sb += 4) { + const int8_t *scales_0 = &b_ptr[l].scales[sb * ncols_interleaved]; + const int8_t *scales_1 = &b_ptr[l].scales[(sb + 1) * ncols_interleaved]; + const int8_t *scales_2 = &b_ptr[l].scales[(sb + 2) * ncols_interleaved]; + const int8_t *scales_3 = &b_ptr[l].scales[(sb + 3) * ncols_interleaved]; + const int qh_idx = sb * 4; + for (int i = 0; i < QK8_0/2; i++) { for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumi0 = 0; - sumi4 = 0; + sumi1 = 0; sumi = 0; - const uint8_t v0 = (uint8_t) (b_ptr[l].ql[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); - const uint8_t v4 = (uint8_t) (b_ptr[l].ql[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); - const int8_t a0 = (int8_t)(v0 | (((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> ((sb / 2) % 2) * 2) & 0x03)) << 4) - 32; - const int8_t a4 = (int8_t)(v4 | (((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> (((sb / 2) % 2) * 2 + 4)) & 0x03)) << 4) - 32; - sumi0 = (a0 * a_ptr[l].qs[(sb / 4) * 64 * 4 + sb * 4 * 16 + i * 4 + m]); - sumi4 = (a4 * a_ptr[l].qs[(sb / 4) * 64 * 4 + sb * 4 * 16 + 4 * 64 + i * 4 + m]); + const uint8_t v0 = (uint8_t) (b_ptr[l].ql[(sb * 8 + i) * ncols_interleaved + j] & 0xF); + const uint8_t v1 = (uint8_t) (b_ptr[l].ql[(sb * 8 + i) * 
ncols_interleaved + j] >> 4); + const uint8_t v2 = (uint8_t) (b_ptr[l].ql[(sb * 8 + 16 + i) * ncols_interleaved + j] & 0xF); + const uint8_t v3 = (uint8_t) (b_ptr[l].ql[(sb * 8 + 16 + i) * ncols_interleaved + j] >> 4); + const int8_t a0 = (int8_t)(v0 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] << 4) & 0x30)) - 32; + const int8_t a1 = (int8_t)(v1 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] << 2) & 0x30)) - 32; + const int8_t a2 = (int8_t)(v2 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j]) & 0x30)) - 32; + const int8_t a3 = (int8_t)(v3 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> 2) & 0x30)) - 32;sumi0 = (a0 * a_ptr[l].qs[sb * 64 + i * 4 + m]); + sumi0 = (a0 * a_ptr[l].qs[sb * 64 + 0 + i * 4 + m]); + sumi1 = (a1 * a_ptr[l].qs[sb * 64 + 64 + i * 4 + m]); + sumi2 = (a2 * a_ptr[l].qs[sb * 64 + 128 + i * 4 + m]); + sumi3 = (a3 * a_ptr[l].qs[sb * 64 + 192 + i * 4 + m]); sumi0 = sumi0 * scales_0[j]; - sumi4 = sumi4 * scales_4[j]; - sumi += sumi0 + sumi4; + sumi1 = sumi1 * scales_1[j]; + sumi2 = sumi2 * scales_2[j]; + sumi3 = sumi3 * scales_3[j]; + sumi += sumi0 + sumi1 + sumi2 + sumi3; sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; } } @@ -1790,6 +1986,20 @@ void ggml_gemv_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemv_q4_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q5_K +void ggml_gemv_q5_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q5_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q5_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int 
nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // Q6_K void ggml_gemv_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_q6_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -2604,6 +2814,20 @@ void ggml_gemm_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemm_q4_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q5_K +void ggml_gemm_q5_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q5_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q5_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // Q6_K void ggml_gemm_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_q6_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -3591,30 +3815,158 @@ static int repack_q4_K_to_q4_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_ } template -static 
block_q6_Kx make_block_q6_KxMx1(block_q6_K * in) { - block_q6_Kx out; +static block_q5_Kx make_block_q5_KxMx1(block_q5_K * in) { + block_q5_Kx out; for (int i = 0; i < nrows_interleaved; i++) { - out.d[i] = in[i].d; + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < nrows_interleaved; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; } const int end_ls = QK_K * nrows_interleaved / 2; + for (int i = 0; i < end_ls; ++i) { - int src_id = i % nrows_interleaved; + int src_id = i % nrows_interleaved; int src_offset = i / nrows_interleaved; int dst_offset = i; - out.ql[dst_offset] = in[src_id].ql[src_offset]; + out.qs[dst_offset] = in[src_id].qs[src_offset]; } - const int end_hs = QK_K * nrows_interleaved / 4; + const int end_hs = 32 * nrows_interleaved; + for (int i = 0; i < end_hs; ++i) { - int src_id = i % nrows_interleaved; + int src_id = i % nrows_interleaved; int src_offset = i / nrows_interleaved; int dst_offset = i; out.qh[dst_offset] = in[src_id].qh[src_offset]; } + // RVV repacking. + // + // Extract scales and mins for all 8 sub-blocks for each block of Q5_K. 
+ uint8_t s[8 * nrows_interleaved], m[8 * nrows_interleaved]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + s[i * nrows_interleaved + j] = in[j].scales[i] & 63; + m[i * nrows_interleaved + j] = in[j].scales[i + 4] & 63; + } + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + s[nrows_interleaved * 8 / 2 + i * nrows_interleaved + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[nrows_interleaved * 8 / 2 + i * nrows_interleaved + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + } + + for (int i = 0; i < 8 * nrows_interleaved; i++) { + out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); + } + for (int i = 0; i < 8 * nrows_interleaved / 2; i++) { + out.scales[nrows_interleaved * 8 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[nrows_interleaved * 8 / 2 + i] & 48) | ((m[nrows_interleaved * 8 / 2 + i] & 48) << 2); + } + + return out; +} + +template +static int repack_q5_K_to_q5_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + + block_q5_Kx * dst = (block_q5_Kx*)t->data; + const block_q5_K * src = (const block_q5_K*) data; + block_q5_K dst_tmp[nrows_interleaved]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q5_KxMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +template +static block_q6_Kx make_block_q6_KxMx1(block_q6_K * in) { + block_q6_Kx out; + for (int i = 0; i < nrows_interleaved; i++) { + out.d[i] = in[i].d; + } + + const int end_ls = QK_K 
/ 2; + for (int i = 0; i < end_ls; i += 64) { + for (int l = 0; l < 2; l++) { + for (int k = 0; k < 16; k++) { + uint8_t temp[nrows_interleaved]; + int dst_offset = (i + k + (l % 2) * 16) * nrows_interleaved; + int src_offset = i + k + (l % 2) * 32; + for (int j = 0; j < nrows_interleaved; j++) { + int src_id = j % nrows_interleaved; + + temp[j] = (in[src_id].ql[src_offset] & 0xF) | ((in[src_id].ql[src_offset + 16] & 0xF) << 4); + } + for (int j = 0; j < nrows_interleaved; j++) { + out.ql[dst_offset + j] = temp[j]; + } + } + } + for (int l = 0; l < 2; l++) { + for (int k = 0; k < 16; k++) { + uint8_t temp[nrows_interleaved]; + int dst_offset = (i + 32 + k + (l % 2) * 16) * nrows_interleaved; + int src_offset = i + k + (l % 2) * 32; + for (int j = 0; j < nrows_interleaved; j++) { + int src_id = j % nrows_interleaved; + + temp[j] = (in[src_id].ql[src_offset] >> 4) | ((in[src_id].ql[src_offset + 16] >> 4) << 4); + } + for (int j = 0; j < nrows_interleaved; j++) { + out.ql[dst_offset + j] = temp[j]; + } + } + } + } + + const int end_hs = QK_K / 4; + for (int i = 0; i < end_hs; i += 32) { + for (int l = 0; l < 2; l++) { + for (int k = 0; k < 16; k++) { + uint8_t temp[nrows_interleaved]; + int dst_offset = (i + l * 16 + k) * nrows_interleaved; + int src_offset = i + k; + for (int j = 0; j < nrows_interleaved; j++) { + int src_id = j; + + uint8_t a = (in[src_id].qh[src_offset] >> (4*(l%2))) & 3; + uint8_t b = (in[src_id].qh[src_offset + 16] >> (4*(l%2))) & 3; + uint8_t c = (in[src_id].qh[src_offset] >> (2+4*(l%2))) & 3; + uint8_t d = (in[src_id].qh[src_offset + 16] >> (2+4*(l%2))) & 3; + + temp[j] = a | (b << 2) | (c << 4) | (d << 6); + } + for (int j = 0; j < nrows_interleaved; j++) { + out.qh[dst_offset + j] = temp[j]; + } + } + } + } + for (int i = 0; i < nrows_interleaved; i++) { for (int j = 0; j < 16; j++) { out.scales[j * nrows_interleaved + i] = in[i].scales[j]; @@ -3954,6 +4306,20 @@ template <> int repack(struct ggml_tensor * t, const void * d return 
repack_q4_K_to_q4_K_Mx1_bl<64>(t, data, data_size); } +// Q5_K +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_Mx1_bl<64>(t, data, data_size); +} + // Q6_K template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q6_K_to_q6_K_Mx1_bl<8>(t, data, data_size); @@ -4115,6 +4481,20 @@ template <> void gemv(int n, float * s, size_ ggml_gemv_q4_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q5_K +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // Q6_K template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q6_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); @@ -4276,6 +4656,20 @@ template <> void gemm(int n, float * s, size_ ggml_gemm_q4_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q5_K +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + 
ggml_gemm_q5_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // Q6_K template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q6_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); @@ -4742,6 +5136,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_K_32x1_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_64x1_q8_K; + // Q5_K + static const ggml::cpu::repack::tensor_traits q5_K_8x1_q8_K; + static const ggml::cpu::repack::tensor_traits q5_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits q5_K_32x1_q8_K; + static const ggml::cpu::repack::tensor_traits q5_K_64x1_q8_K; + // Q6_K static const ggml::cpu::repack::tensor_traits q6_K_8x1_q8_K; static const ggml::cpu::repack::tensor_traits q6_K_16x1_q8_K; @@ -4836,6 +5236,16 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (cur->ne[1] % 8 == 0) { return &q5_K_8x4_q8_K; } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { if (cur->ne[1] % 8 == 0) { return &q5_K_8x1_q8_K; } break; } + case 256: { if (cur->ne[1] % 16 == 0) { return &q5_K_16x1_q8_K; } break; } + case 512: { if (cur->ne[1] % 32 == 0) { return &q5_K_32x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q5_K_64x1_q8_K; } break; } + default: { return nullptr; } + } + #endif } } else if (cur->type == GGML_TYPE_Q6_K) { if (ggml_cpu_has_neon() && 
ggml_cpu_has_matmul_int8()) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 7ef2e8341bc..dbb266a2172 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -82,16 +82,23 @@ static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, static_assert(sizeof(block_q2_Kx32) == sizeof(ggml_half) * 64 + QK_K * 2 + QK_K * 8, "wrong q2_K block size/padding"); static_assert(sizeof(block_q2_Kx64) == sizeof(ggml_half) * 128 + QK_K * 4 + QK_K * 16, "wrong q2_K block size/padding"); -struct block_q5_Kx8 { - ggml_half d[8]; // super-block scale for quantized scales - ggml_half dmin[8]; // super-block scale for quantized mins - uint8_t scales[96]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants - uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4) +template struct block_q5_Kx { + ggml_half d[N]; // super-block scale for quantized scales + ggml_half dmin[N]; // super-block scale for quantized mins + uint8_t scales[12 * N]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K * N / 8]; // high bits of 5-bit quants + uint8_t qs[QK_K * N / 2]; // low bits of 5-bit quants (in groups of 4) }; -static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, - "wrong q5_K block size/padding"); +using block_q5_Kx8 = block_q5_Kx<8>; +using block_q5_Kx16 = block_q5_Kx<16>; +using block_q5_Kx32 = block_q5_Kx<32>; +using block_q5_Kx64 = block_q5_Kx<64>; + +static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 10, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_Kx32) == sizeof(ggml_half) * 64 + K_SCALE_SIZE * 32 + QK_K * 20, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_Kx64) == sizeof(ggml_half) * 128 + K_SCALE_SIZE * 64 + 
QK_K * 40, "wrong q5_K block size/padding"); template struct block_q6_Kx { ggml_half d[N]; @@ -206,6 +213,10 @@ void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -230,6 +241,10 @@ void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void 
ggml_gemm_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -296,6 +311,10 @@ void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x1_q8_K_generic(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -320,6 +339,10 @@ void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void 
ggml_gemm_q5_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From 6df5cef8dfa9b6852e9ab367c06ac6e577c35709 Mon Sep 17 00:00:00 2001 From: Rehan Qasim Date: Wed, 18 Feb 2026 09:53:53 +0500 Subject: [PATCH 11/13] ggml-cpu: refactor; add rvv repacking for q3_K --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 722 +++++++++++++++++++++--- ggml/src/ggml-cpu/repack.cpp | 441 ++++++++++++++- ggml/src/ggml-cpu/repack.h | 35 +- 3 files changed, 1105 insertions(+), 93 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index b9822dc9c03..9a5453132fa 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -507,34 +507,33 @@ void ggml_gemv_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static void inline ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); assert(nr == 1); - assert(nc % 16 == 0); + assert(nc % ncols_interleaved == 0); UNUSED(bs); const int 
num_k_blocks = n / QK_K; - const size_t vl = __riscv_vsetvl_e32m2(ncols_interleaved); for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy; const block_q2_Kx* rhs_base_ptr = (const block_q2_Kx*)vx + (col_tile / ncols_interleaved) * num_k_blocks; - vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int k_block = 0; k_block < num_k_blocks; ++k_block) { const block_q8_K* lhs_current = &lhs_base_ptr[k_block]; const block_q2_Kx* rhs_current = &rhs_base_ptr[k_block]; // 1. Prepare Global Min Scales - vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); - vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl); + vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, ncols_interleaved); + vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, ncols_interleaved); - vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, vl); + vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, ncols_interleaved); - vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); const uint8_t* rhs_qs_ptr = rhs_current->qs; const uint8_t* rhs_sc_ptr = rhs_current->scales; @@ -550,75 +549,77 @@ void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo { vuint8mf2_t v_raw; // Sub-block 0 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); - v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); - v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, ncols_interleaved); + v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, ncols_interleaved), ncols_interleaved); + v_m_sb_0 = 
__riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, ncols_interleaved), ncols_interleaved); // Sub-block 1 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); - v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); - v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + rhs_sc_ptr+=ncols_interleaved; + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, ncols_interleaved); + v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, ncols_interleaved), ncols_interleaved); + v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, ncols_interleaved), ncols_interleaved); // Sub-block 2 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); - v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); - v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + rhs_sc_ptr+=ncols_interleaved; + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, ncols_interleaved); + v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, ncols_interleaved), ncols_interleaved); + v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, ncols_interleaved), ncols_interleaved); // Sub-block 3 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); - v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); - v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + rhs_sc_ptr+=ncols_interleaved; + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, ncols_interleaved); + v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, ncols_interleaved), ncols_interleaved); + v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, ncols_interleaved), ncols_interleaved); - rhs_sc_ptr += 64; + rhs_sc_ptr+=ncols_interleaved; } int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16); int k_offsets[4] = {0, 32, 64, 96}; - // B. 
Inner Dot Product Loop for (int l = 0; l < 16; ++l) { - vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); - rhs_qs_ptr += 16; + vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, ncols_interleaved); + rhs_qs_ptr += ncols_interleaved; // Sub-block 0 { - vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl); + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, ncols_interleaved); vint16m1_t v_w = __riscv_vmul_vv_i16m1( - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), - __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl); + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, ncols_interleaved)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), ncols_interleaved); int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[0] + l]; - v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, ncols_interleaved); } // Sub-block 1 { - vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl); + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, ncols_interleaved), 3, ncols_interleaved); vint16m1_t v_w = __riscv_vmul_vv_i16m1( - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), - __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl); + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, ncols_interleaved)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), ncols_interleaved); int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[1] + l]; - v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, ncols_interleaved); } // Sub-block 2 { - vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl); + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, ncols_interleaved), 3, ncols_interleaved); vint16m1_t v_w = __riscv_vmul_vv_i16m1( - 
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), - __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl); + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, ncols_interleaved)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), ncols_interleaved); int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[2] + l]; - v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, ncols_interleaved); } // Sub-block 3 { - vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl); + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, ncols_interleaved), 3, ncols_interleaved); vint16m1_t v_w = __riscv_vmul_vv_i16m1( - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), - __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl); + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, ncols_interleaved)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), ncols_interleaved); int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[3] + l]; - v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, ncols_interleaved); } } @@ -630,55 +631,56 @@ void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo int sb_idx = sb_base_abs + (k_offsets[0] / 16); int16_t bsum = lhs_current->bsums[sb_idx]; vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0); - vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); - vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); - v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, ncols_interleaved); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, ncols_interleaved), v_g_min_final, ncols_interleaved); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, ncols_interleaved); } 
// Sub-block 1 { int sb_idx = sb_base_abs + (k_offsets[1] / 16); int16_t bsum = lhs_current->bsums[sb_idx]; vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1); - vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); - vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); - v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, ncols_interleaved); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, ncols_interleaved), v_g_min_final, ncols_interleaved); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, ncols_interleaved); } // Sub-block 2 { int sb_idx = sb_base_abs + (k_offsets[2] / 16); int16_t bsum = lhs_current->bsums[sb_idx]; vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2); - vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); - vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); - v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, ncols_interleaved); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, ncols_interleaved), v_g_min_final, ncols_interleaved); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, ncols_interleaved); } // Sub-block 3 { int sb_idx = sb_base_abs + (k_offsets[3] / 16); int16_t bsum = lhs_current->bsums[sb_idx]; vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3); - vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); - vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); - v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, ncols_interleaved); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, ncols_interleaved), v_g_min_final, ncols_interleaved); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, ncols_interleaved); } } // End 
Phase Loop // Apply global Scales - vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl); - vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl); + vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, ncols_interleaved); + vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, ncols_interleaved); - vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, vl); - vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl); - v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, vl); - v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, vl); + vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, ncols_interleaved); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, ncols_interleaved); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, ncols_interleaved); + v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, ncols_interleaved); } // End K-Block - __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); + __riscv_vse32_v_f32m2(s + col_tile, v_sumf, ncols_interleaved); } } + void ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { #if defined __riscv_zvfh ggml_gemv_q2_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); @@ -708,6 +710,271 @@ void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } + + +template +__attribute__((optimize("no-schedule-insns"))) +static void inline ggml_gemv_q3_K_Mx1_q8_K( int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + assert(n % QK_K == 0); + // GEMV processes 1 row against 16 columns of weights + const int N_COLS_TILE = ncols_interleaved; + + assert(nc % N_COLS_TILE == 0); + + const int num_k_blocks = n / QK_K; + + // vl = 16. 
Using LMUL=2 for 32-bit accumulators on VLEN=256 + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + + // Loop over output columns (16 at a time) + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + const block_q8_K * lhs_base_ptr = (const block_q8_K *) vy; + const block_q3_Kx * rhs_base_ptr = (const block_q3_Kx *) vx + (col_tile / N_COLS_TILE) * num_k_blocks; + + // Stage 3: Persistent Float Accumulator (1 vector for 16 columns) + vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_K * lhs_current = &lhs_base_ptr[k_block]; + const block_q3_Kx * rhs_current = &rhs_base_ptr[k_block]; + + const uint8_t * rhs_qs_ptr = rhs_current->qs; + const uint8_t * rhs_hmask_ptr = rhs_current->hmask; + const uint8_t * rhs_sc_low_ptr = rhs_current->scales; + const uint8_t * rhs_sc_high_ptr = rhs_current->scales + 8*ncols_interleaved; + + // Activation pointer (linear access for GEMV) + const int8_t * lhs_qs_ptr = lhs_current->qs; + + // Stage 2: Main Integer Accumulator (1 vector) + vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl); + + for (int group = 0; group < 4; ++group) { + // High scales are needed for all 4 sub-blocks + vuint8mf2_t v_sc_h_quad = __riscv_vle8_v_u8mf2(rhs_sc_high_ptr, vl); + rhs_sc_high_ptr += ncols_interleaved; + + // --- Scope 1: Sub-blocks 1 & 2 (Pair 0) --- + { + vuint8mf2_t v_sc_l_pair0 = __riscv_vle8_v_u8mf2(rhs_sc_low_ptr, vl); + rhs_sc_low_ptr += ncols_interleaved; + + // --- Sub-block 1 --- + { + // 1. Initialize Temps + vint16m1_t v_tsum = __riscv_vmv_v_x_i16m1(0, vl); + + // 2. 
Heavy Dot Product Loop + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + // Mask generation as requested + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + // Masked subtraction as requested + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + + // Scalar broadcast multiply-accumulate + v_tsum = __riscv_vwmacc_vx_i16m1(v_tsum, *lhs_qs_ptr, q_val, vl); + lhs_qs_ptr++; + } + } + } + + // 3. Just-In-Time Scale Calculation + vuint8mf2_t v_sc_lo = __riscv_vand_vx_u8mf2(v_sc_l_pair0, 0x0F, vl); + vuint8mf2_t v_sc_hi = __riscv_vand_vx_u8mf2(v_sc_h_quad, 0x03, vl); + vuint8mf2_t v_sc_u8 = __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + // 4. 
Accumulate + v_isum = __riscv_vwmacc_vv_i32m2(v_isum, v_sc_16, v_tsum, vl); + } + + // --- Sub-block 2 --- + { + vint16m1_t v_tsum = __riscv_vmv_v_x_i16m1(0, vl); + + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + + v_tsum = __riscv_vwmacc_vx_i16m1(v_tsum, *lhs_qs_ptr, q_val, vl); + lhs_qs_ptr++; + } + } + } + + // JIT Scale Calc (Shift 4) + vuint8mf2_t v_sc_lo = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_l_pair0, 4, vl), 0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 2, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + v_isum = __riscv_vwmacc_vv_i32m2(v_isum, v_sc_16, v_tsum, vl); + } + } + + // --- Scope 2: Sub-blocks 3 & 4 (Pair 1) --- + { + vuint8mf2_t v_sc_l_pair1 = __riscv_vle8_v_u8mf2(rhs_sc_low_ptr, vl); + rhs_sc_low_ptr += ncols_interleaved; + + // --- Sub-block 3 --- + { + vint16m1_t v_tsum = __riscv_vmv_v_x_i16m1(0, vl); + + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { 
+ vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + + v_tsum = __riscv_vwmacc_vx_i16m1(v_tsum, *lhs_qs_ptr, q_val, vl); + lhs_qs_ptr++; + } + } + } + + vuint8mf2_t v_sc_lo = __riscv_vand_vx_u8mf2(v_sc_l_pair1, 0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 4, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + v_isum = __riscv_vwmacc_vv_i32m2(v_isum, v_sc_16, v_tsum, vl); + } + // --- Sub-block 4 --- + { + vint16m1_t v_tsum = __riscv_vmv_v_x_i16m1(0, vl); + + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + + v_tsum = __riscv_vwmacc_vx_i16m1(v_tsum, *lhs_qs_ptr, q_val, vl); + lhs_qs_ptr++; + } + } + } + + // JIT Scale Calc (Shift 6) + vuint8mf2_t v_sc_lo = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_l_pair1, 4, vl), 0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 6, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = 
__riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + v_isum = __riscv_vwmacc_vv_i32m2(v_isum, v_sc_16, v_tsum, vl); + } + } // End Scope 2 (Pair 1) + } // End group loop + + // --- Final Super-Block accumulation --- + vfloat32m2_t rhs_d = + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *) rhs_current->d, vl), vl); + float lhs_d = lhs_current->d; + vfloat32m2_t v_isum_f = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl); + + // v_sumf += isum * d_act (scalar) * d_weight (vector) + v_sumf = __riscv_vfmacc_vv_f32m2(v_sumf, __riscv_vfmul_vf_f32m2(v_isum_f, lhs_d, vl), rhs_d, vl); + + } // End k_block loop + + // --- Store Results --- + // GEMV outputs a vector 's' (1 row). We store 16 contiguous elements. + __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); + + } // End col_tile loop +} + + +void ggml_gemv_q3_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q3_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q3_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q3_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q3_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q3_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q3_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q3_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q3_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_q3_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_q3_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_q3_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + template void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; @@ -2092,7 +2359,7 @@ void ggml_gemm_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); const int num_k_blocks = n / QK_K; const int N_ROWS_TILE = 4; @@ -2107,7 +2374,7 @@ void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { // Base Pointers const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks; - const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / ncols_interleaved) * num_k_blocks; + const block_q2_Kx* rhs_base_ptr = (const block_q2_Kx*)vx + (col_tile / ncols_interleaved) * num_k_blocks; // Persistent Float Accumulators vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); @@ -2119,7 +2386,7 @@ void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #pragma GCC unroll 1 for (int k_block = 0; k_block < num_k_blocks; ++k_block) { const block_q8_Kx4* lhs_current = &lhs_base_ptr[k_block]; - const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; + const block_q2_Kx* rhs_current = &rhs_base_ptr[k_block]; // 1. 
Load Global Min Scales (Keep as F16/LMUL=1 to save registers) vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); @@ -2147,26 +2414,29 @@ void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo { vuint8mf2_t v_raw; // Sub-block 0 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr , vl); v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); // Sub-block 1 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); + rhs_sc_ptr+=ncols_interleaved; + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr , vl); v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); // Sub-block 2 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); + rhs_sc_ptr+=ncols_interleaved; + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, vl); v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); // Sub-block 3 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); + rhs_sc_ptr+=ncols_interleaved; + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, vl); v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + rhs_sc_ptr+=ncols_interleaved; - rhs_sc_ptr += 64; } int base_k_phase = (phase < 2) ? 
(phase * 16) : (128 + (phase-2)*16); @@ -2176,8 +2446,7 @@ void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #pragma GCC unroll 1 for (int l = 0; l < 16; ++l) { vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); - rhs_qs_ptr += 16; - + rhs_qs_ptr+=ncols_interleaved; // Unroll over 4 sub-blocks (0, 1, 2, 3 relative to phase) // --- Sub-block 0 --- @@ -2418,6 +2687,321 @@ void ggml_gemm_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +template +__attribute__((optimize("no-schedule-insns"))) +static void ggml_gemm_q3_K_Mx1_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + assert(n % QK_K == 0); + const int N_ROWS_TILE = 4; + const int N_COLS_TILE = ncols_interleaved; + + assert(nr % N_ROWS_TILE == 0); + assert(nc % N_COLS_TILE == 0); + + const int num_k_blocks = n / QK_K; + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + + for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) { + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + const block_q8_Kx4 * lhs_base_ptr = (const block_q8_Kx4 *) vy + (row_tile / N_ROWS_TILE) * num_k_blocks; + const block_q3_Kx * rhs_base_ptr = (const block_q3_Kx *) vx + (col_tile / N_COLS_TILE) * num_k_blocks; + + // Stage 3: Persistent Float Accumulators (8 registers) + vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_Kx4 * lhs_current = &lhs_base_ptr[k_block]; + const block_q3_Kx * rhs_current = &rhs_base_ptr[k_block]; + + const uint8_t * rhs_qs_ptr = rhs_current->qs; + const uint8_t * rhs_hmask_ptr = rhs_current->hmask; + const uint8_t * rhs_sc_low_ptr = rhs_current->scales; + const 
uint8_t * rhs_sc_high_ptr = rhs_current->scales + (8 * ncols_interleaved); + const int8_t * lhs_qs_ptr = lhs_current->qs; + + // Stage 2: Main Integer Accumulators (8 registers) + vint32m2_t v_isum_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int group = 0; group < 4; ++group) { + // High scales are needed for all 4 sub-blocks (0.5 register) + vuint8mf2_t v_sc_h_quad = __riscv_vle8_v_u8mf2(rhs_sc_high_ptr, 16); + rhs_sc_high_ptr += ncols_interleaved; + + // --- Scope 1: Sub-blocks 1 & 2 (Pair 0) --- + // By scoping this, v_sc_l_pair0 dies before we load pair1 + { + vuint8mf2_t v_sc_l_pair0 = __riscv_vle8_v_u8mf2(rhs_sc_low_ptr, 16); + rhs_sc_low_ptr += ncols_interleaved; + + // --- Sub-block 1 --- + { + // 1. Initialize Temps (4 registers) + vint16m1_t v_tsum_0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_3 = __riscv_vmv_v_x_i16m1(0, vl); + + // 2. 
Heavy Dot Product Loop + // Note: v_sc_16 is NOT live here, saving 0.5 - 1 register of pressure + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + v_tsum_0 = __riscv_vwmacc_vx_i16m1(v_tsum_0, (int8_t) lhs_qs_ptr[0], q_val, vl); + v_tsum_1 = __riscv_vwmacc_vx_i16m1(v_tsum_1, (int8_t) lhs_qs_ptr[1], q_val, vl); + v_tsum_2 = __riscv_vwmacc_vx_i16m1(v_tsum_2, (int8_t) lhs_qs_ptr[2], q_val, vl); + v_tsum_3 = __riscv_vwmacc_vx_i16m1(v_tsum_3, (int8_t) lhs_qs_ptr[3], q_val, vl); + lhs_qs_ptr += 4; + } + } + } + + // 3. Just-In-Time Scale Calculation + // Only now do we allocate the register for v_sc_16 + vuint8mf2_t v_sc_lo = __riscv_vand_vx_u8mf2(v_sc_l_pair0, 0x0F, vl); + vuint8mf2_t v_sc_hi = __riscv_vand_vx_u8mf2(v_sc_h_quad, 0x03, vl); + vuint8mf2_t v_sc_u8 = + __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + // 4. 
Accumulate + v_isum_0 = __riscv_vwmacc_vv_i32m2(v_isum_0, v_sc_16, v_tsum_0, vl); + v_isum_1 = __riscv_vwmacc_vv_i32m2(v_isum_1, v_sc_16, v_tsum_1, vl); + v_isum_2 = __riscv_vwmacc_vv_i32m2(v_isum_2, v_sc_16, v_tsum_2, vl); + v_isum_3 = __riscv_vwmacc_vv_i32m2(v_isum_3, v_sc_16, v_tsum_3, vl); + } + + // --- Sub-block 2 --- + { + vint16m1_t v_tsum_0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_3 = __riscv_vmv_v_x_i16m1(0, vl); + + // Dot Product Loop (Same as above) + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + v_tsum_0 = __riscv_vwmacc_vx_i16m1(v_tsum_0, (int8_t) lhs_qs_ptr[0], q_val, vl); + v_tsum_1 = __riscv_vwmacc_vx_i16m1(v_tsum_1, (int8_t) lhs_qs_ptr[1], q_val, vl); + v_tsum_2 = __riscv_vwmacc_vx_i16m1(v_tsum_2, (int8_t) lhs_qs_ptr[2], q_val, vl); + v_tsum_3 = __riscv_vwmacc_vx_i16m1(v_tsum_3, (int8_t) lhs_qs_ptr[3], q_val, vl); + lhs_qs_ptr += 4; + } + } + } + + // JIT Scale Calc + vuint8mf2_t v_sc_lo = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_l_pair0, 4, vl), 0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 2, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = + __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + 
__riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + v_isum_0 = __riscv_vwmacc_vv_i32m2(v_isum_0, v_sc_16, v_tsum_0, vl); + v_isum_1 = __riscv_vwmacc_vv_i32m2(v_isum_1, v_sc_16, v_tsum_1, vl); + v_isum_2 = __riscv_vwmacc_vv_i32m2(v_isum_2, v_sc_16, v_tsum_2, vl); + v_isum_3 = __riscv_vwmacc_vv_i32m2(v_isum_3, v_sc_16, v_tsum_3, vl); + } + } // v_sc_l_pair0 dies here + + // --- Scope 2: Sub-blocks 3 & 4 (Pair 1) --- + { + vuint8mf2_t v_sc_l_pair1 = __riscv_vle8_v_u8mf2(rhs_sc_low_ptr, vl); // vl, not a literal 16: one scale byte per interleaved column + rhs_sc_low_ptr += ncols_interleaved; + + // --- Sub-block 3 --- + { + vint16m1_t v_tsum_0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_3 = __riscv_vmv_v_x_i16m1(0, vl); + + // Dot Product Loop (Same as above) + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + v_tsum_0 = __riscv_vwmacc_vx_i16m1(v_tsum_0, (int8_t) lhs_qs_ptr[0], q_val, vl); + v_tsum_1 = __riscv_vwmacc_vx_i16m1(v_tsum_1, (int8_t) lhs_qs_ptr[1], q_val, vl); + v_tsum_2 = __riscv_vwmacc_vx_i16m1(v_tsum_2, (int8_t) lhs_qs_ptr[2], q_val, vl); + v_tsum_3 = __riscv_vwmacc_vx_i16m1(v_tsum_3, (int8_t) lhs_qs_ptr[3], q_val, vl); + lhs_qs_ptr += 4; + } + } + } + + // JIT Scale Calc + vuint8mf2_t v_sc_lo = __riscv_vand_vx_u8mf2(v_sc_l_pair1, 0x0F, vl); + vuint8mf2_t v_sc_hi = + 
__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 4, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = + __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + v_isum_0 = __riscv_vwmacc_vv_i32m2(v_isum_0, v_sc_16, v_tsum_0, vl); + v_isum_1 = __riscv_vwmacc_vv_i32m2(v_isum_1, v_sc_16, v_tsum_1, vl); + v_isum_2 = __riscv_vwmacc_vv_i32m2(v_isum_2, v_sc_16, v_tsum_2, vl); + v_isum_3 = __riscv_vwmacc_vv_i32m2(v_isum_3, v_sc_16, v_tsum_3, vl); + } + + // --- Sub-block 4 --- + { + vint16m1_t v_tsum_0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_3 = __riscv_vmv_v_x_i16m1(0, vl); + + // Dot Product Loop (Same as above) + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + v_tsum_0 = __riscv_vwmacc_vx_i16m1(v_tsum_0, (int8_t) lhs_qs_ptr[0], q_val, vl); + v_tsum_1 = __riscv_vwmacc_vx_i16m1(v_tsum_1, (int8_t) lhs_qs_ptr[1], q_val, vl); + v_tsum_2 = __riscv_vwmacc_vx_i16m1(v_tsum_2, (int8_t) lhs_qs_ptr[2], q_val, vl); + v_tsum_3 = __riscv_vwmacc_vx_i16m1(v_tsum_3, (int8_t) lhs_qs_ptr[3], q_val, vl); + lhs_qs_ptr += 4; + } + } + } + + // JIT Scale Calc + vuint8mf2_t v_sc_lo = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_l_pair1, 4, vl), 
0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 6, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = + __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + v_isum_0 = __riscv_vwmacc_vv_i32m2(v_isum_0, v_sc_16, v_tsum_0, vl); + v_isum_1 = __riscv_vwmacc_vv_i32m2(v_isum_1, v_sc_16, v_tsum_1, vl); + v_isum_2 = __riscv_vwmacc_vv_i32m2(v_isum_2, v_sc_16, v_tsum_2, vl); + v_isum_3 = __riscv_vwmacc_vv_i32m2(v_isum_3, v_sc_16, v_tsum_3, vl); + } + } // v_sc_l_pair1 dies here + } // End group loop + + vfloat32m2_t rhs_d = + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *) rhs_current->d, vl), vl); + const float * lhs_d_ptr = lhs_current->d; + + vfloat32m2_t v_isum_0_f = __riscv_vfcvt_f_x_v_f32m2(v_isum_0, vl); + vfloat32m2_t v_isum_1_f = __riscv_vfcvt_f_x_v_f32m2(v_isum_1, vl); + vfloat32m2_t v_isum_2_f = __riscv_vfcvt_f_x_v_f32m2(v_isum_2, vl); + vfloat32m2_t v_isum_3_f = __riscv_vfcvt_f_x_v_f32m2(v_isum_3, vl); + + v_sumf_0 = + __riscv_vfmacc_vv_f32m2(v_sumf_0, __riscv_vfmul_vf_f32m2(v_isum_0_f, lhs_d_ptr[0], vl), rhs_d, vl); + v_sumf_1 = + __riscv_vfmacc_vv_f32m2(v_sumf_1, __riscv_vfmul_vf_f32m2(v_isum_1_f, lhs_d_ptr[1], vl), rhs_d, vl); + v_sumf_2 = + __riscv_vfmacc_vv_f32m2(v_sumf_2, __riscv_vfmul_vf_f32m2(v_isum_2_f, lhs_d_ptr[2], vl), rhs_d, vl); + v_sumf_3 = + __riscv_vfmacc_vv_f32m2(v_sumf_3, __riscv_vfmul_vf_f32m2(v_isum_3_f, lhs_d_ptr[3], vl), rhs_d, vl); + + } // End k_block loop + + __riscv_vse32_v_f32m2(s + (row_tile + 0) * bs + col_tile, v_sumf_0, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 1) * bs + col_tile, v_sumf_1, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 2) * bs + col_tile, v_sumf_2, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 3) * bs + col_tile, v_sumf_3, vl); + } + } +} + +void ggml_gemm_q3_K_8x1_q8_K(int n, float * GGML_RESTRICT s, 
size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q3_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q3_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q3_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q3_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q3_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q3_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q3_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q3_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q3_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q3_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q3_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + template void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 0418c7f9128..44398f4abd7 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -467,8 +467,8 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT // Loop over K-blocks for (int k_block = 0; k_block < nb; ++k_block) { - int32_t isum[16] = {0}; - int32_t summs[16] = {0}; + int32_t isum[ncols_interleaved] = {0}; + int32_t summs[ncols_interleaved] = {0}; const uint8_t * qs_rhs = x_ptr[k_block].qs; const uint8_t * sc_rhs = x_ptr[k_block].scales; @@ -479,9 +479,9 @@ static 
inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT for (int sb = 0; sb < 16; ++sb) { // Correction Term int16_t bsum = bs_lhs[sb]; - int scale_offset = sb_perm[sb] * 16; + int scale_offset = sb_perm[sb] * ncols_interleaved; - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { uint8_t sc_val = sc_rhs[scale_offset + col]; summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits } @@ -494,14 +494,14 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT int shift = ((sb / 2) % 4) * 2; - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { uint8_t sc_val = sc_rhs[scale_offset + col]; int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits // Process 16 elements (l=0..15) for (int l = 0; l < 16; ++l) { // Q2: Interleaved by column. Byte `l` contains 4 k-values. - int qs_idx = (byte_base + l) * 16 + col; + int qs_idx = (byte_base + l) * ncols_interleaved + col; uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; // Q8: Linear access @@ -514,7 +514,7 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } // Finalize K-Block - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { float d_lhs = y_ptr[k_block].d; float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); @@ -526,7 +526,89 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { + s[col_tile + col] = sumf[col]; + } + } +} + + +template +void ggml_gemv_q3_K_Mx1_q8_K_generic( + int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % ncols_interleaved == 0); + (void)bs; + + const int nb = n / QK_K; + const 
block_q3_Kx * x = (const block_q3_Kx *) vx; + const block_q8_K * y = (const block_q8_K *) vy; + + const int scale_high_offset = 8 * ncols_interleaved; + + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { + const block_q3_Kx * x_ptr = x + (col_tile / ncols_interleaved) * nb; + + float sumf[ncols_interleaved]; + memset(sumf, 0, sizeof(sumf)); + + for (int k_block = 0; k_block < nb; ++k_block) { + const block_q3_Kx & xb = x_ptr[k_block]; + const block_q8_K & yb = y[k_block]; + + int32_t isum[ncols_interleaved] = {0}; + + for (int sb = 0; sb < 16; ++sb) { + const int s_row_lo = sb >> 1; + const int s_row_hi = sb >> 2; + const int s_shift_lo = (sb & 1) ? 4 : 0; + const int s_shift_hi = (sb & 3) * 2; + + for (int l = 0; l < 16; ++l) { + const int k = sb * 16 + l; + const int qs_row = k >> 2; + const int qs_shift = (k & 3) * 2; + const int hm_row = k >> 3; + const int hm_shift = k & 7; + + const int8_t q8 = yb.qs[k]; + + for (int col = 0; col < ncols_interleaved; ++col) { + // Inline q3k_get_scale6_packed + const uint8_t scale_lo_byte = xb.scales[s_row_lo * ncols_interleaved + col]; + const uint8_t scale_hi_byte = xb.scales[scale_high_offset + s_row_hi * ncols_interleaved + col]; + + const uint8_t s_lo = (scale_lo_byte >> s_shift_lo) & 0x0F; + const uint8_t s_hi = (scale_hi_byte >> s_shift_hi) & 0x03; + const int sc = (int)((s_hi << 4) | s_lo) - 32; + + // Inline q3k_get_val_packed + const uint8_t qs_byte = xb.qs[qs_row * ncols_interleaved + col]; + const uint8_t hm_byte = xb.hmask[hm_row * ncols_interleaved + col]; + + const int low2 = (qs_byte >> qs_shift) & 3; + const int hb = (hm_byte >> hm_shift) & 1; + const int v = low2 - (hb ? 
0 : 4); + + isum[col] += (v * sc) * q8; + } + } + } + + for (int col = 0; col < ncols_interleaved; ++col) { + sumf[col] += (float) isum[col] * (GGML_FP16_TO_FP32(xb.d[col]) * yb.d); + } + } + + for (int col = 0; col < ncols_interleaved; ++col) { s[col_tile + col] = sumf[col]; } } @@ -904,7 +986,7 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT assert(nr % 4 == 0); assert(nc % 16 == 0); const int nb = n / QK_K; - const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q2_Kx * x = (const block_q2_Kx *)vx; const block_q8_Kx4 * y = (const block_q8_Kx4 *)vy; const int sb_perm[16] = { @@ -915,17 +997,17 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT // Iterate Rows in tiles of 4 for (int row_tile = 0; row_tile < nr; row_tile += 4) { // Iterate Columns in tiles of 16 - for (int col_tile = 0; col_tile < nc; col_tile += 16) { + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { - const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q2_Kx * x_ptr = x + (col_tile / ncols_interleaved) * nb; const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; - float sumf[4][16]; + float sumf[4][ncols_interleaved]; memset(sumf, 0, sizeof(sumf)); for (int k_block = 0; k_block < nb; ++k_block) { - int32_t isum[4][16]; - int32_t summs[4][16]; + int32_t isum[4][ncols_interleaved]; + int32_t summs[4][ncols_interleaved]; memset(isum, 0, sizeof(isum)); memset(summs, 0, sizeof(summs)); @@ -935,14 +1017,14 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT const int16_t * bs_lhs = y_ptr[k_block].bsums; for (int sb = 0; sb < 16; ++sb) { - int scale_offset = sb_perm[sb] * 16; + int scale_offset = sb_perm[sb] * ncols_interleaved; int byte_base; if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; else byte_base = (sb % 2 == 0) ? 
32 : 48; int shift = ((sb / 2) % 4) * 2; - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { uint8_t sc_val = sc_rhs[scale_offset + col]; int32_t d_sb = sc_val & 0xF; int32_t m_sb = sc_val >> 4; @@ -955,7 +1037,7 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT // Main Dot Product for (int l = 0; l < 16; ++l) { - int qs_idx = (byte_base + l) * 16 + col; + int qs_idx = (byte_base + l) * ncols_interleaved + col; uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; // Calculate Q8 index for this specific k and row @@ -972,7 +1054,7 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } // Finalize K-Block - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); @@ -986,7 +1068,115 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } for (int r = 0; r < 4; ++r) { - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { + s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; + } + } + } + } +} + +template +void ggml_gemm_q3_K_Mx1_q8_K_generic( + int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + + assert(n % QK_K == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + const int nb = n / QK_K; + const block_q3_Kx * x = (const block_q3_Kx *) vx; + const block_q8_Kx4 * y = (const block_q8_Kx4 *) vy; + + // Offsets for the high part of the scales (8 rows of low bytes * columns) + const int scale_high_offset = 8 * ncols_interleaved; + + for (int row_tile = 0; row_tile < nr; row_tile += 4) { + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { + const block_q3_Kx * x_ptr = x + (col_tile / ncols_interleaved) * nb; + const block_q8_Kx4 * 
y_ptr = y + (row_tile / 4) * nb; + + float sumf[4][ncols_interleaved]; + memset(sumf, 0, sizeof(sumf)); + + for (int k_block = 0; k_block < nb; ++k_block) { + const block_q3_Kx & xb = x_ptr[k_block]; + const block_q8_Kx4 & yb = y_ptr[k_block]; + + int32_t isum[4][ncols_interleaved]; + memset(isum, 0, sizeof(isum)); + + for (int sb = 0; sb < 16; ++sb) { + // Pre-calc scale indices for this sub-block + const int s_row_lo = sb >> 1; + const int s_row_hi = sb >> 2; + const int s_shift_lo = (sb & 1) ? 4 : 0; + const int s_shift_hi = (sb & 3) * 2; + + for (int l = 0; l < 16; ++l) { + const int k = sb * 16 + l; + + // Pre-calc weight indices for this k + const int qs_row = k >> 2; + const int qs_shift = (k & 3) * 2; + const int hm_row = k >> 3; + const int hm_shift = k & 7; + + const int8_t q8_0 = yb.qs[k * 4 + 0]; + const int8_t q8_1 = yb.qs[k * 4 + 1]; + const int8_t q8_2 = yb.qs[k * 4 + 2]; + const int8_t q8_3 = yb.qs[k * 4 + 3]; + + for (int col = 0; col < ncols_interleaved; ++col) { + // Inline q3k_get_scale6_packed + const uint8_t scale_lo_byte = xb.scales[s_row_lo * ncols_interleaved + col]; + const uint8_t scale_hi_byte = xb.scales[scale_high_offset + s_row_hi * ncols_interleaved + col]; + + const uint8_t s_lo = (scale_lo_byte >> s_shift_lo) & 0x0F; + const uint8_t s_hi = (scale_hi_byte >> s_shift_hi) & 0x03; + + const int sc = (int)((s_hi << 4) | s_lo) - 32; + + // Inline q3k_get_val_packed + const uint8_t qs_byte = xb.qs[qs_row * ncols_interleaved + col]; + const uint8_t hm_byte = xb.hmask[hm_row * ncols_interleaved + col]; + + const int low2 = (qs_byte >> qs_shift) & 3; + const int hb = (hm_byte >> hm_shift) & 1; + + const int v = low2 - (hb ? 
0 : 4); + + const int w = v * sc; + isum[0][col] += w * q8_0; + isum[1][col] += w * q8_1; + isum[2][col] += w * q8_2; + isum[3][col] += w * q8_3; + } + } + } + + for (int col = 0; col < ncols_interleaved; ++col) { + const float d_rhs = GGML_FP16_TO_FP32(xb.d[col]); + const float g0 = d_rhs * yb.d[0]; + const float g1 = d_rhs * yb.d[1]; + const float g2 = d_rhs * yb.d[2]; + const float g3 = d_rhs * yb.d[3]; + + sumf[0][col] += (float) isum[0][col] * g0; + sumf[1][col] += (float) isum[1][col] * g1; + sumf[2][col] += (float) isum[2][col] * g2; + sumf[3][col] += (float) isum[3][col] * g3; + } + } + + for (int r = 0; r < 4; ++r) { + for (int col = 0; col < ncols_interleaved; ++col) { s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; } } @@ -1972,6 +2162,20 @@ void ggml_gemv_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemv_q2_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q3_K +void ggml_gemv_q3_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q3_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q3_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q3_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // Q4_K void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, 
int nr, int nc) { ggml_gemv_q4_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -2800,6 +3004,20 @@ void ggml_gemm_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemm_q2_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q3_K +void ggml_gemm_q3_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q3_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q3_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q3_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // Q4_K void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_q4_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -3288,6 +3506,7 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } + static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, @@ -3643,13 +3862,12 @@ static block_q2_Kx make_block_q2_KxMx1(const block_q2_K * in) block_q2_Kx out; constexpr int N_COLS = nrows_interleaved; - // 1. 
Copy Super-Scales (d) and Super-Mins (dmin) for (int i = 0; i < N_COLS; i++) { out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; } - // 2. Interleave Q2_K Data + // Interleave Q2_K Data const int bytes_per_col = 64; const int total_bytes = N_COLS * bytes_per_col; const int end = total_bytes; @@ -3661,7 +3879,7 @@ static block_q2_Kx make_block_q2_KxMx1(const block_q2_K * in) memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], 1); } - // 3. Repack Scales into the Optimized "Sequential-Parallel" Layout + // Repack Scales into the Optimized "Sequential-Parallel" Layout int out_idx = 0; // Arrays define the sub-block order for each group @@ -3721,7 +3939,7 @@ static int repack_q2_K_to_q2_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_ for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - // This loop gathers 16 separate blocks (one from each column) + // This loop gathers nrows_interleaved separate blocks (one from each row of the transposed matrix) // that correspond to the same K-dimension chunk. for (int i = 0; i < nrows_interleaved; i++ ) { dst_tmp[i] = src[x + i * nblocks]; @@ -3736,6 +3954,122 @@ static int repack_q2_K_to_q2_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_ GGML_UNUSED(data_size); } +template +static block_q3_Kx make_block_q3_KxMx1(const block_q3_K * in) { + block_q3_Kx out; + constexpr int N_COLS = nrows_interleaved; + constexpr int scales_bytes = 12; + constexpr int hmask_bytes = 32; + constexpr int qs_bytes = 64; + for (int i = 0; i < N_COLS; i++) { + out.d[i] = in[i].d; + } + + // Process each column: linearize the metadata, then interleave + uint8_t temp_scales[scales_bytes]; + uint8_t temp_hmask[hmask_bytes]; + uint8_t temp_qs[qs_bytes]; + + for (int col = 0; col < N_COLS; ++col) { + const block_q3_K & src = in[col]; + + uint8_t scale6[16]; + for (int sb = 0; sb < 16; ++sb) { + const uint8_t lo = sb < 8 ? 
(src.scales[sb] & 0xF) : (src.scales[sb - 8] >> 4); + const uint8_t hi = (src.scales[8 + (sb & 3)] >> (2 * (sb >> 2))) & 0x3; + scale6[sb] = lo | (hi << 4); + } + + // Repack into linear format: + // 0-7: Low 4 bits of pairs + for (int i = 0; i < 8; ++i) { + temp_scales[i] = (scale6[2*i] & 0x0F) | ((scale6[2*i + 1] & 0x0F) << 4); + } + // 8-11: High 2 bits of quads + for (int i = 0; i < 4; ++i) { + const int base = 4*i; + temp_scales[8 + i] = + (((scale6[base + 0] >> 4) & 0x03) << 0) | + (((scale6[base + 1] >> 4) & 0x03) << 2) | + (((scale6[base + 2] >> 4) & 0x03) << 4) | + (((scale6[base + 3] >> 4) & 0x03) << 6); + } + + // --- transpose HMask --- + memset(temp_hmask, 0, sizeof(temp_hmask)); + for (int hb = 0; hb < hmask_bytes; ++hb) { + const int elem_base = hb * 8; + for (int bit = 0; bit < 8; ++bit) { + const int idx = elem_base + bit; + // We want sequential: Byte `i` contains bits for weights `8*i` to `8*i+7` + const uint8_t hi = (src.hmask[idx & 31] >> (idx >> 5)) & 0x1; + temp_hmask[hb] |= (hi << bit); + } + } + + // --- QS (De-stride) --- + memset(temp_qs, 0, sizeof(temp_qs)); + for (int qb = 0; qb < qs_bytes; ++qb) { + const int elem_base = qb * 4; + for (int lane = 0; lane < 4; ++lane) { + const int idx = elem_base + lane; + // Logic to find byte offset and shift in standard Q3_K strided layout + const int src_byte = ((idx >> 7) << 5) + (idx & 31); + const int shift = ((idx >> 5) & 0x3) << 1; + const uint8_t lo2 = (src.qs[src_byte] >> shift) & 0x3; + temp_qs[qb] |= (lo2 << (2 * lane)); + } + } + + // --- Write Interleaved to Output --- + for (int i = 0; i < scales_bytes; ++i) { + out.scales[i * N_COLS + col] = temp_scales[i]; + } + for (int i = 0; i < hmask_bytes; ++i) { + out.hmask[i * N_COLS + col] = temp_hmask[i]; + } + for (int i = 0; i < qs_bytes; ++i) { + out.qs[i * N_COLS + col] = temp_qs[i]; + } + } + + return out; +} + +template +static int repack_q3_K_to_q3_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + 
GGML_ASSERT(t->type == GGML_TYPE_Q3_K); + + block_q3_Kx * dst = (block_q3_Kx*)t->data; + const block_q3_K * src = (const block_q3_K*) data; + + block_q3_K dst_tmp[nrows_interleaved]; + + const int nrow = ggml_nrows(t); + const int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == (size_t) nrow * nblocks * sizeof(block_q3_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + // Gather N separate blocks from N adjacent rows + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + + *dst++ = make_block_q3_KxMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + template static block_q4_Kx make_block_q4_KxMx1(block_q4_K * in) { block_q4_Kx out; @@ -4292,6 +4626,20 @@ template <> int repack(struct ggml_tensor * t, const void * d return repack_q2_K_to_q2_K_Mx1_bl<64>(t, data, data_size); } +// Q3_K +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_K_to_q3_K_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_K_to_q3_K_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_K_to_q3_K_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_K_to_q3_K_Mx1_bl<64>(t, data, data_size); +} + // Q4_K template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q4_K_to_q4_K_Mx1_bl<8>(t, data, data_size); @@ -4467,6 +4815,20 @@ template <> void gemv(int n, float * s, size_ ggml_gemv_q2_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q3_K +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * 
vy, int nr, int nc) { + ggml_gemv_q3_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q3_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q3_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q3_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // Q4_K template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); @@ -4642,6 +5004,20 @@ template <> void gemm(int n, float * s, size_ ggml_gemm_q2_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q3_K +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q3_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q3_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q3_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q3_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // Q4_K template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); @@ -5130,6 +5506,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q2_K_32x1_q8_K; static const ggml::cpu::repack::tensor_traits q2_K_64x1_q8_K; + // Q3_K + static const ggml::cpu::repack::tensor_traits q3_K_8x1_q8_K; + static const 
ggml::cpu::repack::tensor_traits q3_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits q3_K_32x1_q8_K; + static const ggml::cpu::repack::tensor_traits q3_K_64x1_q8_K; + // Q4_K static const ggml::cpu::repack::tensor_traits q4_K_8x1_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K; @@ -5226,6 +5608,19 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } #endif } + } else if (cur->type == GGML_TYPE_Q3_K) { + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { if (cur->ne[1] % 8 == 0) { return &q3_K_8x1_q8_K; } break; } + case 256: { if (cur->ne[1] % 16 == 0) { return &q3_K_16x1_q8_K; } break; } + case 512: { if (cur->ne[1] % 32 == 0) { return &q3_K_32x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q3_K_64x1_q8_K; } break; } + default: { return nullptr; } + } + #endif + } + } else if (cur->type == GGML_TYPE_Q5_K) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index dbb266a2172..207f9ae5049 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -82,6 +82,24 @@ static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, static_assert(sizeof(block_q2_Kx32) == sizeof(ggml_half) * 64 + QK_K * 2 + QK_K * 8, "wrong q2_K block size/padding"); static_assert(sizeof(block_q2_Kx64) == sizeof(ggml_half) * 128 + QK_K * 4 + QK_K * 16, "wrong q2_K block size/padding"); +template +struct block_q3_Kx { + ggml_half d[N]; // super-block scales + uint8_t scales[12 * N]; // 6-bit quantized scales (packed) + uint8_t hmask[N * QK_K / 8]; // high bit of weights (1 bit/weight) + uint8_t qs[N * QK_K / 4]; // low 2 bits of weights (2 bits/weight) +}; + +using block_q3_Kx8 = block_q3_Kx<8>; +using block_q3_Kx16 = block_q3_Kx<16>; +using block_q3_Kx32 = block_q3_Kx<32>; +using block_q3_Kx64 = block_q3_Kx<64>; + 
+static_assert(sizeof(block_q3_Kx8) == sizeof(ggml_half) * 8 + 12 * 8 + QK_K + QK_K * 2, "wrong q3_K block size/padding for x8"); +static_assert(sizeof(block_q3_Kx16) == sizeof(ggml_half) * 16 + 12 * 16 + QK_K * 2 + QK_K * 4, "wrong q3_K block size/padding for x16"); +static_assert(sizeof(block_q3_Kx32) == sizeof(ggml_half) * 32 + 12 * 32 + QK_K * 4 + QK_K * 8, "wrong q3_K block size/padding for x32"); +static_assert(sizeof(block_q3_Kx64) == sizeof(ggml_half) * 64 + 12 * 64 + QK_K * 8 + QK_K * 16, "wrong q3_K block size/padding for x64"); + template struct block_q5_Kx { ggml_half d[N]; // super-block scale for quantized scales ggml_half dmin[N]; // super-block scale for quantized mins @@ -196,7 +214,6 @@ void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); -void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -209,6 +226,10 @@ void ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -237,6 +258,10 @@ void ggml_gemm_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_16x1_q8_K(int 
n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -307,6 +332,10 @@ void ggml_gemv_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_64x1_q8_K_generic(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -335,6 +364,10 @@ void ggml_gemm_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void 
ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From c4808afc3f84fddc4039d4f9f967a4d83c0aabb8 Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Thu, 19 Feb 2026 16:09:33 +0500 Subject: [PATCH 12/13] ggml-cpu: refactor; add rvv repacking for q5_K --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 26 ++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 9a5453132fa..3c9e595ecd8 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -1206,15 +1206,15 @@ void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo // Accumulation for 2 sub-blocks. // - // This might overflow, so we accumulate in two steps. + // This might overflow, so we accumulate in 4 steps. // // Recheck. - for (int k = 0; k < 2; k++) { + for (int k = 0; k < 4; k++) { // 4xM integer accumulators vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + for (int i = k * 8; i < (k + 1) * 8; i++) { // Load `b_ptr`. const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); @@ -1240,15 +1240,15 @@ void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } // Accumulation for 2 sub-blocks. // - // This might overflow, so we accumulate in two steps. 
+ // This might overflow, so we accumulate in 4 steps. // // Recheck. - for (int k = 0; k < 2; k++) { + for (int k = 0; k < 4; k++) { // 4xM integer accumulators vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + for (int i = k * 8; i < (k + 1) * 8; i++) { // Load `b_ptr`. const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); @@ -1343,7 +1343,7 @@ void ggml_gemv_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int l = 0; l < nb; l++) { vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); - // We process 2 16-element sub-blocks at once. + // We process 4 16-element sub-blocks at once. for (int j = 0; j < QK_K / 16; j += 4) { // Load the scales. // @@ -2736,7 +2736,7 @@ static void ggml_gemm_q3_K_Mx1_q8_K(int n, for (int group = 0; group < 4; ++group) { // High scales are needed for all 4 sub-blocks (0.5 register) vuint8mf2_t v_sc_h_quad = __riscv_vle8_v_u8mf2(rhs_sc_high_ptr, 16); - rhs_sc_high_ptr += ncols_interleaved; + rhs_sc_high_ptr += ncols_interleaved; // --- Scope 1: Sub-blocks 1 & 2 (Pair 0) --- // By scoping this, v_sc_l_pair0 dies before we load pair1 @@ -3437,10 +3437,10 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo // Accumulation for 2 sub-blocks. // - // This might overflow, so we accumulate in two steps. + // This might overflow, so we accumulate in 4 steps. // // Recheck. 
- for (int k = 0; k < 2; k++) { + for (int k = 0; k < 4; k++) { // 4xM integer accumulators vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); @@ -3451,7 +3451,7 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + for (int i = k * 8; i < (k + 1) * 8; i++) { // Load `b_ptr`. const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); @@ -3505,7 +3505,7 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo // This might overflow, so we accumulate in two steps. // // Recheck. - for (int k = 0; k < 2; k++) { + for (int k = 0; k < 4; k++) { // 4xM integer accumulators vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); @@ -3516,7 +3516,7 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + for (int i = k * 8; i < (k + 1) * 8; i++) { // Load `b_ptr`. 
const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); From 053cb1c360463f217f6cd50cef71c6add2d080d4 Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Wed, 4 Mar 2026 19:00:30 +0500 Subject: [PATCH 13/13] ggml-cpu: refactor; add rvv repacking for mxfp4 --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 801 ++++++++---------------- ggml/src/ggml-cpu/repack.cpp | 730 ++++++++++++++++++++- ggml/src/ggml-cpu/repack.h | 42 +- 3 files changed, 999 insertions(+), 574 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 3c9e595ecd8..c3990ed62f5 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -222,10 +222,6 @@ static inline void ggml_gemv_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_ UNUSED(ncols_interleaved); UNUSED(blocklen); -<<<<<<< HEAD -#if defined __riscv_v_intrinsic -======= ->>>>>>> aaf8e8a16 (ggml-cpu: extend rvv gemm, gemv to other vlens) const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); @@ -273,143 +269,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemv_q4_0_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); #else ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); -<<<<<<< HEAD -} - -void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - 
UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const block_q8_K * a_ptr = (const block_q8_K *) vy; - - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); - - // 1x16 Accumulator - vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); - - for (int l = 0; l < nb; l++) { - vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16); - - // Load `dmin`. - const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16); - - // We process 4 sub-blocks at once. - for (int j = 0; j < QK_K / 128; j++) { - // Extract the scales and the mins. - // - // Low bits. - vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); - vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); - vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); - - // High bits. - vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); - vuint8m2_t scales_hi; - vuint8m2_t mins_hi; - if (!j) { - scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); - mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); - } else { - scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); - mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); - } - vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); - vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); - - // Reduce the mins and multiply with `dmin`. - // - // Correct in `sumf`. 
- vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, 16); - bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), 16); - bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), 16); - bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16); - bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - - sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16); - - // Accumulation for 2 sub-blocks. - // - // This might overflow, so we accumulate in two steps. - // - // Recheck. - for (int k = 0; k < 2; k++) { - vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { - // Load `b_ptr`. - const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); - const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); - const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); - - sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, 16); - sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, 16); - } - - sumi = __riscv_vwmacc_vv_i32m2(sumi, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_s_0_16, 16); - sumi = __riscv_vwmacc_vv_i32m2(sumi, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_s_1_16, 16); - } - // Accumulation for 2 sub-blocks. - // - // This might overflow, so we accumulate in two steps. - // - // Recheck. 
- for (int k = 0; k < 2; k++) { - vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { - // Load `b_ptr`. - const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); - const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); - const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); - - sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, 16); - sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, 16); - } - - sumi = __riscv_vwmacc_vv_i32m2(sumi, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), - sumi_s_0_16, 16); - sumi = __riscv_vwmacc_vv_i32m2(sumi, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), - sumi_s_1_16, 16); - } - } - - const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], 16), 16); - const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, 16); - - sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); - } - - __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); - } - return; -======= ->>>>>>> aaf8e8a16 (ggml-cpu: extend rvv gemm, gemv to other vlens) #endif } void ggml_gemv_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -428,7 +287,7 @@ void ggml_gemv_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template <int ncols_interleaved> -void ggml_gemv_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void 
* GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int blocklen = 1; @@ -507,7 +366,7 @@ void ggml_gemv_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template <int ncols_interleaved> -static void inline ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); assert(nr == 1); assert(nc % ncols_interleaved == 0); @@ -680,7 +539,6 @@ static void inline ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_ } } - void ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { #if defined __riscv_zvfh ggml_gemv_q2_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); @@ -710,8 +568,6 @@ void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } - - template <int ncols_interleaved> __attribute__((optimize("no-schedule-insns"))) static void inline ggml_gemv_q3_K_Mx1_q8_K( int n, @@ -945,7 +801,6 @@ static void inline ggml_gemv_q3_K_Mx1_q8_K( int n, } // End col_tile loop } - void ggml_gemv_q3_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { #if defined __riscv_zvfh ggml_gemv_q3_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); @@ -976,7 +831,7 @@ void ggml_gemv_q3_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template <int ncols_interleaved> -void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const 
int blocklen = 1; @@ -1005,6 +860,10 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int l = 0; l < nb; l++) { vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + // Load `dmins`. + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); + // We process 4 sub-blocks at once. const int vl = ncols_interleaved * 4; for (int j = 0; j < QK_K / 128; j++) { @@ -1038,8 +897,6 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); - const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved); // Accumulation for 2 sub-blocks. 
@@ -1138,7 +995,7 @@ void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template <int ncols_interleaved> -void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -1167,6 +1024,10 @@ void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int l = 0; l < nb; l++) { vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + // Load `dmins`. + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); + // We process 4 sub-blocks at once. const int vl = ncols_interleaved * 4; for (int j = 0; j < QK_K / 128; j++) { @@ -1200,8 +1061,6 @@ void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); - const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved); // Accumulation for 2 sub-blocks. 
@@ -1314,7 +1173,7 @@ void ggml_gemv_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template <int ncols_interleaved> -void ggml_gemv_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -1462,7 +1321,7 @@ void ggml_gemv_q6_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template <int ncols_interleaved> -void ggml_gemv_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int blocklen = 1; @@ -1543,6 +1402,94 @@ void ggml_gemv_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const #endif } +template <int ncols_interleaved> +static inline void ggml_gemv_mxfp4_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_mxfp4, 16); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x<ncols_interleaved> * b_ptr = (const block_mxfp4x<ncols_interleaved> *) vx + (x * nb); + + // 1xM Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + // 1xM integer accumulator + 
vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // Accumulation loop. + for (int i = 0; i < QK_MXFP4 / 2; i++) { + // Load `b_ptr`. + const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved); + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved); + } + + const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, ncols_interleaved); + + float b_scales[ncols_interleaved]; + for (int i = 0; i < ncols_interleaved; i++) { + b_scales[i] = GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[i]); + } + const vfloat32m2_t b_e = __riscv_vle32_v_f32m2((const float *)&b_scales[0], ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d), ncols_interleaved); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); + } +} + +void ggml_gemv_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_mxfp4_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_mxfp4_8x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, 
const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_mxfp4_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_mxfp4_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_mxfp4_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_mxfp4_32x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemv_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemv_mxfp4_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemv_mxfp4_64x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1770,7 +1717,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo } template -void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int blocklen = 1; @@ -1789,12 +1736,6 @@ void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo UNUSED(ncols_interleaved); UNUSED(blocklen); -<<<<<<< HEAD -#if defined __riscv_v_intrinsic -======= - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); - ->>>>>>> aaf8e8a16 (ggml-cpu: extend rvv gemm, gemv to other vlens) for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * 
nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -1873,280 +1814,6 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemm_q4_0_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); #else ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); -<<<<<<< HEAD -} - -void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); - - // 4x16 Accumulators - vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - - for (int l = 0; l < nb; l++) { - vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16); - - // Load `dmin`. - const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16); - - // We process 4 sub-blocks at once. - for (int j = 0; j < QK_K / 128; j++) { - // Extract the scales and the mins. - // - // Low bits. 
- vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); - vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); - vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); - - // High bits. - vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); - vuint8m2_t scales_hi; - vuint8m2_t mins_hi; - if (!j) { - scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); - mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); - } else { - scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); - mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); - } - vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); - vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); - - // Reduce the mins and multiply with `dmin`. - // - // Correct in `sumf`. 
- vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16); - - bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, - a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); - bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, - a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); - bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, - a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); - bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, - a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); - bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, - a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); - bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, - a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); - bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, - a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); - bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, - a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); - bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, - a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); - bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, - a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); - bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, - a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); - bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, - a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 
+ 7], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); - bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, - a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, - a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, - a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, - a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - - const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], 16); - const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], 16); - const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], 16); - const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], 16); - - sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16); - sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16); - sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, 16), 16), 16); - sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, 16), 16), 16); - - - // Accumulation for 2 sub-blocks. - // - // This might overflow, so we accumulate in two steps. - // - // Recheck. 
- for (int k = 0; k < 2; k++) { - // 4x16 integer accumulators - vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { - // Load `b_ptr`. - const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); - const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); - const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); - - sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, 16); - sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, 16); - sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, 16); - sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, 16); - - sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, 16); - sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, 16); - sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, 16); - sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, 16); - } - - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_0_s_0_16, 16); - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, - 
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_0_s_1_16, 16); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_1_s_0_16, 16); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_1_s_1_16, 16); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_2_s_0_16, 16); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_2_s_1_16, 16); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_3_s_0_16, 16); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_3_s_1_16, 16); - } - // Accumulation for 2 sub-blocks. - // - // This might overflow, so we accumulate in two steps. - // - // Recheck. - for (int k = 0; k < 2; k++) { - // 4x16 integer accumulators - vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { - // Load `b_ptr`. 
- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); - const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); - const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); - - sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, 16); - sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, 16); - sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, 16); - sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, 16); - - sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, 16); - sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, 16); - sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, 16); - sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, 16); - } - - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), - sumi_0_s_0_16, 16); - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), - sumi_0_s_1_16, 16); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), - sumi_1_s_0_16, 16); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), - sumi_1_s_1_16, 16); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), - sumi_2_s_0_16, 16); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, - 
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), - sumi_2_s_1_16, 16); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), - sumi_3_s_0_16, 16); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), - sumi_3_s_1_16, 16); - } - } - - const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16), 16); - const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16); - const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16); - const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16); - const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], 16); - - sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); - sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); - sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); - sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); - } - - __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); - } - } - return; -======= ->>>>>>> aaf8e8a16 (ggml-cpu: extend rvv gemm, gemv to other vlens) #endif } void ggml_gemm_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -2165,101 +1832,9 @@ void ggml_gemm_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemm_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = 
QK8_0; - const int nb = n / qk; -<<<<<<< HEAD - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - - // 4x16 Accumulators - vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - - for (int l = 0; l < nb; l++) { - // 4x16 integer accumulators - vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); - - // Accumulation loop. - for (int i = 0; i < QK4_NL / 2; i++) { - // Load `b_ptr`. 
- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16); - const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); - const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); - // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); - // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); - - const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16); - const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16); - const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16); - const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], 16); - - const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], 16); - const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], 16); - const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], 16); - const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], 16); - - sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, 16), 16); - sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, 16), 16); - sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, 16), 16); - sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16); - } - - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); - const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); - const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); - const vfloat32m2_t d_3 = 
__riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); - - sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); - sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); - sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); - sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); - } - - __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); - } - } - return; -#endif - ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} - -void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 16; -======= ->>>>>>> aaf8e8a16 (ggml-cpu: extend rvv gemm, gemv to other vlens) const int blocklen = 1; assert (n % qk == 0); @@ -2359,7 +1934,7 @@ void ggml_gemm_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template -static void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); const int num_k_blocks = n / QK_K; const int N_ROWS_TILE = 4; @@ -3003,7 +2578,7 @@ void ggml_gemm_q3_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void 
ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -3039,6 +2614,9 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + // Load `dmins`. + const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved); + // We process 4 sub-blocks at once. const int vl = ncols_interleaved * 4; for (int j = 0; j < QK_K / 128; j++) { @@ -3120,14 +2698,10 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); - const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved); - const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[1], ncols_interleaved); - const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[2], ncols_interleaved); - const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[3], ncols_interleaved); + 
const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], ncols_interleaved); sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, ncols_interleaved), ncols_interleaved), ncols_interleaved); sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, ncols_interleaved), ncols_interleaved), ncols_interleaved); @@ -3303,7 +2877,7 @@ void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -3339,6 +2913,9 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + // Load `dmins`. + const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved); + // We process 4 sub-blocks at once. 
const int vl = ncols_interleaved * 4; for (int j = 0; j < QK_K / 128; j++) { @@ -3420,14 +2997,10 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); - const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved); - const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[1], ncols_interleaved); - const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[2], ncols_interleaved); - const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[3], ncols_interleaved); + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], ncols_interleaved); sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, ncols_interleaved), ncols_interleaved), ncols_interleaved); sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, ncols_interleaved), ncols_interleaved), ncols_interleaved); @@ -3617,7 +3190,7 @@ void ggml_gemm_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, 
const v } template -void ggml_gemm_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -3824,7 +3397,7 @@ void ggml_gemm_q6_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int blocklen = 1; @@ -3870,8 +3443,6 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved); - // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); - // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4], ncols_interleaved); const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4 + 1], ncols_interleaved); @@ -3940,3 +3511,129 @@ void ggml_gemm_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemm_iq4_nl_64x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); 
#endif } + +template +static inline void ggml_gemm_mxfp4_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_mxfp4, 16); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x * b_ptr = (const block_mxfp4x *) vx + (x * nb); + + // 4xM Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + // 4xM integer accumulators + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // Accumulation loop. + for (int i = 0; i < QK_MXFP4 / 2; i++) { + // Load `b_ptr`. 
+ const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved); + sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved); + sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved); + sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved); + + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved); + sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved); + sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved); + sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved); + } + + // Do the final accumulation in i32 to prevent overflow. 
+ const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, ncols_interleaved); + const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, ncols_interleaved); + const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, ncols_interleaved); + const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, ncols_interleaved); + + float b_scales[ncols_interleaved]; + for (int i = 0; i < ncols_interleaved; i++) { + b_scales[i] = GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[i]); + } + const vfloat32m2_t b_e = __riscv_vle32_v_f32m2((const float *)&b_scales[0], ncols_interleaved); + + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[0]), ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[1]), ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[2]), ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[3]), ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); + } + } + return; +#endif + 
ggml_gemm_mxfp4_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_mxfp4_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_mxfp4_8x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_mxfp4_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_mxfp4_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_mxfp4_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_mxfp4_32x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_mxfp4_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_mxfp4_64x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 44398f4abd7..034b1b97c2b 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -585,7 +585,7 @@ void ggml_gemv_q3_K_Mx1_q8_K_generic( // Inline q3k_get_scale6_packed const uint8_t scale_lo_byte = xb.scales[s_row_lo * ncols_interleaved + col]; const uint8_t scale_hi_byte = xb.scales[scale_high_offset + s_row_hi * ncols_interleaved + col]; - + const uint8_t s_lo = (scale_lo_byte >> s_shift_lo) & 0x0F; const uint8_t s_hi = (scale_hi_byte >> s_shift_hi) & 0x03; const int sc = (int)((s_hi << 4) | s_lo) - 32; @@ -593,7 +593,7 @@ void 
ggml_gemv_q3_K_Mx1_q8_K_generic( // Inline q3k_get_val_packed const uint8_t qs_byte = xb.qs[qs_row * ncols_interleaved + col]; const uint8_t hm_byte = xb.hmask[hm_row * ncols_interleaved + col]; - + const int low2 = (qs_byte >> qs_shift) & 3; const int hb = (hm_byte >> hm_shift) & 1; const int v = low2 - (hb ? 0 : 4); @@ -877,6 +877,44 @@ static inline void ggml_gemv_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRIC for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } } + +template +static inline void ggml_gemv_mxfp4_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[ncols_interleaved]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x * b_ptr = (const block_mxfp4x *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} #endif #if defined __riscv_zvfh @@ -1121,7 +1159,7 @@ void ggml_gemm_q3_K_Mx1_q8_K_generic( for (int l = 0; l < 16; ++l) { const 
int k = sb * 16 + l; - + // Pre-calc weight indices for this k const int qs_row = k >> 2; const int qs_shift = (k & 3) * 2; @@ -1137,19 +1175,19 @@ void ggml_gemm_q3_K_Mx1_q8_K_generic( // Inline q3k_get_scale6_packed const uint8_t scale_lo_byte = xb.scales[s_row_lo * ncols_interleaved + col]; const uint8_t scale_hi_byte = xb.scales[scale_high_offset + s_row_hi * ncols_interleaved + col]; - + const uint8_t s_lo = (scale_lo_byte >> s_shift_lo) & 0x0F; const uint8_t s_hi = (scale_hi_byte >> s_shift_hi) & 0x03; - + const int sc = (int)((s_hi << 4) | s_lo) - 32; // Inline q3k_get_val_packed const uint8_t qs_byte = xb.qs[qs_row * ncols_interleaved + col]; const uint8_t hm_byte = xb.hmask[hm_row * ncols_interleaved + col]; - + const int low2 = (qs_byte >> qs_shift) & 3; const int hb = (hm_byte >> hm_shift) & 1; - + const int v = low2 - (hb ? 0 : 4); const int w = v * sc; @@ -1498,8 +1536,448 @@ static inline void ggml_gemm_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRIC } } } + +template +static inline void ggml_gemm_mxfp4_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][ncols_interleaved]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x * b_ptr = (const block_mxfp4x *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved 
* blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); + } + sumf[m][j] += sumi * GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} #endif +template +static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l 
= 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; + } + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +template +static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; + const int q8_half_stride = 512; + const int q8_low_high_step = 256; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + + float sumf[4][8]; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for 
(int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0f; + } + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = + qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = + qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int 
q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i]; + const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step]; + + sumi_l += q_l * q8_l; + sumi_h += q_h * q8_h; + } + + sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * + a_ptr[l].d[m]; + } + } + } + } + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +template +static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[ncols_interleaved]; + float sum_minf[ncols_interleaved]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + constexpr int 
scale_stride = 32; + uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; + uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + + const int qh_shift = (k / (32 / blocklen)) * 2; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * blocklen + i) % 32; + const int qh_chunk = qh_idx / blocklen; + const int qh_pos = qh_idx % blocklen; + const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +template +static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = 
QK_K; + const int nb = n / qk; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][ncols_interleaved]; + float sum_minf[4][ncols_interleaved]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + constexpr int scale_stride = 32; + uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; + uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + + const int qh_shift = (k / (32 / blocklen)) * 2; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * blocklen + i) % 32; + const int qh_chunk = qh_idx / blocklen; + const int qh_pos = qh_idx % blocklen; + const int b_qh_offset = + qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + + const uint8_t qh_val = 
b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k / (32 / blocklen)) * 256 + + (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + extern "C" { void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1861,16 +2339,101 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } -<<<<<<< HEAD void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); } -======= ->>>>>>> 4044254f5 (ggml-cpu: refactor; add rvv repacking for q6_K) void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT 
vx, const void * GGML_RESTRICT vy, int nr, int nc) { - ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); + constexpr int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + + + for (int k = 0; k < 16; k++) { + // k = 0.. 7 weights 0-63 low, 64-127 high + // k = 8..15 weights 128-191 low, 192-255 high + const int base_l = (k / 8) * 128 + (k % 8) * 8; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + // qh_half: offset to the correct 32-byte half (0 or 32) + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + for (int j = 0; j < ncols_interleaved; j++) { + // Interleaved scales + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * 64 + j * 8 + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + // qh indexing with 8-byte interleaving (like q5_K) + const int qh_byte_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_byte_l / 8; + const int qh_pos_l = qh_byte_l % 8; + const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) 
& 0x3; + + const int qh_byte_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_byte_h / 8; + const int qh_pos_h = qh_byte_h % 8; + const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; + + printf("w: %d %d, b: %d %d %d\n", q_l, a_l, l_4, hi_2_l, b_ptr[l].qh[qh_offset_h]); + } + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } } void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -2231,6 +2794,20 @@ void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t b void ggml_gemv_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_iq4_nl_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } + +// MXFP4 +void ggml_gemv_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_mxfp4_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_mxfp4_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + 
ggml_gemv_mxfp4_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_mxfp4_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); +} #endif void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -3073,6 +3650,20 @@ void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t b void ggml_gemm_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_iq4_nl_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } + +// MXFP4 +void ggml_gemm_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); +} #endif } // extern "C" @@ -4392,6 +4983,59 @@ static int repack_iq4_nl_to_iq4_nl_Mx1_bl(struct ggml_tensor * t, const void * G GGML_UNUSED(data_size); } + +template +static block_mxfp4x make_block_mxfp4xMx1(block_mxfp4 * in) { + block_mxfp4x out; + + for (int i 
= 0; i < nrows_interleaved; i++) { + out.e[i] = in[i].e; + } + + const int end = QK_MXFP4 * nrows_interleaved / 2; + + for (int i = 0; i < end; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + return out; +} + +template +static int repack_mxfp4_to_mxfp4_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + + const block_mxfp4 * src = (const block_mxfp4 *)data; + block_mxfp4x * dst = ( block_mxfp4x *)t->data; + + block_mxfp4 dst_tmp[nrows_interleaved]; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_mxfp4xMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} #endif static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) { @@ -4695,6 +5339,20 @@ template <> int repack(struct ggml_tensor * t, const void * template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_Mx1_bl<64>(t, data, data_size); } + +// MXFP4 +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_Mx1_bl<32>(t, data, data_size); +} +template <> int 
repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_Mx1_bl<64>(t, data, data_size); +} #endif // gemv @@ -4884,6 +5542,20 @@ template <> void gemv(int n, float * s, siz template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_64x1_q8_0(n, s, bs, vx, vy, nr, nc); } + +// MXFP4 +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} #endif // gemm @@ -5073,6 +5745,20 @@ template <> void gemm(int n, float * s, siz template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_64x1_q8_0(n, s, bs, vx, vy, nr, nc); } + +// MXFP4 +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} #endif class tensor_traits_base : public 
ggml::cpu::tensor_traits { @@ -5535,6 +6221,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_32x1_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_64x1_q8_0; + + // MXFP4 + static const ggml::cpu::repack::tensor_traits mxfp4_8x1_q8_0; + static const ggml::cpu::repack::tensor_traits mxfp4_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits mxfp4_32x1_q8_0; + static const ggml::cpu::repack::tensor_traits mxfp4_64x1_q8_0; #endif if (cur->type == GGML_TYPE_Q4_0) { @@ -5620,7 +6312,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } #endif } - + } else if (cur->type == GGML_TYPE_Q5_K) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 8 == 0) { @@ -5631,6 +6323,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (cur->ne[1] % 8 == 0) { return &q5_K_8x4_q8_K; } + } if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { @@ -5652,6 +6345,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (cur->ne[1] % 8 == 0) { return &q6_K_8x4_q8_K; } + } if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { @@ -5718,6 +6412,18 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } #endif } + } else if (cur->type == GGML_TYPE_MXFP4) { + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { if (cur->ne[1] % 8 == 0) { return &mxfp4_8x1_q8_0; } break; } + case 256: { if (cur->ne[1] % 16 == 0) { return &mxfp4_16x1_q8_0; } break; } + case 512: { if (cur->ne[1] % 32 == 0) { return &mxfp4_32x1_q8_0; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &mxfp4_64x1_q8_0; } break; } + default: { return nullptr; } + } + #endif + } } return nullptr; 
diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 207f9ae5049..e585cb68ece 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -82,7 +82,7 @@ static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, static_assert(sizeof(block_q2_Kx32) == sizeof(ggml_half) * 64 + QK_K * 2 + QK_K * 8, "wrong q2_K block size/padding"); static_assert(sizeof(block_q2_Kx64) == sizeof(ggml_half) * 128 + QK_K * 4 + QK_K * 16, "wrong q2_K block size/padding"); -template +template struct block_q3_Kx { ggml_half d[N]; // super-block scales uint8_t scales[12 * N]; // 6-bit quantized scales (packed) @@ -160,17 +160,23 @@ static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "w static_assert(sizeof(block_iq4_nlx32) == 32 * sizeof(ggml_half) + QK4_NL * 16, "wrong iq4_nlx32 block size/padding"); static_assert(sizeof(block_iq4_nlx64) == 64 * sizeof(ggml_half) + QK4_NL * 32, "wrong iq4_nlx64 block size/padding"); -struct block_mxfp4x4 { - uint8_t e[4]; - uint8_t qs[QK_MXFP4 * 2]; +template struct block_mxfp4x { + ggml_half e[N]; // deltas for `N` mxfp4 blocks + uint8_t qs[QK_MXFP4 * N / 2]; // nibbles / quants for N mxfp4 blocks }; -static_assert(sizeof(block_mxfp4x4) == 4 + QK_MXFP4 * 2, "wrong mxfp4x4 block size/padding"); -struct block_mxfp4x8 { - uint8_t e[8]; - uint8_t qs[QK_MXFP4 * 4]; -}; -static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding"); +using block_mxfp4x4 = block_mxfp4x<4>; +using block_mxfp4x8 = block_mxfp4x<8>; +using block_mxfp4x16 = block_mxfp4x<16>; +using block_mxfp4x32 = block_mxfp4x<32>; +using block_mxfp4x64 = block_mxfp4x<64>; + +static_assert(sizeof(block_mxfp4x4) == 4 * sizeof(ggml_half) + QK_MXFP4 * 2, "wrong mxfp4x4 block size/padding"); +static_assert(sizeof(block_mxfp4x8) == 8 * sizeof(ggml_half) + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding"); +static_assert(sizeof(block_mxfp4x16) == 16 * sizeof(ggml_half) + QK_MXFP4 * 8, "wrong 
mxfp4x16 block size/padding"); +static_assert(sizeof(block_mxfp4x32) == 32 * sizeof(ggml_half) + QK_MXFP4 * 16, "wrong mxfp4x32 block size/padding"); +static_assert(sizeof(block_mxfp4x64) == 64 * sizeof(ggml_half) + QK_MXFP4 * 32, "wrong mxfp4x64 block size/padding"); + #if defined(__cplusplus) extern "C" { @@ -246,6 +252,10 @@ void ggml_gemv_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -278,6 +288,10 @@ void ggml_gemm_iq4_nl_8x1_q8_0(int 
n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif // Native implementations @@ -352,6 +366,10 @@ void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -384,6 +402,10 @@ void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_64x1_q8_0_generic(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif #if defined(__cplusplus)