From 0aeee9e82cdab73cb459f688c513677e606c0bf1 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Tue, 7 Apr 2026 17:06:27 +0800 Subject: [PATCH 1/2] issue/1118: qyblas error --- include/infiniop/ops/gptq_qyblas_gemm.h | 37 ++ .../ops/gptq_qyblas_gemm/gptq_qyblas_gemm.h | 49 +++ src/infiniop/ops/gptq_qyblas_gemm/info.h | 92 +++++ .../nvidia/gptq_qyblas_gemm_nvidia.cu | 223 +++++++++++ .../nvidia/gptq_qyblas_gemm_nvidia.cuh | 7 + src/infiniop/ops/gptq_qyblas_gemm/operator.cc | 99 +++++ test/infiniop/gptq_qyblas_gemm.py | 360 ++++++++++++++++++ test/infiniop/libinfiniop/op_register.py | 40 ++ 8 files changed, 907 insertions(+) create mode 100644 include/infiniop/ops/gptq_qyblas_gemm.h create mode 100644 src/infiniop/ops/gptq_qyblas_gemm/gptq_qyblas_gemm.h create mode 100644 src/infiniop/ops/gptq_qyblas_gemm/info.h create mode 100644 src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu create mode 100644 src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cuh create mode 100644 src/infiniop/ops/gptq_qyblas_gemm/operator.cc create mode 100644 test/infiniop/gptq_qyblas_gemm.py diff --git a/include/infiniop/ops/gptq_qyblas_gemm.h b/include/infiniop/ops/gptq_qyblas_gemm.h new file mode 100644 index 000000000..bb105132c --- /dev/null +++ b/include/infiniop/ops/gptq_qyblas_gemm.h @@ -0,0 +1,37 @@ +#ifndef __INFINIOP_GPTQ_QYBLAS_GEMM_API_H__ +#define __INFINIOP_GPTQ_QYBLAS_GEMM_API_H__ + +#include "../operator_descriptor.h" +#include + +typedef struct InfiniopDescriptor *infiniopGptqQyblasGemmDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateGptqQyblasGemmDescriptor( + infiniopHandle_t handle, + infiniopGptqQyblasGemmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc); + +__INFINI_C __export infiniStatus_t infiniopGetGptqQyblasGemmWorkspaceSize( + infiniopGptqQyblasGemmDescriptor_t desc, + size_t *size); + +__INFINI_C __export infiniStatus_t infiniopGptqQyblasGemm( + infiniopGptqQyblasGemmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *b_scale, + void *b_zero, + int64_t quant_type, + int64_t bit, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyGptqQyblasGemmDescriptor( + infiniopGptqQyblasGemmDescriptor_t desc); +#endif diff --git a/src/infiniop/ops/gptq_qyblas_gemm/gptq_qyblas_gemm.h b/src/infiniop/ops/gptq_qyblas_gemm/gptq_qyblas_gemm.h new file mode 100644 index 000000000..456e540e1 --- /dev/null +++ b/src/infiniop/ops/gptq_qyblas_gemm/gptq_qyblas_gemm.h @@ -0,0 +1,49 @@ +#ifndef GPTQ_QYBLAS_GEMM_H +#define GPTQ_QYBLAS_GEMM_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::gptq_qyblas_gemm::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + GptqQyblasGemmInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + GptqQyblasGemmInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t a_desc, \ + infiniopTensorDescriptor_t b_desc, \ + infiniopTensorDescriptor_t b_scales_desc, \ + infiniopTensorDescriptor_t b_zeros_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, \ + const void *a, const void *b, void *b_scale, void *b_zero, int64_t quant_type, int64_t bit, \ + void *stream) const; \ + }; \ + } + +#endif // GPTQ_QYBLAS_GEMM_H diff --git a/src/infiniop/ops/gptq_qyblas_gemm/info.h b/src/infiniop/ops/gptq_qyblas_gemm/info.h new file mode 100644 index 000000000..ee3a85c07 --- /dev/null +++ b/src/infiniop/ops/gptq_qyblas_gemm/info.h @@ -0,0 +1,92 @@ +#ifndef __GPTQ_QYBLAS_GEMM_INFO_H__ +#define __GPTQ_QYBLAS_GEMM_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include +#include + +namespace op::gptq_qyblas_gemm { + +class GptqQyblasGemmInfo { + GptqQyblasGemmInfo() = default; + +public: + infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype; + size_t M, K, N, scales_size_0, scales_size_1; + ptrdiff_t lda, ldb, result_ld; + bool transpose_mat_1, transpose_mat_2, transpose_result; + + static utils::Result createGptqQyblasGemmInfo( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc) { + + auto dtype = a_desc->dtype(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); + + const infiniDtype_t weight_dtype = b_desc->dtype(); + CHECK_DTYPE(weight_dtype, INFINI_DTYPE_F8, INFINI_DTYPE_U8, INFINI_DTYPE_I8); + + const infiniDtype_t scales_dtype = b_scales_desc->dtype(); + const infiniDtype_t zeros_dtype = b_zeros_desc->dtype(); + const infiniDtype_t out_dtype = out_desc->dtype(); + + size_t M = out_desc->shape()[0]; + size_t N = out_desc->shape()[1]; + size_t K = a_desc->shape()[1]; + + size_t scales_size_0 = b_scales_desc->shape()[0]; + size_t scales_size_1 = b_scales_desc->shape()[1]; + + auto ndim = out_desc->ndim(); + CHECK_OR_RETURN(ndim == 2 + && a_desc->ndim() == ndim + && b_desc->ndim() == ndim + && b_scales_desc->ndim() == ndim + && b_zeros_desc->ndim() == ndim, + INFINI_STATUS_BAD_TENSOR_SHAPE); + + bool transpose_result = false; + if (out_desc->strides()[0] == 1 && out_desc->strides()[1] >= std::max(1, out_desc->shape()[0])) { + transpose_result = true; + } else if (out_desc->strides()[1] == 1 && out_desc->strides()[0] >= std::max(1, out_desc->shape()[1])) { + transpose_result = false; + } else { + transpose_result = false; + } + bool transpose_mat_1 = false; + if (a_desc->strides()[0] == 1 && a_desc->strides()[1] >= std::max(1, a_desc->shape()[0])) { + transpose_mat_1 = true; + } else if (a_desc->strides()[1] == 1 && a_desc->strides()[0] >= std::max(1, a_desc->shape()[1])) { + transpose_mat_1 = false; + } else { + transpose_mat_1 = false; + } + bool transpose_mat_2 = false; + if (b_desc->strides()[0] == 1 && b_desc->strides()[1] >= std::max(1, b_desc->shape()[0])) { + transpose_mat_2 = true; + } else if (b_desc->strides()[1] == 1 && b_desc->strides()[0] >= std::max(1, b_desc->shape()[1])) { + transpose_mat_2 = false; + } else { + transpose_mat_2 = false; + } + + ptrdiff_t lda = a_desc->strides()[transpose_mat_1 ? 1 : 0]; + ptrdiff_t ldb = b_desc->strides()[transpose_mat_2 ? 1 : 0]; + ptrdiff_t result_ld = out_desc->strides()[transpose_result ? 1 : 0]; + + return utils::Result(GptqQyblasGemmInfo{ + dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype, + M, K, N, scales_size_0, scales_size_1, + lda, ldb, result_ld, + transpose_mat_1, transpose_mat_2, transpose_result}); + } +}; + +} // namespace op::gptq_qyblas_gemm + +#endif // __GPTQ_QYBLAS_GEMM_INFO_H__ diff --git a/src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu b/src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu new file mode 100644 index 000000000..654d9738c --- /dev/null +++ b/src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu @@ -0,0 +1,223 @@ +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "dlblas_ext.h" +#include "gptq_qyblas_gemm_nvidia.cuh" + +namespace op::gptq_qyblas_gemm::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc) { + + auto info = GptqQyblasGemmInfo::createGptqQyblasGemmInfo(out_desc, a_desc, b_desc, b_scales_desc, b_zeros_desc); + + CHECK_RESULT(info); + + size_t workspace_size = 0; + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), workspace_size, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *b_scales, + void *b_zeros, + int64_t quant_type, + int64_t bit, + void *stream) const { + + int64_t M = static_cast(_info.M); + int64_t K = static_cast(_info.K); + int64_t N = static_cast(_info.N); + int64_t scales_size_0 = static_cast(_info.scales_size_0); + int64_t scales_size_1 = static_cast(_info.scales_size_1); + int64_t lda = static_cast(_info.lda); + int64_t ldb = static_cast(_info.ldb); + int64_t result_ld = static_cast(_info.result_ld); + bool transpose_mat_1 = _info.transpose_mat_1; + bool transpose_mat_2 = _info.transpose_mat_2; + + cudaDataType_t computeType_ = (cudaDataType_t)CUDA_R_32F; + cudaDataType_t kernel_Atype_, kernel_Btype_, kernel_Ctype_, kernel_Stype_, kernel_Ztype_; + + switch (_info.dtype) { + case INFINI_DTYPE_F16: + kernel_Atype_ = CUDA_R_16F; + break; + case INFINI_DTYPE_BF16: + kernel_Atype_ = CUDA_R_16BF; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (quant_type == 0) { + if (8 == bit) { + kernel_Atype_ = (cudaDataType_t)CUDA_R_8U; + } + + if (4 == bit) { + kernel_Atype_ = (cudaDataType_t)CUDA_R_4U; + K = K * 2; + } + } + + switch (_info.weight_dtype) { + case INFINI_DTYPE_F8: + kernel_Btype_ = (cudaDataType_t)CUDA_R_8F_E4M3; + break; + case INFINI_DTYPE_U8: + kernel_Btype_ = CUDA_R_8U; + break; + case INFINI_DTYPE_I8: + kernel_Btype_ = CUDA_R_8I; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + switch (_info.out_dtype) { + case INFINI_DTYPE_F16: + kernel_Ctype_ = CUDA_R_16F; + break; + case INFINI_DTYPE_BF16: + kernel_Ctype_ = CUDA_R_16BF; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + switch (_info.scales_dtype) { + case INFINI_DTYPE_F32: + kernel_Stype_ = CUDA_R_32F; + break; + case INFINI_DTYPE_F16: + kernel_Stype_ = CUDA_R_16F; + break; + case INFINI_DTYPE_BF16: + kernel_Stype_ = CUDA_R_16BF; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + switch (_info.zeros_dtype) { + case INFINI_DTYPE_F32: + kernel_Ztype_ = CUDA_R_32F; + break; + case INFINI_DTYPE_F16: + kernel_Ztype_ = CUDA_R_16F; + break; + case INFINI_DTYPE_BF16: + kernel_Ztype_ = CUDA_R_16BF; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + float alpha = 1.0f; + float beta = 0.0f; + + dlblasExtQuantParametersV2_t extParameters; + + if (quant_type == 0) { + extParameters.a_group_size_m = M / scales_size_0; + extParameters.a_group_size_k = K / scales_size_1; + extParameters.a_zeropoints_type = kernel_Ztype_; + extParameters.a_zeropoints = b_zeros; + extParameters.a_scales_type = kernel_Stype_; + extParameters.a_scales = b_scales; + } else if (quant_type == 1) { + extParameters.a_group_size_m = 1; + extParameters.a_group_size_k = K; + extParameters.a_zeropoints = nullptr; + extParameters.a_scales_type = kernel_Stype_; + extParameters.a_scales = b_scales; + + } else if (quant_type == 2 || quant_type == 3) { + // calculate block_shape according weight/scales shape + int block_shape = 128; + while ((N + block_shape - 1) / block_shape < scales_size_0) { + block_shape /= 2; + if (block_shape < 32) { + fprintf(stderr, + "INTERNAL ASSERT FAILED: block_shape >= 32\n" + "Invalid fp blockwise linear arguments. Weight: [%d, %d]. Scales: [%d, %d].\n", + (int)N, (int)K, (int)scales_size_0, (int)scales_size_1); + abort(); + } + } + if (!((K + block_shape - 1) / block_shape == scales_size_1)) { + fprintf(stderr, + "CHECK FAILED: (K + block_shape - 1) / block_shape == scales_size_1\n"); + abort(); + } + extParameters.a_group_size_m = block_shape; + extParameters.a_group_size_k = block_shape; + extParameters.a_scales_type = kernel_Stype_; + extParameters.a_zeropoints = nullptr; + extParameters.a_scales = b_scales; + } + + cublasOperation_t transa = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t transb = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N; + printf("a=%s, b=%s, c=%s\n", + _info.transpose_mat_1 ? "true" : "false", + _info.transpose_mat_2 ? "true" : "false", + _info.transpose_result ? "true" : "false"); + printf("M-K-N:[%ld, %ld, %ld], lda-ldb-ldc:[%ld, %ld, %ld]\n", M, K, N, lda, ldb, result_ld); + printf("quant type:%ld, bit:%ld, block_shape:%d\n", quant_type, bit, extParameters.a_group_size_m); + + if (_info.dtype == INFINI_DTYPE_F16 || _info.dtype == INFINI_DTYPE_BF16) { + CHECK_STATUS(_opaque->internal->useCublas( + (cudaStream_t)stream, + [&](cublasHandle_t handle) { + CHECK_CUBLAS( + dlblasGemmExV2(handle, + transa, + transb, + N, + M, + K, + &alpha, + b, + kernel_Btype_, + ldb, + a, + kernel_Atype_, + lda, + &beta, + out, + kernel_Ctype_, + result_ld, + computeType_, + CUBLAS_GEMM_DEFAULT_TENSOR_OP, + &extParameters)); + return INFINI_STATUS_SUCCESS; + })); + return INFINI_STATUS_SUCCESS; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::gptq_qyblas_gemm::nvidia diff --git a/src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cuh b/src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cuh new file mode 100644 index 000000000..b489858d9 --- /dev/null +++ b/src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cuh @@ -0,0 +1,7 @@ +#ifndef __GPTQ_QYBLAS_GEMM_NVIDIA_API_H__ +#define __GPTQ_QYBLAS_GEMM_NVIDIA_API_H__ +#include "../gptq_qyblas_gemm.h" + +DESCRIPTOR(nvidia) + +#endif // __GPTQ_QYBLAS_GEMM_NVIDIA_API_H__ diff --git a/src/infiniop/ops/gptq_qyblas_gemm/operator.cc b/src/infiniop/ops/gptq_qyblas_gemm/operator.cc new file mode 100644 index 000000000..297bfd408 --- /dev/null +++ b/src/infiniop/ops/gptq_qyblas_gemm/operator.cc @@ -0,0 +1,99 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/gptq_qyblas_gemm.h" + +#if defined ENABLE_QY_API +#include "nvidia/gptq_qyblas_gemm_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateGptqQyblasGemmDescriptor( + infiniopHandle_t handle, + infiniopGptqQyblasGemmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::gptq_qyblas_gemm::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, a_desc, b_desc, b_scales_desc, b_zeros_desc); + + switch (handle->device) { + +#ifdef ENABLE_QY_API + CREATE(INFINI_DEVICE_QY, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopGetGptqQyblasGemmWorkspaceSize( + infiniopGptqQyblasGemmDescriptor_t desc, + size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopGptqQyblasGemm( + infiniopGptqQyblasGemmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *b_scale, + void *b_zero, + int64_t quant_type, + int64_t bit, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, out, a, b, b_scale, b_zero, quant_type, bit, stream); + + switch (desc->device_type) { + +#ifdef ENABLE_QY_API + CALCULATE(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopDestroyGptqQyblasGemmDescriptor( + infiniopGptqQyblasGemmDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#ifdef ENABLE_QY_API + DESTROY(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} diff --git a/test/infiniop/gptq_qyblas_gemm.py b/test/infiniop/gptq_qyblas_gemm.py new file mode 100644 index 000000000..9230a65ae --- /dev/null +++ b/test/infiniop/gptq_qyblas_gemm.py @@ -0,0 +1,360 @@ +import torch +import numpy +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + to_torch_dtype, +) +from enum import Enum, auto +import itertools + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +# Test configurations + +BLOCK_SIZE = [[128, 128]] +M_list = [1, 7, 83, 512, 2048] +N_list = [128, 512, 1024, 4096, 7748, 13824] +K_list = [256, 4096, 5120, 3884, 13824] +SEEDS = 0 + +def to_iter(x): + return x if isinstance(x, (list, tuple)) else (x,) + + +_TEST_CASES = list( + itertools.product( + to_iter(M_list), + to_iter(K_list), + to_iter(N_list), + to_iter(BLOCK_SIZE), + ) +) + + +# Data types used for testing +_TENSOR_DTYPES = [InfiniDtype.F16] + + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def native_w8a16_block_int8_matmul( + A, + B, + Bs, + block_size, + output_dtype: torch.float16, +) -> torch.Tensor: + """Matrix multiplication with block-wise quantization using native torch.""" + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +def native_w8a16_block_fp8_matmul( + A, + B, + Bs, + block_size, + output_dtype: torch.float16, +) -> torch.Tensor: + return native_w8a16_block_int8_matmul(A, B, Bs, block_size, output_dtype) + + +def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + #A_fp32 = A_fp32.fill_(1) + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + #B_fp32 = B_fp32.fill_(1) + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + #As = As.fill_(1) + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + #Bs = Bs.fill_(1.5) + #ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, + # out_dtype) + ref_out = native_w8a16_block_fp8_matmul(A_fp32.to(torch.bfloat16), B_fp8, Bs, block_size, out_dtype) + #out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + + B_fp8_T = B_fp8.t() + #print('B_fp8_T', B_fp8_T.size(), B_fp8_T) + + Bs_T = Bs + quant_type = 3 + bit = 8 + return ref_out, A_fp32.to(torch.bfloat16), B_fp8_T, Bs_T, Bs_T, quant_type, bit + + +def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + int8_info = torch.iinfo(torch.int8) + int8_max, int8_min = int8_info.max, int8_info.min + + A_fpb16 = torch.rand(M, K, dtype=torch.float32) / 10 + + + #A_fp32 = A_fp32.fill_(1) + #A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * int8_max + #B_fp32 = B_fp32.fill_(1) + B_int8 = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + A_fpb16 =A_fpb16.to(torch.float16) + + #As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + #As = As.fill_(1) + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + #Bs = Bs.fill_(1.5) + #ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + + ref_out = native_w8a16_block_fp8_matmul(A_fpb16, B_int8, Bs, block_size, out_dtype) + #a_q, a_s = native_per_token_group_quant_int8(A_fpb16, block_k) + #ref_out = native_w8a8_block_int8_matmul(a_q, B_int8, a_s, Bs, block_size, output_dtype=A_fpb16.dtype) + ##out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + #print('Bs', Bs.size(), Bs.dtype) + quant_type = 3 + bit = 8 + return ref_out, A_fpb16, B_int8, Bs, Bs, quant_type, bit + + +def test_int8( + handle, + device, + M, + K, + N, + block_size, + dtype=InfiniDtype.BF16, + sync=None, +): + + print( + f"Testing int8 Gptq Qyblas Gemm on {InfiniDeviceNames[device]} with M-K-N:{M, K, N}, block_size:{block_size}, dtype:{InfiniDtypeNames[dtype]}" + ) + out_dtype = to_torch_dtype(dtype) + ans, a, b_orig, b_scales, b_zeros, quant_type, bit = test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, SEEDS) + b = b_orig.t() + + A = TestTensor( + a.shape, + a.stride(), + InfiniDtype.F16, + device, + mode="manual", + set_tensor=a, + ) + B_orig = TestTensor( + b_orig.shape, + b_orig.stride(), + InfiniDtype.I8, + device, + mode="manual", + set_tensor=b_orig, + ) + B = TestTensor( + b.shape, + b.stride(), + InfiniDtype.I8, + device, + mode="manual", + set_tensor=b, + ) + b_scales = TestTensor( + b_scales.shape, + b_scales.stride(), + InfiniDtype.F32, + device, + mode="manual", + set_tensor=b_scales, + ) + b_zeros = TestTensor( + b_zeros.shape, + b_zeros.stride(), + InfiniDtype.F32, + device, + mode="manual", + set_tensor=b_zeros, + ) + out = TestTensor( + ans.shape, + None, + dtype, + device, + ) + + print("a: ", A.torch_tensor().shape, A.torch_tensor().stride(), A.torch_tensor().dtype) + print("b: ", B.torch_tensor().shape, B.torch_tensor().stride(), B.torch_tensor().dtype) + print("scales: ", b_scales.torch_tensor().shape, b_scales.torch_tensor().dtype) + print("zeros: ", b_zeros.torch_tensor().shape, b_zeros.torch_tensor().dtype) + print("out: ", out.torch_tensor().shape, out.torch_tensor().dtype) + if sync is not None: + sync() + + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateGptqQyblasGemmDescriptor( + handle, + ctypes.byref(descriptor), + out.descriptor, + A.descriptor, + B.descriptor, + b_scales.descriptor, + b_zeros.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + + for tensor in [out, A, B, b_scales, b_zeros]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetGptqQyblasGemmWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, A.device) + + def lib_gptq_qyblas_gemm(): + check_error( + LIBINFINIOP.infiniopGptqQyblasGemm( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + A.data(), + B.data(), + b_scales.data(), + b_zeros.data(), + quant_type, + bit, + None, + ) + ) + + lib_gptq_qyblas_gemm() + + if sync is not None: + sync() + + tmpa = out.torch_tensor().to(torch.float32).detach().to('cpu').numpy().flatten() + tmpb = ans.to(torch.float32).to('cpu').detach().numpy().flatten() + + atol = max(abs(tmpa - tmpb)) + + rtol = atol / (max(abs(tmpb)) + 1e-8) + + + print("absolute error:%.4e"%(atol)) + print("relative error:%.4e"%(rtol)) + print(out.torch_tensor().device, ans.device) + # print(out.torch_tensor()) + # print(ans) + ans = ans.to(out.torch_tensor().device) + rel_diff = (torch.mean( + torch.abs(out.torch_tensor().to(torch.float32) - ans.to(torch.float32))) / + torch.mean(torch.abs(ans.to(torch.float32)))) + print(rel_diff) + assert rel_diff < 0.05 + + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: native_w8a16_block_fp8_matmul(A.torch_tensor(), B_orig.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_gptq_qyblas_gemm(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + + check_error(LIBINFINIOP.infiniopDestroyGptqQyblasGemmDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test_int8, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 9a91c931c..26ade4c2a 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -13,6 +13,7 @@ c_double, c_int64, c_bool, + c_int64, ) @@ -1313,6 +1314,45 @@ def per_tensor_dequant_int8_(lib): ] +@OpRegister.operator +def gptq_qyblas_gemm_(lib): + lib.infiniopCreateGptqQyblasGemmDescriptor.restype = c_int32 + lib.infiniopCreateGptqQyblasGemmDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetGptqQyblasGemmWorkspaceSize.restype = c_int32 + lib.infiniopGetGptqQyblasGemmWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopGptqQyblasGemm.restype = c_int32 + lib.infiniopGptqQyblasGemm.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_int64, + c_int64, + c_void_p, + ] + + lib.infiniopDestroyGptqQyblasGemmDescriptor.restype = c_int32 + lib.infiniopDestroyGptqQyblasGemmDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + @OpRegister.operator def softplus_(lib): lib.infiniopCreateSoftplusDescriptor.restype = c_int32 From d194d478939aad901be4daeedb94bf5500b6af7e Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Fri, 10 Apr 2026 10:50:40 +0800 Subject: [PATCH 2/2] issue/1118: success qy int8 test --- src/infiniop/ops/gptq_qyblas_gemm/info.h | 6 +- .../nvidia/gptq_qyblas_gemm_nvidia.cu | 19 +- test/infiniop/gptq_qyblas_gemm.py | 207 +++++------------- 3 files changed, 64 insertions(+), 168 deletions(-) diff --git a/src/infiniop/ops/gptq_qyblas_gemm/info.h b/src/infiniop/ops/gptq_qyblas_gemm/info.h index ee3a85c07..86bfdf35e 100644 --- a/src/infiniop/ops/gptq_qyblas_gemm/info.h +++ b/src/infiniop/ops/gptq_qyblas_gemm/info.h @@ -12,7 +12,7 @@ class GptqQyblasGemmInfo { GptqQyblasGemmInfo() = default; public: - infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype; + infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype; size_t M, K, N, scales_size_0, scales_size_1; ptrdiff_t lda, ldb, result_ld; bool transpose_mat_1, transpose_mat_2, transpose_result; @@ -27,13 +27,13 @@ class GptqQyblasGemmInfo { auto dtype = a_desc->dtype(); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); + CHECK_DTYPE(dtype, out_desc->dtype()); const infiniDtype_t weight_dtype = b_desc->dtype(); CHECK_DTYPE(weight_dtype, INFINI_DTYPE_F8, INFINI_DTYPE_U8, INFINI_DTYPE_I8); const infiniDtype_t scales_dtype = b_scales_desc->dtype(); const infiniDtype_t zeros_dtype = b_zeros_desc->dtype(); - const infiniDtype_t out_dtype = out_desc->dtype(); size_t M = out_desc->shape()[0]; size_t N = out_desc->shape()[1]; @@ -80,7 +80,7 @@ class GptqQyblasGemmInfo { ptrdiff_t result_ld = out_desc->strides()[transpose_result ? 1 : 0]; return utils::Result(GptqQyblasGemmInfo{ - dtype, weight_dtype, scales_dtype, zeros_dtype, out_dtype, + dtype, weight_dtype, scales_dtype, zeros_dtype, M, K, N, scales_size_0, scales_size_1, lda, ldb, result_ld, transpose_mat_1, transpose_mat_2, transpose_result}); diff --git a/src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu b/src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu index 654d9738c..12137951f 100644 --- a/src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu +++ b/src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu @@ -1,3 +1,4 @@ +#if defined ENABLE_QY_API #include "../../../devices/nvidia/nvidia_handle.cuh" #include "dlblas_ext.h" #include "gptq_qyblas_gemm_nvidia.cuh" @@ -93,16 +94,7 @@ infiniStatus_t Descriptor::calculate(void *workspace, return INFINI_STATUS_BAD_TENSOR_DTYPE; } - switch (_info.out_dtype) { - case INFINI_DTYPE_F16: - kernel_Ctype_ = CUDA_R_16F; - break; - case INFINI_DTYPE_BF16: - kernel_Ctype_ = CUDA_R_16BF; - break; - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } + kernel_Ctype_ = kernel_Atype_; switch (_info.scales_dtype) { case INFINI_DTYPE_F32: @@ -178,12 +170,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, cublasOperation_t transa = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t transb = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N; - printf("a=%s, b=%s, c=%s\n", - _info.transpose_mat_1 ? "true" : "false", - _info.transpose_mat_2 ? "true" : "false", - _info.transpose_result ? "true" : "false"); - printf("M-K-N:[%ld, %ld, %ld], lda-ldb-ldc:[%ld, %ld, %ld]\n", M, K, N, lda, ldb, result_ld); - printf("quant type:%ld, bit:%ld, block_shape:%d\n", quant_type, bit, extParameters.a_group_size_m); if (_info.dtype == INFINI_DTYPE_F16 || _info.dtype == INFINI_DTYPE_BF16) { CHECK_STATUS(_opaque->internal->useCublas( @@ -221,3 +207,4 @@ infiniStatus_t Descriptor::calculate(void *workspace, } } // namespace op::gptq_qyblas_gemm::nvidia +#endif diff --git a/test/infiniop/gptq_qyblas_gemm.py b/test/infiniop/gptq_qyblas_gemm.py index 9230a65ae..8adb67a94 100644 --- a/test/infiniop/gptq_qyblas_gemm.py +++ b/test/infiniop/gptq_qyblas_gemm.py @@ -29,9 +29,11 @@ # Test configurations BLOCK_SIZE = [[128, 128]] -M_list = [1, 7, 83, 512, 2048] -N_list = [128, 512, 1024, 4096, 7748, 13824] -K_list = [256, 4096, 5120, 3884, 13824] +M_list = [1, 7]#, 83, 512, 2048] +N_list = [128, 512]#, 1024, 4096, 7748, 13824] +K_list = [256, 4096]#, 5120, 3884, 13824] +_WEIGHT_DTYPES = [InfiniDtype.I8] + SEEDS = 0 def to_iter(x): @@ -44,12 +46,13 @@ def to_iter(x): to_iter(K_list), to_iter(N_list), to_iter(BLOCK_SIZE), + to_iter(_WEIGHT_DTYPES), ) ) # Data types used for testing -_TENSOR_DTYPES = [InfiniDtype.F16] +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] DEBUG = False @@ -108,164 +111,82 @@ def native_w8a16_block_int8_matmul( return C -def native_w8a16_block_fp8_matmul( - A, - B, - Bs, - block_size, - output_dtype: torch.float16, -) -> torch.Tensor: - return native_w8a16_block_int8_matmul(A, B, Bs, block_size, output_dtype) - - -def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - #A_fp32 = A_fp32.fill_(1) - A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - #B_fp32 = B_fp32.fill_(1) - B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - block_n, block_k = block_size[0], block_size[1] - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - - As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale - #As = As.fill_(1) - Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale - #Bs = Bs.fill_(1.5) - #ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, - # out_dtype) - ref_out = native_w8a16_block_fp8_matmul(A_fp32.to(torch.bfloat16), B_fp8, Bs, block_size, out_dtype) - #out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - - B_fp8_T = B_fp8.t() - #print('B_fp8_T', B_fp8_T.size(), B_fp8_T) - - Bs_T = Bs - quant_type = 3 - bit = 8 - return ref_out, A_fp32.to(torch.bfloat16), B_fp8_T, Bs_T, Bs_T, quant_type, bit - - -def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): - torch.manual_seed(seed) - factor_for_scale = 1e-2 - int8_info = torch.iinfo(torch.int8) - int8_max, int8_min = int8_info.max, int8_info.min - - A_fpb16 = torch.rand(M, K, dtype=torch.float32) / 10 - - - #A_fp32 = A_fp32.fill_(1) - #A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * int8_max - #B_fp32 = B_fp32.fill_(1) - B_int8 = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) - - block_n, block_k = block_size[0], block_size[1] - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - - A_fpb16 =A_fpb16.to(torch.float16) - - #As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale - #As = As.fill_(1) - Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale - #Bs = Bs.fill_(1.5) - #ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - - ref_out = native_w8a16_block_fp8_matmul(A_fpb16, B_int8, Bs, block_size, out_dtype) - #a_q, a_s = native_per_token_group_quant_int8(A_fpb16, block_k) - #ref_out = native_w8a8_block_int8_matmul(a_q, B_int8, a_s, Bs, block_size, output_dtype=A_fpb16.dtype) - ##out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - #print('Bs', Bs.size(), Bs.dtype) - quant_type = 3 - bit = 8 - return ref_out, A_fpb16, B_int8, Bs, Bs, quant_type, bit - - -def test_int8( +def test( handle, device, M, K, N, block_size, + weight_dtype=InfiniDtype.I8, dtype=InfiniDtype.BF16, sync=None, ): print( - f"Testing int8 Gptq Qyblas Gemm on {InfiniDeviceNames[device]} with M-K-N:{M, K, N}, block_size:{block_size}, dtype:{InfiniDtypeNames[dtype]}" + f"Testing int8 Gptq Qyblas Gemm on {InfiniDeviceNames[device]} with M-K-N:{M, K, N}, block_size:{block_size}, weight dtype:{InfiniDtypeNames[weight_dtype]}, dtype:{InfiniDtypeNames[dtype]}" ) - out_dtype = to_torch_dtype(dtype) - ans, a, b_orig, b_scales, b_zeros, quant_type, bit = test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, SEEDS) - b = b_orig.t() - + quant_type = 3 + bit = 8 + + int8_info = torch.iinfo(torch.int8) + int8_max, int8_min = int8_info.max, int8_info.min + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + A = TestTensor( - a.shape, - a.stride(), - InfiniDtype.F16, - device, - mode="manual", - set_tensor=a, - ) - B_orig = TestTensor( - b_orig.shape, - b_orig.stride(), - InfiniDtype.I8, - device, - mode="manual", - set_tensor=b_orig, - ) - B = TestTensor( - b.shape, - b.stride(), - InfiniDtype.I8, + (M, K), + None, + dtype, device, - mode="manual", - set_tensor=b, ) + if weight_dtype == InfiniDtype.I8: + B_orig = TestTensor( + (N, K), + None, + weight_dtype, + device, + randint_low=int8_min, + randint_high=int8_max, + ) + B_torch = B_orig.torch_tensor().t() + B = TestTensor( + (K, N), + B_torch.stride(), + weight_dtype, + device, + mode="manual", + set_tensor=B_torch, + ) + b_scales = TestTensor( - b_scales.shape, - b_scales.stride(), + (n_tiles, k_tiles), + None, InfiniDtype.F32, device, - mode="manual", - set_tensor=b_scales, ) + b_zeros = TestTensor( - b_zeros.shape, - b_zeros.stride(), + (n_tiles, k_tiles), + None, InfiniDtype.F32, device, - mode="manual", - set_tensor=b_zeros, + mode="zeros", ) + out = TestTensor( - ans.shape, + (M, N), None, dtype, device, + mode="zeros", ) - - print("a: ", A.torch_tensor().shape, A.torch_tensor().stride(), A.torch_tensor().dtype) - print("b: ", B.torch_tensor().shape, B.torch_tensor().stride(), B.torch_tensor().dtype) - print("scales: ", b_scales.torch_tensor().shape, b_scales.torch_tensor().dtype) - print("zeros: ", b_zeros.torch_tensor().shape, b_zeros.torch_tensor().dtype) - print("out: ", out.torch_tensor().shape, out.torch_tensor().dtype) + if sync is not None: sync() - descriptor = infiniopOperatorDescriptor_t() check_error( LIBINFINIOP.infiniopCreateGptqQyblasGemmDescriptor( @@ -278,7 +199,6 @@ def test_int8( b_zeros.descriptor, ) ) - # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel for tensor in [out, A, B, b_scales, b_zeros]: @@ -314,31 +234,20 @@ def lib_gptq_qyblas_gemm(): if sync is not None: sync() - tmpa = out.torch_tensor().to(torch.float32).detach().to('cpu').numpy().flatten() - tmpb = ans.to(torch.float32).to('cpu').detach().numpy().flatten() + out_dtype = to_torch_dtype(dtype) + ans = native_w8a16_block_int8_matmul(A.torch_tensor(), B_orig.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype) - atol = max(abs(tmpa - tmpb)) - - rtol = atol / (max(abs(tmpb)) + 1e-8) - - - print("absolute error:%.4e"%(atol)) - print("relative error:%.4e"%(rtol)) - print(out.torch_tensor().device, ans.device) - # print(out.torch_tensor()) - # print(ans) - ans = ans.to(out.torch_tensor().device) rel_diff = (torch.mean( - torch.abs(out.torch_tensor().to(torch.float32) - ans.to(torch.float32))) / + torch.abs(out.actual_tensor().to(torch.float32) - ans.to(torch.float32))) / torch.mean(torch.abs(ans.to(torch.float32)))) - print(rel_diff) + assert rel_diff < 0.05 # Profiling workflow if PROFILE: # fmt: off - profile_operation("PyTorch", lambda: native_w8a16_block_fp8_matmul(A.torch_tensor(), B_orig.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation("PyTorch", lambda: native_w8a16_block_int8_matmul(A.torch_tensor(), B_orig.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype), device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_gptq_qyblas_gemm(), device, NUM_PRERUN, NUM_ITERATIONS) # fmt: on @@ -355,6 +264,6 @@ def lib_gptq_qyblas_gemm(): NUM_ITERATIONS = args.num_iterations for device in get_test_devices(args): - test_operator(device, test_int8, _TEST_CASES, _TENSOR_DTYPES) + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) print("\033[92mTest passed!\033[0m")