diff --git a/include/infiniop/ops/gptq_gemm.h b/include/infiniop/ops/gptq_gemm.h
new file mode 100644
index 000000000..d1f25a799
--- /dev/null
+++ b/include/infiniop/ops/gptq_gemm.h
@@ -0,0 +1,38 @@
#ifndef __INFINIOP_GPTQ_GEMM_API_H__
#define __INFINIOP_GPTQ_GEMM_API_H__

#include "../operator_descriptor.h"

// Opaque handle for a GPTQ quantized-GEMM operator instance.
typedef struct InfiniopDescriptor *infiniopGptqGemmDescriptor_t;

// Create a descriptor for out = a @ dequant(b), where b is GPTQ-quantized
// with per-group scales/zeros and an optional group index (g_idx).
// `use_exllama` selects the exllama kernel path; `quant_bit` is the weight
// bit-width.
__INFINI_C __export infiniStatus_t infiniopCreateGptqGemmDescriptor(
    infiniopHandle_t handle,
    infiniopGptqGemmDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t b_scales_desc,
    infiniopTensorDescriptor_t b_zeros_desc,
    infiniopTensorDescriptor_t b_g_idx_desc,
    bool use_exllama,
    int quant_bit);

// Query the scratch-workspace byte count required by infiniopGptqGemm.
__INFINI_C __export infiniStatus_t infiniopGetGptqGemmWorkspaceSize(
    infiniopGptqGemmDescriptor_t desc,
    size_t *size);

// Execute the quantized GEMM on `stream` using a workspace of at least the
// size reported by infiniopGetGptqGemmWorkspaceSize.
__INFINI_C __export infiniStatus_t infiniopGptqGemm(
    infiniopGptqGemmDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *a,
    const void *b,
    const void *b_scale,
    const void *b_zero,
    const void *b_g_idx,
    void *stream);

// Release all resources held by the descriptor.
__INFINI_C __export infiniStatus_t infiniopDestroyGptqGemmDescriptor(
    infiniopGptqGemmDescriptor_t desc);
#endif
diff --git a/src/infiniop/ops/gptq_gemm/cuda/compat.cuh b/src/infiniop/ops/gptq_gemm/cuda/compat.cuh
new file mode 100644
index 000000000..1da7ebd14
--- /dev/null
+++ b/src/infiniop/ops/gptq_gemm/cuda/compat.cuh
@@ -0,0 +1,69 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _compat_cuh
#define _compat_cuh

// 1. Core CUDA runtime header (required: provides the basic CUDA type
//    definitions). NOTE(review): the original angle-bracket include target was
//    stripped by markup; restored per the original comment — confirm.
#include <cuda_runtime.h>

// 2.
// Half-precision float header (defines half / half2 / __half_raw).
// NOTE(review): include target was stripped by markup; restored per the
// original comment — confirm.
#include <cuda_fp16.h>

namespace vllm {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x

// Software emulation of atomicAdd on a half value via a 32-bit CAS loop on
// the aligned word containing it. The `(size_t)address & 2` tests select
// which 16-bit half of the word holds the operand.
__device__ __forceinline__ void atomicAdd_half(half *address, half val) {
    unsigned int *address_as_ui = (unsigned int *)((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do {
        assumed = old;
        __half_raw hsum;
        // Pick the 16 bits corresponding to `address` out of the 32-bit word.
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        // Splice the updated 16 bits back into the word and try to commit.
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)
                                  : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
    } while (assumed != old);
}

// atomicAdd for half2 types

// CAS-loop emulation of atomicAdd on a half2 (one aligned 32-bit word).
__device__ __forceinline__ void atomicAdd_half2(half2 *address, half2 val) {
    unsigned int *address_as_ui = (unsigned int *)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do {
        assumed = old;
        half2 old_val = *((half2 *)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int *)&new_val));
    } while (assumed != old);
}

//

// Provide atomicAdd overloads only where the hardware lacks the native ones:
// half atomicAdd requires SM70+, half2 atomicAdd requires SM60+.
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half *address, half val) {
    atomicAdd_half(address, val);
}

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2 *address, half2 val) {
    atomicAdd_half2(address, val);
}
#endif

#endif
#endif

} // namespace gptq
} // namespace vllm
#endif
diff --git a/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh b/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh
new file mode 100644
index 000000000..b93715198
--- /dev/null
+++ b/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh
@@ -0,0 +1,199 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and
https://github.com/qwopqwop200/GPTQ-for-LLaMa
*/

// NOTE(review): the four angle-bracket include targets below were stripped by
// markup removal; restored with the most plausible headers — confirm against
// the upstream source.
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include <cstdint>
#include <cstdio>

#include
"compat.cuh" +#include "matrix_view.cuh" +#include "qdq_2.cuh" +#include "qdq_3.cuh" +#include "qdq_4.cuh" +#include "qdq_8.cuh" + +namespace vllm { +namespace gptq { + +#define BLOCK_KN_SIZE 128 +#define BLOCK_M_SIZE_MAX 8 +#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) +#define MAX_Q_GEMM_ROWS 50 +#define MAX_Q_GEMM_ROWS_8BIT 24 +#define MAX_ALT_GEMM_ROWS 8 +#define THREADS_X 32 +#define THREADS_Y 32 +#define DIVIDE(x, size) (((x) + (size)-1) / (size)) + +#if defined(USE_ROCM) +#include +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( + hipblasHandle_t handle, hipblasOperation_t transA, + hipblasOperation_t transB, int m, int n, int k, const half *alpha, + const half *AP, int lda, const half *BP, int ldb, const half *beta, + half *CP, int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} +#define hipblasHgemm __compat_hipblasHgemm + +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. 
+#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_hgemm __compat_hipblasHgemm +#endif + +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half *a_ptr, + const half2 g_result) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + return __hadd2(result, g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half *a_ptr) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + return __half2float(__low2half(result)) + __half2float(__high2half(result)); +} + +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half *a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half *a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half *a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half *a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2 *a2_ptr = (const half2 
*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half *a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half *a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half *a_ptr, + const half g_result, + const half qs_h) { + // Use FP32 accumulator to avoid potential overflow since unscaled weights are + // in the range -128..127 + + float result = {}; +#pragma unroll + for (int i = 0; i < 4; i++) { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half *a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) { + 
result = __hfma2(dq[i], *a2_ptr++, result); + } + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half *a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +} // namespace gptq +} // namespace vllm diff --git a/src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh b/src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh new file mode 100644 index 000000000..5b7e0bd88 --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh @@ -0,0 +1,291 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/turboderp/exllama +*/ + +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { + +class MatrixView_half { +public: + const half *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half *data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2 *)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { + return __half2half2(data[row * width + column]); + } + __device__ __forceinline__ const half *item_ptr(int row, int column) const { + return &data[row * width + column]; + } + + __device__ __forceinline__ void item4(half (&items)[4], int row, + int column) const { + half2 *ptr = (half2 *)item_ptr(row, column); + half2 i01 = 
ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, + int column) const { + half2 *ptr = (half2 *)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, + int column) const { + half2 *ptr = (half2 *)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw { +public: + half *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half *data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2 *)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half *item_ptr(int row, int column) { + return &data[row * width + column]; + } + __device__ __forceinline__ void set(int row, int column, half value) { + data[row * width + column] = value; + } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { + ((half2 *)data)[(row * width + column) / 2] = value; + } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, + half v2, half v3) { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2 *ptr = (half2 *)item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } 
+}; + +class MatrixView_q4_row { +public: + const uint32_t *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t *data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +class MatrixView_q4_column { +public: + const uint32_t *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t *data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { + return data[row / 8 * width + column]; + } + __device__ __forceinline__ const uint32_t *item_uint32_ptr(int row, + int column) { + return &data[row / 8 * width + column]; + } +}; + +class MatrixView_q2_row { +public: + const uint32_t *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q2_row(const uint32_t *data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + 
int shift = (column & 0x0f) * 2; + return (data[row * width / 16 + column / 16] >> shift) & 0x03; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + items[2] = (d >> 4) & 0x03; + items[3] = (d >> 6) & 0x03; + } +}; + +class MatrixView_q3_row { +public: + const uint32_t *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q3_row(const uint32_t *data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int z_w = column * 3 / 32; + int z_mod = column & 0x1f; + + if (z_mod == 10) { + return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); + } else if (z_mod == 21) { + return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); + } else if (z_mod < 10) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; + } else if (z_mod < 21) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; + } else { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; + } + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x1f); + uint32_t d; + if (shift <= 4) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); + } else if (shift == 8) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); + } else if (shift <= 16) { + d = 
data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
    } else if (shift == 20) {
      // 3-bit field straddles a word boundary at this offset: stitch the top
      // 4 bits of one word with the low 8 bits of the next.
      d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
    } else {
      d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
    }
    items[0] = d & 0x07;
    items[1] = (d >> 3) & 0x07;
    items[2] = (d >> 6) & 0x07;
    items[3] = (d >> 9) & 0x07;
  }
};

