Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions include/infiniop/ops/gptq_qyblas_gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#ifndef __INFINIOP_GPTQ_QYBLAS_GEMM_API_H__
#define __INFINIOP_GPTQ_QYBLAS_GEMM_API_H__

#include "../operator_descriptor.h"
#include <cstdint>

/// Opaque handle to a GPTQ qyblas quantized-GEMM operator descriptor.
typedef struct InfiniopDescriptor *infiniopGptqQyblasGemmDescriptor_t;

/// @brief Create a descriptor for a GPTQ qyblas quantized GEMM.
/// @param handle        Library/device handle the operator is bound to.
/// @param desc_ptr      [out] Receives the newly created descriptor.
/// @param out_desc      Descriptor of the output matrix.
/// @param a_desc        Descriptor of the activation matrix `a`.
/// @param b_desc        Descriptor of the quantized weight matrix `b`.
/// @param b_scales_desc Descriptor of the dequantization scales for `b`.
/// @param b_zeros_desc  Descriptor of the zero-points for `b`.
__INFINI_C __export infiniStatus_t infiniopCreateGptqQyblasGemmDescriptor(
infiniopHandle_t handle,
infiniopGptqQyblasGemmDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t b_scales_desc,
infiniopTensorDescriptor_t b_zeros_desc);

/// @brief Query the scratch-memory size (in bytes) required by the operator.
__INFINI_C __export infiniStatus_t infiniopGetGptqQyblasGemmWorkspaceSize(
infiniopGptqQyblasGemmDescriptor_t desc,
size_t *size);

/// @brief Execute the quantized GEMM described by `desc`.
/// @param workspace      Scratch buffer of at least the size reported above.
/// @param workspace_size Size of `workspace` in bytes.
/// @param out            Output matrix buffer.
/// @param a              Activation matrix buffer.
/// @param b              Packed quantized weight buffer.
/// @param b_scale        Dequantization scales buffer.
/// @param b_zero         Zero-points buffer (may be unused by some schemes).
/// @param quant_type     Quantization scheme selector — presumably 0..3 as
///                       handled by the backend; confirm against the backend
///                       implementation.
/// @param bit            Weight bit-width (e.g. 4 or 8); interpretation
///                       depends on `quant_type`.
/// @param stream         Backend stream/queue to launch on (e.g. cudaStream_t).
__INFINI_C __export infiniStatus_t infiniopGptqQyblasGemm(
infiniopGptqQyblasGemmDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *a,
const void *b,
void *b_scale,
void *b_zero,
int64_t quant_type,
int64_t bit,
void *stream);

/// @brief Destroy a descriptor created by
///        infiniopCreateGptqQyblasGemmDescriptor and release its resources.
__INFINI_C __export infiniStatus_t infiniopDestroyGptqQyblasGemmDescriptor(
infiniopGptqQyblasGemmDescriptor_t desc);
#endif
49 changes: 49 additions & 0 deletions src/infiniop/ops/gptq_qyblas_gemm/gptq_qyblas_gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef GPTQ_QYBLAS_GEMM_H
#define GPTQ_QYBLAS_GEMM_H

#include "../../operator.h"
#include "info.h"

// Declares the per-backend Descriptor class inside
// namespace op::gptq_qyblas_gemm::<NAMESPACE>.
//
// The class holds a backend-private Opaque state pointer, the validated
// GptqQyblasGemmInfo, and the workspace size; each backend supplies the
// definitions of ~Descriptor(), create() and calculate() in its own source
// file. NOTE: comments cannot be placed inside the macro body itself —
// line splicing (backslash-newline) happens before comment removal, so a
// `//` comment would swallow the following macro line.
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::gptq_qyblas_gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
GptqQyblasGemmInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
GptqQyblasGemmInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t b_scales_desc, \
infiniopTensorDescriptor_t b_zeros_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, \
const void *a, const void *b, void *b_scale, void *b_zero, int64_t quant_type, int64_t bit, \
void *stream) const; \
}; \
}

#endif // GPTQ_QYBLAS_GEMM_H
92 changes: 92 additions & 0 deletions src/infiniop/ops/gptq_qyblas_gemm/info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#ifndef __GPTQ_QYBLAS_GEMM_INFO_H__
#define __GPTQ_QYBLAS_GEMM_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"
#include <optional>
#include <vector>