// Row-major view of 8-bit packed quantized values (4 items per uint32_t).
class MatrixView_q8_row {
public:
  const uint32_t *data;
  const int height;
  const int width;

  __device__ __forceinline__ MatrixView_q8_row(const uint32_t *data,
                                               const int height,
                                               const int width)
      : data(data), height(height), width(width) {}

  // Extract a single byte-wide item.
  __device__ __forceinline__ int item(int row, int column) const {
    int shift = (column & 0x03) * 8;
    return (data[row * width / 4 + column / 4] >> shift) & 0xff;
  }

  // Extract two consecutive byte-wide items starting at `column`.
  __device__ __forceinline__ void item2(int (&items)[2], int row,
                                        int column) const {
    int shift = (column & 0x03) * 8;
    uint32_t d = data[row * width / 4 + column / 4] >> shift;
    items[0] = d & 0xff;
    items[1] = (d >> 8) & 0xff;
  }

  // Extract four consecutive byte-wide items starting at `column`.
  __device__ __forceinline__ void item4(int (&items)[4], int row,
                                        int column) const {
    // FIX: original computed `(column & 0x03) * 2`, inconsistent with item()
    // and item2() which use `* 8` for byte-wide fields. The `* 2` variant is
    // only accidentally correct when `column` is a multiple of 4 (shift 0).
    int shift = (column & 0x03) * 8;
    uint32_t d = data[row * width / 4 + column / 4] >> shift;
    items[0] = d & 0xff;
    items[1] = (d >> 8) & 0xff;
    items[2] = (d >> 16) & 0xff;
    items[3] = (d >> 24) & 0xff;
  }
};

} // namespace gptq
} // namespace vllm
#endif
diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh
new file mode 100644
index 000000000..c9bb50efc
--- /dev/null
+++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh
@@ -0,0 +1,76 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_2_cuh
#define _qdq_2_cuh

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {

// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200

__forceinline__ __device__ void shuffle_2bit_16(uint32_t *q, int stride) {
uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, + half2 (&dq)[8], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z4 = __half2half2(z4_); + const half2 z16 = __half2half2(z16_); + const half2 z64 = __half2half2(z64_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = 
__hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +} // namespace gptq +} // namespace vllm + +#endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh new file mode 100644 index 000000000..eefef3151 --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh @@ -0,0 +1,149 @@ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32(uint32_t *q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { + uint32_t t0 = qa & 0x07; + uint32_t t1 = (qa & 0x38) >> 3; + qa >>= 6; + za |= (t0 << (i * 3)); + za |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qb & 0x07; + uint32_t t1 = (qb & 0x38) >> 3; + qb >>= 6; + zb |= (t0 << (i * 3)); + zb |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qc & 0x07; + uint32_t t1 = (qc & 0x38) >> 3; + qc >>= 6; + zc |= (t0 << (i * 3)); + zc |= (t1 << (i * 3 + 16)); + } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 
2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + 
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y8, z8); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y8, z8); + dq[4] = __hfma2(q4.as_half2, y64, z64); + dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hfma2(q6.as_half2, y8, z8); + dq[7] = __hadd2(q7.as_half2, z1); + dq[8] = __hfma2(q8.as_half2, y8, z8); + dq[9] = __hfma2(q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +} // namespace gptq +} // namespace vllm + +#endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh new file mode 100644 index 000000000..29ca387ce --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh @@ -0,0 +1,122 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8(uint32_t *q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, 
+ half2 (&dq)[4], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z16 = __half2half2(z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale( + const uint32_t zero, const half scale, half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, + half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} 
+ +__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, bool scaled) { + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) { + dq[0] = __hfma2(q0.as_half2, y1y16[0], + z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } else { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], + z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} +} // namespace gptq +} // namespace vllm + +#endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh new file mode 100644 index 000000000..fb9b8fc47 --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh @@ -0,0 +1,35 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_8_cuh +#define _qdq_8_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { + +__forceinline__ __device__ void shuffle_8bit_4(uint32_t *q, int stride) {} + +__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], int stride, + const uint32_t zero) { + half dqh[8]; + for (int i = 0; i < 4; i++) { + dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); + } + 
for (int i = 0; i < 4; i++) {
    // Second word q_1 supplies items 4..7.
    dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
  }

  // Pack the eight dequantized halves into four half2 lanes.
  for (int i = 0; i < 4; i++) {
    dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
  }
}

} // namespace gptq
} // namespace vllm

#endif
diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh
new file mode 100644
index 000000000..acac37c81
--- /dev/null
+++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh
@@ -0,0 +1,62 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_util_cuh
#define _qdq_util_cuh

// 1. Core CUDA runtime header (required: provides the basic CUDA type
//    definitions). NOTE(review): stripped include target restored per the
//    original comment — confirm.
#include <cuda_runtime.h>

// 2. Half-precision float header (defines half / half2 / __half_raw).
//    NOTE(review): stripped include target restored per the original
//    comment — confirm.
#include <cuda_fp16.h>

// Fixed-width integer types (uint32_t / uint16_t) used by the unions below.
#include <cstdint>

namespace vllm {
namespace gptq {

// Reinterpret a 32-bit word as a half2 (and vice versa) without conversion.
union half2_uint32 {
  uint32_t as_uint32;
  half2 as_half2;
  __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
  __device__ half2_uint32(half2 val) : as_half2(val) {}
};

// Reinterpret a 16-bit word as a half (and vice versa) without conversion.
union half_uint16 {
  uint16_t as_uint16;
  half as_half;
  __device__ half_uint16(uint16_t val) : as_uint16(val) {}
  __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
  int qs_i = qs + 1;
  half qs_h = __int2half_rn(qs_i * qs_i);
  qs_h = __hmul(qs_h, max_scale);
  return qs_h;
}

// Dequantize: (q - qzero) * scale.
__forceinline__ __device__ half dq(const int q, const int qzero,
                                   const half scale) {
  return __hmul(__int2half_rn(q - qzero), scale);
}

// Dequantize without scale: q - qzero.
__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
  // return __hsub(__int2half_rn(q), __int2half_rn(qzero));
  return __int2half_rn(q - qzero);
}

// Extract a bit field from a single word.
__forceinline__ __device__ int exb(const uint32_t q, const int shift,
                                   const int mask) {
  return (int)((q >> shift) & mask);
}

// Extract a bit field straddling two words via funnel shift.
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0,
                                   const int shift, const int mask) {
  return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

} // namespace gptq
} // namespace vllm
#endif
diff --git a/src/infiniop/ops/gptq_gemm/gptq_gemm.h b/src/infiniop/ops/gptq_gemm/gptq_gemm.h new file mode 100644 index 000000000..9498b5b6d --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/gptq_gemm.h @@ -0,0 +1,52 @@
#ifndef GPTQ_GEMM_H
#define GPTQ_GEMM_H

#include "../../operator.h"
#include "info.h"

// Declares the per-backend GPTQ GEMM Descriptor class inside
// op::gptq_gemm::NAMESPACE. Each backend translation unit supplies the
// definitions of the nested Opaque struct, create(), calculate() and the
// destructor; the class holds the validated GptqGemmInfo and the
// workspace size reported to callers.
#define DESCRIPTOR(NAMESPACE)                                           \
                                                                        \
    namespace op::gptq_gemm::NAMESPACE {                                \
    class Descriptor final : public InfiniopDescriptor {                \
        struct Opaque;                                                  \
        Opaque *_opaque;                                                \
        GptqGemmInfo _info;                                             \
        size_t _workspace_size;                                         \
                                                                        \
        Descriptor(                                                     \
            Opaque *opaque,                                             \
            GptqGemmInfo info,                                          \
            size_t workspace_size,                                      \
            infiniDevice_t device_type,                                 \
            int device_id)                                              \
            : InfiniopDescriptor{device_type, device_id},               \
              _opaque(opaque),                                          \
              _info(info),                                              \
              _workspace_size(workspace_size) {}                        \
                                                                        \
    public:                                                             \
        ~Descriptor();                                                  \
                                                                        \
        size_t workspaceSize() const { return _workspace_size; }        \
                                                                        \
        static infiniStatus_t create(                                   \
            infiniopHandle_t handle,                                    \
            Descriptor **desc_ptr,                                      \
            infiniopTensorDescriptor_t out_desc,                        \
            infiniopTensorDescriptor_t a_desc,                          \
            infiniopTensorDescriptor_t b_desc,                          \
            infiniopTensorDescriptor_t b_scales_desc,                   \
            infiniopTensorDescriptor_t b_zeros_desc,                    \
            infiniopTensorDescriptor_t b_g_idx_desc,                    \
            bool use_exllama,                                           \
            int quant_bit);                                             \
                                                                        \
        infiniStatus_t calculate(                                       \
            void *workspace, size_t workspace_size,                     \
            void *out,                                                  \
            const void *a, const void *b, const void *b_scale, const void *b_zero, const void *b_g_idx, \
            void *stream) const;                                        \
    };                                                                  \
    }

#endif // GPTQ_GEMM_H
diff --git a/src/infiniop/ops/gptq_gemm/info.h b/src/infiniop/ops/gptq_gemm/info.h new file mode 100644 index 000000000..11c82d38d --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/info.h @@ -0,0 +1,75 @@
#ifndef __GPTQ_GEMM_INFO_H__
#define __GPTQ_GEMM_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"
#include
#include

namespace op::gptq_gemm {

// Validated shape/dtype metadata for one GPTQ GEMM invocation.
class GptqGemmInfo {
    GptqGemmInfo() = default; // constructed only via createGptqGemmInfo

public:
    // --- Data Type
--- + infiniDtype_t dtype; + + // --- Shape Dimensions --- + size_t M, K, N, b_size_0; + int block_size, num_groups; + bool use_exllama; + int64_t quant_bit; + + static utils::Result createGptqGemmInfo( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t b_g_idx_desc, + bool use_exllama, + int quant_bit) { + + auto dtype = out_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16); + if (b_scales_desc->dtype() != dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (b_zeros_desc->dtype() != INFINI_DTYPE_I32 || b_g_idx_desc->dtype() != INFINI_DTYPE_I32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + size_t M = out_desc->shape()[0]; + size_t N = out_desc->shape()[1]; + size_t K = a_desc->shape()[1]; + size_t b_size_0 = b_desc->shape()[0]; + int block_size = 128; + int num_groups = K / block_size; + if (quant_bit != 4) { + throw std::runtime_error( + "quant_bit must be 4, but got " + std::to_string(quant_bit)); + } + + auto ndim = out_desc->ndim(); + CHECK_OR_RETURN(ndim == 2 + && a_desc->ndim() == ndim + && b_desc->ndim() == ndim + && b_scales_desc->ndim() == ndim + && b_zeros_desc->ndim() == ndim, + INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(b_scales_desc->shape()[1] == N + && static_cast(b_scales_desc->shape()[0]) == num_groups, + INFINI_STATUS_BAD_TENSOR_SHAPE); + + return utils::Result(GptqGemmInfo{ + dtype, + M, K, N, b_size_0, + block_size, num_groups, + use_exllama, static_cast(quant_bit)}); + } +}; + +} // namespace op::gptq_gemm + +#endif // __GPTQ_GEMM_INFO_H__ diff --git a/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu new file mode 100644 index 000000000..dd7b78dc1 --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu @@ -0,0 +1,1745 @@ +#include 
"../../../devices/nvidia/nvidia_common.cuh"
#include "gptq_gemm_nvidia.cuh"

#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
#include

#include "../cuda/kernel.cuh"
namespace vllm {
namespace gptq {

// Signature shared by all bit-width variants of the GPTQ half*quant GEMM
// kernels, so a variant can be selected at runtime by function pointer.
typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half *, const uint32_t *,
                                                const uint32_t *, const half *,
                                                half *, const int, const int,
                                                const int, const int,
                                                const int *);

// C (size_m x size_n) += A (size_m x size_k, fp16) * dequant(B), for B packed
// 8 values per uint32 at 4 bits each. Each block covers BLOCK_KN_SIZE*4
// output columns (4 per thread), m_count output rows, and a BLOCK_KN_SIZE
// slice of K; partial products over the K slices are combined with atomicAdd,
// so the kernel presumably runs with gridDim.z = ceil(size_k/BLOCK_KN_SIZE).
// b_q_perm, when non-null, is the activation-reordering permutation of A's
// K dimension.
template
INFINIOP_CUDA_KERNEL gemm_half_q_half_gptq_4bit_kernel(
    const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight,
    const uint32_t *__restrict__ b_gptq_qzeros,
    const half *__restrict__ b_gptq_scales, half *__restrict__ c,
    const int size_m, const int size_n, const int size_k, const int groups,
    const int *__restrict__ b_q_perm) {
    MatrixView_half a_(a, size_m, size_k);
    MatrixView_half_rw c_(c, size_m, size_n);
    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

    auto t = threadIdx.x;

    // Block
    auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
    auto offset_m = blockIdx.y * m_count;
    auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

    [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
    [[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

    int n = offset_n + t * 4;

    // Preload block_a: stage this block's A slice (with optional K
    // permutation applied) into shared memory.
    // NOTE(review): the m-loop does not check offset_m + m < size_m even
    // though end_m is computed above — rows past size_m are read when
    // size_m is not a multiple of m_count; confirm A is padded/allocated
    // accordingly.
    __shared__ half block_a[m_count][BLOCK_KN_SIZE];

    if (offset_k + t < end_k) {
        for (int m = 0; m < m_count; ++m) {
            const half *a_ptr = a_.item_ptr(offset_m + m, 0);
            half *block_a_ptr = block_a[m];

            half a0;
            if (b_q_perm) {
                a0 = a_ptr[b_q_perm[offset_k + t]];
            } else {
                a0 = a_ptr[offset_k + t];
            }
            block_a_ptr[t] = a0;
        }
    }

    // Zero output
    if (n >= size_n) {
        return;
    }

    // The z == 0 block zeroes the 4 output halves (one 64-bit store) that
    // all K-slice blocks subsequently accumulate into via atomicAdd.
    if (blockIdx.z == 0) {
        for (int m = 0; m < m_count; m++) {
            *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0;
        }
    }

    __syncthreads();

    // Find initial group
    int groupsize = size_k / groups;
    int group = offset_k / groupsize;
    int nextgroup = offset_k + groupsize;

    // a, b offset (8 4-bit values per packed uint32 row of B)
    int qk = offset_k / (32 / 4);

    const uint32_t *b_ptr = b_q_weight + qk * size_n + n;
    const half *a_ptr = &block_a[0][0];
    int a_stride = BLOCK_KN_SIZE;

    // Initial group: zeros/scales for the 4 columns handled by this thread
    int zeros[4];
    float scales[4];
    half2 z1z16[4][2];
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_f(scales, group, n);
    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);

    // Column result (fp32 accumulators)
    float block_c[m_count][4] = {};

    // Dequantize and multiply
    int k = offset_k;
    while (k < end_k) {
        // Refresh zeros/scales when crossing a quantization-group boundary.
        if (k == nextgroup) {
            group++;
            nextgroup += groupsize;
            b_gptq_qzeros_.item4(zeros, group, n);
            b_gptq_scales_.item4_f(scales, group, n);
            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
        }

#pragma unroll
        for (int j = 0; j < 4; j++) {
            // One int4 load fetches the packed words for all 4 columns.
            const int4 *b_ptr4 = (int4 *)b_ptr;
            int4 load_int4 = *b_ptr4;

            half2 dq[4][4];
            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n,
                                false);
            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n,
                                false);
            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n,
                                false);
            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n,
                                false);

#pragma unroll
            for (int m = 0; m < m_count; m++) {
                block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0],
                                    block_c[m][0]);
                block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1],
                                    block_c[m][1]);
                block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2],
                                    block_c[m][2]);
                block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3],
                                    block_c[m][3]);
            }

            b_ptr += size_n;
            a_ptr += 8;
        }

        k += 32;
    }

    // Publish the 4 column partial sums; atomicAdd combines results from
    // the other K-slice blocks (half2 atomicAdd — see compat.cuh fallback).
    for (int m = 0; m < m_count; m++) {
        half2 *out = (half2 *)c_.item_ptr(offset_m + m, n);
        half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]),
                                        __float2half_rn(block_c[m][1]));
        half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]),
                                        __float2half_rn(block_c[m][3]));
        atomicAdd(out, result01);
        atomicAdd(out + 1, result23);
    }
}