namespace op::gptq_qyblas_gemm {

class GptqQyblasGemmInfo {
    GptqQyblasGemmInfo() = default;

public:
    // Element types of the activations/output, packed weights, scales and
    // zero-points, as read from the tensor descriptors.
    infiniDtype_t dtype, weight_dtype, scales_dtype, zeros_dtype;
    // GEMM extents: out is (M, N), a is (M, K); scales_size_* is the shape
    // of the b_scales tensor.
    size_t M, K, N, scales_size_0, scales_size_1;
    // Leading dimensions derived from the descriptors' strides.
    ptrdiff_t lda, ldb, result_ld;
    // True when the corresponding matrix is laid out column-major.
    bool transpose_mat_1, transpose_mat_2, transpose_result;

    /// Validate the five descriptors and collect the metadata the backend
    /// kernels need. Returns INFINI_STATUS_BAD_TENSOR_SHAPE for non-2-D
    /// inputs and a dtype error for unsupported element types.
    static utils::Result<GptqQyblasGemmInfo> createGptqQyblasGemmInfo(
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc,
        infiniopTensorDescriptor_t b_scales_desc,
        infiniopTensorDescriptor_t b_zeros_desc) {

        // Check ranks FIRST: everything below indexes shape()[1] and
        // strides()[1], which is out of range unless every tensor is 2-D.
        // (The original read those before performing this check.)
        auto ndim = out_desc->ndim();
        CHECK_OR_RETURN(ndim == 2
                            && a_desc->ndim() == ndim
                            && b_desc->ndim() == ndim
                            && b_scales_desc->ndim() == ndim
                            && b_zeros_desc->ndim() == ndim,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);

        auto dtype = a_desc->dtype();

        // Activations must be f16/bf16 and match the output dtype.
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
        CHECK_DTYPE(dtype, out_desc->dtype());

        // Packed weights use 8-bit storage: fp8 (e4m3), unsigned or signed.
        const infiniDtype_t weight_dtype = b_desc->dtype();
        CHECK_DTYPE(weight_dtype, INFINI_DTYPE_F8, INFINI_DTYPE_U8, INFINI_DTYPE_I8);

        const infiniDtype_t scales_dtype = b_scales_desc->dtype();
        const infiniDtype_t zeros_dtype = b_zeros_desc->dtype();

        size_t M = out_desc->shape()[0];
        size_t N = out_desc->shape()[1];
        size_t K = a_desc->shape()[1];

        size_t scales_size_0 = b_scales_desc->shape()[0];
        size_t scales_size_1 = b_scales_desc->shape()[1];

        // A 2-D tensor counts as column-major ("transposed") when dim 0 is
        // contiguous and dim 1's stride covers at least one full column;
        // any other layout is treated as row-major. This replaces three
        // copy-pasted if/else-if chains whose second and third branches
        // both yielded `false`.
        auto isColumnMajor = [](infiniopTensorDescriptor_t desc) {
            return desc->strides()[0] == 1
                && desc->strides()[1] >= std::max<int64_t>(1, desc->shape()[0]);
        };
        bool transpose_result = isColumnMajor(out_desc);
        bool transpose_mat_1 = isColumnMajor(a_desc);
        bool transpose_mat_2 = isColumnMajor(b_desc);

        // Leading dimension = stride of the non-unit axis for the detected
        // layout.
        ptrdiff_t lda = a_desc->strides()[transpose_mat_1 ? 1 : 0];
        ptrdiff_t ldb = b_desc->strides()[transpose_mat_2 ? 1 : 0];
        ptrdiff_t result_ld = out_desc->strides()[transpose_result ? 1 : 0];

        return utils::Result<GptqQyblasGemmInfo>(GptqQyblasGemmInfo{
            dtype, weight_dtype, scales_dtype, zeros_dtype,
            M, K, N, scales_size_0, scales_size_1,
            lda, ldb, result_ld,
            transpose_mat_1, transpose_mat_2, transpose_result});
    }
};

} // namespace op::gptq_qyblas_gemm

#endif // __GPTQ_QYBLAS_GEMM_INFO_H__
210 changes: 210 additions & 0 deletions src/infiniop/ops/gptq_qyblas_gemm/nvidia/gptq_qyblas_gemm_nvidia.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
#if defined ENABLE_QY_API
#include "../../../devices/nvidia/nvidia_handle.cuh"
#include "dlblas_ext.h"
#include "gptq_qyblas_gemm_nvidia.cuh"