// 2-bit variant of the kernel above: B packs 16 values per uint32, the
// inner step covers 16 K elements, and accumulation is done in fp16
// (dot22_16_h) rather than fp32.
template
INFINIOP_CUDA_KERNEL gemm_half_q_half_gptq_2bit_kernel(
    const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight,
    const uint32_t *__restrict__ b_gptq_qzeros,
    const half *__restrict__ b_gptq_scales, half *__restrict__ c,
    const int size_m, const int size_n, const int size_k, const int groups,
    const int *__restrict__ b_q_perm) {
    MatrixView_half a_(a, size_m, size_k);
    MatrixView_half_rw c_(c, size_m, size_n);
    MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

    auto t = threadIdx.x;

    // Block
    auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
    auto offset_m = blockIdx.y * m_count;
    auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

    [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
    [[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

    int n = offset_n + t * 4;

    // Preload block_a (same staging scheme as the 4-bit kernel)
    __shared__ half block_a[m_count][BLOCK_KN_SIZE];

    if (offset_k + t < end_k) {
        for (int m = 0; m < m_count; ++m) {
            const half *a_ptr = a_.item_ptr(offset_m + m, 0);
            half *block_a_ptr = block_a[m];

            half a0;
            if (b_q_perm) {
                a0 = a_ptr[b_q_perm[offset_k + t]];
            } else {
                a0 = a_ptr[offset_k + t];
            }
            block_a_ptr[t] = a0;
        }
    }

    // Zero output
    if (n >= size_n) {
        return;
    }

    if (blockIdx.z == 0) {
        for (int m = 0; m < m_count; m++) {
            *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0;
        }
    }

    __syncthreads();

    // Find initial group
    int groupsize = size_k / groups;
    int group = offset_k / groupsize;
    int nextgroup = offset_k + groupsize;

    // a, b offset (16 2-bit values per packed uint32 row of B)
    int qk = offset_k / (32 / 2);

    const uint32_t *b_ptr = b_q_weight + qk * size_n + n;
    const half *a_ptr = &block_a[0][0];
    int a_stride = BLOCK_KN_SIZE;

    // Initial group
    int zeros[4];
    half scales[4];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4(scales, group, n);
    // Column result (fp16 accumulators)
    half block_c[m_count][4] = {};

    // Dequantize and multiply
    int k = offset_k;
    while (k < end_k) {
        if (k == nextgroup) {
            group++;
            nextgroup += groupsize;
            b_gptq_qzeros_.item4(zeros, group, n);
            b_gptq_scales_.item4(scales, group, n);
        }

#pragma unroll
        for (int j = 0; j < 1; j++) {
            const int4 *b_ptr4 = (int4 *)b_ptr;
            int4 load_int4 = *b_ptr4;

            half2 dq[4][8];
            dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
            dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
            dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
            dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);

#pragma unroll
            for (int m = 0; m < m_count; m++) {
                block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
                block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
                block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
                block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
            }

            b_ptr += size_n;
            a_ptr += 16;
        }

        k += 16;
    }

    // Publish partial sums via half2 atomicAdd, as in the 4-bit kernel.
    for (int m = 0; m < m_count; m++) {
        half2 *out = (half2 *)c_.item_ptr(offset_m + m, n);
        half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
        half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
        atomicAdd(out, result01);
        atomicAdd(out + 1, result23);
    }
}

+template +INFINIOP_CUDA_KERNEL gemm_half_q_half_gptq_3bit_kernel( + const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, half *__restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int *__restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto t = threadIdx.x; + + // Block + auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + auto offset_m = blockIdx.y * m_count; + auto offset_k = blockIdx.z * BLOCK_KN_SIZE; + + [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + [[maybe_unused]] int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half *a_ptr = a_.item_ptr(offset_m + m, 0); + half *block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) { + a0 = a_ptr[b_q_perm[offset_k + t]]; + } else { + a0 = a_ptr[offset_k + t]; + } + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) { + return; + } + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) { + *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0; + } + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / 32 * 3; + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + const half *a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); 
+ // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[3]; + load_int4[0] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4 *)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 32; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2 *out = (half2 *)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +template +INFINIOP_CUDA_KERNEL gemm_half_q_half_gptq_8bit_kernel( + const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, half *__restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int *__restrict__ 
b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto t = threadIdx.x; + + // Block + auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + auto offset_m = blockIdx.y * m_count; + auto offset_k = blockIdx.z * BLOCK_KN_SIZE; + + [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + [[maybe_unused]] int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half *a_ptr = a_.item_ptr(offset_m + m, 0); + half *block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) { + a0 = a_ptr[b_q_perm[offset_k + t]]; + } else { + a0 = a_ptr[offset_k + t]; + } + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) { + return; + } + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) { + *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0; + } + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 8); + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + const half *a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 4; j++) { + int4 load_int4[2]; + 
load_int4[0] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4 *)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 8; + } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2 *out = (half2 *)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel( + bool first_block, const int m_count, const int bit) { +#define SELECT_KERNEL(M_COUNT) \ + if (m_count == M_COUNT) { \ + if (bit == 2) \ + return gemm_half_q_half_gptq_2bit_kernel; \ + if (bit == 3) \ + return gemm_half_q_half_gptq_3bit_kernel; \ + if (bit == 4) \ + return gemm_half_q_half_gptq_4bit_kernel; \ + if (bit == 8) \ + return gemm_half_q_half_gptq_8bit_kernel; \ + } +#if BLOCK_M_SIZE_MAX >= 1 + SELECT_KERNEL(1); +#endif +#if BLOCK_M_SIZE_MAX >= 2 + SELECT_KERNEL(2); +#endif +#if BLOCK_M_SIZE_MAX >= 3 + SELECT_KERNEL(3); +#endif +#if BLOCK_M_SIZE_MAX >= 4 + SELECT_KERNEL(4); +#endif +#if BLOCK_M_SIZE_MAX >= 5 + SELECT_KERNEL(5); +#endif +#if BLOCK_M_SIZE_MAX >= 6 + SELECT_KERNEL(6); +#endif +#if BLOCK_M_SIZE_MAX >= 7 + SELECT_KERNEL(7); +#endif +#if 
BLOCK_M_SIZE_MAX >= 8 + SELECT_KERNEL(8); +#endif + return NULL; +} + +void gemm_half_q_half_cuda_part(const half *a, const uint32_t *b_q_weight, + const uint32_t *b_gptq_qzeros, + const half *b_gptq_scales, const int *b_q_perm, + half *c, int size_m, int size_n, int size_k, + int m_count, int groups, int bit, cudaStream_t stream) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + + kernel<<>>(a, b_q_weight, b_gptq_qzeros, + b_gptq_scales, c, size_m, size_n, + size_k, groups, b_q_perm); +} + +INFINIOP_CUDA_KERNEL reconstruct_exllama_8bit_kernel( + const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) { + perm[t] = b_q_perm[offset_k + t]; + } + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) { + return; + } + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 8); + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + 
b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 4; p++) { + int4 load_int4[2]; + load_int4[0] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4 *)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +INFINIOP_CUDA_KERNEL reconstruct_exllama_4bit_kernel( + const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row 
b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) { + perm[t] = b_q_perm[offset_k + t]; + } + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) { + return; + } + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) { + half2 dq[4][4]; + const int4 *b_ptr4 = (int4 *)b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + 
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +INFINIOP_CUDA_KERNEL reconstruct_exllama_3bit_kernel( + const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) { + perm[t] = b_q_perm[offset_k + t]; + } + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) { + return; + } + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + 
groupsize; + + // b offset + int qk = offset_k / 32 * 3; + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 1; p++) { + int4 load_int4[3]; + load_int4[0] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4 *)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + + if (b_q_perm) { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +INFINIOP_CUDA_KERNEL reconstruct_exllama_2bit_kernel( + const uint32_t *__restrict__ b_q_weight, 
const int *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) { + perm[t] = b_q_perm[offset_k + t]; + } + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) { + return; + } + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 2); + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 2; p++) { + const int4 *b_ptr4 = (int4 *)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(perm[lk++], n, 
__low2half(dq[0][j]), __low2half(dq[1][j]),
+                __low2half(dq[2][j]), __low2half(dq[3][j]));
+        b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]),
+                __high2half(dq[2][j]), __high2half(dq[3][j]));
+      }
+    } else {
+      for (int j = 0; j < 8; j++) {
+        for (int v = 0; v < 4; v++) {
+          dq[v][j] = __hmul2(scales[v], dq[v][j]);
+        }
+        b_.set4(offset_k + lk++, n, __low2half(dq[0][j]),
+                __low2half(dq[1][j]), __low2half(dq[2][j]),
+                __low2half(dq[3][j]));
+        b_.set4(offset_k + lk++, n, __high2half(dq[0][j]),
+                __high2half(dq[1][j]), __high2half(dq[2][j]),
+                __high2half(dq[3][j]));
+      }
+    }
+  }
+  k += 32;
+  }
+}
+
+void reconstruct_exllama(const uint32_t *b_q_weight,
+                         const uint32_t *b_gptq_qzeros,
+                         const half *b_gptq_scales, const int *b_q_perm,
+                         half *out, int height, int width, int groups,
+                         int bit, cudaStream_t stream) {
+  dim3 blockDim, gridDim;
+  blockDim.x = BLOCK_KN_SIZE;
+  blockDim.y = 1;
+  gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
+  gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
+
+  auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel;
+  if (bit == 2) {
+    reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel;
+  } else if (bit == 3) {
+    reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel;
+  } else if (bit == 8) {
+    reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel;
+  }
+
+  reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>(
+      b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups,
+      out);
+}
+
+INFINIOP_CUDA_KERNEL gemm_half_q_half_alt_4bit_kernel(
+    const half2 *__restrict__ vec, const uint32_t *__restrict__ mat,
+    half *__restrict__ mul, const half *__restrict__ scales,
+    const uint32_t *__restrict__ zeros, const int *__restrict__ g_idx,
+    int batch, int height, int width) {
+  int zero_width = width / 8;
+  int vec_height = height * 4;
+  const int blockwidth2 = BLOCK_KN_SIZE / 2;
+  auto b = blockIdx.y * BLOCK_M_SIZE_MAX;
+  int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
+  auto h = BLOCK_KN_SIZE * blockIdx.z / 8;
+  int
h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + threadIdx.x]; + } + } + + __shared__ half2 deq2[256][8]; + auto val = threadIdx.x / 8; + auto off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4)); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) { + mul[(b + m) * width + w] = __int2half_rn(0); + } + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + // res2 = {}; + res2 = __halves2half2(__float2half(0.0f), __float2half(0.0f)); +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + res2 = __hfma2( + __hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); + res2 = __hfma2( 
+ __hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), + blockvec[m][k + 2], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), + blockvec[m][k + 3], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + +INFINIOP_CUDA_KERNEL gemm_half_q_half_alt_8bit_kernel( + const half2 *__restrict__ vec, const uint32_t *__restrict__ mat, + half *__restrict__ mul, const half *__restrict__ scales, + const uint32_t *__restrict__ zeros, const int *__restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 4; + int vec_height = height * 2; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + auto b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + auto h = BLOCK_KN_SIZE * blockIdx.z / 4; + int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; + auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + threadIdx.x]; + } + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) { + mul[(b + m) * width + w] = __int2half_rn(0); + } + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 4; + int k = 0; + int z_w = w / 4; + int z_mod = (w % 4) * 8; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[2]; + half2 zeros_tmp[2]; + for (int tmp_k = 0; tmp_k < 2; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + 
half2 scale = __halves2half2(scale_f, scale_f2);
+      half2 zero = __halves2half2(
+          __hmul(scale_f,
+                 __int2half_rn(
+                     -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)),
+          __hmul(scale_f2,
+                 __int2half_rn(
+                     -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1)));
+      scales_tmp[tmp_k] = scale;
+      zeros_tmp[tmp_k] = zero;
+    }
+    for (int m = 0; m < b_end; m++) {
+#ifndef USE_ROCM
+      // res2 = {};
+      res2 = __halves2half2(__float2half(0.0f), __float2half(0.0f));
+#else
+      res2.x = __half_as_ushort(__float2half(0));
+      res2.y = __half_as_ushort(__float2half(0));
+#endif
+      half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF),
+                                 __int2half_rn((tmp >> 8) & 0xFF));
+      res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]),
+                     blockvec[m][k + 0], res2);
+      half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF),
+                                 __int2half_rn((tmp >> 24) & 0xFF));
+      res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]),
+                     blockvec[m][k + 1], res2);
+#ifndef USE_ROCM
+      res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
+#else
+      res[m] = __hadd(
+          res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
+#endif
+    }
+    i += width;
+    k += 2;
+  }
+  for (int m = 0; m < b_end; m++) {
+    atomicAdd(&mul[(b + m) * width + w], res[m]);
+  }
+}
+
+void gemm_half_q_half_alt(const half *a, const uint32_t *b_q_weight,
+                          const uint32_t *b_gptq_qzeros,
+                          const half *b_gptq_scales, const int *b_g_idx,
+                          half *c, int size_m, int size_n, int size_k,
+                          int bit, cudaStream_t stream) {
+  dim3 blockDim, gridDim;
+  blockDim.x = BLOCK_KN_SIZE;
+  blockDim.y = 1;
+  blockDim.z = 1;
+  gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
+  gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
+  gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+
+  auto kernel = gemm_half_q_half_alt_4bit_kernel;
+  if (bit == 8) {
+    kernel = gemm_half_q_half_alt_8bit_kernel;
+  }
+
+  kernel<<<gridDim, blockDim, 0, stream>>>(
+      (const half2 *)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx,
+      size_m, size_k / 32 * bit, size_n);
+}
+
+template <class T, int bit>
+INFINIOP_CUDA_KERNEL
reconstruct_gptq_kernel(const uint32_t *__restrict__ w, + const half *__restrict__ w_scales, + const uint32_t *__restrict__ w_zeros, + const int *__restrict__ g_idx, + const int height, const int width, + const int group, + half *__restrict__ out) { + // Start of block + + auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + auto row = blockIdx.y * 32 / bit; + if (column >= width) { + return; + } + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + T w_zeros_(w_zeros, group, width); + + uint32_t w_read = w[blockIdx.y * width + column]; + half *out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int s = 0; s < 32; s += bit) { + int group = g_idx[row + s / bit]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), + w_scale); + *out_ptr = w_item; + out_ptr += out_.width; + } +} + +INFINIOP_CUDA_KERNEL reconstruct_gptq_3bit_kernel( + const uint32_t *__restrict__ w, const half *__restrict__ w_scales, + const uint32_t *__restrict__ w_zeros, const int *__restrict__ g_idx, + const int height, const int width, const int group, + half *__restrict__ out) { + // Start of block + auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + auto row = blockIdx.y * 32; + if (column >= width) { + return; + } + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q3_row w_zeros_(w_zeros, group, width); + + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; + uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; + uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; + half *out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int i = 0; i < 32; i += 1) { + int group = g_idx[row + i]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; 
+    int w_item;
+    if (i == 10) {
+      w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
+    } else if (i == 21) {
+      w_item = (w2 >> 31) | ((w3 << 1) & 0x6);
+    } else if (i < 10) {
+      w_item = ((w1 >> (i * 3)) & 0x7);
+    } else if (i < 21) {
+      w_item = ((w2 >> (i * 3 - 32)) & 0x7);
+    } else {
+      w_item = ((w3 >> (i * 3 - 64)) & 0x7);
+    }
+    *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale);
+    out_ptr += out_.width;
+  }
+}
+
+void reconstruct_gptq(const uint32_t *b_q_weight, const uint32_t *b_gptq_qzeros,
+                      const half *b_gptq_scales, const int *b_g_idx, half *out,
+                      int height, int width, int groups, int bit, cudaStream_t stream) {
+  dim3 blockDim, gridDim;
+  blockDim.x = BLOCK_KN_SIZE;
+  blockDim.y = 1;
+  gridDim.y = DIVIDE(height, 32 / bit);
+  gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
+
+  auto kernel = reconstruct_gptq_kernel<MatrixView_q4_row, 4>;
+  if (bit == 2) {
+    kernel = reconstruct_gptq_kernel<MatrixView_q2_row, 2>;
+  } else if (bit == 8) {
+    kernel = reconstruct_gptq_kernel<MatrixView_q8_row, 8>;
+  } else if (bit == 3) {
+    kernel = reconstruct_gptq_3bit_kernel;
+    gridDim.y = DIVIDE(height, 32);
+  }
+
+  kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales,
+                                           b_gptq_qzeros, b_g_idx, height,
+                                           width, groups, out);
+}
+
+void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half *a,
+                           const uint32_t *b_q_weight,
+                           const uint32_t *b_gptq_qzeros,
+                           const half *b_gptq_scales, const int *b_g_idx,
+                           half *c, half *temp_dq, int size_m, int size_n,
+                           int size_k, int groups, bool use_exllama, int bit, cudaStream_t stream) {
+  bool use_reconstruct;
+  if (use_exllama) {
+    use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS));
+  } else {
+    // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so
+    // we disabled them for now.
+ use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + } + if (use_reconstruct) { + // Reconstruct FP16 matrix, then cuBLAS + if (use_exllama) { + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit, stream); + } else { + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit, stream); + } + + const half alpha = __float2half(1.0f); + const half beta = __float2half(0.0f); + cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k, + &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); + } else if (use_exllama) { + // Quantized matmul + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) { + gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, c, last_chunk, size_n, size_k, + BLOCK_M_SIZE_MAX, groups, bit, stream); + } + + if (last_chunk_size) { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, + b_gptq_qzeros, b_gptq_scales, b_g_idx, + c + last_chunk * size_n, last_chunk_size, + size_n, size_k, last_chunk_size, groups, bit, stream); + } + } else { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, size_m, size_n, size_k, bit, stream); + } +} + +INFINIOP_CUDA_KERNEL shuffle_4bit_kernel(uint32_t *__restrict__ b_q_weight, + const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) { + return; + } + int k = 0; + uint32_t *b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } +} + +INFINIOP_CUDA_KERNEL shuffle_8bit_kernel(uint32_t *__restrict__ b_q_weight, + const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) { + return; + } + int k = 0; + uint32_t *b_ptr = b_q_weight + n; + while (k < 
size_k) { + shuffle_8bit_4(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 4; + } +} + +INFINIOP_CUDA_KERNEL shuffle_2bit_kernel(uint32_t *__restrict__ b_q_weight, + const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) { + return; + } + int k = 0; + uint32_t *b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } +} + +INFINIOP_CUDA_KERNEL shuffle_3bit_kernel(uint32_t *__restrict__ b_q_weight, + const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) { + return; + } + int k = 0; + uint32_t *b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } +} + +INFINIOP_CUDA_KERNEL make_sequential_4bit_kernel(const uint32_t *__restrict__ w, + uint32_t *__restrict__ w_new, + const int *__restrict__ q_perm, + const int w_width) { + const uint64_t *w2 = (uint64_t *)w; + uint64_t *w_new2 = (uint64_t *)w_new; + int w2_stride = w_width >> 1; + auto w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) { + return; + } + auto w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +INFINIOP_CUDA_KERNEL make_sequential_2bit_kernel(const uint32_t *__restrict__ w, + uint32_t *__restrict__ w_new, + const int *__restrict__ q_perm, + const int w_width) { + const uint64_t *w2 = (uint64_t *)w; + uint64_t *w_new2 = (uint64_t *)w_new; + int w2_stride = w_width >> 1; + auto w2_column = 
THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) { + return; + } + auto w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 4; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 16; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 4; + int w2_subrow = source_row & 0x0f; + int w2_row_shift = w2_subrow << 1; + int wnew2_row_shift = i << 1; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000300000003; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +INFINIOP_CUDA_KERNEL make_sequential_3bit_kernel(const uint32_t *__restrict__ w, + uint32_t *__restrict__ w_new, + const int *__restrict__ q_perm, + const int w_width) { + auto w_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w_column >= w_width) { + return; + } + auto w_new_row = blockIdx.y * 3; + auto q_perm_idx = blockIdx.y << 5; + uint32_t dst[3] = {0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 32; i++) { + int source_row = q_perm[q_perm_idx++]; + int z_w = (source_row / 32) * 3; + int z_mod = source_row % 32; + int z_bit; + + if (z_mod != 10) { + if (z_mod != 21) { + z_bit = z_mod; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + uint64_t src; + if (z_mod == 10) { + src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); + } else if (z_mod == 21) { + src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); + } else { + src = w[z_w * w_width + w_column]; + src >>= z_bit; + src &= 0x07; + } + + z_w = 0; + if (i != 10) { + if (i != 21) { + z_bit = i; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } 
+    if (i == 10) {
+      dst[z_w] |= (src & 0x03) << 30;
+      dst[z_w + 1] |= ((src & 0x4) >> 2);
+    } else if (i == 21) {
+      dst[z_w] |= (src & 0x01) << 31;
+      dst[z_w + 1] |= ((src & 0x6) >> 1);
+    } else {
+      dst[z_w] |= (src << z_bit);
+    }
+  }
+  w_new[w_new_row * w_width + w_column] = dst[0];
+  w_new[(w_new_row + 1) * w_width + w_column] = dst[1];
+  w_new[(w_new_row + 2) * w_width + w_column] = dst[2];
+}
+
+INFINIOP_CUDA_KERNEL make_sequential_8bit_kernel(const uint32_t *__restrict__ w,
+                                                 uint32_t *__restrict__ w_new,
+                                                 const int *__restrict__ q_perm,
+                                                 const int w_width) {
+  const uint64_t *w2 = (uint64_t *)w;
+  uint64_t *w_new2 = (uint64_t *)w_new;
+  int w2_stride = w_width >> 1;
+  auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+  if (w2_column >= w2_stride) {
+    return;
+  }
+  auto w_new2_row = blockIdx.y;
+  int q_perm_idx = w_new2_row << 2;
+  uint64_t dst = 0;
+
+#pragma unroll
+  for (int i = 0; i < 4; i++) {
+    int source_row = q_perm[q_perm_idx++];
+
+    int w2_row = source_row >> 2;
+    int w2_subrow = source_row & 0x03;
+    int w2_row_shift = w2_subrow << 3;
+    int wnew2_row_shift = i << 3;
+
+    uint64_t src = w2[w2_row * w2_stride + w2_column];
+    src >>= w2_row_shift;
+    src &= 0x000000ff000000ff;
+    src <<= wnew2_row_shift;
+    dst |= src;
+  }
+  w_new2[w_new2_row * w2_stride + w2_column] = dst;
+}
+
+} // namespace gptq
+} // namespace vllm
+
+cublasStatus_t GptqGemmKernel(void *c, const void *a, const void *b,
+                              const void *b_scales, const void *b_zeros, const void *b_g_idx,
+                              int M, int K, int N, int num_groups,
+                              bool use_exllama, int64_t bit, cublasHandle_t cublas_handle, void *workspace, cudaStream_t stream) {
+
+  char *workspace_ptr = reinterpret_cast<char *>(workspace);
+  half *temp_dq = reinterpret_cast<half *>(workspace_ptr); // dequantized-B scratch; presumably K * N halfs (sized by create()) — verify
+ + vllm::gptq::gemm_half_q_half_cuda( + cublas_handle, (const half *)a, + (const uint32_t *)b, + (const uint32_t *)b_zeros, + (const half *)b_scales, + (const int *)b_g_idx, + (half *)c, temp_dq, + M, + N, + K, + num_groups, + use_exllama, bit, stream); + return CUBLAS_STATUS_SUCCESS; +} + +namespace op::gptq_gemm::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t b_g_idx_desc, + bool use_exllama, + int quant_bit) { + + auto info = GptqGemmInfo::createGptqGemmInfo(out_desc, a_desc, b_desc, b_scales_desc, b_zeros_desc, b_g_idx_desc, use_exllama, quant_bit); + + CHECK_RESULT(info); + + size_t workspace_size = b_desc->shape()[0] * 32 / static_cast(quant_bit) * b_desc->shape()[1] * infiniSizeOf(a_desc->dtype()); + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), workspace_size, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + const void *b_scales, + const void *b_zeros, + const void *b_g_idx, + void *stream) const { + + int M = _info.M; + int K = _info.K; + int N = _info.N; + int num_groups = _info.num_groups; + bool use_exllama = _info.use_exllama; + int64_t quant_bit = _info.quant_bit; + + if (_info.dtype == INFINI_DTYPE_F16) { + CHECK_STATUS(_opaque->internal->useCublas( + (cudaStream_t)stream, + [&](cublasHandle_t handle) { + CHECK_CUBLAS( + GptqGemmKernel(out, a, b, + b_scales, b_zeros, + b_g_idx, M, K, N, num_groups, + use_exllama, quant_bit, handle, workspace, 
(cudaStream_t)stream)); + return INFINI_STATUS_SUCCESS; + })); + return INFINI_STATUS_SUCCESS; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::gptq_gemm::nvidia diff --git a/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cuh b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cuh new file mode 100644 index 000000000..e718932ea --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cuh @@ -0,0 +1,7 @@ +#ifndef __GPTQ_GEMM_NVIDIA_API_H__ +#define __GPTQ_GEMM_NVIDIA_API_H__ +#include "../gptq_gemm.h" + +DESCRIPTOR(nvidia) + +#endif // __GPTQ_GEMM_NVIDIA_API_H__ diff --git a/src/infiniop/ops/gptq_gemm/operator.cc b/src/infiniop/ops/gptq_gemm/operator.cc new file mode 100644 index 000000000..220dab07e --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/operator.cc @@ -0,0 +1,109 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/gptq_gemm.h" + +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) +#include "nvidia/gptq_gemm_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateGptqGemmDescriptor( + infiniopHandle_t handle, + infiniopGptqGemmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t b_g_idx_desc, + bool use_exllama, + int quant_bit) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::gptq_gemm::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, a_desc, b_desc, b_scales_desc, b_zeros_desc, b_g_idx_desc, use_exllama, quant_bit); + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + CREATE(INFINI_DEVICE_QY, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C 
infiniStatus_t infiniopGetGptqGemmWorkspaceSize( + infiniopGptqGemmDescriptor_t desc, + size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopGptqGemm( + infiniopGptqGemmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + const void *b_scale, + const void *b_zero, + const void *b_g_idx, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, out, a, b, b_scale, b_zero, b_g_idx, stream); + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + CALCULATE(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopDestroyGptqGemmDescriptor( + infiniopGptqGemmDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + DESTROY(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} diff --git a/test/infiniop/gptq_gemm.py b/test/infiniop/gptq_gemm.py new file mode 100644 index 000000000..dcb2cbfbf --- /dev/null +++ b/test/infiniop/gptq_gemm.py @@ -0,0 +1,345 @@ +import torch +import numpy +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + 
debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) +from enum import Enum, auto + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES = [ + # M, K, N, use_exllama, quant_bit, group_size + (1, 2048, 2048, True, 4, 128), + (1, 2048, 4096, False, 4, 128), + (1, 4096, 2048, False, 4, 128), + (8, 2048, 2048, False, 4, 128), + (8, 2048, 4096, False, 4, 128), + (8, 4096, 2048, False, 4, 128), + (128, 2048, 2048, False, 4, 128), + (128, 2048, 4096, False, 4, 128), + (128, 4096, 2048, False, 4, 128), +] + + +# Data types used for testing +_TENSOR_DTYPES = [InfiniDtype.F16] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 5e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device 
= q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def torch_dequantize(q_weight, q_zeros, scales, g_idx, use_shuffle, bit, K, N): + assert bit == 4, "Reference dequantization only supports 4-bit" + group_size = K // scales.shape[0] + pack_factor = 32 // bit + + # unpack q_weight: (K//pack_factor, N) -> (K, N) + unpacked_q_weight = torch.empty( + q_weight.shape[0] * pack_factor, + q_weight.shape[1], + dtype=torch.uint8, + device=q_weight.device, + ) + for i in range(pack_factor): + unpacked_q_weight[i::pack_factor, :] = (q_weight >> (i * 4)) & 0x0F + + # unpack q_zeros: (num_groups, N//pack_factor) -> (num_groups, N) + unpacked_q_zeros = torch.empty( + q_zeros.shape[0], + q_zeros.shape[1] * pack_factor, + dtype=torch.uint8, + device=q_zeros.device, + ) + for i in range(pack_factor): + unpacked_q_zeros[:, i::pack_factor] = (q_zeros >> (i * 4)) & 0x0F + + unpacked_q_zeros += 1 + unpacked_q_zeros = unpacked_q_zeros.to(scales.dtype) + + scale_zeros = unpacked_q_zeros * scales # (num_groups, N) + + current_g_idx = torch.tensor( + [i // group_size for i in range(K)], dtype=torch.int32, device=q_weight.device + ) + + scale_mat = scales[current_g_idx] # (K, N) + scale_zeros_mat = scale_zeros[current_g_idx] # (K, N) + + # dequant: weight * scale - scale_zeros + dequantized_b = unpacked_q_weight.to(scales.dtype) * scale_mat - scale_zeros_mat + + return dequantized_b.reshape(K, N) + + +def torch_gptq_gemm( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit +): + K, N = a.shape[1], b_q_weight.shape[1] + + b_dequant = torch_dequantize( + b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit, K, N + ) + c = torch.matmul(a, b_dequant) + return c + + +def test( + handle, + device, + M, + 
K, + N, + use_exllama, + quant_bit, + group_size, + dtype=InfiniDtype.F16, + sync=None, +): + + print( + f"Testing Gptq Gemm on {InfiniDeviceNames[device]} with M-K-N:{M, K, N}, use_exllama:{use_exllama}, quant_bit:{quant_bit}, group_size:{group_size}, dtype:{InfiniDtypeNames[dtype]}" + ) + b_fp = TestTensor((K, N), None, dtype, device) + + assert K % group_size == 0, "K must be divisible by group_size" + num_groups = K // group_size + use_shuffle = use_exllama + + if use_shuffle: + print(f"not support use_shuffle:{use_shuffle}") + return + else: + g_idx = torch.tensor( + [i // group_size for i in range(K)], + dtype=torch.int32, + device=b_fp.torch_tensor().device, + ) + b_shuffled = b_fp.torch_tensor()[g_idx] + + b_grouped = b_shuffled.reshape(num_groups, group_size, N) + + b_max = torch.max(b_grouped, dim=1, keepdim=True)[0] + b_min = torch.min(b_grouped, dim=1, keepdim=True)[0] + + scales = (b_max - b_min) / (2**quant_bit - 1) + scales = scales.clamp(min=1e-6) + + zeros_float = (-b_min / scales).round() + + q_b = ( + (b_grouped / scales + zeros_float) + .round() + .clamp(0, 2**quant_bit - 1) + .to(torch.uint8) + ) + + q_zeros_unpacked = zeros_float.to(torch.uint8) - 1 + + b_q_weight = pack_rows(q_b.reshape(K, N), quant_bit, K, N) + + q_zeros_unpacked = q_zeros_unpacked.reshape(num_groups, N) + b_gptq_qzeros = pack_cols(q_zeros_unpacked, quant_bit, num_groups, N) + b_gptq_scales = scales.squeeze(1) + + A = TestTensor((M, K), None, dtype, device) + C = TestTensor((M, N), None, dtype, device) + + B = TestTensor( + b_q_weight.shape, + b_q_weight.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=b_q_weight, + ) + b_scales = TestTensor( + b_gptq_scales.shape, + b_gptq_scales.stride(), + dtype, + device, + mode="manual", + set_tensor=b_gptq_scales, + ) + b_zeros = TestTensor( + b_gptq_qzeros.shape, + b_gptq_qzeros.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=b_gptq_qzeros, + ) + b_g_idx = TestTensor( + (K,), g_idx.stride(), 
InfiniDtype.I32, device, mode="manual", set_tensor=g_idx
+    )
+
+    if sync is not None:
+        sync()
+
+    ans = torch_gptq_gemm(
+        A.torch_tensor(),
+        B.torch_tensor(),
+        b_zeros.torch_tensor(),
+        b_scales.torch_tensor(),
+        b_g_idx.torch_tensor(),
+        use_shuffle,
+        quant_bit,
+    )
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateGptqGemmDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            C.descriptor,
+            A.descriptor,
+            B.descriptor,
+            b_scales.descriptor,
+            b_zeros.descriptor,
+            b_g_idx.descriptor,
+            use_exllama,
+            quant_bit,
+        )
+    )
+
+    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
+
+    for tensor in [C, A, B, b_scales, b_zeros, b_g_idx]:
+        tensor.destroy_desc()
+
+    workspace_size = c_uint64(0)
+    check_error(
+        LIBINFINIOP.infiniopGetGptqGemmWorkspaceSize(
+            descriptor, ctypes.byref(workspace_size)
+        )
+    )
+    workspace = TestWorkspace(workspace_size.value, A.device)
+
+    def lib_gptq_gemm():
+        check_error(
+            LIBINFINIOP.infiniopGptqGemm(
+                descriptor,
+                workspace.data(),
+                workspace_size.value,
+                C.data(),
+                A.data(),
+                B.data(),
+                b_scales.data(),
+                b_zeros.data(),
+                b_g_idx.data(),
+                None,
+            )
+        )
+
+    lib_gptq_gemm()
+
+    if sync is not None:
+        sync()
+
+    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
+    if DEBUG:
+        debug(C.actual_tensor(), ans, atol=atol, rtol=rtol)
+
+    assert torch.allclose(C.actual_tensor(), ans, atol=atol, rtol=rtol)
+
+    # Profiling workflow
+    if PROFILE:
+        # fmt: off
+        profile_operation("PyTorch", lambda: torch_gptq_gemm(A.torch_tensor(), B.torch_tensor(), b_zeros.torch_tensor(), b_scales.torch_tensor(), b_g_idx.torch_tensor(), use_shuffle, quant_bit), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation(" lib", lambda: lib_gptq_gemm(), device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
+
+    check_error(LIBINFINIOP.infiniopDestroyGptqGemmDescriptor(descriptor))
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    # Configure testing options
+    DEBUG 
= args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 1c90feb22..546b96ff3 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -1131,6 +1131,48 @@ def per_tensor_dequant_int8_(lib): ] +@OpRegister.operator +def gptq_gemm_(lib): + lib.infiniopCreateGptqGemmDescriptor.restype = c_int32 + lib.infiniopCreateGptqGemmDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_bool, + c_int32, + ] + + lib.infiniopGetGptqGemmWorkspaceSize.restype = c_int32 + lib.infiniopGetGptqGemmWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopGptqGemm.restype = c_int32 + lib.infiniopGptqGemm.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyGptqGemmDescriptor.restype = c_int32 + lib.infiniopDestroyGptqGemmDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def softplus_(lib): lib.infiniopCreateSoftplusDescriptor.restype = c_int32