namespace op::gptq_qyblas_gemm::nvidia {

// Backend-private state: keeps the shared NVIDIA handle internals (cuBLAS
// handle pool, etc.) alive for the descriptor's lifetime.
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

// Releases the backend-private state; the shared_ptr inside Opaque drops its
// reference to the handle internals.
Descriptor::~Descriptor() {
delete _opaque;
}

// Validates the tensor descriptors, then allocates a Descriptor that carries
// the collected GEMM metadata and a reference to the NVIDIA handle internals.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle, Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t b_scales_desc,
    infiniopTensorDescriptor_t b_zeros_desc) {

    // Shape/dtype validation and metadata collection happen up front.
    auto result = GptqQyblasGemmInfo::createGptqQyblasGemmInfo(
        out_desc, a_desc, b_desc, b_scales_desc, b_zeros_desc);
    CHECK_RESULT(result);

    auto *nv_handle = reinterpret_cast<device::nvidia::Handle *>(handle);

    // This operator requests no scratch memory of its own.
    size_t workspace_size = 0;

    *desc_ptr = new Descriptor(
        new Opaque{nv_handle->internal()},
        result.take(), workspace_size,
        handle->device, handle->device_id);

    return INFINI_STATUS_SUCCESS;
}

/// Runs the quantized GEMM through the dlblas cuBLAS extension.
///
/// Fixes vs. the original:
///  * `extParameters` is now value-initialized — previously it was left
///    uninitialized, so fields no branch set (e.g. `a_zeropoints_type` for
///    quant_type 1, or every field for an unrecognized quant_type) were
///    passed to dlblasGemmExV2 as garbage.
///  * Shape-mismatch diagnostics in the blockwise path return
///    INFINI_STATUS_BAD_TENSOR_SHAPE instead of abort()ing the process.
infiniStatus_t Descriptor::calculate(void *workspace,
                                     size_t workspace_size,
                                     void *out,
                                     const void *a,
                                     const void *b,
                                     void *b_scales,
                                     void *b_zeros,
                                     int64_t quant_type,
                                     int64_t bit,
                                     void *stream) const {

    // Workspace is intentionally unused: create() always reports size 0.
    (void)workspace;
    (void)workspace_size;

    int64_t M = static_cast<int64_t>(_info.M);
    int64_t K = static_cast<int64_t>(_info.K);
    int64_t N = static_cast<int64_t>(_info.N);
    int64_t scales_size_0 = static_cast<int64_t>(_info.scales_size_0);
    int64_t scales_size_1 = static_cast<int64_t>(_info.scales_size_1);
    int64_t lda = static_cast<int64_t>(_info.lda);
    int64_t ldb = static_cast<int64_t>(_info.ldb);
    int64_t result_ld = static_cast<int64_t>(_info.result_ld);
    bool transpose_mat_1 = _info.transpose_mat_1;
    bool transpose_mat_2 = _info.transpose_mat_2;

    // Accumulate in fp32 regardless of the storage types.
    cudaDataType_t computeType_ = (cudaDataType_t)CUDA_R_32F;
    cudaDataType_t kernel_Atype_, kernel_Btype_, kernel_Ctype_, kernel_Stype_, kernel_Ztype_;

    switch (_info.dtype) {
    case INFINI_DTYPE_F16:
        kernel_Atype_ = CUDA_R_16F;
        break;
    case INFINI_DTYPE_BF16:
        kernel_Atype_ = CUDA_R_16BF;
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // NOTE(review): for quant_type 0 this overwrites the *activation* type
    // (and, via kernel_Ctype_ below, the output type) with the packed
    // unsigned type, and doubles K for 4-bit packing (two nibbles per byte).
    // This presumably matches the dlblas extension's operand convention —
    // confirm against the dlblasGemmExV2 documentation.
    if (quant_type == 0) {
        if (8 == bit) {
            kernel_Atype_ = (cudaDataType_t)CUDA_R_8U;
        }

        if (4 == bit) {
            kernel_Atype_ = (cudaDataType_t)CUDA_R_4U;
            K = K * 2;
        }
    }

    switch (_info.weight_dtype) {
    case INFINI_DTYPE_F8:
        kernel_Btype_ = (cudaDataType_t)CUDA_R_8F_E4M3;
        break;
    case INFINI_DTYPE_U8:
        kernel_Btype_ = CUDA_R_8U;
        break;
    case INFINI_DTYPE_I8:
        kernel_Btype_ = CUDA_R_8I;
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // Output type follows the (possibly quant-overwritten) activation type.
    kernel_Ctype_ = kernel_Atype_;

    switch (_info.scales_dtype) {
    case INFINI_DTYPE_F32:
        kernel_Stype_ = CUDA_R_32F;
        break;
    case INFINI_DTYPE_F16:
        kernel_Stype_ = CUDA_R_16F;
        break;
    case INFINI_DTYPE_BF16:
        kernel_Stype_ = CUDA_R_16BF;
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    switch (_info.zeros_dtype) {
    case INFINI_DTYPE_F32:
        kernel_Ztype_ = CUDA_R_32F;
        break;
    case INFINI_DTYPE_F16:
        kernel_Ztype_ = CUDA_R_16F;
        break;
    case INFINI_DTYPE_BF16:
        kernel_Ztype_ = CUDA_R_16BF;
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    float alpha = 1.0f;
    float beta = 0.0f;

    // Value-initialized so that fields not assigned by the branches below
    // (and the whole struct for an unrecognized quant_type) are zero/null
    // rather than indeterminate.
    dlblasExtQuantParametersV2_t extParameters{};

    if (quant_type == 0) {
        // Group-wise quantization: group sizes derived from the scales shape.
        extParameters.a_group_size_m = M / scales_size_0;
        extParameters.a_group_size_k = K / scales_size_1;
        extParameters.a_zeropoints_type = kernel_Ztype_;
        extParameters.a_zeropoints = b_zeros;
        extParameters.a_scales_type = kernel_Stype_;
        extParameters.a_scales = b_scales;
    } else if (quant_type == 1) {
        // Per-row scaling over the full K extent; no zero-points.
        extParameters.a_group_size_m = 1;
        extParameters.a_group_size_k = K;
        extParameters.a_zeropoints = nullptr;
        extParameters.a_scales_type = kernel_Stype_;
        extParameters.a_scales = b_scales;
    } else if (quant_type == 2 || quant_type == 3) {
        // Blockwise fp quantization: derive the square block size from the
        // weight/scales shapes, halving from 128 until the grid covers the
        // scales' first dimension.
        int block_shape = 128;
        while ((N + block_shape - 1) / block_shape < scales_size_0) {
            block_shape /= 2;
            if (block_shape < 32) {
                fprintf(stderr,
                        "INTERNAL ASSERT FAILED: block_shape >= 32\n"
                        "Invalid fp blockwise linear arguments. Weight: [%d, %d]. Scales: [%d, %d].\n",
                        (int)N, (int)K, (int)scales_size_0, (int)scales_size_1);
                // Report the mismatch instead of abort()ing the process.
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        }
        if (!((K + block_shape - 1) / block_shape == scales_size_1)) {
            fprintf(stderr,
                    "CHECK FAILED: (K + block_shape - 1) / block_shape == scales_size_1\n");
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        extParameters.a_group_size_m = block_shape;
        extParameters.a_group_size_k = block_shape;
        extParameters.a_scales_type = kernel_Stype_;
        extParameters.a_zeropoints = nullptr;
        extParameters.a_scales = b_scales;
    }
    // Any other quant_type falls through with zeroed extension parameters
    // (previously: indeterminate values — UB).

    // Operands are swapped (b first) — presumably the usual trick of
    // computing out^T in cuBLAS's column-major convention; confirm against
    // the dlblas extension docs.
    cublasOperation_t transa = transpose_mat_2 ? CUBLAS_OP_T : CUBLAS_OP_N;
    cublasOperation_t transb = transpose_mat_1 ? CUBLAS_OP_T : CUBLAS_OP_N;

    // The dtype switch above already rejected everything but F16/BF16, so
    // the original's redundant dtype re-check (and its unreachable else /
    // trailing return) is dropped.
    CHECK_STATUS(_opaque->internal->useCublas(
        (cudaStream_t)stream,
        [&](cublasHandle_t handle) {
            CHECK_CUBLAS(
                dlblasGemmExV2(handle,
                               transa,
                               transb,
                               N,
                               M,
                               K,
                               &alpha,
                               b,
                               kernel_Btype_,
                               ldb,
                               a,
                               kernel_Atype_,
                               lda,
                               &beta,
                               out,
                               kernel_Ctype_,
                               result_ld,
                               computeType_,
                               CUBLAS_GEMM_DEFAULT_TENSOR_OP,
                               &extParameters));
            return INFINI_STATUS_SUCCESS;
        }));
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::gptq_qyblas_gemm::nvidia
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef __GPTQ_QYBLAS_GEMM_NVIDIA_API_H__
#define __GPTQ_QYBLAS_GEMM_NVIDIA_API_H__
#include "../gptq_qyblas_gemm.h"

// Declares op::gptq_qyblas_gemm::nvidia::Descriptor (see the DESCRIPTOR
// macro in ../gptq_qyblas_gemm.h); definitions live in the .cu backend file.
DESCRIPTOR(nvidia)

#endif // __GPTQ_QYBLAS_GEMM_NVIDIA_API_H__
Loading
Loading