From da487c22c5a4bf325469fb889e6a8f3541b0bed7 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Tue, 31 Mar 2026 14:18:45 +0800 Subject: [PATCH 1/3] issue/1083: gptq_marlin_gemm --- include/infiniop/ops/gptq_marlin_gemm.h | 42 + .../ops/gptq_marlin_gemm/gptq_marlin_gemm.h | 66 + src/infiniop/ops/gptq_marlin_gemm/info.h | 59 + .../marlin/awq_marlin_repack.cuh | 281 +++ .../ops/gptq_marlin_gemm/marlin/dequant.h | 504 +++++ .../gptq_marlin_gemm/marlin/gptq_marlin.cuh | 1085 ++++++++++ .../marlin/gptq_marlin_repack.cuh | 398 ++++ .../ops/gptq_marlin_gemm/marlin/kernel.h | 34 + .../ops/gptq_marlin_gemm/marlin/marlin.cuh | 92 + .../gptq_marlin_gemm/marlin/marlin_dtypes.cuh | 78 + .../gptq_marlin_gemm/marlin/marlin_template.h | 1917 +++++++++++++++++ .../nvidia/gptq_marlin_gemm_nvidia.cu | 1141 ++++++++++ .../nvidia/gptq_marlin_gemm_nvidia.cuh | 8 + src/infiniop/ops/gptq_marlin_gemm/operator.cc | 120 ++ .../sgl_kernel/scalar_type.hpp | 335 +++ .../sgl_kernel/source_location.h | 41 + .../ops/gptq_marlin_gemm/sgl_kernel/tensor.h | 621 ++++++ .../ops/gptq_marlin_gemm/sgl_kernel/utils.cuh | 310 +++ .../ops/gptq_marlin_gemm/sgl_kernel/utils.h | 241 +++ test/infiniop/gptq_marlin_gemm.py | 623 ++++++ test/infiniop/libinfiniop/op_register.py | 49 +- test/infiniop/libinfiniop/scalar_type.py | 354 +++ xmake/nvidia.lua | 20 + 23 files changed, 8418 insertions(+), 1 deletion(-) create mode 100644 include/infiniop/ops/gptq_marlin_gemm.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/gptq_marlin_gemm.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/info.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_dtypes.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu create mode 100644 src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/operator.cc create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h create mode 100644 test/infiniop/gptq_marlin_gemm.py create mode 100644 test/infiniop/libinfiniop/scalar_type.py diff --git a/include/infiniop/ops/gptq_marlin_gemm.h b/include/infiniop/ops/gptq_marlin_gemm.h new file mode 100644 index 000000000..37e22baec --- /dev/null +++ b/include/infiniop/ops/gptq_marlin_gemm.h @@ -0,0 +1,42 @@ +#ifndef __INFINIOP_GPTQ_MARLIN_GEMM_API_H__ +#define __INFINIOP_GPTQ_MARLIN_GEMM_API_H__ + +#include "../operator_descriptor.h" +#include + +typedef struct InfiniopDescriptor *infiniopGptqMarlinGemmDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateGptqMarlinGemmDescriptor(infiniopHandle_t handle, + infiniopGptqMarlinGemmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t global_scale_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t g_idx_desc, + infiniopTensorDescriptor_t perm_desc); + +__INFINI_C __export infiniStatus_t infiniopGetGptqMarlinGemmWorkspaceSize(infiniopGptqMarlinGemmDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopGptqMarlinGemm(infiniopGptqMarlinGemmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *b_scales, + void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyGptqMarlinGemmDescriptor(infiniopGptqMarlinGemmDescriptor_t desc); + +#endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/gptq_marlin_gemm.h b/src/infiniop/ops/gptq_marlin_gemm/gptq_marlin_gemm.h new file mode 100644 index 000000000..4b02f5e1b --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/gptq_marlin_gemm.h @@ -0,0 +1,66 @@ +#ifndef __GPTQ_MARLIN_GEMM_H__ +#define __GPTQ_MARLIN_GEMM_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::gptq_marlin_gemm::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + GptqMarlinGemmInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + size_t workspace_size_, \ + Opaque *opaque, \ + GptqMarlinGemmInfo info, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t a_desc, \ + infiniopTensorDescriptor_t b_desc, \ + infiniopTensorDescriptor_t b_scales_desc, \ + infiniopTensorDescriptor_t global_scale_desc, \ + infiniopTensorDescriptor_t b_zeros_desc, \ + infiniopTensorDescriptor_t g_idx_desc, \ + infiniopTensorDescriptor_t perm_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *out, \ + const void *a, \ + const void *b, \ + void *b_scales, \ + void *global_scale, \ + void *b_zeros, \ + void *g_idx, \ + void *perm, \ + int64_t b_q_type_id, \ + bool is_k_full, \ + bool use_atomic_add, \ + bool use_fp32_reduce, \ + bool is_zp_float, \ + void *stream) const; \ + }; \ + } + +#endif //__GPTQ_MARLIN_GEMM_H__ diff --git a/src/infiniop/ops/gptq_marlin_gemm/info.h b/src/infiniop/ops/gptq_marlin_gemm/info.h new file mode 100644 index 000000000..422a53e3e --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/info.h @@ -0,0 +1,59 @@ +#ifndef __GPTQ_MARLIN_GEMM_INFO_H__ +#define __GPTQ_MARLIN_GEMM_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include + +#include + +namespace op::gptq_marlin_gemm { + +class GptqMarlinGemmInfo { + GptqMarlinGemmInfo() = default; + +public: + infiniDtype_t dtype; + size_t M, K, N, b_q_size_1; + int num_groups; + ptrdiff_t a_stride_0; + + static utils::Result create( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t global_scale_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t g_idx_desc, + infiniopTensorDescriptor_t perm_desc) { + CHECK_OR_RETURN( + out_desc != nullptr && a_desc != nullptr && b_desc != nullptr && b_scales_desc != nullptr, + INFINI_STATUS_NULL_POINTER); + const infiniDtype_t dtype = a_desc->dtype(); + size_t M = out_desc->dim(0); + size_t N = out_desc->dim(1); + size_t K = a_desc->dim(1); + size_t b_q_size_1 = b_desc->dim(1); + int num_groups = static_cast(b_scales_desc->dim(0)); + ptrdiff_t a_stride_0 = a_desc->strides()[0]; + + auto ndim = out_desc->ndim(); + CHECK_OR_RETURN(ndim == 2 + && a_desc->ndim() == ndim + && b_desc->ndim() == ndim + && b_scales_desc->ndim() == ndim, + INFINI_STATUS_BAD_TENSOR_SHAPE); + + CHECK_OR_RETURN(b_scales_desc->shape()[1] == N + && a_stride_0 % 8 == 0, + INFINI_STATUS_BAD_TENSOR_SHAPE); + + return utils::Result( + GptqMarlinGemmInfo{dtype, M, K, N, b_q_size_1, num_groups, a_stride_0}); + } +}; + +} // namespace op::gptq_marlin_gemm + +#endif // __GPTQ_MARLIN_GEMM_INFO_H__ diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh new file mode 100644 index 000000000..2aea26529 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh @@ -0,0 +1,281 @@ +#pragma once + +#include "../sgl_kernel/tensor.h" + +#include "../sgl_kernel/utils.cuh" + +#include "marlin.cuh" + +namespace device::marlin +{ + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + template + __global__ void awq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) + { + return; + } +#else + + template + __global__ void awq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) + { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, (int)gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) + { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() + { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int tile_n_ints = tile_n_size / pack_factor; + + constexpr int stage_n_threads = tile_n_ints / 4; + constexpr int stage_k_threads = tile_k_size; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) + { + if (n_tile_id >= n_tiles) + { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + int first_n_packed = first_n / pack_factor; + + int4 *sh_ptr = sh + stage_size * pipe; + + if (threadIdx.x < stage_size) + { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) + { + if (n_tile_id >= n_tiles) + { + return; + } + + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) + { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; + + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4 *sh_stage_ptr = sh + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) + { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + else + { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + + uint32_t vals[8]; +#pragma unroll + for (int i = 0; i < 4; i++) + { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } + + constexpr int tile_size_val = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size_val; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) + { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) + { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + } + else + { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) + { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) + { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) + { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) + { + int n_tile_id = 0; + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) + { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) + { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } + } +#endif + +} // namespace device::marlin + +// Host wrapper +void awq_marlin_repack( + tvm::ffi::TensorView out, tvm::ffi::TensorView b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) +{ + using namespace host; + using namespace device::marlin; + + // Validate alignment + RuntimeCheck(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); + RuntimeCheck(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + + int const pack_factor = 32 / num_bits; + + // Validate tensors + SymbolicDevice cuda_device; + cuda_device.set_options(); + + TensorMatcher({size_k, size_n / pack_factor}).with_dtype().with_device(cuda_device).verify(b_q_weight); + + TensorMatcher({size_k / tile_size, size_n * tile_size / pack_factor}) + .with_dtype() + .with_device(cuda_device) + .verify(out); + + // Get device and stream + auto device = cuda_device.unwrap(); + auto stream = LaunchKernel::resolve_device(device); + + // Get pointers + auto *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + auto *out_ptr = reinterpret_cast(out.data_ptr()); + + // Get device attributes + int blocks = 0; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device.device_id); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device.device_id); + RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + // Dispatch based on num_bits + if (num_bits == 4) + { + cudaFuncSetAttribute( + awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } + else if (num_bits == 8) + { + cudaFuncSetAttribute( + awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } + else + { + RuntimeCheck(false, "Unsupported repack config: num_bits = ", num_bits); + } +} diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h new file mode 100644 index 000000000..764375f62 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h @@ -0,0 +1,504 @@ +/* +Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16) + +The process of fast dequantization can be summarized as a combination +of bitwise operations and floating-point computations: + +weight =>(bit_op / bitwise operations)=> +f16_value =>(flop / floating-point computation)=> +dequantized_weight + +Since the dequantized weights typically require subtracting the zero point and +applying a scale factor, the floating-point computation step can be fused with +the zero-point subtraction and scaling operations. + +The following are the parts that need to be modified for the fused operation +of zero-point subtraction and scaling. + +## INT4 => FP16/BF16 or INT8 => FP16 + +The floating-point computation is `__hsub2` + +If has zero points: + + flop(bit_op(weight)) - flop(bit_op(zp)) + = sub(bit_op(weight), bias) - sub(bit_op(zp), bias) + = bit_op(weight) - bit_op(zp) + +so we don't need additional modification. + +If has float zero points: + + flop(bit_op(weight)) - fzp + = sub(bit_op(weight), bias) - fzp + = bit_op(weight) - (fzp + bias) + +where the `fzp + bias` can be computed at weight loading. But this +may have accuracy issue, so we should not use this in most cases. + +If has not zero points: + + scale(flop(bit_op(weight))) + = scale(sub(bit_op(weight), bias)) + = scale(bit_op(weight)) - scale(bias) + = fma(bit_op(weight), scale_factor, scale(bias)) + +where the `scale(bias)` can be cached. But this may have accuracy issue, +so we should not use this in most cases. + + +## INT8 => BF16 + +INT8 => BF16 is a special case, it use byte_perm instead of flop. +We cannot fused byte_perm with scaling. + + +## FP4/FP8 => FP16/BF16 + + scale(flop(bit_op(weight))) + = scale(mul(bit_op(weight), multiplier)) + = mul(bit_op(weight), scale_factor * multiplier) + +where `scale_factor * multiplier` can be computed at weight loading. + +*/ + +#include "marlin_dtypes.cuh" + +namespace device::marlin { + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + // clang-format on + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43084308; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43004300; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline void dequant(int q, half2* frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +// New version with s_type_id parameter for marlin_moe_wna16_v2 +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16, + // but we assume that such a extreme value would not occur in real models. + int Out1 = (q & 0xFF00FF00) >> 1; + q <<= 7; + int Out2 = q & 0x7F807F80; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +#endif + +} // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh new file mode 100644 index 000000000..653501357 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh @@ -0,0 +1,1085 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#pragma once + +#include "../sgl_kernel/tensor.h" + +#include "../sgl_kernel/scalar_type.hpp" + +#include "kernel.h" +#include "marlin_template.h" + +namespace device::marlin +{ + + __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + + using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + + __global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) {} + +#else + + // For a given "a" of size [M,K] performs a permutation of the K columns based + // on the given "perm" indices. + __global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) + { + auto start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) + { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int input_row_stride = lda * sizeof(half) / 16; + int output_row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) + { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int input_offset = row * input_row_stride; + int output_offset = row * output_row_stride; + + half const *a_row_half = reinterpret_cast(a_int4_ptr + input_offset); + half *out_half = reinterpret_cast(out_int4_ptr + output_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) + { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) + { + if (threadIdx.x < rest) + { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) + { + int cur_row = start_row + i; + if (cur_row < size_m) + { + permute_row(cur_row); + } + } + } + + typedef struct + { + int thread_k; + int thread_n; + int num_threads; + } thread_config_t; + + thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}}; + + thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; + + typedef struct + { + int blocks_per_sm; + thread_config_t tb_cfg; + } exec_config_t; + + int get_scales_cache_size( + thread_config_t const &th_config, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full) + { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) + { + tb_groups = 1; + } + else if (group_size == 0) + { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } + else + { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) + { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } + else + { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } + } + + int get_kernel_cache_size( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float) + { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8); + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) + { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; + } + + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; + + return total_size; + } + + bool is_valid_config( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float, + int max_shared_mem) + { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) + { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) + { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) + { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) + { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + return cache_size <= max_shared_mem; + } + +#define _GET_IF( \ + W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) \ + { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } + +// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) +// this is the most common cases +// BIGGROUP: cases for big group size (group_blocks in [-1, 8]) +// FZP: cases for float-zero-point (is_zp_float = true) +// ACT: cases for act order case (group_blocks == 0) +// FP4: cases for nvfp4(e2m1) (group_blocks == 1) +#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) + + template + MarlinFuncPtr get_marlin_kernel( + const host::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) + { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) + { + } + + COMMON_GET_IF(host::kU4) + COMMON_GET_IF(host::kU4B8) + COMMON_GET_IF(host::kU8B128) + + FP4_GET_IF(host::kFE2M1f) + + BIGGROUP_GET_IF(host::kFE4M3fn) + + ACT_GET_IF(host::kU4B8) + ACT_GET_IF(host::kU8B128) + + if (std::is_same::value) + { + if (false) + { + } + FZP_GET_IF(host::kU4) + } + + return kernel; + } + + template + exec_config_t determine_exec_config( + const host::ScalarType &q_type, + int prob_m, + int prob_n, + int prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem, + int sms) + { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t *thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; + int thread_configs_size = thread_m_blocks > 1 ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + for (int i = 0; i < thread_configs_size; i++) + { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) + { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) + { + group_blocks = group_size == -1 ? -1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) + continue; + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; + } + + return exec_cfg; + } + + template + void marlin_mm( + const void *A, + const void *B, + void *C, + void *C_tmp, + void *s, + void *s2, + void *zp, + void *g_idx, + void *perm, + void *a_tmp, + int prob_m, + int prob_n, + int prob_k, + int lda, + void *workspace, + host::ScalarType const &q_type, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k_init, + int thread_n_init, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) + { + if (has_zp) + { + host::RuntimeCheck( + q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } + else + { + host::RuntimeCheck( + q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); + } + + host::RuntimeCheck( + prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) + { + if (is_k_full) + { + host::RuntimeCheck(group_size != -1); + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + else + { + host::RuntimeCheck(group_size == 0); + group_blocks = 0; + } + } + else + { + if (group_size == -1) + { + group_blocks = -1; + } + else + { + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + int4 *C_ptr = (int4 *)C; + int4 *C_tmp_ptr = (int4 *)C_tmp; + const int4 *s_ptr = (const int4 *)s; + const uint16_t *s2_ptr = (const uint16_t *)s2; + const int4 *zp_ptr = (const int4 *)zp; + const int *g_idx_ptr = (const int *)g_idx; + const int *perm_ptr = (const int *)perm; + int4 *a_tmp_ptr = (int4 *)a_tmp; + + int *locks = (int *)workspace; + + if (has_act_order) + { + // Permute A columns + int block_rows = div_ceil(prob_m, sms); + host::LaunchKernel(sms, default_threads, stream)( + permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + A_ptr = a_tmp_ptr; + lda = prob_k; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) + has_act_order = false; + } + + int max_shared_mem = 0; + host::RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + host::RuntimeCheck(max_shared_mem > 0); + + int max_par = 16; + if (prob_n <= 4096) + max_par = 16 * 8; + int max_shared_mem_new = max_shared_mem; + int rest_m = prob_m; + int max_thread_m_blocks = 4; + while (rest_m) + { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) + par_count = max_par; + int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) + { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } + else + { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m_split, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem, + sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) + { + max_thread_m_blocks--; + continue; + } + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) + max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + host::RuntimeCheck( + is_valid_config( + thread_tfg, + thread_m_blocks, + prob_m_split, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", + thread_m_blocks, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", + num_bits, + ", prob_m_split = ", + prob_m_split, + ", group_size = ", + group_size, + ", has_act_order = ", + has_act_order, + ", is_k_full = ", + is_k_full, + ", has_zp = ", + has_zp, + ", is_zp_float = ", + is_zp_float, + ", max_shared_mem_new = ", + max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + thread_n_blocks, + thread_k_blocks, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + num_threads, + is_zp_float); + + if (kernel == MarlinDefault) + { + host::Panic( + "Unsupported shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", prob_m_split = ", + prob_m_split, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_threads = ", + num_threads, + ", num_bits = ", + num_bits); + } + + host::RuntimeDeviceCheck( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new)); + + bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)( + kernel, + A_ptr, + B_ptr, + C_ptr, + C_tmp_ptr, + s_ptr, + s2_ptr, + zp_ptr, + g_idx_ptr, + num_groups, + prob_m_split, + prob_n, + prob_k, + lda, + locks, + part_use_atomic_add, + use_fp32_reduce, + max_shared_mem_new); + + A_ptr += prob_m_split * (lda / 8); + C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; + } + } + +#endif + +} // namespace device::marlin + +template +void gptq_marlin_gemm( + tvm::ffi::TensorView a, + tvm::ffi::TensorView b_q_weight, + tvm::ffi::TensorView b_scales, + tvm::ffi::TensorView global_scale, + tvm::ffi::TensorView b_zeros, + tvm::ffi::TensorView g_idx, + tvm::ffi::TensorView perm, + tvm::ffi::TensorView c, + tvm::ffi::TensorView c_tmp, + tvm::ffi::TensorView a_tmp, + tvm::ffi::TensorView workspace, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) +{ + using namespace host; + + ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + // Bind symbolic sizes + auto M = SymbolicSize{"M"}; + auto K = SymbolicSize{"K"}; + auto N = SymbolicSize{"N"}; + auto device = SymbolicDevice{}; + device.set_options(); + + // Verify a: [M, K] + auto lda = SymbolicSize{"lda"}; + TensorMatcher({M, K}).with_strides({lda, 1}).with_dtype().with_device(device).verify(a); + + int64_t size_m = M.unwrap(); + int64_t size_k = K.unwrap(); + + // Verify b_q_weight: [K/tile_size, packed_N] + RuntimeCheck( + size_k % device::marlin::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size; + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(expected_bqw_dim0); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device).verify(b_q_weight); + + RuntimeCheck( + b_q_weight.size(1) % device::marlin::tile_size == 0, + "b_q_weight.size(1) = ", + b_q_weight.size(1), + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t actual_size_n = (b_q_weight.size(1) / device::marlin::tile_size) * pack_factor; + N.set_value(actual_size_n); + int64_t size_n = N.unwrap(); + + // Verify stride alignment + int64_t a_stride0 = a.stride(0); + RuntimeCheck(a_stride0 % 8 == 0, "a.stride(0) must be divisible by 8"); + + // Verify b_scales: [num_groups, N] + auto num_groups_sym = SymbolicSize{"num_groups"}; + TensorMatcher({num_groups_sym, N}).with_device(device).verify(b_scales); + int num_groups = static_cast(num_groups_sym.unwrap()); + + // Verify c: [M, N] + TensorMatcher({M, N}).with_dtype().with_device(device).verify(c); + + // Early return for zero-size M + if (size_m == 0) + return; + + // Determine has_act_order from g_idx/perm sizes + int64_t g_idx_size = g_idx.size(0); + int64_t perm_size = perm.size(0); + bool has_act_order = g_idx_size > 0 && perm_size > 0; + + if (has_act_order) + { + RuntimeCheck( + (g_idx_size == size_k && perm_size == size_k), + "Unexpected g_idx.size(0) = ", + g_idx_size, + " and perm.size(0) = ", + perm_size, + ", where size_k = ", + size_k); + } + + // Determine has_zp from b_zeros size + int64_t b_zeros_size = b_zeros.size(0); + bool has_zp = b_zeros_size > 0; + + if (has_zp) + { + RuntimeCheck( + b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + } + else + { + RuntimeCheck( + b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) + { + RuntimeCheck( + std::is_same::value, "Computation type must be float16 (half) when using float zero points."); + } + + // Verify b_zeros shape + if (has_zp) + { + RuntimeCheck(b_zeros.dim() == 2, "b_zeros rank = ", b_zeros.dim(), " is not 2"); + if (is_zp_float) + { + RuntimeCheck(b_zeros.size(1) == size_n, "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n = ", size_n); + RuntimeCheck( + num_groups == b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck(num_groups != -1, "num_groups must be != -1"); + } + else + { + RuntimeCheck( + b_zeros.size(0) == num_groups, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck( + b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", + b_zeros.size(1), + " is not size_n / pack_factor = ", + size_n / pack_factor); + } + } + + // Verify global_scale + int64_t global_scale_size = global_scale.size(0); + if (global_scale_size > 0) + { + RuntimeCheck(b_q_type == kFE2M1f, "global_scale can only be used for float4_e2m1f."); + } + else + { + RuntimeCheck(!(b_q_type == kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); + } + + // Derive group_size + int group_size = -1; + if (has_act_order) + { + if (is_k_full) + { + RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } + else + { + group_size = 0; + } + } + else + { + if (num_groups > 1) + { + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } + else + { + group_size = -1; + } + } + + // Verify workspace and get device info + RuntimeCheck( + size_n % device::marlin::min_thread_n == 0, + "size_n = ", + size_n, + ", is not divisible by min_thread_n = ", + device::marlin::min_thread_n); + + DLDevice dl_device = device.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + + int sms = -1; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev)); + + RuntimeCheck( + workspace.size(0) >= sms, "workspace.size(0) = ", workspace.size(0), " is below min_workspace_size = ", sms); + + // Hardcoded defaults (auto config) + int thread_k_init = -1; + int thread_n_init = -1; + + // Compute c_tmp and a_tmp pointers + // c_tmp and a_tmp are pre-allocated by caller + + device::marlin::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + b_scales.data_ptr(), + global_scale.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + static_cast(size_m), + static_cast(size_n), + static_cast(size_k), + static_cast(a_stride0), + workspace.data_ptr(), + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + dev, + stream, + thread_k_init, + thread_n_init, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); +} diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh new file mode 100644 index 000000000..d0c2d5414 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh @@ -0,0 +1,398 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#pragma once + +#include "../sgl_kernel/tensor.h" + +#include "../sgl_kernel/utils.cuh" + +#include "marlin.cuh" + +namespace device::marlin +{ + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + template + __global__ void gptq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, + int size_k, + int size_n) + { + return; + } +#else + template + __global__ void gptq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, + int size_k, + int size_n) + { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) + { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() + { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4 *sh_perm_ptr = sh; + int4 *sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) + { + sh_pipe_ptr += perm_size; + } + + constexpr int tile_ints = tile_k_size / pack_factor; + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) + { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) + { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) + { + if (n_tile_id >= n_tiles) + { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) + { + if (threadIdx.x < stage_size) + { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + uint32_t const *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + } + else + { + if (threadIdx.x < stage_size) + { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) + { + if (n_tile_id >= n_tiles) + { + return; + } + + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) + { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[8]; + + if constexpr (has_perm) + { + for (int i = 0; i < 4; i++) + { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + } + else + { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; + +#pragma unroll + for (int i = 0; i < tile_ints; i++) + { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } + +#pragma unroll + for (int i = 0; i < 4; i++) + { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) + { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) + { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + } + else + { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) + { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) + { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) + { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) + { + int n_tile_id = 0; + + if constexpr (has_perm) + { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) + { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) + { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } + } +#endif + +} // namespace device::marlin + +#define CALL_IF_REPACK(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) \ + { \ + host::RuntimeDeviceCheck(cudaFuncSetAttribute( \ + device::marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem)); \ + host::LaunchKernel(blocks, device::marlin::repack_threads, stream, static_cast(max_shared_mem))( \ + device::marlin::gptq_marlin_repack_kernel, \ + b_q_weight_ptr, \ + perm_ptr, \ + out_ptr, \ + size_k, \ + size_n); \ + } + +void gptq_marlin_repack( + tvm::ffi::TensorView b_q_weight, + tvm::ffi::TensorView perm, + tvm::ffi::TensorView out, + int64_t size_k, + int64_t size_n, + int64_t num_bits) +{ + using namespace host; + + // Validate num_bits + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / static_cast(num_bits); + + // Validate size alignment + RuntimeCheck( + size_k % device::marlin::tile_k_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_k_size = ", + device::marlin::tile_k_size); + RuntimeCheck( + size_n % device::marlin::tile_n_size == 0, + "size_n = ", + size_n, + " is not divisible by tile_n_size = ", + device::marlin::tile_n_size); + + // Validate b_q_weight + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(size_k / pack_factor); + bqw_dim1.set_value(size_n); + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device_).verify(b_q_weight); + + // Validate out + auto out_dim0 = SymbolicSize{"out_dim0"}; + auto out_dim1 = SymbolicSize{"out_dim1"}; + out_dim0.set_value(size_k / device::marlin::tile_size); + out_dim1.set_value(size_n * device::marlin::tile_size / pack_factor); + TensorMatcher({out_dim0, out_dim1}).with_dtype().with_device(device_).verify(out); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const *perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + DLDevice dl_device = device_.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + int blocks; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev)); + + int max_shared_mem = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + if (false) + { + } + CALL_IF_REPACK(4, false) + CALL_IF_REPACK(4, true) + CALL_IF_REPACK(8, false) + CALL_IF_REPACK(8, true) + else + { + Panic("Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); + } +} + +#undef CALL_IF_REPACK diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h new file mode 100644 index 000000000..e0e36cdd4 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h @@ -0,0 +1,34 @@ + +#include "../sgl_kernel/scalar_type.hpp" + +#include "marlin.cuh" +#include "marlin_dtypes.cuh" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, \ + const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ + bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem + +namespace device::marlin +{ + template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > + __global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh new file mode 100644 index 000000000..9e99d0f4d --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh @@ -0,0 +1,92 @@ +#pragma once + +#include "../sgl_kernel/utils.cuh" + +#include + +namespace device::marlin +{ + // Marlin params + + // 8 warps are a good choice since every SM has 4 schedulers and having more + // than 1 warp per schedule allows some more latency hiding. At the same time, + // we want relatively few warps to have many registers per warp and small tiles. + static constexpr int default_threads = 256; + + static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + + static constexpr int min_thread_n = 64; + static constexpr int min_thread_k = 64; + static constexpr int max_thread_n = 256; + + static constexpr int tile_size = 16; + static constexpr int max_par = 16; + + // Repack params + static constexpr int repack_stages = 8; + + static constexpr int repack_threads = 256; + + static constexpr int tile_k_size = tile_size; + static constexpr int tile_n_size = tile_k_size * 4; + + // Helpers + template + struct Vec + { + T elems[n]; + __device__ T &operator[](int i) + { + return elems[i]; + } + }; + + using I4 = Vec; + + using host::div_ceil; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async +#else + + __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, bool pred = true) + { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), + "l"(glob_ptr), + "n"(BYTES)); + } + + __device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) + { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), + "n"(BYTES)); + } + + __device__ inline void cp_async_fence() + { + asm volatile("cp.async.commit_group;\n" ::); + } + + template + __device__ inline void cp_async_wait() + { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); + } + +#endif + +} // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_dtypes.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_dtypes.cuh new file mode 100644 index 000000000..783374ff2 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_dtypes.cuh @@ -0,0 +1,78 @@ +#ifndef _data_types_cuh +#define _data_types_cuh +#include "../sgl_kernel/utils.cuh" + +#include "marlin.cuh" + +namespace device::marlin { + +template +class ScalarType { +}; + +template <> +class ScalarType<__half> { +public: + using scalar_t = __half; + using scalar_t2 = fp16x2_t; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + + static __device__ float inline num2float(const __half x) { + return __half2float(x); + } + + static __device__ fp16x2_t inline num2num2(const __half x) { + return __half2half2(x); + } + + static __device__ fp16x2_t inline nums2num2(const __half x1, const __half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ __half inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType<__nv_bfloat16> { +public: + using scalar_t = __nv_bfloat16; + using scalar_t2 = bf16x2_t; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const __nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ bf16x2_t inline num2num2(const __nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ bf16x2_t inline nums2num2(const __nv_bfloat16 x1, const __nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ __nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } +#endif +}; + +} // namespace device::marlin + +#endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h new file mode 100644 index 000000000..8f35f227d --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h @@ -0,0 +1,1917 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ +#include "../sgl_kernel/scalar_type.hpp" + +#include "dequant.h" +#include "marlin.cuh" +#include "marlin_dtypes.cuh" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert( \ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace device::marlin +{ + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + + template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > + __global__ void Marlin( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce + ) + { + } + +} // namespace device::marlin + +#else + + // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 + // output/accumulation. + template + __device__ inline void + mma(const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + typename ScalarType::FragC &frag_c) + { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + float *c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } + else if constexpr (std::is_same::value) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } + else + { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } + } + + template + __device__ inline void mma_trans( + const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + const typename ScalarType::FragB &frag_b2, + typename ScalarType::FragC &frag_c) + { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + const uint32_t *b2 = reinterpret_cast(&frag_b2); + float *c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } + else if constexpr (std::is_same::value) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } + else + { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } + } + + // Instruction for loading a full 16x16 matrix fragment of operand A from shared + // memory, directly in tensor core layout. + template + __device__ inline void ldsm(typename ScalarType::FragA &frag_a, const void *smem_ptr) + { + uint32_t *a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) + { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } + else if constexpr (count == 2) + { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); + } + else if constexpr (count == 1) + { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(a[0]) : "r"(smem)); + } + else + { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } + } + + // Multiply dequantized values by the corresponding quantization scale; used + // only for grouped quantization. + template + __device__ inline void + scale(typename ScalarType::FragB &frag_b, typename ScalarType::FragS &frag_s, int i) + { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); + } + + template + __device__ inline void scale_and_sub(typename ScalarType::FragB &frag_b, scalar_t s, scalar_t zp) + { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); + } + + template + __device__ inline void + sub_zp(typename ScalarType::FragB &frag_b, typename ScalarType::scalar_t2 &frag_zp, int i) + { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); + } + + // Same as above, but for act_order (each K is multiplied individually) + template + __device__ inline void scale4( + typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s_1, + typename ScalarType::FragS &frag_s_2, + typename ScalarType::FragS &frag_s_3, + typename ScalarType::FragS &frag_s_4, + int i) + { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); + } + + // Given 2 floats multiply by 2 scales (halves) + template + __device__ inline void scale_float(float *c, typename ScalarType::FragS &s) + { + scalar_t *s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); + } + + // Wait until barrier reaches `count`, then lock for current threadblock. + __device__ inline void barrier_acquire(int *lock, int count) + { + if (threadIdx.x == 0) + { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); + } + + // Release barrier and increment visitation count. + __device__ inline void barrier_release(int *lock, bool reset = false) + { + __syncthreads(); + if (threadIdx.x == 0) + { + if (reset) + { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } + } + + // Wait until value of lock to be negative, and then add 1 + __device__ inline void wait_negative_and_add(int *lock) + { + if (threadIdx.x == 0) + { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); + } + + template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > + __global__ void Marlin( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t *__restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4 *__restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int lda, // A.stride(0), equal to prob_k is A is contiguous + int *locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) + { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + static constexpr auto w_type = host::ScalarType::from_id(w_type_id); + constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8; + constexpr bool is_int_type = + w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == host::kU8); + + scalar_t2 global_scale; + + if constexpr (w_type == host::kFE2M1f) + { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + + constexpr bool has_act_order = group_blocks == 0; + constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > m_block_size) + { + parallel = prob_m / m_block_size; + prob_m = m_block_size; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) + { + if (group_blocks >= thread_k_blocks) + { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) + { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) + { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } + else + { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&](bool first_init = false) + { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) + { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else + { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) + { + if (slice_count > 1 && slice_idx == slice_count - 1) + { + locks_off++; + } + } + else + { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) + { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m); + if (m_block_size_8) + m_per_thread = div_ceil(8, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) + { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < prob_m) + { + int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; + C[row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. + __syncthreads(); + if (threadIdx.x == 0) + locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) + { + A += 16 * thread_m_blocks * lda / 8; + C += 16 * thread_m_blocks * prob_n / 8; + slice_col = 0; + par_id++; + } + }; + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = lda / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * m_block_size; + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float ? 16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) + { + if constexpr (group_blocks == -1) + { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + else + { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + + s_sh_stride * slice_col + threadIdx.x; + } + } + auto s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) + { + if constexpr (group_blocks == -1) + { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + else + { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } + } + auto zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) + { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + } + else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) + { + if constexpr (is_zp_float) + { + if constexpr (group_blocks != -1) + { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } + } + else + { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) + { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4 *sh_b = sh; + int4 *sh_red = sh; + int4 *sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + int4 *sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); + int4 *sh_s = sh_zp + (stages * zp_sh_stage); + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); + int4 *sh_a = sh_s + sh_s_size; + // constexpr int shm_size_used = + // stages * (g_idx_stage + zp_sh_stage) + sh_s_size + + // (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. + auto zero_accums = [&]() + { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) + { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) + { + sh_num_groups = act_s_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) + { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) + { + for (int i = 0; i < sh_num_groups; i++) + { + if (threadIdx.x < s_sh_stride) + { + cp_async4_pred( + &sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } + else + { + for (int i = 0; i < sh_num_groups; i++) + { + if (threadIdx.x < s_sh_stride) + { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) + { + if (pred) + { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) + { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) + { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) + { + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const *cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) + { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } + else + { + if constexpr (group_blocks != -1) + { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) + { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) + { + if (s_sh_wr_pred) + { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + else + { + for (int i = 0; i < s_tb_groups; i++) + { + if (s_sh_wr_pred) + { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) + { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) + { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) + { + if (zp_sh_wr_pred) + { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + else + { + for (int i = 0; i < zp_tb_groups; i++) + { + if (zp_sh_wr_pred) + { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() + { + if (zp_sh_wr_pred) + { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() + { + if (s_sh_wr_pred) + { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() + { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) + { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) + { + frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) + { + if constexpr (!has_act_order) + { + return; + } + + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) + { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) + { + // No act-order case + if constexpr (group_blocks == -1) + { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) + { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + else if constexpr (group_blocks != -1) + { + if constexpr (group_blocks >= thread_k_blocks) + { + if (k % b_sh_wr_iters == 0) + { + int4 *sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + else + { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } + } + else + { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1)); + + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (w_type_id != host::kFE2M1f.id()) + { + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + else + { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) + { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) + { + if (k % 2 == 0) + { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } + else + { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) + { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) + { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) + { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) + { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) + { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) + { +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) + { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + } + else if constexpr (group_blocks >= thread_k_blocks) + { + if (k % b_sh_wr_iters == 0) + { + int4 *sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) + { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + else + { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) + { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) + { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) + { + if constexpr (group_blocks >= thread_k_blocks) + { + if (k % b_sh_wr_iters == 0) + { + int4 *sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } + } + else + { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + auto dequant_data = [&](int q, scalar_t2 *frag_b_ptr) + { + dequant(q, frag_b_ptr); + }; + + // Execute the actual tensor core matmul of a sub-tile. + bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) + { + int k2 = k % 2; + const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) + { + if (is_new_zp) + { + if constexpr (group_blocks == -1) + is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) + { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } + else + { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) + { + if (is_new_zp) + { + reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; + } + } + + if constexpr (w_type == host::kFE2M1f) + { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) + { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type_id == host::kFE2M1f.id()) + { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } + else if constexpr (w_type.size_bits() == 4) + { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } + else + { + static_assert(w_type.size_bits() == 8); + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) + { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) + { + static_assert(group_blocks != -1); + scale4( + frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4( + frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } + else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) + { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } + else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) + { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } + else if constexpr (group_blocks != -1) + { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + { + if constexpr (m_block_size_8) + { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } + else + { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() + { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) + { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) + { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) + { + if (i <= red_idx && red_idx < 2 * i) + { +#pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) + { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) + { + float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh_red[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) + { +#pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) + { + float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) + { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) + { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) + { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + else + { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + auto c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) + { +// Interestingly, doing direct global accesses here really seems to mess up +// the compiler and lead to slowdowns, hence we also use async-copies even +// though these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) + { + if constexpr (m_block_size_8) + { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], + (threadIdx.x % 4) * 2 + i < prob_m); + } + else + { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) + { + bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) || + (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); + if (mask) + { + if (!first) + { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) + { + int delta = 0; + if constexpr (m_block_size_8) + { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) + { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) + { + int delta = 0; + if constexpr (m_block_size_8) + { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + if constexpr (m_block_size_8) + C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; + else + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) + { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) + { + return; + } + + if (!first) + { + float *frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) + { + sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float *sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); +#pragma unroll + for (int f = 0; f < 4; f++) + { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) + { + int4 *frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) + { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() + { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) + { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } + else + { + c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS &s) + { + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) + { + res = __hmul2(res, s[0]); + } + + if constexpr (w_type == host::kFE2M1f) + { + res = __hmul2(res, global_scale); + } + + if constexpr (m_block_size_8) + { + ((scalar_t *)sh_red)[idx] = res.x; + ((scalar_t *)sh_red)[idx + 8 * c_sh_stride] = res.y; + } + else + { + ((scalar_t2 *)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) + { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + { +#pragma unroll + for (int j = 0; j < 4; j++) + { + if constexpr (m_block_size_8) + { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + else + { + int wr = c_sh_wr + 8 * j; + write( + wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write( + wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) + { + if (c_gl_wr < c_gl_wr_end) + { + if (use_atomic_add && slice_count > 1) + { + scalar_t2 *C_half2 = reinterpret_cast(&C[c_gl_wr]); + scalar_t2 *sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); +#pragma unroll + for (int a = 0; a < 4; a++) + { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } + else + { + C[c_gl_wr] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() + { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) + { + if (has_act_order && i == 0) + { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) + { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) + { + if (i == 0) + { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) + { + fetch_col_scale_to_shared(); + } + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) + { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } + }; + if (slice_iters) + { + start_pipes(); + } + + // Main loop. + while (slice_iters) + { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + +#pragma unroll + for (int pipe = 0; pipe < stages;) + { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) + { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) + { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) + { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) + { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) + { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) + { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) + { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) + { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) + { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) + { + if (s_sh_wr_pred) + { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) + { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) + { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) + { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) + { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 *frag_s_half2 = reinterpret_cast(frag_s); +#pragma unroll + for (int i = 0; i < 8; i++) + { + frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) + { + if (threadIdx.x / 32 < thread_n_blocks / 4) + { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + { +#pragma unroll + for (int j = 0; j < 4; j++) + { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); + + if constexpr (!m_block_size_8) + { + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) + { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) + { + global_reduce_fp32(slice_idx == 0, last); + } + else + { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) + wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + if (slice_iters) + { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) + { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) + { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + } + else + { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } + } + +} // namespace device::marlin + +#endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu new file mode 100644 index 000000000..3e424ac4f --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu @@ -0,0 +1,1141 @@ +#if defined ENABLE_NVIDIA_API + +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../gptq_marlin_gemm.h" +#include "../sgl_kernel/tensor.h" +#include "gptq_marlin_gemm_nvidia.cuh" + +#include "../sgl_kernel/scalar_type.hpp" + +#include "../marlin/kernel.h" +#include "../marlin/marlin_template.h" + +namespace device::marlin { + +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) {} + +#else + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) { + auto start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int input_row_stride = lda * sizeof(half) / 16; + int output_row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int input_offset = row * input_row_stride; + int output_offset = row * output_row_stride; + + half const *a_row_half = reinterpret_cast(a_int4_ptr + input_offset); + half *out_half = reinterpret_cast(out_int4_ptr + output_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +typedef struct +{ + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +typedef struct +{ + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; + +int get_scales_cache_size( + thread_config_t const &th_config, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +int get_kernel_cache_size( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8); + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) { + sh_zp_size = sh_s_size; + } else if (num_bits == 4) { + sh_zp_size = sh_s_size / 4; + } else if (num_bits == 8) { + sh_zp_size = sh_s_size / 2; + } + } + + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; + + return total_size; +} + +bool is_valid_config( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + return cache_size <= max_shared_mem; +} + +#define _GET_IF( \ + W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } + +// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) +// this is the most common cases +// BIGGROUP: cases for big group size (group_blocks in [-1, 8]) +// FZP: cases for float-zero-point (is_zp_float = true) +// ACT: cases for act order case (group_blocks == 0) +// FP4: cases for nvfp4(e2m1) (group_blocks == 1) +#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) + +template +MarlinFuncPtr get_marlin_kernel( + const host::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } + + COMMON_GET_IF(host::kU4) + COMMON_GET_IF(host::kU4B8) + COMMON_GET_IF(host::kU8B128) + + FP4_GET_IF(host::kFE2M1f) + + BIGGROUP_GET_IF(host::kFE4M3fn) + + ACT_GET_IF(host::kU4B8) + ACT_GET_IF(host::kU8B128) + + if (std::is_same::value) { + if (false) { + } + FZP_GET_IF(host::kU4) + } + + return kernel; +} + +template +exec_config_t determine_exec_config( + const host::ScalarType &q_type, + int prob_m, + int prob_n, + int prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem, + int sms) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t *thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; + int thread_configs_size = thread_m_blocks > 1 ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? -1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + continue; + } + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; + } + + return exec_cfg; +} + +template +void marlin_mm( + const void *A, + const void *B, + void *C, + void *C_tmp, + void *s, + void *s2, + void *zp, + void *g_idx, + void *perm, + void *a_tmp, + int prob_m, + int prob_n, + int prob_k, + int lda, + void *workspace, + host::ScalarType const &q_type, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k_init, + int thread_n_init, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + if (has_zp) { + host::RuntimeCheck( + q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + host::RuntimeCheck( + q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); + } + + host::RuntimeCheck( + prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + host::RuntimeCheck(group_size != -1); + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } else { + host::RuntimeCheck(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + int4 *C_ptr = (int4 *)C; + int4 *C_tmp_ptr = (int4 *)C_tmp; + const int4 *s_ptr = (const int4 *)s; + const uint16_t *s2_ptr = (const uint16_t *)s2; + const int4 *zp_ptr = (const int4 *)zp; + const int *g_idx_ptr = (const int *)g_idx; + const int *perm_ptr = (const int *)perm; + int4 *a_tmp_ptr = (int4 *)a_tmp; + + int *locks = (int *)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, sms); + host::LaunchKernel(sms, default_threads, stream)( + permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + A_ptr = a_tmp_ptr; + lda = prob_k; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + } + + int max_shared_mem = 0; + host::RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + host::RuntimeCheck(max_shared_mem > 0); + + int max_par = 16; + if (prob_n <= 4096) { + max_par = 16 * 8; + } + int max_shared_mem_new = max_shared_mem; + int rest_m = prob_m; + int max_thread_m_blocks = 4; + while (rest_m) { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) { + par_count = max_par; + } + int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m_split, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem, + sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { + max_thread_m_blocks--; + continue; + } + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) { + max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + host::RuntimeCheck( + is_valid_config( + thread_tfg, + thread_m_blocks, + prob_m_split, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", + thread_m_blocks, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", + num_bits, + ", prob_m_split = ", + prob_m_split, + ", group_size = ", + group_size, + ", has_act_order = ", + has_act_order, + ", is_k_full = ", + is_k_full, + ", has_zp = ", + has_zp, + ", is_zp_float = ", + is_zp_float, + ", max_shared_mem_new = ", + max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + thread_n_blocks, + thread_k_blocks, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + host::Panic( + "Unsupported shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", prob_m_split = ", + prob_m_split, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_threads = ", + num_threads, + ", num_bits = ", + num_bits); + } + + host::RuntimeDeviceCheck( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new)); + + bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)( + kernel, + A_ptr, + B_ptr, + C_ptr, + C_tmp_ptr, + s_ptr, + s2_ptr, + zp_ptr, + g_idx_ptr, + num_groups, + prob_m_split, + prob_n, + prob_k, + lda, + locks, + part_use_atomic_add, + use_fp32_reduce, + max_shared_mem_new); + + A_ptr += prob_m_split * (lda / 8); + C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; + } +} + +#endif + +} // namespace device::marlin + +template +void gptq_marlin_gemm(const void *a, + const void *b_q_weight, + void *b_scales, + void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + void *c, + void *c_tmp, + void *a_tmp, + void *workspace, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + int64_t size_m, + int64_t size_k, + int64_t size_n, + int64_t b_q_size_1, + int64_t a_stride0, + int num_groups, cudaStream_t stream) + +{ + using namespace host; + + // Verify a: [M, K] + ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + // Verify b_q_weight: [K/tile_size, packed_N] + RuntimeCheck( + size_k % device::marlin::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + device::marlin::tile_size); + + int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size; + RuntimeCheck( + b_q_size_1 % device::marlin::tile_size == 0, + "b_q_weight.size(1) = ", + b_q_size_1, + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t actual_size_n = (b_q_size_1 / device::marlin::tile_size) * pack_factor; + RuntimeCheck(actual_size_n == size_n, "actual_size_n must = size_n"); + // size_n = actual_size_n + RuntimeCheck(a_stride0 % 8 == 0, "a.stride(0) must be divisible by 8"); + // Verify b_scales: [num_groups, N] + // Early return for zero-size M + if (size_m == 0) { + return; + } + + // int64_t g_idx_size = g_idx.size(0);// g_idx_size == size_k + // int64_t perm_size = perm.size(0);// perm_size == size_k + bool has_act_order = (g_idx != nullptr && perm != nullptr); + + // Determine has_zp from b_zeros size + // int64_t b_zeros_size = b_zeros.size(0); + bool has_zp = (b_zeros != nullptr); + + if (has_zp) { + RuntimeCheck( + b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + } else { + RuntimeCheck( + b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + RuntimeCheck( + std::is_same::value, "Computation type must be float16 (half) when using float zero points."); + } + + // int64_t global_scale_size = global_scale.size(0); + if (global_scale != nullptr) { + RuntimeCheck(b_q_type == kFE2M1f, "global_scale can only be used for float4_e2m1f."); + } else { + RuntimeCheck(!(b_q_type == kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); + } + // Derive group_size + int group_size = -1; + if (has_act_order) { + if (is_k_full) { + RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = 0; + } + } else { + if (num_groups > 1) { + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = -1; + } + } + + // Verify workspace and get device info + RuntimeCheck( + size_n % device::marlin::min_thread_n == 0, + "size_n = ", + size_n, + ", is not divisible by min_thread_n = ", + device::marlin::min_thread_n); + + int device_id = 0; + int sms = -1; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device_id)); + // RuntimeCheck(workspace.size(0) >= sms, "workspace.size(0) = ", workspace.size(0), " is below min_workspace_size = ", sms); + // Hardcoded defaults (auto config) + int thread_k_init = -1; + int thread_n_init = -1; + + // Compute c_tmp and a_tmp pointers + // c_tmp and a_tmp are pre-allocated by caller + + device::marlin::marlin_mm( + a, + b_q_weight, + c, + c_tmp, + b_scales, + global_scale, + b_zeros, + g_idx, + perm, + a_tmp, + static_cast(size_m), + static_cast(size_n), + static_cast(size_k), + static_cast(a_stride0), + workspace, + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + device_id, + stream, + thread_k_init, + thread_n_init, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); +} + +template +infiniStatus_t gptq_marlin_gemm_kernel(void *c, + const void *a, + const void *b_q_weight, + void *b_scales, + void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + int64_t size_m, + int64_t size_k, + int64_t size_n, + int64_t b_q_size_1, + int64_t a_stride0, + int num_groups, void *total_buffer, cudaStream_t stream) { + int _MAX_THREAD_N = 256; + int max_blocks_per_sm = 1; + float *c_tmp = nullptr; + void *a_tmp = nullptr; + void *workspace = nullptr; + + // 获取设备 SM 数量(只查询 1 次!) + int dev; + cudaGetDevice(&dev); // 获取当前设备号 + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, dev); + const int sms = prop.multiProcessorCount; + + // ===================== 1. 计算每块内存大小 ===================== + size_t c_tmp_bytes = 0; + if (use_fp32_reduce) { + int max_m_block = ((size_m + 15) / 16) * 16; + max_m_block = min(max_m_block, 64); + const size_t c_elems = (size_t)sms * max_m_block * _MAX_THREAD_N; + c_tmp_bytes = c_elems * sizeof(float); + } + + size_t a_tmp_bytes = 0; + bool has_act_order = false; + if (g_idx != nullptr && perm != nullptr) { + has_act_order = true; + } + if (has_act_order) { + a_tmp_bytes = (size_t)size_m * size_k * sizeof(scalar_t); + } + + // workspace 大小(int 类型,必须分配) + const size_t workspace_elems = (size_t)sms * max_blocks_per_sm; + const size_t workspace_bytes = workspace_elems * sizeof(int); + + // ===================== 2. 计算总内存大小 ===================== + const size_t total_bytes = c_tmp_bytes + a_tmp_bytes + workspace_bytes; + + // ===================== 3. 单次 cudaMalloc 分配 ===================== + if (total_bytes > 0) { + cudaMemset(total_buffer, 0, total_bytes); + } + + // ===================== 4. 手动切分指针(核心!) ===================== + uint8_t *ptr = reinterpret_cast(total_buffer); + + // 分配 c_tmp + if (use_fp32_reduce && c_tmp_bytes > 0) { + c_tmp = reinterpret_cast(ptr); + ptr += c_tmp_bytes; + } + + // 分配 a_tmp + if (has_act_order && a_tmp_bytes > 0) { + a_tmp = ptr; + ptr += a_tmp_bytes; + } + + // 分配 workspace + if (workspace_bytes > 0) { + workspace = ptr; + ptr += workspace_bytes; + } + + gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + global_scale, + b_zeros, + g_idx, + perm, + c, + c_tmp, + a_tmp, + workspace, + b_q_type_id, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + size_m, + size_k, + size_n, + b_q_size_1, + a_stride0, + num_groups, + stream); + return INFINI_STATUS_SUCCESS; +} + +int getCudaDeviceSMCount() { + int dev; + cudaGetDevice(&dev); + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, dev); + + return prop.multiProcessorCount; +} + +namespace op::gptq_marlin_gemm::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { delete _opaque; } + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t global_scale_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t g_idx_desc, + infiniopTensorDescriptor_t perm_desc) { + + auto handle = reinterpret_cast(handle_); + auto result = GptqMarlinGemmInfo::create(out_desc, a_desc, b_desc, b_scales_desc, global_scale_desc, b_zeros_desc, g_idx_desc, perm_desc); + + int sms = getCudaDeviceSMCount(); + int _MAX_THREAD_N = 256; + int max_blocks_per_sm = 1; + int max_m_block = ((out_desc->dim(0) + 15) / 16) * 16; + max_m_block = min(max_m_block, 64); + const size_t c_elems = (size_t)sms * max_m_block * _MAX_THREAD_N; + size_t c_tmp_bytes = c_elems * sizeof(float); + size_t a_tmp_bytes = (size_t)a_desc->dim(0) * a_desc->dim(1) * infiniSizeOf(a_desc->dtype()); + const size_t workspace_elems = (size_t)sms * max_blocks_per_sm; + const size_t workspace_bytes = workspace_elems * sizeof(int); + size_t workspace_size = c_tmp_bytes + a_tmp_bytes + workspace_bytes; + + *desc_ptr = new Descriptor( + workspace_size, + new Opaque{handle->internal()}, + result.take(), + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t +Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, + const void *a, + const void *b, + void *b_scales, + void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + void *stream_) const { + + cudaStream_t stream = (cudaStream_t)stream_; + int64_t M = static_cast(_info.M); + int64_t K = static_cast(_info.K); + int64_t N = static_cast(_info.N); + int64_t b_q_size_1 = static_cast(_info.b_q_size_1); + int64_t a_stride_0 = static_cast(_info.a_stride_0); + int num_groups = _info.num_groups; + +#define MARLIN(TDATA) \ + gptq_marlin_gemm_kernel(out, a, b, b_scales, global_scale, b_zeros, g_idx, perm, b_q_type_id, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float, M, K, N, b_q_size_1, a_stride_0, num_groups, workspace, stream) + + if (_info.dtype == INFINI_DTYPE_F16) { + return MARLIN(half); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + return MARLIN(__nv_bfloat16); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::gptq_marlin_gemm::nvidia + +#endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cuh b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cuh new file mode 100644 index 000000000..f9c7eb6e9 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __GPTQ_MARLIN_GEMM_CUDA_CUH__ +#define __GPTQ_MARLIN_GEMM_CUDA_CUH__ + +#include "../gptq_marlin_gemm.h" + +DESCRIPTOR(nvidia) + +#endif // __GPTQ_MARLIN_GEMM_CUDA_CUH__ diff --git a/src/infiniop/ops/gptq_marlin_gemm/operator.cc b/src/infiniop/ops/gptq_marlin_gemm/operator.cc new file mode 100644 index 000000000..04d9e43f6 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/operator.cc @@ -0,0 +1,120 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/gptq_marlin_gemm.h" + +#if defined ENABLE_NVIDIA_API +#include "nvidia/gptq_marlin_gemm_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateGptqMarlinGemmDescriptor( + infiniopHandle_t handle, + infiniopGptqMarlinGemmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t global_scale_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t g_idx_desc, + infiniopTensorDescriptor_t perm_desc) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::gptq_marlin_gemm::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, \ + a_desc, \ + b_desc, \ + b_scales_desc, \ + global_scale_desc, \ + b_zeros_desc, \ + g_idx_desc, \ + perm_desc) + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetGptqMarlinGemmWorkspaceSize(infiniopGptqMarlinGemmDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__INFINI_C infiniStatus_t infiniopGptqMarlinGemm( + infiniopGptqMarlinGemmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *b_scales, + void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, out, a, b, b_scales, global_scale, b_zeros, g_idx, perm, b_q_type_id, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float, stream) + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t +infiniopDestroyGptqMarlinGemmDescriptor(infiniopGptqMarlinGemmDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} + +// #endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp new file mode 100644 index 000000000..15f46457f --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp @@ -0,0 +1,335 @@ +#pragma once + +#include +#include +#ifndef __CUDACC__ +#include +#endif + +namespace host { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType( + uint8_t exponent, + uint8_t mantissa, + bool signed_, + int32_t bias, + bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { + assert(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { + assert(nan_repr < NAN_REPR_ID_MAX); + assert(mantissa > 0 && exponent > 0); + assert(nan_repr != NAN_IEEE_754); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { + return signed_; + } + constexpr bool is_integer() const { + return exponent == 0; + } + constexpr bool is_floating_point() const { + return exponent > 0; + } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { + return bias != 0; + } + +#ifndef __CUDACC__ + private: + double _floating_point_max() const { + assert(mantissa <= 52 && exponent <= 11); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + assert(exponent < 11); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + assert(size_bits() < 64 || (size_bits() == 64 && is_signed())); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + assert(is_signed()); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + assert(!is_signed() || size_bits() <= 64); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); + } + + // Min representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); + } +#endif // __CUDACC__ + + public: + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = + "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +} // namespace host + diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h new file mode 100644 index 000000000..57573171a --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h @@ -0,0 +1,41 @@ +/// \file source_location.h +/// \brief Portable `source_location` wrapper. +/// +/// Uses `std::source_location` when available (C++20), otherwise falls +/// back to a minimal stub that returns empty/zero values. + +#pragma once +#include + +/// NOTE: fallback to a minimal source_location implementation +#if defined(__cpp_lib_source_location) +#include + +using source_location_t = std::source_location; + +#else + +struct source_location_fallback { + public: + static constexpr source_location_fallback current() noexcept { + return source_location_fallback{}; + } + constexpr source_location_fallback() noexcept = default; + constexpr unsigned line() const noexcept { + return 0; + } + constexpr unsigned column() const noexcept { + return 0; + } + constexpr const char* file_name() const noexcept { + return ""; + } + constexpr const char* function_name() const noexcept { + return ""; + } +}; + +using source_location_t = source_location_fallback; + +#endif + diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h new file mode 100644 index 000000000..9f48edd96 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h @@ -0,0 +1,621 @@ +/// \file tensor.h +/// \brief Tensor validation and symbolic matching utilities. +#pragma once +#include "utils.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include "utils.cuh" +#endif + +namespace host +{ + struct SymbolicSize; + struct SymbolicDType; + struct SymbolicDevice; + + namespace details + { + inline constexpr auto kAnyDeviceID = -1; + inline constexpr auto kAnySize = static_cast(-1); + inline constexpr auto kNullSize = static_cast(-1); + inline constexpr auto kNullDType = static_cast(18u); + inline constexpr auto kNullDevice = static_cast(-1); + + template + struct ArrayView + { + const T *data; + size_t size; + + __host__ __device__ ArrayView() : data(nullptr), size(0) {} + __host__ __device__ ArrayView(const T *d, size_t s) : data(d), size(s) {} + + template + __host__ __device__ ArrayView(const std::array &arr) + : data(arr.data()), size(arr.size()) {} + + __host__ __device__ const T &operator[](size_t i) const { return data[i]; } + __host__ __device__ bool empty() const { return size == 0; } + }; + + template + struct PrintAbleSpan + { + const T *data; + size_t length; + + PrintAbleSpan(const T *p, size_t l) : data(p), length(l) {} + size_t size() const { return length; } + const T &operator[](size_t i) const { return data[i]; } + }; + + inline constexpr const char *kDeviceStringMap[] = { + "", // 0 + "cpu", // 1 + "cuda", // 2 + "cuda_host", // 3 + "opencl", // 4 + "vulkan", // 5 + "metal", // 6 + "vpi", // 7 + "rocm", // 8 + "rocm_host", // 9 + "ext_dev", // 10 + "cuda_managed", // 11 + "oneapi", // 12 + "webgpu", // 13 + "hexagon", // 14 + "maia", // 15 + "trn", // 16 + }; + + constexpr int kMaxDeviceType = 16; + + struct PrintableDevice + { + DLDevice device; + }; + + template + struct _dtype_trait; + + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLInt, 8, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLInt, 16, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLInt, 32, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLInt, 64, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLUInt, 8, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLUInt, 16, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLUInt, 32, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLUInt, 64, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLFloat, 32, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLFloat, 64, 1}; + }; + +#ifdef __CUDACC__ + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLFloat, 16, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLBfloat, 16, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLFloat8_e4m3fn, 8, 1}; + }; +#endif + + template + struct _device_trait + { + static constexpr DLDevice value = {Code, kAnyDeviceID}; + }; + + template + inline constexpr std::array kDTypeList = { + _dtype_trait::value...}; + + template + inline constexpr std::array kDeviceList = { + _device_trait::value...}; + + } // namespace details + + inline std::ostream &operator<<(std::ostream &os, DLDevice device) + { + int code = static_cast(device.device_type); + if (code < 1 || code > details::kMaxDeviceType) + RuntimeCheck(false, "Unknown device: ", code); + os << details::kDeviceStringMap[code]; + if (device.device_id != details::kAnyDeviceID && device.device_type != kDLCPU) + os << ":" << device.device_id; + return os; + } + + inline std::ostream &operator<<(std::ostream &os, details::PrintableDevice pd) + { + return os << pd.device; + } + + template + inline std::ostream &operator<<(std::ostream &os, const details::PrintAbleSpan &span) + { + os << "["; + for (size_t i = 0; i < span.size(); ++i) + { + if (i > 0) + os << ", "; + os << span[i]; + } + os << "]"; + return os; + } + + // ============================================== + // SymbolicSize 完整定义 + // ============================================== + struct SymbolicSize + { + public: + explicit SymbolicSize(std::string_view ann = {}) + : m_value(details::kNullSize), m_ann(ann) {} + + SymbolicSize(const SymbolicSize &) = delete; + SymbolicSize &operator=(const SymbolicSize &) = delete; + + std::string_view get_name() const { return m_ann; } + bool has_value() const { return m_value != details::kNullSize; } + + void set_value(int64_t v) + { + RuntimeCheck(!has_value(), "Size already set"); + m_value = v; + } + + std::optional get_value() const + { + return has_value() ? std::optional(m_value) : std::nullopt; + } + + int64_t unwrap(DebugInfo info = {}) const + { + RuntimeCheck(info, has_value(), "Size not set"); + return m_value; + } + + void verify(int64_t v, const char *prefix, int64_t dim) + { + if (has_value()) + { + if (m_value != v) [[unlikely]] + { + Panic("Size mismatch for ", m_name_str(prefix, dim), ": expected ", m_value, " got ", v); + } + } + else + { + set_value(v); + } + } + + std::string value_or_name(const char *prefix, int64_t dim) const + { + if (auto v = get_value()) + return std::to_string(*v); + return m_name_str(prefix, dim); + } + + private: + std::string m_name_str(const char *prefix, int64_t dim) const + { + std::ostringstream os; + os << prefix << '#' << dim; + if (!m_ann.empty()) + os << "('" << m_ann << "')"; + return os.str(); + } + + int64_t m_value; + std::string_view m_ann; + }; + + inline bool operator==(DLDevice a, DLDevice b) + { + return a.device_type == b.device_type && a.device_id == b.device_id; + } + + // ============================================== + // SymbolicDType 完整定义 + // ============================================== + struct SymbolicDType + { + public: + SymbolicDType() : m_value({details::kNullDType, 0, 0}) {} + SymbolicDType(const SymbolicDType &) = delete; + SymbolicDType &operator=(const SymbolicDType &) = delete; + + bool has_value() const { return m_value.code != details::kNullDType; } + + void set_value(DLDataType v) + { + RuntimeCheck(!has_value(), "DType already set"); + RuntimeCheck(m_check(v), "DType not allowed: ", v); + m_value = v; + } + + std::optional get_value() const + { + return has_value() ? std::optional(m_value) : std::nullopt; + } + + DLDataType unwrap(DebugInfo info = {}) const + { + RuntimeCheck(info, has_value(), "DType not set"); + return m_value; + } + + void set_options(details::ArrayView opts) { m_opts = opts; } + + template + void set_options() + { + m_opts = details::ArrayView(details::kDTypeList.data(), details::kDTypeList.size()); + } + + void verify(DLDataType dtype) + { + if (has_value()) + { + RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " got ", dtype); + } + else + { + set_value(dtype); + } + } + + template + bool is_type() const + { + return m_value == details::_dtype_trait::value; + } + + private: + bool m_check(DLDataType v) const + { + if (m_opts.empty()) + return true; + for (size_t i = 0; i < m_opts.size; ++i) + if (m_opts[i] == v) + return true; + return false; + } + + details::ArrayView m_opts; + DLDataType m_value; + }; + + // ============================================== + // SymbolicDevice 完整定义 + // ============================================== + struct SymbolicDevice + { + public: + SymbolicDevice() : m_value({details::kNullDevice, details::kAnyDeviceID}) {} + SymbolicDevice(const SymbolicDevice &) = delete; + SymbolicDevice &operator=(const SymbolicDevice &) = delete; + + bool has_value() const { return m_value.device_type != details::kNullDevice; } + + void set_value(DLDevice v) + { + RuntimeCheck(!has_value(), "Device already set"); + RuntimeCheck(m_check(v), "Device not allowed: ", details::PrintableDevice{v}); + m_value = v; + } + + std::optional get_value() const + { + return has_value() ? std::optional(m_value) : std::nullopt; + } + + DLDevice unwrap(DebugInfo info = {}) const + { + RuntimeCheck(info, has_value(), "Device not set"); + return m_value; + } + + void set_options(details::ArrayView opts) { m_opts = opts; } + + template + void set_options() + { + m_opts = details::ArrayView(details::kDeviceList.data(), details::kDeviceList.size()); + } + + void verify(DLDevice dev) + { + if (has_value()) + { + RuntimeCheck(m_value == dev, "Device mismatch: expected ", + details::PrintableDevice{m_value}, " got ", details::PrintableDevice{dev}); + } + else + { + set_value(dev); + } + } + + private: + bool m_check(DLDevice v) const + { + if (m_opts.empty()) + return true; + for (size_t i = 0; i < m_opts.size; ++i) + { + auto o = m_opts[i]; + if (o.device_type != v.device_type) + continue; + if (o.device_id == details::kAnyDeviceID || o.device_id == v.device_id) + return true; + } + return false; + } + + details::ArrayView m_opts; + DLDevice m_value; + }; + + // ============================================== + // BaseRef / Ref 类型(现在类型已完整定义) + // ============================================== + namespace details + { + template + struct BaseRef + { + BaseRef() : m_ref(&m_cache) {} + explicit BaseRef(T &r) : m_ref(&r) {} + BaseRef(const BaseRef &) = delete; + BaseRef &operator=(const BaseRef &) = delete; + + T *operator->() const { return m_ref; } + T &operator*() const { return *m_ref; } + void rebind(T &r) { m_ref = &r; } + + private: + T *m_ref; + T m_cache; + }; + + struct SizeRef : public BaseRef + { + using BaseRef::BaseRef; + SizeRef(int64_t v); + }; + + struct DTypeRef : public BaseRef + { + using BaseRef::BaseRef; + DTypeRef(DLDataType); + DTypeRef(std::initializer_list); + DTypeRef(ArrayView); + }; + + struct DeviceRef : public BaseRef + { + using BaseRef::BaseRef; + DeviceRef(DLDevice); + DeviceRef(std::initializer_list); + DeviceRef(ArrayView); + }; + + inline SizeRef::SizeRef(int64_t v) + { + if (v != kAnySize) + (**this).set_value(v); + } + inline DTypeRef::DTypeRef(DLDataType v) { (**this).set_value(v); } + inline DTypeRef::DTypeRef(std::initializer_list l) : DTypeRef(ArrayView(l.begin(), l.size())) {} + inline DTypeRef::DTypeRef(ArrayView v) { (**this).set_options(v); } + inline DeviceRef::DeviceRef(DLDevice v) { (**this).set_value(v); } + inline DeviceRef::DeviceRef(std::initializer_list l) : DeviceRef(ArrayView(l.begin(), l.size())) {} + inline DeviceRef::DeviceRef(ArrayView v) { (**this).set_options(v); } + + } // namespace details + + template + inline bool is_type(DLDataType dtype) + { + return dtype == details::_dtype_trait::value; + } + + // ============================================== + // TensorMatcher + // ============================================== + struct TensorMatcher + { + using SizeRef = details::SizeRef; + using DTypeRef = details::DTypeRef; + using DeviceRef = details::DeviceRef; + + TensorMatcher(const TensorMatcher &) = delete; + TensorMatcher &operator=(const TensorMatcher &) = delete; + + explicit TensorMatcher(std::initializer_list s) + : m_shape(s.begin(), s.size()), m_strides(nullptr, 0) {} + + TensorMatcher &&with_strides(std::initializer_list s) && + { + RuntimeCheck(m_strides.empty(), "Strides already set"); + RuntimeCheck(m_shape.size == s.size(), "Stride/shape size mismatch"); + m_strides = details::ArrayView(s.begin(), s.size()); + return std::move(*this); + } + + template + TensorMatcher &&with_dtype(DTypeRef &&d) && + { + m_dtype.rebind(*d); + m_dtype->template set_options(); + return std::move(*this); + } + + template + TensorMatcher &&with_dtype() && + { + m_dtype->template set_options(); + return std::move(*this); + } + + template + TensorMatcher &&with_device(DeviceRef &&d) && + { + m_device.rebind(*d); + m_device->template set_options(); + return std::move(*this); + } + + template + TensorMatcher &&with_device() && + { + m_device->template set_options(); + return std::move(*this); + } + + const TensorMatcher &&verify(tvm::ffi::TensorView, DebugInfo = {}) const &&; + + private: + static void s_print_tensor(std::ostringstream &, tvm::ffi::TensorView); + void m_verify_impl(tvm::ffi::TensorView) const; + + details::ArrayView m_shape; + details::ArrayView m_strides; + DTypeRef m_dtype; + DeviceRef m_device; + }; + + inline void TensorMatcher::s_print_tensor(std::ostringstream &os, tvm::ffi::TensorView v) + { + os << "Tensor<"; + size_t d = 0; + for (int64_t s : v.shape()) + { + if (d++) + os << ", "; + os << s; + } + os << ">[strides=<"; + d = 0; + for (int64_t s : v.strides()) + { + if (d++) + os << ", "; + os << s; + } + os << ">, dtype=" << v.dtype(); + os << ", device=" << details::PrintableDevice{v.device()} << "]"; + } + + inline const TensorMatcher &&TensorMatcher::verify(tvm::ffi::TensorView v, DebugInfo info) const && + { + try + { + m_verify_impl(v); + } + catch (PanicError &e) + { + std::ostringstream os; + os << "Tensor match failed: "; + s_print_tensor(os, v); + os << " @ " << info.file_name() << ":" << info.line() << "\n- cause: " << e.root_cause(); + throw PanicError(os.str()); + } + return std::move(*this); + } + + inline void TensorMatcher::m_verify_impl(tvm::ffi::TensorView v) const + { + size_t dim = static_cast(v.dim()); + RuntimeCheck(dim == m_shape.size, "Dim mismatch: expected ", m_shape.size, " got ", dim); + + for (size_t i = 0; i < dim; ++i) + m_shape[i]->verify(v.size(i), "shape", (int64_t)i); + + if (!m_strides.empty()) + { + for (size_t i = 0; i < dim; ++i) + { + if (v.size(i) != 1 || !m_strides[i]->has_value()) + m_strides[i]->verify(v.stride(i), "stride", (int64_t)i); + } + } + else + { + RuntimeCheck(v.is_contiguous(), "Tensor not contiguous"); + } + + m_dtype->verify(v.dtype()); + m_device->verify(v.device()); + } + +} // namespace host diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh new file mode 100644 index 000000000..d73c2ac04 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh @@ -0,0 +1,310 @@ +/// \file utils.cuh +/// \brief Core CUDA/device utilities: type aliases, PDL helpers, +/// typed pointer access, kernel launch wrapper, and error checking. +/// +/// This header is included (directly or transitively) by nearly every +/// JIT kernel. It provides: +/// - Scalar/packed type aliases (`fp16_t`, `bf16_t`, `fp8_e4m3_t`, ...). +/// - `SGL_DEVICE` macro (forced-inline device function qualifier). +/// - `kWarpThreads` constant (32). +/// - PDL (Programmatic Dependent Launch) helpers for Hopper (sm_90+). +/// - Typed `load_as` / `store_as` for void-pointer access. +/// - `pointer::offset` for safe void-pointer arithmetic. +/// - `host::LaunchKernel` - kernel launcher with optional PDL. +/// - `host::RuntimeDeviceCheck` - CUDA error checking. + +#pragma once + +#include "utils.h" + +#include +#include + +#include +#include +#include +#ifndef USE_ROCM +#include +#include +#include +#include +#else +#include +#include +#include +#ifndef __grid_constant__ +#define __grid_constant__ +#endif +using cudaError_t = hipError_t; +using cudaStream_t = hipStream_t; +using cudaLaunchConfig_t = hipLaunchConfig_t; +using cudaLaunchAttribute = hipLaunchAttribute; +inline constexpr auto cudaSuccess = hipSuccess; +#define cudaStreamPerThread hipStreamPerThread +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaLaunchKernel hipLaunchKernel +#endif + +#ifndef USE_ROCM +using fp32_t = float; +// using fp16_t = __half; +// using bf16_t = __nv_bfloat16; +using fp8_e4m3_t = __nv_fp8_e4m3; +using fp8_e5m2_t = __nv_fp8_e5m2; + +using fp32x2_t = float2; +using fp16x2_t = __half2; +using bf16x2_t = __nv_bfloat162; +using fp8x2_e4m3_t = __nv_fp8x2_e4m3; +using fp8x2_e5m2_t = __nv_fp8x2_e5m2; + +using fp32x4_t = float4; +#else +using fp32_t = float; +using fp16_t = __half; +using bf16_t = __hip_bfloat16; +using fp8_e4m3_t = uint8_t; +using fp8_e5m2_t = uint8_t; +using fp32x2_t = float2; +using fp16x2_t = half2; +using bf16x2_t = __hip_bfloat162; +using fp8x2_e4m3_t = uint16_t; +using fp8x2_e5m2_t = uint16_t; +using fp32x4_t = float4; +#endif + +/* + * LDG Support + */ +#ifndef USE_ROCM +#define SGLANG_LDG(arg) __ldg(arg) +#else +#define SGLANG_LDG(arg) *(arg) +#endif + +namespace device { + +/// \brief Macro: forced-inline device function qualifier. +#define SGL_DEVICE __forceinline__ __device__ + +// Architecture detection: SGL_CUDA_ARCH is injected by load_jit() and is +// available in both host and device compilation passes, whereas __CUDA_ARCH__ +// is only defined by nvcc during the device pass. +#if !defined(USE_ROCM) +#if !defined(SGL_CUDA_ARCH) +#error "SGL_CUDA_ARCH is not defined. JIT compilation must inject -DSGL_CUDA_ARCH via load_jit()." +#endif +#if defined(__CUDA_ARCH__) +static_assert( + __CUDA_ARCH__ == SGL_CUDA_ARCH, "SGL_CUDA_ARCH mismatch: injected arch flag does not match device target"); +#endif +#define SGL_ARCH_HOPPER_OR_GREATER (SGL_CUDA_ARCH >= 900) +#define SGL_ARCH_BLACKWELL_OR_GREATER ((SGL_CUDA_ARCH >= 1000) && (CUDA_VERSION >= 12090)) +#else // USE_ROCM +#define SGL_ARCH_HOPPER_OR_GREATER 0 +#define SGL_ARCH_BLACKWELL_OR_GREATER 0 +#endif + +// Maximum vector size in bytes supported by current architecture. +// Pre-Blackwell / AMD: 128-bit (16 bytes) +// Blackwell or greater: 256-bit (32 bytes) +inline constexpr std::size_t kMaxVecBytes = SGL_ARCH_BLACKWELL_OR_GREATER ? 32 : 16; + +/// \brief Number of threads per warp (always 32 on NVIDIA/AMD GPUs). +inline constexpr auto kWarpThreads = 32u; +/// \brief Full warp active mask (all 32 lanes). +inline constexpr auto kFullMask = 0xffffffffu; + +/** + * \brief PDL (Programmatic Dependent Launch): wait for the primary kernel. + * + * On Hopper (sm_90+), inserts a `griddepcontrol.wait` instruction to + * synchronize with a preceding kernel in the same stream. On older + * architectures or ROCm this is a no-op. + */ +template +SGL_DEVICE void PDLWaitPrimary() { +#if SGL_ARCH_HOPPER_OR_GREATER + if constexpr (kUsePDL) { + asm volatile("griddepcontrol.wait;" ::: "memory"); + } +#endif +} + +/** + * \brief PDL: trigger dependent (secondary) kernel launch. + * + * On Hopper (sm_90+), inserts a `griddepcontrol.launch_dependents` + * instruction. On older architectures or ROCm this is a no-op. + */ +template +SGL_DEVICE void PDLTriggerSecondary() { +#if SGL_ARCH_HOPPER_OR_GREATER + if constexpr (kUsePDL) { + asm volatile("griddepcontrol.launch_dependents;" :::); + } +#endif +} + +template +SGL_DEVICE constexpr auto div_ceil(T a, U b) { + static_assert(std::is_integral::value && std::is_integral::value, + "div_ceil requires integer types"); + return (a + b - 1) / b; +} + +/** + * \brief Load data with the specified type and offset from a void pointer. + * \tparam T The type to load. + * \param ptr The base pointer. + * \param offset The offset in number of elements of type T. + */ +template +SGL_DEVICE T load_as(const void *ptr, int64_t offset = 0) { + return static_cast(ptr)[offset]; +} + +/** + * \brief Store data with the specified type and offset to a void pointer. + * \tparam T The type to store. + * \param ptr The base pointer. + * \param val The value to store. + * \param offset The offset in number of elements of type T. + * \note we use type_identity_t to force the caller to explicitly specify + * the template parameter `T`, which can avoid accidentally using the wrong type. + */ +template +SGL_DEVICE void store_as(void *ptr, T val, int64_t offset = 0) { + static_cast(ptr)[offset] = val; +} + +/// \brief Safe void-pointer arithmetic (byte-level by default). +namespace pointer { +// we only allow void * pointer arithmetic for safety + +template +SGL_DEVICE auto offset(void *ptr, U... offset) -> void * { + return static_cast(ptr) + (offset + ...); +} + +template +SGL_DEVICE auto offset(const void *ptr, U... offset) -> const void * { + return static_cast(ptr) + (offset + ...); +} + +} // namespace pointer + +} // namespace device + +namespace host { + +/** + * \brief Check the CUDA error code and panic with location info on failure. + */ +inline void RuntimeDeviceCheck(::cudaError_t error, DebugInfo location = {}) { + if (error != ::cudaSuccess) { + [[unlikely]]; + ::host::panic(location, "CUDA error: ", ::cudaGetErrorString(error)); + } +} + +/// \brief Check the last CUDA error (calls `cudaGetLastError`). +inline void RuntimeDeviceCheck(DebugInfo location = {}) { + return RuntimeDeviceCheck(::cudaGetLastError(), location); +} + +/** + * \brief Kernel launcher with automatic stream resolution and PDL support. + * + * Usage: + * \code + * host::LaunchKernel(grid, block, device) + * .enable_pdl(true) + * (my_kernel, arg1, arg2); + * \endcode + * + * The constructor resolves the CUDA stream from a `DLDevice` (via + * `TVMFFIEnvGetStream`) or accepts a raw `cudaStream_t`. The call + * operator launches the kernel and checks for errors. + */ +struct LaunchKernel { +public: + explicit LaunchKernel( + dim3 grid_dim, + dim3 block_dim, + DLDevice device, + std::size_t dynamic_shared_mem_bytes = 0, + DebugInfo location = {}) noexcept + : m_config(s_make_config(grid_dim, block_dim, resolve_device(device), dynamic_shared_mem_bytes)), + m_location(location) {} + + explicit LaunchKernel( + dim3 grid_dim, + dim3 block_dim, + cudaStream_t stream, + std::size_t dynamic_shared_mem_bytes = 0, + DebugInfo location = {}) noexcept + : m_config(s_make_config(grid_dim, block_dim, stream, dynamic_shared_mem_bytes)), m_location(location) {} + + LaunchKernel(const LaunchKernel &) = delete; + LaunchKernel &operator=(const LaunchKernel &) = delete; + + static auto resolve_device(DLDevice device) -> cudaStream_t { + return static_cast(::TVMFFIEnvGetStream(device.device_type, device.device_id)); + } + + auto enable_pdl(bool enabled = true) -> LaunchKernel & { +#ifdef USE_ROCM + (void)enabled; + m_config.numAttrs = 0; +#else + if (enabled) { + m_attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + m_attrs[0].val.programmaticStreamSerializationAllowed = true; + m_config.numAttrs = 1; + m_config.attrs = m_attrs; + } else { + m_config.numAttrs = 0; + } +#endif + return *this; + } + + template + auto operator()(T &&kernel, Args &&...args) const -> void { +#ifdef USE_ROCM + hipLaunchKernelGGL( + std::forward(kernel), + m_config.gridDim, + m_config.blockDim, + m_config.dynamicSmemBytes, + m_config.stream, + std::forward(args)...); + RuntimeDeviceCheck(m_location); +#else + RuntimeDeviceCheck(::cudaLaunchKernelEx(&m_config, kernel, std::forward(args)...), m_location); +#endif + } + +private: + static auto s_make_config( // Make a config for kernel launch + dim3 grid_dim, + dim3 block_dim, + cudaStream_t stream, + std::size_t smem) -> cudaLaunchConfig_t { + auto config = ::cudaLaunchConfig_t{}; + config.gridDim = grid_dim; + config.blockDim = block_dim; + config.dynamicSmemBytes = smem; + config.stream = stream; + config.numAttrs = 0; + return config; + } + + cudaLaunchConfig_t m_config; + const DebugInfo m_location; + cudaLaunchAttribute m_attrs[1]; +}; + +} // namespace host diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h new file mode 100644 index 000000000..bf7a5ce40 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h @@ -0,0 +1,241 @@ +/// \file utils.h +/// \brief Host-side C++ utilities used by JIT kernel wrappers. +/// +/// Provides: +/// - `DebugInfo` - wraps `std::source_location` for error reporting. +/// - `RuntimeCheck` - runtime assertion with formatted error messages. +/// - `Panic` - unconditional abort with formatted error messages. +/// - `pointer::offset` - safe void-pointer arithmetic (host side). +/// - `div_ceil` - integer ceiling division. +/// - `dtype_bytes` - byte width of a `DLDataType`. +/// - `irange` - Python-style integer range for range-for loops. + +#pragma once + +// ref: https://forums.developer.nvidia.com/t/c-20s-source-location-compilation-error-when-using-nvcc-12-1/258026/3 +#ifdef __CUDACC__ +#include +#if CUDA_VERSION <= 12010 + +#pragma push_macro("__cpp_consteval") +#pragma push_macro("_NODISCARD") +#pragma push_macro("__builtin_LINE") + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wbuiltin-macro-redefined" +#define __cpp_consteval 201811L +#pragma clang diagnostic pop + +#ifdef _NODISCARD +#undef _NODISCARD +#define _NODISCARD +#endif + +#define consteval constexpr + +#include "source_location.h" + +#undef consteval +#pragma pop_macro("__cpp_consteval") +#pragma pop_macro("_NODISCARD") +#else // __CUDACC__ && CUDA_VERSION > 12010 +#include "source_location.h" +#endif +#else // no __CUDACC__ +#include "source_location.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace host +{ + + template + inline constexpr bool dependent_false_v = false; + + /// \brief Source-location wrapper for debug/error messages. + struct DebugInfo : public source_location_t + { + DebugInfo(source_location_t loc = source_location_t::current()) : source_location_t(loc) {} + }; + + /// \brief Exception type thrown by `RuntimeCheck` and `Panic`. + struct PanicError : public std::runtime_error + { + public: + explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {} + auto root_cause() const -> std::string_view + { + const auto str = std::string_view{m_message}; + const auto pos = str.find(": "); + return pos == std::string_view::npos ? str : str.substr(pos + 2); + } + + private: + std::string m_message; + }; + + /// \brief Unconditionally abort with a formatted error message. + template + [[noreturn]] + inline auto panic(DebugInfo location, Args &&...args) -> void + { + std::ostringstream os; + os << "Runtime check failed at " << location.file_name() << ":" << location.line(); + if constexpr (sizeof...(args) > 0) + { + os << ": "; + (os << ... << std::forward(args)); + } + else + { + os << " in " << location.function_name(); + } + throw PanicError(std::move(os).str()); + } + + /** + * \brief Runtime assertion: panics with a formatted message when `condition` + * is false. Extra `args` are streamed to the error message. + * + * Example: + * \code + * RuntimeCheck(n > 0, "n must be positive, got ", n); + * \endcode + */ + template + struct RuntimeCheck + { + template + explicit RuntimeCheck(Cond &&condition, Args &&...args, DebugInfo location = {}) + { + if (condition) + return; + [[unlikely]] ::host::panic(location, std::forward(args)...); + } + template + explicit RuntimeCheck(DebugInfo location, Cond &&condition, Args &&...args) + { + if (condition) + return; + [[unlikely]] ::host::panic(location, std::forward(args)...); + } + }; + + template + struct Panic + { + explicit Panic(Args &&...args, DebugInfo location = {}) + { + ::host::panic(location, std::forward(args)...); + } + explicit Panic(DebugInfo location, Args &&...args) + { + ::host::panic(location, std::forward(args)...); + } + [[noreturn]] ~Panic() + { + std::terminate(); + } + }; + + template + explicit RuntimeCheck(Cond &&, Args &&...) -> RuntimeCheck; + + template + explicit RuntimeCheck(DebugInfo, Cond &&, Args &&...) -> RuntimeCheck; + + template + explicit Panic(Args &&...) -> Panic; + + template + explicit Panic(DebugInfo, Args &&...) -> Panic; + + namespace pointer + { + + // we only allow void * pointer arithmetic for safety + + template ::value && ...)>> + inline auto offset(void *ptr, U... offset) -> void * + { + return static_cast(ptr) + (... + offset); + } + + template ::value && ...)>> + inline auto offset(const void *ptr, U... offset) -> const void * + { + return static_cast(ptr) + (... + offset); + } + + } // namespace pointer + + /// \brief Integer ceiling division: ceil(a / b). + template + inline constexpr auto div_ceil(T a, U b) + { + static_assert(std::is_integral::value, "T must be integral"); + static_assert(std::is_integral::value, "U must be integral"); + return (a + b - 1) / b; + } + + /// \brief Returns the byte width of a DLPack data type. + inline auto dtype_bytes(DLDataType dtype) -> std::size_t + { + return static_cast(dtype.bits / 8); + } + + // ====================== 修复开始:纯 C++11 兼容版 irange ====================== + // 移除所有 std::ranges / std::integral,完全兼容旧版 CUDA 编译器 + + template + struct IntegerRange + { + T start_; + T end_; + + struct Iterator + { + T value; + + T operator*() const { return value; } + Iterator &operator++() + { + ++value; + return *this; + } + bool operator!=(const Iterator &other) const + { + return value != other.value; + } + }; + + Iterator begin() const { return {start_}; } + Iterator end() const { return {end_}; } + }; + + /// Python-style integer range: irange(n) -> [0, n) + template + IntegerRange irange(T end) + { + return {0, end}; + } + + /// Python-style integer range: irange(start, end) -> [start, end) + template + IntegerRange irange(T start, T end) + { + return {start, end}; + } + // ====================== 修复结束 ====================== + +} // namespace host diff --git a/test/infiniop/gptq_marlin_gemm.py b/test/infiniop/gptq_marlin_gemm.py new file mode 100644 index 000000000..9ba296d18 --- /dev/null +++ b/test/infiniop/gptq_marlin_gemm.py @@ -0,0 +1,623 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + TestWorkspace, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) +import itertools +from libinfiniop.scalar_type import scalar_types, ScalarType +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +import numpy as np + + +# ============================================================================== +# Configuration +# ============================================================================== + +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (13, 17, 67), + (257, 13, 11), +] + +k_chunk = 128 +n_chunk = [64, 256] +quant_type = [scalar_types.uint4, scalar_types.uint4b8] +group_size = [-1, 128] +mnk_factors = MNK_FACTORS +act_order = [False, True] + +def to_iter(x): + return x if isinstance(x, (list, tuple)) else (x,) + +_TEST_CASES = list(itertools.product( + to_iter(k_chunk), + to_iter(n_chunk), + to_iter(quant_type), + to_iter(group_size), + to_iter(mnk_factors), + to_iter(act_order), +)) + +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16] + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +# ============================================================================== +# Reference Implementation (matches CUDA kernel) +# ============================================================================== + + +GPTQ_MARLIN_TILE = 16 +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. + if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: list[int] = [] + for i in range(32): + perm1: list[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + +def get_scale_perms(): + scale_perm: list[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: list[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_res = np.zeros((size_k, size_n // pack_factor), dtype=np.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(np.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + +def marlin_make_workspace( + device: torch.device, max_blocks_per_sm: int = 1 +) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros( + sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False + ) + + +# ============================================================================== +# Test Entrypoint +# ============================================================================== + + +def test( + handle, + device, + k_chunk, + n_chunk, + quant_type, + group_size, + mnk_factors, + act_order, + dtype=None, + sync=None, +): + m_factor, n_factor, k_factor = mnk_factors + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + if has_zp: + return + + if size_k % group_size != 0: + return + + print( + f"Testing Gptq Marlin Gemm on {InfiniDeviceNames[device]} with M-K-N:({size_m, size_k, size_n}), group_size:{group_size}, dtype:{InfiniDtypeNames[dtype]}" + ) + + a_input = TestTensor((size_m, size_k), None, dtype, device) + b_weight = TestTensor((size_k, size_n), None, dtype, device) + if has_zp: + w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( + b_weight.torch_tensor(), quant_type, group_size + ) + g_idx = None + sort_indices = None + marlin_s2 = None + else: + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + b_weight.torch_tensor(), quant_type, group_size, act_order + ) + marlin_zp = None + marlin_s2 = None + output_ref = torch.matmul(a_input.torch_tensor(), w_ref) + b = TestTensor(marlin_q_w.shape, marlin_q_w.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=marlin_q_w) + c = TestTensor(output_ref.shape, None, dtype, device) + b_scales = TestTensor(marlin_s.shape, marlin_s.stride(), dtype, device, mode="manual", set_tensor=marlin_s) + global_scale = None + if marlin_zp is not None: + b_zeros = TestTensor(marlin_zp.shape, marlin_zp.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=marlin_zp) + else: + b_zeros = None + if g_idx is not None: + b_g_idx = TestTensor(g_idx.shape, g_idx.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=g_idx) + else: + b_g_idx = None + if sort_indices is not None: + perm = TestTensor(sort_indices.shape, sort_indices.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=sort_indices) + else: + perm = None + + is_k_full=True + use_atomic_add=False + use_fp32_reduce=False + is_zp_float=False + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateGptqMarlinGemmDescriptor( + handle, + ctypes.byref(descriptor), + c.descriptor, + a_input.descriptor, + b.descriptor, + b_scales.descriptor, + global_scale.descriptor if global_scale is not None else None, + b_zeros.descriptor if b_zeros is not None else None, + b_g_idx.descriptor if b_g_idx is not None else None, + perm.descriptor if perm is not None else None, + ) + ) + + # Invalidate descriptors (same pattern as other tests) + for tensor in [c, a_input, b, b_scales, global_scale, b_zeros, b_g_idx, perm]: + if tensor is not None: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetGptqMarlinGemmWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + def lib_gptq_marlin_gemm(): + check_error( + LIBINFINIOP.infiniopGptqMarlinGemm( + descriptor, + workspace.data(), + workspace_size.value, + c.data(), + a_input.data(), + b.data(), + b_scales.data(), + global_scale.data() if global_scale is not None else None, + b_zeros.data() if b_zeros is not None else None, + b_g_idx.data() if b_g_idx is not None else None, + perm.data() if perm is not None else None, + quant_type.id, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + None, + ) + ) + + lib_gptq_marlin_gemm() + + + max_diff = torch.mean(torch.abs(c.actual_tensor() - output_ref)) / torch.mean( + torch.abs(output_ref) + ) + assert max_diff < 0.04 + + if PROFILE: + profile_operation( + "PyTorch", + lambda: torch.matmul(a_input.torch_tensor(), w_ref), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", lambda: lib_gptq_marlin_gemm(), device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroyGptqMarlinGemmDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 1c90feb22..1074432bf 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -4,7 +4,7 @@ infiniopOperatorDescriptor_t, ) -from ctypes import c_int32, c_void_p, c_size_t, POINTER, c_float, c_double, c_uint64, c_bool +from ctypes import c_int32, c_void_p, c_size_t, POINTER, c_float, c_double, c_uint64, c_bool, c_int64 class OpRegister: registry = [] @@ -1022,6 +1022,53 @@ def dequantize_gptq_(lib): ] +@OpRegister.operator +def gptq_marlin_gemm_(lib): + lib.infiniopCreateGptqMarlinGemmDescriptor.restype = c_int32 + lib.infiniopCreateGptqMarlinGemmDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetGptqMarlinGemmWorkspaceSize.restype = c_int32 + lib.infiniopGetGptqMarlinGemmWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopGptqMarlinGemm.restype = c_int32 + lib.infiniopGptqMarlinGemm.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_int64, + c_bool, + c_bool, + c_bool, + c_bool, + c_void_p, + + ] + lib.infiniopDestroyGptqMarlinGemmDescriptor.restype = c_int32 + lib.infiniopDestroyGptqMarlinGemmDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def per_channel_quant_int8_(lib): lib.infiniopCreatePerChannelQuantI8Descriptor.restype = c_int32 diff --git a/test/infiniop/libinfiniop/scalar_type.py b/test/infiniop/libinfiniop/scalar_type.py new file mode 100644 index 000000000..bc9f067c1 --- /dev/null +++ b/test/infiniop/libinfiniop/scalar_type.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + +_SCALAR_TYPES_ID_MAP = {} + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. +@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE: + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = max_exponent - exponent_bias + exponent_bias_double + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. + return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack("!d", struct.pack("!Q", double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert ( + self.size_bits < 64 or self.size_bits == 64 and self.is_signed() + ), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert ( + self.is_signed() + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack("!d", struct.pack("!Q", min_raw))[0] + else: + assert ( + not self.is_signed() or self.size_bits <= 64 + ), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64" + + _SCALAR_TYPES_ID_MAP[val] = self + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + return self.signed + + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754 and not self._finite_values_only + + def __str__(self) -> str: + """ + naming generally follows: https://github.com/jax-ml/ml_dtypes + for floating point types (leading f) the scheme is: + `float_em[flags]` + flags: + - no-flags: means it follows IEEE 754 conventions + - f: means finite values only (no infinities) + - n: means nans are supported (non-standard encoding) + for integer types the scheme is: + `[u]int[b]` + - if bias is not present it means its zero + """ + if self.is_floating_point(): + ret = ( + "float" + + str(self.size_bits) + + "_e" + + str(self.exponent) + + "m" + + str(self.mantissa) + ) + + if not self.is_ieee_754(): + if self._finite_values_only: + ret = ret + "f" + if self.nan_repr != NanRepr.NONE: + ret = ret + "n" + + return ret + else: + ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) + if self.has_bias(): + ret = ret + "b" + str(self.bias) + return ret + + def __repr__(self) -> str: + return "ScalarType." + self.__str__() + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": + "Create a signed integer scalar type (size_bits includes sign-bit)." + ret = cls(0, size_bits - 1, True, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": + """Create a unsigned integer scalar type.""" + ret = cls(0, size_bits, False, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType": + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + assert mantissa > 0 and exponent > 0 + ret = cls(exponent, mantissa, True, 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_( + cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr + ) -> "ScalarType": + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). + """ + assert mantissa > 0 and exponent > 0 + assert nan_repr != NanRepr.IEEE_754, ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions" + ) + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 + + diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 602fb190d..18c8239a3 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -13,6 +13,17 @@ local FLASH_ATTN_ROOT = get_config("flash-attn") local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini") + +function parse_sgl_cuda_arch(arch) + + local num = arch:match("sm_(%d+)") + if not num then + return nil + end + + return tonumber(num) * 10 +end + target("infiniop-nvidia") set_kind("static") add_deps("infini-utils") @@ -64,6 +75,15 @@ target("infiniop-nvidia") add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations", "-Xcompiler=-Wno-error=unused-function") local arch_opt = get_config("cuda_arch") + if arch_opt then + local sgl_arch = parse_sgl_cuda_arch(arch_opt) + if sgl_arch then + add_defines("SGL_CUDA_ARCH=" .. sgl_arch) + print("SGL_CUDA_ARCH =", sgl_arch) + else + print("Invalid cuda_arch:", arch_opt) + end + end if arch_opt and type(arch_opt) == "string" then for _, arch in ipairs(arch_opt:split(",")) do arch = arch:trim() From 3997fed751e1f31b7dead47f8c4af173f6f555d6 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Tue, 31 Mar 2026 14:22:48 +0800 Subject: [PATCH 2/3] issue/1083: modified format --- .../marlin/awq_marlin_repack.cuh | 369 ++- .../ops/gptq_marlin_gemm/marlin/dequant.h | 580 ++-- .../gptq_marlin_gemm/marlin/gptq_marlin.cuh | 1630 ++++++----- .../marlin/gptq_marlin_repack.cuh | 501 ++-- .../ops/gptq_marlin_gemm/marlin/kernel.h | 39 +- .../ops/gptq_marlin_gemm/marlin/marlin.cuh | 75 +- .../gptq_marlin_gemm/marlin/marlin_template.h | 2470 ++++++++--------- .../sgl_kernel/scalar_type.hpp | 509 ++-- .../sgl_kernel/source_location.h | 35 +- .../ops/gptq_marlin_gemm/sgl_kernel/tensor.h | 826 +++--- .../ops/gptq_marlin_gemm/sgl_kernel/utils.cuh | 6 +- .../ops/gptq_marlin_gemm/sgl_kernel/utils.h | 249 +- test/infiniop/gptq_marlin_gemm.py | 110 +- test/infiniop/libinfiniop/scalar_type.py | 2 - 14 files changed, 3471 insertions(+), 3930 deletions(-) diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh index 2aea26529..2963dbf6b 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh @@ -6,22 +6,19 @@ #include "marlin.cuh" -namespace device::marlin -{ +namespace device::marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - template - __global__ void awq_marlin_repack_kernel( - uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) - { +template +__global__ void awq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) { return; - } +} #else - template - __global__ void awq_marlin_repack_kernel( - uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) - { +template +__global__ void awq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; int k_tiles = size_k / tile_k_size; @@ -29,22 +26,20 @@ namespace device::marlin int block_k_tiles = div_ceil(k_tiles, (int)gridDim.x); auto start_k_tile = blockIdx.x * block_k_tiles; - if (start_k_tile >= k_tiles) - { - return; + if (start_k_tile >= k_tiles) { + return; } int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() - { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); }; extern __shared__ int4 sh[]; @@ -55,227 +50,201 @@ namespace device::marlin constexpr int stage_k_threads = tile_k_size; constexpr int stage_size = stage_k_threads * stage_n_threads; - auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) - { - if (n_tile_id >= n_tiles) - { - cp_async_fence(); - return; - } + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } - int first_n = n_tile_id * tile_n_size; - int first_n_packed = first_n / pack_factor; + int first_n = n_tile_id * tile_n_size; + int first_n_packed = first_n / pack_factor; - int4 *sh_ptr = sh + stage_size * pipe; + int4 *sh_ptr = sh + stage_size * pipe; - if (threadIdx.x < stage_size) - { - auto k_id = threadIdx.x / stage_n_threads; - auto n_id = threadIdx.x % stage_n_threads; + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; - int first_k = k_tile_id * tile_k_size; + int first_k = k_tile_id * tile_k_size; - cp_async4( - &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( - &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); - } + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); + } - cp_async_fence(); + cp_async_fence(); }; - auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) - { - if (n_tile_id >= n_tiles) - { - return; - } + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } - auto warp_id = threadIdx.x / 32; - auto th_id = threadIdx.x % 32; + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; - if (warp_id >= 4) - { - return; - } + if (warp_id >= 4) { + return; + } - int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; - constexpr int tc_offsets[4] = {0, 1, 8, 9}; + constexpr int tc_offsets[4] = {0, 1, 8, 9}; - int cur_n = warp_id * 16 + tc_col; - int cur_n_packed = cur_n / pack_factor; - int cur_n_pos = cur_n % pack_factor; + int cur_n = warp_id * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; - constexpr int sh_stride = tile_n_ints; - constexpr uint32_t mask = (1 << num_bits) - 1; + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; - int4 *sh_stage_ptr = sh + stage_size * pipe; - uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + int4 *sh_stage_ptr = sh + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - // Undo interleaving - int cur_n_pos_unpacked; - if constexpr (num_bits == 4) - { - constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; - cur_n_pos_unpacked = undo_pack[cur_n_pos]; - } - else - { - constexpr int undo_pack[4] = {0, 2, 1, 3}; - cur_n_pos_unpacked = undo_pack[cur_n_pos]; - } + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } else { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } - uint32_t vals[8]; + uint32_t vals[8]; #pragma unroll - for (int i = 0; i < 4; i++) - { - int cur_elem = tc_row + tc_offsets[i]; + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; - int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; - int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; - vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; - vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; - } + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } - constexpr int tile_size_val = tile_k_size * tile_n_size / pack_factor; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size_val; + constexpr int tile_size_val = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size_val; - // Result of: - // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) - { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - uint32_t res = 0; + uint32_t res = 0; #pragma unroll - for (int i = 0; i < 8; i++) - { - res |= vals[pack_idx[i]] << (i * 4); - } + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } - out_ptr[out_offset + th_id * 4 + warp_id] = res; - } - else - { - constexpr int pack_idx[4] = {0, 2, 1, 3}; + out_ptr[out_offset + th_id * 4 + warp_id] = res; + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; - uint32_t res1 = 0; - uint32_t res2 = 0; + uint32_t res1 = 0; + uint32_t res2 = 0; #pragma unroll - for (int i = 0; i < 4; i++) - { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); - } + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; - } + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } }; - auto start_pipes = [&](int k_tile_id, int n_tile_id) - { + auto start_pipes = [&](int k_tile_id, int n_tile_id) { #pragma unroll - for (int pipe = 0; pipe < repack_stages - 1; pipe++) - { - fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); - } + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } - wait_for_stage(); + wait_for_stage(); }; #pragma unroll - for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) - { - int n_tile_id = 0; + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; - start_pipes(k_tile_id, n_tile_id); + start_pipes(k_tile_id, n_tile_id); - while (n_tile_id < n_tiles) - { + while (n_tile_id < n_tiles) { #pragma unroll - for (int pipe = 0; pipe < repack_stages; pipe++) - { - fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); - repack_tile(pipe, k_tile_id, n_tile_id + pipe); - wait_for_stage(); + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; } - n_tile_id += repack_stages; - } } - } +} #endif } // namespace device::marlin // Host wrapper void awq_marlin_repack( - tvm::ffi::TensorView out, tvm::ffi::TensorView b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) -{ - using namespace host; - using namespace device::marlin; - - // Validate alignment - RuntimeCheck(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); - RuntimeCheck(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); - RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); - - int const pack_factor = 32 / num_bits; - - // Validate tensors - SymbolicDevice cuda_device; - cuda_device.set_options(); - - TensorMatcher({size_k, size_n / pack_factor}).with_dtype().with_device(cuda_device).verify(b_q_weight); - - TensorMatcher({size_k / tile_size, size_n * tile_size / pack_factor}) - .with_dtype() - .with_device(cuda_device) - .verify(out); - - // Get device and stream - auto device = cuda_device.unwrap(); - auto stream = LaunchKernel::resolve_device(device); - - // Get pointers - auto *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); - auto *out_ptr = reinterpret_cast(out.data_ptr()); - - // Get device attributes - int blocks = 0; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device.device_id); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device.device_id); - RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); - - // Dispatch based on num_bits - if (num_bits == 4) - { - cudaFuncSetAttribute( - awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); - LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( - awq_marlin_repack_kernel, - b_q_weight_ptr, - out_ptr, - static_cast(size_k), - static_cast(size_n)); - } - else if (num_bits == 8) - { - cudaFuncSetAttribute( - awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); - LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( - awq_marlin_repack_kernel, - b_q_weight_ptr, - out_ptr, - static_cast(size_k), - static_cast(size_n)); - } - else - { - RuntimeCheck(false, "Unsupported repack config: num_bits = ", num_bits); - } + tvm::ffi::TensorView out, tvm::ffi::TensorView b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) { + using namespace host; + using namespace device::marlin; + + // Validate alignment + RuntimeCheck(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); + RuntimeCheck(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + + int const pack_factor = 32 / num_bits; + + // Validate tensors + SymbolicDevice cuda_device; + cuda_device.set_options(); + + TensorMatcher({size_k, size_n / pack_factor}).with_dtype().with_device(cuda_device).verify(b_q_weight); + + TensorMatcher({size_k / tile_size, size_n * tile_size / pack_factor}) + .with_dtype() + .with_device(cuda_device) + .verify(out); + + // Get device and stream + auto device = cuda_device.unwrap(); + auto stream = LaunchKernel::resolve_device(device); + + // Get pointers + auto *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + auto *out_ptr = reinterpret_cast(out.data_ptr()); + + // Get device attributes + int blocks = 0; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device.device_id); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device.device_id); + RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + // Dispatch based on num_bits + if (num_bits == 4) { + cudaFuncSetAttribute( + awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } else if (num_bits == 8) { + cudaFuncSetAttribute( + awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } else { + RuntimeCheck(false, "Unsupported repack config: num_bits = ", num_bits); + } } diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h index 764375f62..6a0d90e5d 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h @@ -73,22 +73,26 @@ namespace device::marlin { // all cases. template __device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; } // Constructs destination register by taking bytes from 2 sources (based on // mask) template __device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); - return res; + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; } template -__device__ inline void dequant(int q, scalar_t2* frag_b); +__device__ inline void dequant(int q, scalar_t2 *frag_b); // // Efficiently dequantize 4bit values packed in an int32 value into a full @@ -100,102 +104,102 @@ __device__ inline void dequant(int q, scalar_t2* frag_b); // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 // template <> -__device__ inline void dequant(int q, half2* frag_b) { - const int MASK = 0x000f000f; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - frag_b[0] = *reinterpret_cast(&lo); - frag_b[1] = *reinterpret_cast(&hi); +__device__ inline void dequant(int q, half2 *frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - // clang-format off +__device__ inline void dequant(int q, half2 *frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // clang-format on - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2( - *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - // clang-format off +__device__ inline void dequant(int q, half2 *frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // clang-format on - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2( - *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; - // Guarantee that the `(a & b) | c` operations are LOP3s. - // clang-format off + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - // clang-format on + // clang-format on - frag_b[0] = *reinterpret_cast(&lo); - frag_b[1] = *reinterpret_cast(&hi); + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); - static constexpr uint32_t SUB = 0x43084308; + static constexpr uint32_t SUB = 0x43084308; - frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); - frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); - static constexpr uint32_t SUB = 0x43004300; + static constexpr uint32_t SUB = 0x43004300; - frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); - frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); } // @@ -207,298 +211,298 @@ __device__ inline void dequant(int q, nv_bf // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 // template <> -__device__ inline void dequant(int q, half2* frag_b) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; +__device__ inline void dequant(int q, half2 *frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); - frag_b[0] = *reinterpret_cast(&lo); - frag_b[1] = *reinterpret_cast(&hi); + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388736.f; - fp32_intermediates[1] -= 8388736.f; - fp32_intermediates[2] -= 8388736.f; - fp32_intermediates[3] -= 8388736.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + float fp32_intermediates[4]; + uint32_t *fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t *bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388608.f; - fp32_intermediates[1] -= 8388608.f; - fp32_intermediates[2] -= 8388608.f; - fp32_intermediates[3] -= 8388608.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + float fp32_intermediates[4]; + uint32_t *fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t *bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - constexpr int MASK = 0x7F007F00; - - // Extract and shift FP8 values to FP16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 8; - int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant(int q, half2 *frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); - // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - // Convert to half2 and apply bias - frag_b[1] = __hmul2(frag_b[1], bias_reg); - frag_b[0] = __hmul2(frag_b[0], bias_reg); + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - constexpr int MASK = 0x7F007F00; + constexpr int MASK = 0x7F007F00; - // Extract and shift FP8 values to BF16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 8; - int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); - - // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent - // position - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - - // Convert to bfloat162 and apply bias - frag_b[1] = __hmul2(frag_b[1], bias_reg); - frag_b[0] = __hmul2(frag_b[0], bias_reg); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - // Constants for FP4 (E2M1) and FP16 formats - constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; - constexpr int MASK = 0x70007000; - - // Extract and shift FP4 values to FP16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 4; - int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant(int q, half2 *frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); - // Constants for FP4 (E2M1) and FP16 formats - constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - // Convert to half2 and apply bias - frag_b[1] = __hmul2(frag_b[1], bias_reg); - frag_b[0] = __hmul2(frag_b[0], bias_reg); + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - // Constants for FP4 (E2M1) and FP16 formats - constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; - constexpr int MASK = 0x70007000; - - // Extract and shift FP4 values to FP16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 4; - int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); - - // Constants for FP4 (E2M1) and BF16 formats - constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); - // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent - // position - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - - // Convert to half2 and apply bias - frag_b[1] = __hmul2(frag_b[1], bias_reg); - frag_b[0] = __hmul2(frag_b[0], bias_reg); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template -__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); +__device__ inline void dequant_fp8_scales(int q, scalar_t2 *frag_b); template <> -__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { - int Out1 = (q & 0xFF00FF00) >> 1; - ; - q <<= 8; - int Out2 = (q & 0xFF00FF00) >> 1; - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant_fp8_scales(int q, half2 *frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); }; template <> -__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { - constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - constexpr int MASK = 0x7F007F00; - - // Extract and shift FP8 values to BF16 format - int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 8; - int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162 *frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); }; // New version with s_type_id parameter for marlin_moe_wna16_v2 template -__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); +__device__ inline void dequant_fp8_scales(int q, scalar_t2 *frag_b); template <> -__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { - int Out1 = (q & 0xFF00FF00) >> 1; - ; - q <<= 8; - int Out2 = (q & 0xFF00FF00) >> 1; - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant_fp8_scales(int q, half2 *frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); }; template <> -__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { - constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - constexpr int MASK = 0x7F007F00; - - // Extract and shift FP8 values to BF16 format - int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 8; - int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162 *frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } template <> -__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { - // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16, - // but we assume that such a extreme value would not occur in real models. - int Out1 = (q & 0xFF00FF00) >> 1; - q <<= 7; - int Out2 = q & 0x7F807F80; - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162 *frag_b) { + // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16, + // but we assume that such a extreme value would not occur in real models. + int Out1 = (q & 0xFF00FF00) >> 1; + q <<= 7; + int Out2 = q & 0x7F807F80; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } #endif -} // namespace device::marlin +} // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh index 653501357..ca85889f5 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh @@ -28,132 +28,122 @@ #include "kernel.h" #include "marlin_template.h" -namespace device::marlin -{ +namespace device::marlin { - __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; - using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __global__ void permute_cols_kernel( - int4 const *__restrict__ a_int4_ptr, - int const *__restrict__ perm_int_ptr, - int4 *__restrict__ out_int4_ptr, - int size_m, - int size_k, - int lda, - int block_rows) {} +__global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) {} #else - // For a given "a" of size [M,K] performs a permutation of the K columns based - // on the given "perm" indices. - __global__ void permute_cols_kernel( - int4 const *__restrict__ a_int4_ptr, - int const *__restrict__ perm_int_ptr, - int4 *__restrict__ out_int4_ptr, - int size_m, - int size_k, - int lda, - int block_rows) - { +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) { auto start_row = block_rows * blockIdx.x; int finish_row = start_row + block_rows; - if (finish_row > size_m) - { - finish_row = size_m; + if (finish_row > size_m) { + finish_row = size_m; } int cur_block_rows = finish_row - start_row; int input_row_stride = lda * sizeof(half) / 16; int output_row_stride = size_k * sizeof(half) / 16; - auto permute_row = [&](int row) - { - int iters = size_k / default_threads; - int rest = size_k % default_threads; + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; - int input_offset = row * input_row_stride; - int output_offset = row * output_row_stride; + int input_offset = row * input_row_stride; + int output_offset = row * output_row_stride; - half const *a_row_half = reinterpret_cast(a_int4_ptr + input_offset); - half *out_half = reinterpret_cast(out_int4_ptr + output_offset); + half const *a_row_half = reinterpret_cast(a_int4_ptr + input_offset); + half *out_half = reinterpret_cast(out_int4_ptr + output_offset); - int base_k = 0; + int base_k = 0; - for (int i = 0; i < iters; i++) - { - auto cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; + for (int i = 0; i < iters; i++) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; - out_half[cur_k] = a_row_half[src_pos]; + out_half[cur_k] = a_row_half[src_pos]; - base_k += default_threads; - } + base_k += default_threads; + } - if (rest) - { - if (threadIdx.x < rest) - { - auto cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; + if (rest) { + if (threadIdx.x < rest) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; - out_half[cur_k] = a_row_half[src_pos]; + out_half[cur_k] = a_row_half[src_pos]; + } } - } }; - for (int i = 0; i < cur_block_rows; i++) - { - int cur_row = start_row + i; - if (cur_row < size_m) - { - permute_row(cur_row); - } + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } } - } +} - typedef struct - { +typedef struct +{ int thread_k; int thread_n; int num_threads; - } thread_config_t; +} thread_config_t; - thread_config_t small_batch_thread_configs[] = { - // Ordered by priority +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority - // thread_k, thread_n, num_threads - {128, 128, 256}, - {64, 128, 128}, - {128, 64, 128}}; + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}}; - thread_config_t large_batch_thread_configs[] = { - // Ordered by priority +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority - // thread_k, thread_n, num_threads - {64, 256, 256}, - {64, 128, 128}, - {128, 64, 128}}; + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; - typedef struct - { +typedef struct +{ int blocks_per_sm; thread_config_t tb_cfg; - } exec_config_t; - - int get_scales_cache_size( - thread_config_t const &th_config, - int prob_m, - int prob_n, - int prob_k, - int num_bits, - int group_size, - bool has_act_order, - bool is_k_full) - { +} exec_config_t; + +int get_scales_cache_size( + thread_config_t const &th_config, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full) { bool cache_scales_chunk = has_act_order && !is_k_full; int tb_n = th_config.thread_n; @@ -161,46 +151,37 @@ namespace device::marlin // Get max scale groups per thread-block int tb_groups; - if (group_size == -1) - { - tb_groups = 1; - } - else if (group_size == 0) - { - tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size - } - else - { - tb_groups = div_ceil(tb_k, group_size); + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); } - if (cache_scales_chunk) - { - int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 2; - } - else - { - int tb_scales = tb_groups * tb_n * 2; + if (cache_scales_chunk) { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } else { + int tb_scales = tb_groups * tb_n * 2; - return tb_scales * pipe_stages; + return tb_scales * pipe_stages; } - } - - int get_kernel_cache_size( - thread_config_t const &th_config, - int thread_m_blocks, - int prob_m, - int prob_n, - int prob_k, - int num_bits, - int group_size, - bool has_act_order, - bool is_k_full, - int has_zp, - int is_zp_float) - { +} + +int get_kernel_cache_size( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float) { int pack_factor = 32 / num_bits; // Get B size @@ -210,61 +191,55 @@ namespace device::marlin int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8); - int sh_s_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; int sh_zp_size = 0; - if (has_zp) - { - if (is_zp_float) - sh_zp_size = sh_s_size; - else if (num_bits == 4) - sh_zp_size = sh_s_size / 4; - else if (num_bits == 8) - sh_zp_size = sh_s_size / 2; + if (has_zp) { + if (is_zp_float) { + sh_zp_size = sh_s_size; + } else if (num_bits == 4) { + sh_zp_size = sh_s_size / 4; + } else if (num_bits == 8) { + sh_zp_size = sh_s_size / 2; + } } int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; return total_size; - } - - bool is_valid_config( - thread_config_t const &th_config, - int thread_m_blocks, - int prob_m, - int prob_n, - int prob_k, - int num_bits, - int group_size, - bool has_act_order, - bool is_k_full, - int has_zp, - int is_zp_float, - int max_shared_mem) - { +} + +bool is_valid_config( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float, + int max_shared_mem) { // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) - { - return false; + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; } // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) - { - return false; + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; } // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) - { - return false; + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; } // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) - { - return false; + if (th_config.num_threads < 128) { + return false; } // Check that pipeline fits into cache @@ -281,27 +256,24 @@ namespace device::marlin has_zp, is_zp_float); return cache_size <= max_shared_mem; - } - -#define _GET_IF( \ - W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if ( \ - q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) \ - { \ - kernel = Marlin< \ - scalar_t, \ - W_TYPE.id(), \ - NUM_THREADS, \ - THREAD_M_BLOCKS, \ - THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, \ - pipe_stages, \ - GROUP_BLOCKS, \ - IS_ZP_FLOAT>; \ - } +} + +#define _GET_IF( \ + W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) // this is the most common cases @@ -309,132 +281,130 @@ namespace device::marlin // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) // FP4: cases for nvfp4(e2m1) (group_blocks == 1) -#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) - -#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - -#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - -#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - -#define FP4_GET_IF(W_TYPE) \ - FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M234(W_TYPE, 4, 8, 128) +#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) // We currently have 4-bit models only with group_blocks == 4 -#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - -#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - -#define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M234(W_TYPE, 4, 8, 128) +#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) // We currently have 4-bit models only with group_blocks == 4 -#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - -#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - -#define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 4, 8, 128) - - template - MarlinFuncPtr get_marlin_kernel( - const host::ScalarType q_type, - int thread_m_blocks, - int thread_n_blocks, - int thread_k_blocks, - bool m_block_size_8, - bool has_act_order, - bool has_zp, - int group_blocks, - int num_threads, - bool is_zp_float) - { +#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) + +template +MarlinFuncPtr get_marlin_kernel( + const host::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) { int num_bits = q_type.size_bits(); auto kernel = MarlinDefault; - if (false) - { + if (false) { } COMMON_GET_IF(host::kU4) @@ -448,181 +418,163 @@ namespace device::marlin ACT_GET_IF(host::kU4B8) ACT_GET_IF(host::kU8B128) - if (std::is_same::value) - { - if (false) - { - } - FZP_GET_IF(host::kU4) + if (std::is_same::value) { + if (false) { + } + FZP_GET_IF(host::kU4) } return kernel; - } - - template - exec_config_t determine_exec_config( - const host::ScalarType &q_type, - int prob_m, - int prob_n, - int prob_k, - int thread_m_blocks, - bool m_block_size_8, - int num_bits, - int group_size, - bool has_act_order, - bool is_k_full, - bool has_zp, - bool is_zp_float, - int max_shared_mem, - int sms) - { +} + +template +exec_config_t determine_exec_config( + const host::ScalarType &q_type, + int prob_m, + int prob_n, + int prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem, + int sms) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; thread_config_t *thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; int thread_configs_size = thread_m_blocks > 1 ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); - for (int i = 0; i < thread_configs_size; i++) - { - thread_config_t th_config = thread_configs[i]; - - if (!is_valid_config( - th_config, - thread_m_blocks, - prob_m, - prob_n, - prob_k, - num_bits, - group_size, - has_act_order, - is_k_full, - has_zp, - is_zp_float, - max_shared_mem)) - { - continue; - } - - int cache_size = get_kernel_cache_size( - th_config, - thread_m_blocks, - prob_m, - prob_n, - prob_k, - num_bits, - group_size, - has_act_order, - is_k_full, - has_zp, - is_zp_float); - - int group_blocks = 0; - if (!has_act_order) - { - group_blocks = group_size == -1 ? -1 : group_size / 16; - } - - auto kernel = get_marlin_kernel( - q_type, - thread_m_blocks, - th_config.thread_n / 16, - th_config.thread_k / 16, - m_block_size_8, - has_act_order, - has_zp, - group_blocks, - th_config.num_threads, - is_zp_float); - - if (kernel == MarlinDefault) - continue; - - // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); - // int n_tiles = prob_n / th_config.thread_n; - // int k_tiles = prob_k / th_config.thread_k; - - return {1, th_config}; + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? -1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + continue; + } + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; } return exec_cfg; - } - - template - void marlin_mm( - const void *A, - const void *B, - void *C, - void *C_tmp, - void *s, - void *s2, - void *zp, - void *g_idx, - void *perm, - void *a_tmp, - int prob_m, - int prob_n, - int prob_k, - int lda, - void *workspace, - host::ScalarType const &q_type, - bool has_act_order, - bool is_k_full, - bool has_zp, - int num_groups, - int group_size, - int dev, - cudaStream_t stream, - int thread_k_init, - int thread_n_init, - int sms, - bool use_atomic_add, - bool use_fp32_reduce, - bool is_zp_float) - { - if (has_zp) - { - host::RuntimeCheck( - q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); - } - else - { - host::RuntimeCheck( - q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, - "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. Got = ", - q_type.str()); +} + +template +void marlin_mm( + const void *A, + const void *B, + void *C, + void *C_tmp, + void *s, + void *s2, + void *zp, + void *g_idx, + void *perm, + void *a_tmp, + int prob_m, + int prob_n, + int prob_k, + int lda, + void *workspace, + host::ScalarType const &q_type, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k_init, + int thread_n_init, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + if (has_zp) { + host::RuntimeCheck( + q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + host::RuntimeCheck( + q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); } host::RuntimeCheck( prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); int group_blocks = 0; - if (has_act_order) - { - if (is_k_full) - { - host::RuntimeCheck(group_size != -1); - group_blocks = group_size / 16; - host::RuntimeCheck( - prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); - } - else - { - host::RuntimeCheck(group_size == 0); - group_blocks = 0; - } - } - else - { - if (group_size == -1) - { - group_blocks = -1; - } - else - { - group_blocks = group_size / 16; - host::RuntimeCheck( - prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); - } + if (has_act_order) { + if (is_k_full) { + host::RuntimeCheck(group_size != -1); + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } else { + host::RuntimeCheck(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } } int num_bits = q_type.size_bits(); @@ -639,20 +591,20 @@ namespace device::marlin int *locks = (int *)workspace; - if (has_act_order) - { - // Permute A columns - int block_rows = div_ceil(prob_m, sms); - host::LaunchKernel(sms, default_threads, stream)( - permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); - A_ptr = a_tmp_ptr; - lda = prob_k; - - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by - // having a full K, we have full original groups) - if (is_k_full) - has_act_order = false; + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, sms); + host::LaunchKernel(sms, default_threads, stream)( + permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + A_ptr = a_tmp_ptr; + lda = prob_k; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } } int max_shared_mem = 0; @@ -660,187 +612,184 @@ namespace device::marlin host::RuntimeCheck(max_shared_mem > 0); int max_par = 16; - if (prob_n <= 4096) - max_par = 16 * 8; + if (prob_n <= 4096) { + max_par = 16 * 8; + } int max_shared_mem_new = max_shared_mem; int rest_m = prob_m; int max_thread_m_blocks = 4; - while (rest_m) - { - int par_count = rest_m / (max_thread_m_blocks * 16); - if (par_count > max_par) - par_count = max_par; - int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; - - int thread_k = thread_k_init; - int thread_n = thread_n_init; - - int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); - int m_block_size_8 = prob_m_split <= 8; - - // Set thread config - exec_config_t exec_cfg; - thread_config_t thread_tfg; - if (thread_k != -1 && thread_n != -1) - { - thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; - exec_cfg = exec_config_t{1, thread_tfg}; - host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); - host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); - } - else - { - // Auto config - exec_cfg = determine_exec_config( - q_type, - prob_m_split, - prob_n, - prob_k, + while (rest_m) { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) { + par_count = max_par; + } + int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m_split, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem, + sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { + max_thread_m_blocks--; + continue; + } + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) { + max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + host::RuntimeCheck( + is_valid_config( + thread_tfg, + thread_m_blocks, + prob_m_split, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, - m_block_size_8, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", num_bits, + ", prob_m_split = ", + prob_m_split, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, + ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, + ", is_zp_float = ", is_zp_float, - max_shared_mem, - sms); - thread_tfg = exec_cfg.tb_cfg; - if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) - { - max_thread_m_blocks--; - continue; - } - } - - int num_threads = thread_tfg.num_threads; - thread_k = thread_tfg.thread_k; - thread_n = thread_tfg.thread_n; - int blocks = sms * exec_cfg.blocks_per_sm; - if (exec_cfg.blocks_per_sm > 1) - max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - host::RuntimeCheck( - is_valid_config( - thread_tfg, - thread_m_blocks, - prob_m_split, - prob_n, - prob_k, - num_bits, - group_size, - has_act_order, - is_k_full, - has_zp, - is_zp_float, - max_shared_mem_new), - "Invalid thread config: thread_m_blocks = ", - thread_m_blocks, - ", thread_k = ", - thread_tfg.thread_k, - ", thread_n = ", - thread_tfg.thread_n, - ", num_threads = ", - thread_tfg.num_threads, - " for MKN = [", - prob_m, - ", ", - prob_k, - ", ", - prob_n, - "] and num_bits = ", - num_bits, - ", prob_m_split = ", - prob_m_split, - ", group_size = ", - group_size, - ", has_act_order = ", - has_act_order, - ", is_k_full = ", - is_k_full, - ", has_zp = ", - has_zp, - ", is_zp_float = ", - is_zp_float, - ", max_shared_mem_new = ", - max_shared_mem_new); - - auto kernel = get_marlin_kernel( - q_type, - thread_m_blocks, - thread_n_blocks, - thread_k_blocks, - m_block_size_8, - has_act_order, - has_zp, - group_blocks, - num_threads, - is_zp_float); - - if (kernel == MarlinDefault) - { - host::Panic( - "Unsupported shapes: MNK = [", - prob_m, - ", ", - prob_n, - ", ", - prob_k, - "]", - ", has_act_order = ", - has_act_order, - ", num_groups = ", - num_groups, - ", group_size = ", - group_size, - ", prob_m_split = ", - prob_m_split, - ", thread_m_blocks = ", + ", max_shared_mem_new = ", + max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, - ", thread_n_blocks = ", thread_n_blocks, - ", thread_k_blocks = ", thread_k_blocks, - ", num_threads = ", + m_block_size_8, + has_act_order, + has_zp, + group_blocks, num_threads, - ", num_bits = ", - num_bits); - } - - host::RuntimeDeviceCheck( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new)); - - bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; - - host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)( - kernel, - A_ptr, - B_ptr, - C_ptr, - C_tmp_ptr, - s_ptr, - s2_ptr, - zp_ptr, - g_idx_ptr, - num_groups, - prob_m_split, - prob_n, - prob_k, - lda, - locks, - part_use_atomic_add, - use_fp32_reduce, - max_shared_mem_new); - - A_ptr += prob_m_split * (lda / 8); - C_ptr += prob_m_split * (prob_n / 8); - rest_m -= prob_m_split; + is_zp_float); + + if (kernel == MarlinDefault) { + host::Panic( + "Unsupported shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", prob_m_split = ", + prob_m_split, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_threads = ", + num_threads, + ", num_bits = ", + num_bits); + } + + host::RuntimeDeviceCheck( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new)); + + bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)( + kernel, + A_ptr, + B_ptr, + C_ptr, + C_tmp_ptr, + s_ptr, + s2_ptr, + zp_ptr, + g_idx_ptr, + num_groups, + prob_m_split, + prob_n, + prob_k, + lda, + locks, + part_use_atomic_add, + use_fp32_reduce, + max_shared_mem_new); + + A_ptr += prob_m_split * (lda / 8); + C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; } - } +} #endif @@ -863,223 +812,202 @@ void gptq_marlin_gemm( bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) -{ - using namespace host; - - ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); - int pack_factor = 32 / b_q_type.size_bits(); - - // Bind symbolic sizes - auto M = SymbolicSize{"M"}; - auto K = SymbolicSize{"K"}; - auto N = SymbolicSize{"N"}; - auto device = SymbolicDevice{}; - device.set_options(); - - // Verify a: [M, K] - auto lda = SymbolicSize{"lda"}; - TensorMatcher({M, K}).with_strides({lda, 1}).with_dtype().with_device(device).verify(a); - - int64_t size_m = M.unwrap(); - int64_t size_k = K.unwrap(); - - // Verify b_q_weight: [K/tile_size, packed_N] - RuntimeCheck( - size_k % device::marlin::tile_size == 0, - "size_k = ", - size_k, - " is not divisible by tile_size = ", - device::marlin::tile_size); - int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size; - auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; - auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; - bqw_dim0.set_value(expected_bqw_dim0); - TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device).verify(b_q_weight); - - RuntimeCheck( - b_q_weight.size(1) % device::marlin::tile_size == 0, - "b_q_weight.size(1) = ", - b_q_weight.size(1), - " is not divisible by tile_size = ", - device::marlin::tile_size); - int64_t actual_size_n = (b_q_weight.size(1) / device::marlin::tile_size) * pack_factor; - N.set_value(actual_size_n); - int64_t size_n = N.unwrap(); - - // Verify stride alignment - int64_t a_stride0 = a.stride(0); - RuntimeCheck(a_stride0 % 8 == 0, "a.stride(0) must be divisible by 8"); - - // Verify b_scales: [num_groups, N] - auto num_groups_sym = SymbolicSize{"num_groups"}; - TensorMatcher({num_groups_sym, N}).with_device(device).verify(b_scales); - int num_groups = static_cast(num_groups_sym.unwrap()); - - // Verify c: [M, N] - TensorMatcher({M, N}).with_dtype().with_device(device).verify(c); - - // Early return for zero-size M - if (size_m == 0) - return; - - // Determine has_act_order from g_idx/perm sizes - int64_t g_idx_size = g_idx.size(0); - int64_t perm_size = perm.size(0); - bool has_act_order = g_idx_size > 0 && perm_size > 0; - - if (has_act_order) - { - RuntimeCheck( - (g_idx_size == size_k && perm_size == size_k), - "Unexpected g_idx.size(0) = ", - g_idx_size, - " and perm.size(0) = ", - perm_size, - ", where size_k = ", - size_k); - } - - // Determine has_zp from b_zeros size - int64_t b_zeros_size = b_zeros.size(0); - bool has_zp = b_zeros_size > 0; - - if (has_zp) - { - RuntimeCheck( - b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); - } - else - { + bool is_zp_float) { + using namespace host; + + ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + // Bind symbolic sizes + auto M = SymbolicSize{"M"}; + auto K = SymbolicSize{"K"}; + auto N = SymbolicSize{"N"}; + auto device = SymbolicDevice{}; + device.set_options(); + + // Verify a: [M, K] + auto lda = SymbolicSize{"lda"}; + TensorMatcher({M, K}).with_strides({lda, 1}).with_dtype().with_device(device).verify(a); + + int64_t size_m = M.unwrap(); + int64_t size_k = K.unwrap(); + + // Verify b_q_weight: [K/tile_size, packed_N] RuntimeCheck( - b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, - "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. Got = ", - b_q_type.str()); - } - - if (has_zp && is_zp_float) - { + size_k % device::marlin::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size; + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(expected_bqw_dim0); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device).verify(b_q_weight); + RuntimeCheck( - std::is_same::value, "Computation type must be float16 (half) when using float zero points."); - } - - // Verify b_zeros shape - if (has_zp) - { - RuntimeCheck(b_zeros.dim() == 2, "b_zeros rank = ", b_zeros.dim(), " is not 2"); - if (is_zp_float) - { - RuntimeCheck(b_zeros.size(1) == size_n, "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n = ", size_n); - RuntimeCheck( - num_groups == b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); - RuntimeCheck(num_groups != -1, "num_groups must be != -1"); + b_q_weight.size(1) % device::marlin::tile_size == 0, + "b_q_weight.size(1) = ", + b_q_weight.size(1), + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t actual_size_n = (b_q_weight.size(1) / device::marlin::tile_size) * pack_factor; + N.set_value(actual_size_n); + int64_t size_n = N.unwrap(); + + // Verify stride alignment + int64_t a_stride0 = a.stride(0); + RuntimeCheck(a_stride0 % 8 == 0, "a.stride(0) must be divisible by 8"); + + // Verify b_scales: [num_groups, N] + auto num_groups_sym = SymbolicSize{"num_groups"}; + TensorMatcher({num_groups_sym, N}).with_device(device).verify(b_scales); + int num_groups = static_cast(num_groups_sym.unwrap()); + + // Verify c: [M, N] + TensorMatcher({M, N}).with_dtype().with_device(device).verify(c); + + // Early return for zero-size M + if (size_m == 0) { + return; } - else - { - RuntimeCheck( - b_zeros.size(0) == num_groups, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); - RuntimeCheck( - b_zeros.size(1) == size_n / pack_factor, - "b_zeros dim 1 = ", - b_zeros.size(1), - " is not size_n / pack_factor = ", - size_n / pack_factor); + + // Determine has_act_order from g_idx/perm sizes + int64_t g_idx_size = g_idx.size(0); + int64_t perm_size = perm.size(0); + bool has_act_order = g_idx_size > 0 && perm_size > 0; + + if (has_act_order) { + RuntimeCheck( + (g_idx_size == size_k && perm_size == size_k), + "Unexpected g_idx.size(0) = ", + g_idx_size, + " and perm.size(0) = ", + perm_size, + ", where size_k = ", + size_k); } - } - - // Verify global_scale - int64_t global_scale_size = global_scale.size(0); - if (global_scale_size > 0) - { - RuntimeCheck(b_q_type == kFE2M1f, "global_scale can only be used for float4_e2m1f."); - } - else - { - RuntimeCheck(!(b_q_type == kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); - } - - // Derive group_size - int group_size = -1; - if (has_act_order) - { - if (is_k_full) - { - RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); - RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); - group_size = static_cast(size_k / num_groups); + + // Determine has_zp from b_zeros size + int64_t b_zeros_size = b_zeros.size(0); + bool has_zp = b_zeros_size > 0; + + if (has_zp) { + RuntimeCheck( + b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + } else { + RuntimeCheck( + b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + b_q_type.str()); } - else - { - group_size = 0; + + if (has_zp && is_zp_float) { + RuntimeCheck( + std::is_same::value, "Computation type must be float16 (half) when using float zero points."); } - } - else - { - if (num_groups > 1) - { - RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); - group_size = static_cast(size_k / num_groups); + + // Verify b_zeros shape + if (has_zp) { + RuntimeCheck(b_zeros.dim() == 2, "b_zeros rank = ", b_zeros.dim(), " is not 2"); + if (is_zp_float) { + RuntimeCheck(b_zeros.size(1) == size_n, "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n = ", size_n); + RuntimeCheck( + num_groups == b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck(num_groups != -1, "num_groups must be != -1"); + } else { + RuntimeCheck( + b_zeros.size(0) == num_groups, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck( + b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", + b_zeros.size(1), + " is not size_n / pack_factor = ", + size_n / pack_factor); + } } - else - { - group_size = -1; + + // Verify global_scale + int64_t global_scale_size = global_scale.size(0); + if (global_scale_size > 0) { + RuntimeCheck(b_q_type == kFE2M1f, "global_scale can only be used for float4_e2m1f."); + } else { + RuntimeCheck(!(b_q_type == kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); } - } - - // Verify workspace and get device info - RuntimeCheck( - size_n % device::marlin::min_thread_n == 0, - "size_n = ", - size_n, - ", is not divisible by min_thread_n = ", - device::marlin::min_thread_n); - - DLDevice dl_device = device.unwrap(); - int dev = dl_device.device_id; - cudaStream_t stream = LaunchKernel::resolve_device(dl_device); - - int sms = -1; - RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev)); - - RuntimeCheck( - workspace.size(0) >= sms, "workspace.size(0) = ", workspace.size(0), " is below min_workspace_size = ", sms); - - // Hardcoded defaults (auto config) - int thread_k_init = -1; - int thread_n_init = -1; - - // Compute c_tmp and a_tmp pointers - // c_tmp and a_tmp are pre-allocated by caller - - device::marlin::marlin_mm( - a.data_ptr(), - b_q_weight.data_ptr(), - c.data_ptr(), - c_tmp.data_ptr(), - b_scales.data_ptr(), - global_scale.data_ptr(), - b_zeros.data_ptr(), - g_idx.data_ptr(), - perm.data_ptr(), - a_tmp.data_ptr(), - static_cast(size_m), - static_cast(size_n), - static_cast(size_k), - static_cast(a_stride0), - workspace.data_ptr(), - b_q_type, - has_act_order, - is_k_full, - has_zp, - num_groups, - group_size, - dev, - stream, - thread_k_init, - thread_n_init, - sms, - use_atomic_add, - use_fp32_reduce, - is_zp_float); + + // Derive group_size + int group_size = -1; + if (has_act_order) { + if (is_k_full) { + RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = 0; + } + } else { + if (num_groups > 1) { + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = -1; + } + } + + // Verify workspace and get device info + RuntimeCheck( + size_n % device::marlin::min_thread_n == 0, + "size_n = ", + size_n, + ", is not divisible by min_thread_n = ", + device::marlin::min_thread_n); + + DLDevice dl_device = device.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + + int sms = -1; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev)); + + RuntimeCheck( + workspace.size(0) >= sms, "workspace.size(0) = ", workspace.size(0), " is below min_workspace_size = ", sms); + + // Hardcoded defaults (auto config) + int thread_k_init = -1; + int thread_n_init = -1; + + // Compute c_tmp and a_tmp pointers + // c_tmp and a_tmp are pre-allocated by caller + + device::marlin::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + b_scales.data_ptr(), + global_scale.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + static_cast(size_m), + static_cast(size_n), + static_cast(size_k), + static_cast(a_stride0), + workspace.data_ptr(), + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + dev, + stream, + thread_k_init, + thread_n_init, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); } diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh index d0c2d5414..f23f73cbf 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh @@ -27,29 +27,26 @@ #include "marlin.cuh" -namespace device::marlin -{ +namespace device::marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - template - __global__ void gptq_marlin_repack_kernel( - uint32_t const *__restrict__ b_q_weight_ptr, - uint32_t const *__restrict__ perm_ptr, - uint32_t *__restrict__ out_ptr, - int size_k, - int size_n) - { +template +__global__ void gptq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, + int size_k, + int size_n) { return; - } +} #else - template - __global__ void gptq_marlin_repack_kernel( - uint32_t const *__restrict__ b_q_weight_ptr, - uint32_t const *__restrict__ perm_ptr, - uint32_t *__restrict__ out_ptr, - int size_k, - int size_n) - { +template +__global__ void gptq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, + int size_k, + int size_n) { constexpr int pack_factor = 32 / num_bits; int k_tiles = size_k / tile_k_size; @@ -57,22 +54,20 @@ namespace device::marlin int block_k_tiles = div_ceil(k_tiles, gridDim.x); auto start_k_tile = blockIdx.x * block_k_tiles; - if (start_k_tile >= k_tiles) - { - return; + if (start_k_tile >= k_tiles) { + return; } int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() - { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); }; extern __shared__ int4 sh[]; @@ -81,9 +76,8 @@ namespace device::marlin int4 *sh_perm_ptr = sh; int4 *sh_pipe_ptr = sh_perm_ptr; - if constexpr (has_perm) - { - sh_pipe_ptr += perm_size; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; } constexpr int tile_ints = tile_k_size / pack_factor; @@ -92,232 +86,202 @@ namespace device::marlin constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; constexpr int stage_size = stage_k_threads * stage_n_threads; - auto load_perm_to_shared = [&](int k_tile_id) - { - int first_k_int4 = (k_tile_id * tile_k_size) / 4; + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; - int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); + int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); - if (threadIdx.x < perm_size) - { - sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; - } - __syncthreads(); + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); }; - auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) - { - if (n_tile_id >= n_tiles) - { - cp_async_fence(); - return; - } + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } - int first_n = n_tile_id * tile_n_size; + int first_n = n_tile_id * tile_n_size; - int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; + int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; - if constexpr (has_perm) - { - if (threadIdx.x < stage_size) - { - auto k_id = threadIdx.x / stage_n_threads; - auto n_id = threadIdx.x % stage_n_threads; + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; - uint32_t const *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + uint32_t const *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); - int src_k = sh_perm_int_ptr[k_id]; - int src_k_packed = src_k / pack_factor; + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; - cp_async4( - &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); - } - } - else - { - if (threadIdx.x < stage_size) - { - auto k_id = threadIdx.x / stage_n_threads; - auto n_id = threadIdx.x % stage_n_threads; - - int first_k = k_tile_id * tile_k_size; - int first_k_packed = first_k / pack_factor; - - cp_async4( - &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + } else { + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); + } } - } - cp_async_fence(); + cp_async_fence(); }; - auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) - { - if (n_tile_id >= n_tiles) - { - return; - } + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } - auto warp_id = threadIdx.x / 32; - auto th_id = threadIdx.x % 32; + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; - if (warp_id >= 4) - { - return; - } + if (warp_id >= 4) { + return; + } - int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; - constexpr int tc_offsets[4] = {0, 1, 8, 9}; + constexpr int tc_offsets[4] = {0, 1, 8, 9}; - int cur_n = warp_id * 16 + tc_col; + int cur_n = warp_id * 16 + tc_col; - constexpr int sh_stride = 64; - constexpr uint32_t mask = (1 << num_bits) - 1; + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; - int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; - uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); - uint32_t vals[8]; + uint32_t vals[8]; - if constexpr (has_perm) - { - for (int i = 0; i < 4; i++) - { - int k_idx = tc_row + tc_offsets[i]; + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; - uint32_t src_k = sh_perm_int_ptr[k_idx]; - uint32_t src_k_pos = src_k % pack_factor; + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; - uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; - uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; - uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; - uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; - vals[i] = b1_cur_val; - vals[4 + i] = b2_cur_val; - } - } - else - { - uint32_t b1_vals[tile_ints]; - uint32_t b2_vals[tile_ints]; + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + } else { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; #pragma unroll - for (int i = 0; i < tile_ints; i++) - { - b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; - b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; - } + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } #pragma unroll - for (int i = 0; i < 4; i++) - { - int cur_elem = tc_row + tc_offsets[i]; - int cur_int = cur_elem / pack_factor; - int cur_pos = cur_elem % pack_factor; - - vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; - vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } } - } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; - // Result of: - // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) - { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - uint32_t res = 0; + uint32_t res = 0; #pragma unroll - for (int i = 0; i < 8; i++) - { - res |= vals[pack_idx[i]] << (i * 4); - } + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } - out_ptr[out_offset + th_id * 4 + warp_id] = res; - } - else - { - constexpr int pack_idx[4] = {0, 2, 1, 3}; + out_ptr[out_offset + th_id * 4 + warp_id] = res; + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; - uint32_t res1 = 0; - uint32_t res2 = 0; + uint32_t res1 = 0; + uint32_t res2 = 0; #pragma unroll - for (int i = 0; i < 4; i++) - { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); - } + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; - } + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } }; - auto start_pipes = [&](int k_tile_id, int n_tile_id) - { + auto start_pipes = [&](int k_tile_id, int n_tile_id) { #pragma unroll - for (int pipe = 0; pipe < repack_stages - 1; pipe++) - { - fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); - } + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } - wait_for_stage(); + wait_for_stage(); }; #pragma unroll - for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) - { - int n_tile_id = 0; + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; - if constexpr (has_perm) - { - load_perm_to_shared(k_tile_id); - } + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } - start_pipes(k_tile_id, n_tile_id); + start_pipes(k_tile_id, n_tile_id); - while (n_tile_id < n_tiles) - { + while (n_tile_id < n_tiles) { #pragma unroll - for (int pipe = 0; pipe < repack_stages; pipe++) - { - fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); - repack_tile(pipe, k_tile_id, n_tile_id + pipe); - wait_for_stage(); + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; } - n_tile_id += repack_stages; - } } - } +} #endif } // namespace device::marlin -#define CALL_IF_REPACK(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) \ - { \ - host::RuntimeDeviceCheck(cudaFuncSetAttribute( \ - device::marlin::gptq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem)); \ - host::LaunchKernel(blocks, device::marlin::repack_threads, stream, static_cast(max_shared_mem))( \ - device::marlin::gptq_marlin_repack_kernel, \ - b_q_weight_ptr, \ - perm_ptr, \ - out_ptr, \ - size_k, \ - size_n); \ - } +#define CALL_IF_REPACK(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + host::RuntimeDeviceCheck(cudaFuncSetAttribute( \ + device::marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem)); \ + host::LaunchKernel(blocks, device::marlin::repack_threads, stream, static_cast(max_shared_mem))( \ + device::marlin::gptq_marlin_repack_kernel, \ + b_q_weight_ptr, \ + perm_ptr, \ + out_ptr, \ + size_k, \ + size_n); \ + } void gptq_marlin_repack( tvm::ffi::TensorView b_q_weight, @@ -325,74 +289,71 @@ void gptq_marlin_repack( tvm::ffi::TensorView out, int64_t size_k, int64_t size_n, - int64_t num_bits) -{ - using namespace host; - - // Validate num_bits - RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); - int const pack_factor = 32 / static_cast(num_bits); - - // Validate size alignment - RuntimeCheck( - size_k % device::marlin::tile_k_size == 0, - "size_k = ", - size_k, - " is not divisible by tile_k_size = ", - device::marlin::tile_k_size); - RuntimeCheck( - size_n % device::marlin::tile_n_size == 0, - "size_n = ", - size_n, - " is not divisible by tile_n_size = ", - device::marlin::tile_n_size); - - // Validate b_q_weight - auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; - auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; - bqw_dim0.set_value(size_k / pack_factor); - bqw_dim1.set_value(size_n); - auto device_ = SymbolicDevice{}; - device_.set_options(); - TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device_).verify(b_q_weight); - - // Validate out - auto out_dim0 = SymbolicSize{"out_dim0"}; - auto out_dim1 = SymbolicSize{"out_dim1"}; - out_dim0.set_value(size_k / device::marlin::tile_size); - out_dim1.set_value(size_n * device::marlin::tile_size / pack_factor); - TensorMatcher({out_dim0, out_dim1}).with_dtype().with_device(device_).verify(out); - - // Detect if there is act_order - bool has_perm = perm.size(0) != 0; - - // Get ptrs - uint32_t const *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); - uint32_t const *perm_ptr = reinterpret_cast(perm.data_ptr()); - uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); - - // Get dev info - DLDevice dl_device = device_.unwrap(); - int dev = dl_device.device_id; - cudaStream_t stream = LaunchKernel::resolve_device(dl_device); - int blocks; - RuntimeDeviceCheck(cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev)); - - int max_shared_mem = 0; - RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); - RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); - - if (false) - { - } - CALL_IF_REPACK(4, false) - CALL_IF_REPACK(4, true) - CALL_IF_REPACK(8, false) - CALL_IF_REPACK(8, true) - else - { - Panic("Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); - } + int64_t num_bits) { + using namespace host; + + // Validate num_bits + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / static_cast(num_bits); + + // Validate size alignment + RuntimeCheck( + size_k % device::marlin::tile_k_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_k_size = ", + device::marlin::tile_k_size); + RuntimeCheck( + size_n % device::marlin::tile_n_size == 0, + "size_n = ", + size_n, + " is not divisible by tile_n_size = ", + device::marlin::tile_n_size); + + // Validate b_q_weight + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(size_k / pack_factor); + bqw_dim1.set_value(size_n); + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device_).verify(b_q_weight); + + // Validate out + auto out_dim0 = SymbolicSize{"out_dim0"}; + auto out_dim1 = SymbolicSize{"out_dim1"}; + out_dim0.set_value(size_k / device::marlin::tile_size); + out_dim1.set_value(size_n * device::marlin::tile_size / pack_factor); + TensorMatcher({out_dim0, out_dim1}).with_dtype().with_device(device_).verify(out); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const *perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + DLDevice dl_device = device_.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + int blocks; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev)); + + int max_shared_mem = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + if (false) { + } + CALL_IF_REPACK(4, false) + CALL_IF_REPACK(4, true) + CALL_IF_REPACK(8, false) + CALL_IF_REPACK(8, true) + else { + Panic("Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); + } } #undef CALL_IF_REPACK diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h index e0e36cdd4..785d5e9b9 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h @@ -10,25 +10,24 @@ const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem -namespace device::marlin -{ - template < - typename scalar_t, // compute dtype, half or nv_float16 - const host::ScalarTypeId w_type_id, // weight ScalarType id - const int threads, // number of threads in a threadblock - const int thread_m_blocks, // number of 16x16 blocks in the m - // dimension (batchsize) of the - // threadblock - const int thread_n_blocks, // same for n dimension (output) - const int thread_k_blocks, // same for k dimension (reduction) - const bool m_block_size_8, // whether m_block_size == 8 - // only works when thread_m_blocks == 1 - const int stages, // number of stages for the async global->shared - // fetch pipeline - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? - > - __global__ void Marlin(MARLIN_KERNEL_PARAMS); +namespace device::marlin { +template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); } // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh index 9e99d0f4d..944ca3522 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh @@ -4,53 +4,49 @@ #include -namespace device::marlin -{ - // Marlin params +namespace device::marlin { +// Marlin params - // 8 warps are a good choice since every SM has 4 schedulers and having more - // than 1 warp per schedule allows some more latency hiding. At the same time, - // we want relatively few warps to have many registers per warp and small tiles. - static constexpr int default_threads = 256; +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; - static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory - static constexpr int min_thread_n = 64; - static constexpr int min_thread_k = 64; - static constexpr int max_thread_n = 256; +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; - static constexpr int tile_size = 16; - static constexpr int max_par = 16; +static constexpr int tile_size = 16; +static constexpr int max_par = 16; - // Repack params - static constexpr int repack_stages = 8; +// Repack params +static constexpr int repack_stages = 8; - static constexpr int repack_threads = 256; +static constexpr int repack_threads = 256; - static constexpr int tile_k_size = tile_size; - static constexpr int tile_n_size = tile_k_size * 4; +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; - // Helpers - template - struct Vec - { +// Helpers +template +struct Vec { T elems[n]; - __device__ T &operator[](int i) - { - return elems[i]; + __device__ T &operator[](int i) { + return elems[i]; } - }; +}; - using I4 = Vec; +using I4 = Vec; - using host::div_ceil; +using host::div_ceil; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // No support for async #else - __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, bool pred = true) - { +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( @@ -62,10 +58,9 @@ namespace device::marlin "r"(smem), "l"(glob_ptr), "n"(BYTES)); - } +} - __device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) - { +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( @@ -74,18 +69,16 @@ namespace device::marlin "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES)); - } +} - __device__ inline void cp_async_fence() - { +__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); - } +} - template - __device__ inline void cp_async_wait() - { +template +__device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); - } +} #endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h index 8f35f227d..4ea220265 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h @@ -24,209 +24,187 @@ #include "marlin.cuh" #include "marlin_dtypes.cuh" -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ - static_assert( \ - std::is_same::value || std::is_same::value, \ - "only float16 and bfloat16 is supported"); +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert( \ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); -namespace device::marlin -{ +namespace device::marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - template < - typename scalar_t, // compute dtype, half or nv_float16 - const host::ScalarTypeId w_type_id, // weight ScalarType id - const int threads, // number of threads in a threadblock - const int thread_m_blocks, // number of 16x16 blocks in the m - // dimension (batchsize) of the - // threadblock - const int thread_n_blocks, // same for n dimension (output) - const int thread_k_blocks, // same for k dimension (reduction) - const bool m_block_size_8, // whether m_block_size == 8 - // only works when thread_m_blocks == 1 - const int stages, // number of stages for the async global->shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? - > - __global__ void Marlin( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) - const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int *__restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks, // extra global storage for barrier synchronization - bool use_fp32_reduce // whether to use fp32 global reduce - ) - { - } +template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce +) { +} } // namespace device::marlin #else - // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 - // output/accumulation. - template - __device__ inline void - mma(const typename ScalarType::FragA &a_frag, - const typename ScalarType::FragB &frag_b, - typename ScalarType::FragC &frag_c) - { +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void +mma(const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + typename ScalarType::FragC &frag_c) { const uint32_t *a = reinterpret_cast(&a_frag); const uint32_t *b = reinterpret_cast(&frag_b); float *c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } - else if constexpr (std::is_same::value) - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } - else - { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } - } - - template - __device__ inline void mma_trans( - const typename ScalarType::FragA &a_frag, - const typename ScalarType::FragB &frag_b, - const typename ScalarType::FragB &frag_b2, - typename ScalarType::FragC &frag_c) - { +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + const typename ScalarType::FragB &frag_b2, + typename ScalarType::FragC &frag_c) { const uint32_t *a = reinterpret_cast(&a_frag); const uint32_t *b = reinterpret_cast(&frag_b); const uint32_t *b2 = reinterpret_cast(&frag_b2); float *c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), - "r"(b2[0]), - "r"(b[1]), - "r"(b2[1]), - "r"(a[0]), - "r"(a[1]), - "f"(c[0]), - "f"(c[1]), - "f"(c[2]), - "f"(c[3])); - } - else if constexpr (std::is_same::value) - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), - "r"(b2[0]), - "r"(b[1]), - "r"(b2[1]), - "r"(a[0]), - "r"(a[1]), - "f"(c[0]), - "f"(c[1]), - "f"(c[2]), - "f"(c[3])); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } - else - { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } - } +} - // Instruction for loading a full 16x16 matrix fragment of operand A from shared - // memory, directly in tensor core layout. - template - __device__ inline void ldsm(typename ScalarType::FragA &frag_a, const void *smem_ptr) - { +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm(typename ScalarType::FragA &frag_a, const void *smem_ptr) { uint32_t *a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - if constexpr (count == 4) - { - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); - } - else if constexpr (count == 2) - { - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); - } - else if constexpr (count == 1) - { - asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(a[0]) : "r"(smem)); + if constexpr (count == 4) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(a[0]) + : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); } - else - { - static_assert(count == 1 || count == 2 || count == 4, "invalid count"); - } - } - - // Multiply dequantized values by the corresponding quantization scale; used - // only for grouped quantization. - template - __device__ inline void - scale(typename ScalarType::FragB &frag_b, typename ScalarType::FragS &frag_s, int i) - { +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void +scale(typename ScalarType::FragB &frag_b, typename ScalarType::FragS &frag_s, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); - } +} - template - __device__ inline void scale_and_sub(typename ScalarType::FragB &frag_b, scalar_t s, scalar_t zp) - { +template +__device__ inline void scale_and_sub(typename ScalarType::FragB &frag_b, scalar_t s, scalar_t zp) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s2 = ScalarType::num2num2(s); scalar_t2 zp2 = ScalarType::num2num2(zp); frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); - } +} - template - __device__ inline void - sub_zp(typename ScalarType::FragB &frag_b, typename ScalarType::scalar_t2 &frag_zp, int i) - { +template +__device__ inline void +sub_zp(typename ScalarType::FragB &frag_b, typename ScalarType::scalar_t2 &frag_zp, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 zp = ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], zp); - } - - // Same as above, but for act_order (each K is multiplied individually) - template - __device__ inline void scale4( - typename ScalarType::FragB &frag_b, - typename ScalarType::FragS &frag_s_1, - typename ScalarType::FragS &frag_s_2, - typename ScalarType::FragS &frag_s_3, - typename ScalarType::FragS &frag_s_4, - int i) - { +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4( + typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s_1, + typename ScalarType::FragS &frag_s_2, + typename ScalarType::FragS &frag_s_3, + typename ScalarType::FragS &frag_s_4, + int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s_val_1_2; s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; @@ -238,106 +216,103 @@ namespace device::marlin frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); - } +} - // Given 2 floats multiply by 2 scales (halves) - template - __device__ inline void scale_float(float *c, typename ScalarType::FragS &s) - { +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float *c, typename ScalarType::FragS &s) { scalar_t *s_ptr = reinterpret_cast(&s); c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); - } - - // Wait until barrier reaches `count`, then lock for current threadblock. - __device__ inline void barrier_acquire(int *lock, int count) - { - if (threadIdx.x == 0) - { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); - while (state != count); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int *lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do { + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + } while (state != count); } __syncthreads(); - } +} - // Release barrier and increment visitation count. - __device__ inline void barrier_release(int *lock, bool reset = false) - { +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int *lock, bool reset = false) { __syncthreads(); - if (threadIdx.x == 0) - { - if (reset) - { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); } - } - - // Wait until value of lock to be negative, and then add 1 - __device__ inline void wait_negative_and_add(int *lock) - { - if (threadIdx.x == 0) - { - int state = 0; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); - while (state >= 0); - atomicAdd(lock, 1); +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int *lock) { + if (threadIdx.x == 0) { + int state = 0; + do { + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + } while (state >= 0); + atomicAdd(lock, 1); } __syncthreads(); - } - - template < - typename scalar_t, // compute dtype, half or nv_float16 - const host::ScalarTypeId w_type_id, // weight ScalarType id - const int threads, // number of threads in a threadblock - const int thread_m_blocks, // number of 16x16 blocks in the m - // dimension (batchsize) of the - // threadblock - const int thread_n_blocks, // same for n dimension (output) - const int thread_k_blocks, // same for k dimension (reduction) - const bool m_block_size_8, // whether m_block_size == 8 - // only works when thread_m_blocks == 1 - const int stages, // number of stages for the async global->shared - // fetch pipeline - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? - > - __global__ void Marlin( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) - const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const uint16_t *__restrict__ scale2_ptr, // fp16 global scale (for nvfp4 - // only) - const int4 *__restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int *__restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int lda, // A.stride(0), equal to prob_k is A is contiguous - int *locks, // extra global storage for barrier synchronization - bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce, // whether to use fp32 global reduce - int max_shared_mem) - { +} + +template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t *__restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4 *__restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int lda, // A.stride(0), equal to prob_k is A is contiguous + int *locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM @@ -359,19 +334,15 @@ namespace device::marlin static constexpr auto w_type = host::ScalarType::from_id(w_type_id); constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8; - constexpr bool is_int_type = - w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128; + constexpr bool is_int_type = w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128; // see comments of dequant.h for more details - constexpr bool dequant_skip_flop = !is_int_type || - has_zp && !is_zp_float && !std::is_same::value || - has_zp && !is_zp_float && !(w_type == host::kU8); + constexpr bool dequant_skip_flop = !is_int_type || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(w_type == host::kU8); scalar_t2 global_scale; - if constexpr (w_type == host::kFE2M1f) - { - uint16_t val = scale2_ptr[0]; - global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + if constexpr (w_type == host::kFE2M1f) { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); } constexpr bool has_act_order = group_blocks == 0; @@ -383,25 +354,22 @@ namespace device::marlin // For larger GEMMs we run multiple batchsize 64 versions in parallel for a // better partitioning with less reductions int parallel = 1; - if (prob_m > m_block_size) - { - parallel = prob_m / m_block_size; - prob_m = m_block_size; + if (prob_m > m_block_size) { + parallel = prob_m / m_block_size; + prob_m = m_block_size; } int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - if constexpr (!has_act_order && group_blocks != -1) - { - if (group_blocks >= thread_k_blocks) - { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); - } + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } } int slice_row = (iters * blockIdx.x) % k_tiles; @@ -417,97 +385,89 @@ namespace device::marlin // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers - if (slice_col_par >= n_tiles) - { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - slice_col = slice_col_par % n_tiles; - par_id = slice_col_par / n_tiles; - } - if (parallel * n_tiles >= gridDim.x) - { - // when parallel * n_tiles >= sms - // then there are at most $sms$ conflict tile blocks - locks_off = blockIdx.x; + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; } - else - { - locks_off = (iters * blockIdx.x) / k_tiles - 1; + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; } // Compute all information about the current slice which is required for // synchronization. - auto init_slice = [&](bool first_init = false) - { - slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) - { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else - { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + auto init_slice = [&](bool first_init = false) { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) { + slice_iters = 0; } - } - if (parallel * n_tiles >= gridDim.x) - { - if (slice_count > 1 && slice_idx == slice_count - 1) - { - locks_off++; + if (slice_iters == 0) { + return; } - } - else - { - locks_off++; - } - - if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) - { - constexpr int threads_per_m = 16 * thread_n_blocks / 8; - int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m); - if (m_block_size_8) - m_per_thread = div_ceil(8, threads / threads_per_m); - for (int i = 0; i < m_per_thread; i++) - { - int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; - if (row < prob_m) - { - int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; - C[row * prob_n / 8 + col] = {0, 0, 0, 0}; - } + if (slice_row + slice_iters > k_tiles) { + slice_iters = k_tiles - slice_row; + } + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) { + slice_count++; + } + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) { + slice_idx = slice_count - 1; + } else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) { + slice_idx--; + } + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m); + if (m_block_size_8) { + m_per_thread = div_ceil(8, threads / threads_per_m); + } + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < prob_m) { + int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; + C[row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. + __syncthreads(); + if (threadIdx.x == 0) { + locks[locks_off] = 1 - slice_count; + } + } + + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * lda / 8; + C += 16 * thread_m_blocks * prob_n / 8; + slice_col = 0; + par_id++; } - // After write zero to output, write a negative value to lock. - // Every SM that processes the same slice would wait for - // the negative value, and then atomicAdd 1 to it. - // After all SMs are processed, the lock value would back to 0 again. - __syncthreads(); - if (threadIdx.x == 0) - locks[locks_off] = 1 - slice_count; - } - - if (slice_col == n_tiles) - { - A += 16 * thread_m_blocks * lda / 8; - C += 16 * thread_m_blocks * prob_n / 8; - slice_col = 0; - par_id++; - } }; init_slice(true); @@ -549,8 +509,8 @@ namespace device::marlin int s_gl_stride = prob_n / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) - : 1; + ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) + : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -579,8 +539,7 @@ namespace device::marlin // Shared write index of current thread. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); // Shared read index. - int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + - (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; @@ -598,33 +557,24 @@ namespace device::marlin // No act_order int s_gl_rd; - if constexpr (!has_act_order) - { - if constexpr (group_blocks == -1) - { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } - else - { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; - } + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + s_sh_stride * slice_col + threadIdx.x; + } } auto s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; // Zero-points int zp_gl_rd; - if constexpr (has_zp) - { - if constexpr (group_blocks == -1) - { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - else - { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; - } + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } } auto zp_sh_wr = threadIdx.x; bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; @@ -633,21 +583,20 @@ namespace device::marlin // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) - { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; + if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + } else if constexpr (group_blocks != -1) { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; + } else { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; } - else if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; // Zero-points have the same read layout as the scales // (without column-wise case) @@ -655,20 +604,14 @@ namespace device::marlin constexpr int num_row_threads = 4; constexpr int num_ints_per_thread = 8 / pack_factor; int zp_sh_rd; - if constexpr (has_zp) - { - if constexpr (is_zp_float) - { - if constexpr (group_blocks != -1) - { - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } - } - else - { - zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } } // Precompute which thread should not read memory in which iterations; this is @@ -676,8 +619,9 @@ namespace device::marlin // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } // To ensure that writing and reading A tiles to/from shared memory, the // latter in fragment format, is fully bank conflict free, we need to use a @@ -685,25 +629,25 @@ namespace device::marlin // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the // same shared memory banks. Further, it seems (based on NSight-Compute) that // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) - { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); }; // Since the computation of this remapping is non-trivial and, due to our main // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + } int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - { + for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } } // Since B-accesses have non-constant stride they have to be computed at @@ -712,8 +656,9 @@ namespace device::marlin // optimization. const int4 *B_ptr[b_sh_wr_iters]; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + for (int i = 0; i < b_sh_wr_iters; i++) { + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + } extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. @@ -744,1173 +689,956 @@ namespace device::marlin FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ // Zero accumulators. - auto zero_accums = [&]() - { + auto zero_accums = [&]() { #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; }; int sh_first_group_id = -1; int sh_num_groups = -1; - auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) - { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups > act_s_max_num_groups) - { - sh_num_groups = act_s_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) - { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) - { - for (int i = 0; i < sh_num_groups; i++) - { - if (threadIdx.x < s_sh_stride) - { - cp_async4_pred( - &sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); - } + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; } - } - else - { - for (int i = 0; i < sh_num_groups; i++) - { - if (threadIdx.x < s_sh_stride) - { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; - } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred( + &sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; + } + } } - } }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) - { - if (pred) - { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - { + for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) - { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } - B_ptr[i] += b_gl_rd_delta_o; - } + B_ptr[i] += b_gl_rd_delta_o; + } - if constexpr (has_act_order) - { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) - { - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int4 const *cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + int4 const *cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); - if (threadIdx.x < g_idx_stage) - { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } - else - { - if constexpr (group_blocks != -1) - { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) - { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) - { - if (s_sh_wr_pred) - { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - else - { - for (int i = 0; i < s_tb_groups; i++) - { - if (s_sh_wr_pred) - { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); + } } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) - { - int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) - { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) - { - if (zp_sh_wr_pred) - { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } else { + if constexpr (group_blocks != -1) { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } } - zp_gl_rd += zp_gl_rd_delta; - } - } - else - { - for (int i = 0; i < zp_tb_groups; i++) - { - if (zp_sh_wr_pred) - { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); + + if constexpr (has_zp && group_blocks != -1) { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } } - zp_gl_rd += zp_gl_rd_delta; - } } - } } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); }; - auto fetch_col_zp_to_shared = [&]() - { - if (zp_sh_wr_pred) - { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } }; - auto fetch_col_scale_to_shared = [&]() - { - if (s_sh_wr_pred) - { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } }; // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() - { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); }; // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) - { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; + auto fetch_to_registers = [&](int k, int pipe) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + for (int i = 0; i < thread_m_blocks; i++) + ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) - { - frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; int same_group_id[stages]; - auto init_same_group = [&](int pipe) - { - if constexpr (!has_act_order) - { - return; - } - - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) - { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) - { - // No act-order case - if constexpr (group_blocks == -1) - { - // load only when starting a new slice - if (k == 0 && full_pipe == 0) - { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; } - else if constexpr (group_blocks != -1) - { - if constexpr (group_blocks >= thread_k_blocks) - { - if (k % b_sh_wr_iters == 0) - { - int4 *sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - else - { - reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; - } - } - else - { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1)); + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; - if constexpr (w_type_id != host::kFE2M1f.id()) - { - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - else - { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4 *sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1)); + + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (w_type_id != host::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } } - } - } - return; - } + return; + } - // Act-order case + // Act-order case - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) - { - return; - } + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); - // Determine "position" inside the thread-block (based on warp and - // thread-id) - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; - cur_k += warp_row * 16; + cur_k += warp_row * 16; - auto th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; - if (is_same_group[pipe]) - { - if (k % 2 == 0) - { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; - } - else - { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } - for (int i = 1; i < 4; i++) - { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; } - return; - } - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread + constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread #pragma unroll - for (int i = 0; i < 4; i++) - { - int actual_k = cur_k + k_frag_offsets[i]; + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } }; - auto fetch_zp_to_registers = [&](int k, int full_pipe) - { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(!has_zp || group_blocks != 0); + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); - if constexpr (has_zp && !is_zp_float) - { - int pipe = full_pipe % stages; + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; - if constexpr (group_blocks == -1) - { - // load only when starting a new slice - if (k == 0 && full_pipe == 0) - { + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { #pragma unroll - for (int i = 0; i < num_ints_per_thread; i++) - { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - } - } - else if constexpr (group_blocks >= thread_k_blocks) - { - if (k % b_sh_wr_iters == 0) - { - int4 *sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + } else if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); #pragma unroll - for (int i = 0; i < num_ints_per_thread; i++) - { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - } - else - { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); - int k_blocks = cur_k / 16; - int cur_group_id = 0; + int k_blocks = cur_k / 16; + int cur_group_id = 0; - // Suppress bogus and persistent divide-by-zero warning + // Suppress bogus and persistent divide-by-zero warning #pragma nv_diagnostic push #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; + cur_group_id = k_blocks / group_blocks; #pragma nv_diagnostic pop - int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; - sh_zp_stage += cur_group_id * zp_sh_stride; + sh_zp_stage += cur_group_id * zp_sh_stride; #pragma unroll - for (int i = 0; i < num_ints_per_thread; i++) - { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } } - } - else if constexpr (has_zp && is_zp_float) - { - int pipe = full_pipe % stages; + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; - if constexpr (group_blocks != -1) - { - if constexpr (group_blocks >= thread_k_blocks) - { - if (k % b_sh_wr_iters == 0) - { - int4 *sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; - } - } - else - { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); - int k_blocks = cur_k / 16; - // Suppress bogus and persistent divide-by-zero warning + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning #pragma nv_diagnostic push #pragma nv_diag_suppress divide_by_zero - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = k_blocks / group_blocks; #pragma nv_diagnostic pop - int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; - reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; - } + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } } - } }; - auto dequant_data = [&](int q, scalar_t2 *frag_b_ptr) - { - dequant(q, frag_b_ptr); + auto dequant_data = [&](int q, scalar_t2 *frag_b_ptr) { + dequant(q, frag_b_ptr); }; // Execute the actual tensor core matmul of a sub-tile. bool is_first_matmul_in_slice = true; - auto matmul = [&](int k) - { - int k2 = k % 2; - const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || - (group_blocks == -1 && is_first_matmul_in_slice); - if constexpr (has_zp && !is_zp_float) - { - if (is_new_zp) - { - if constexpr (group_blocks == -1) - is_first_matmul_in_slice = false; - FragB frag_zp_0; - FragB frag_zp_1; - int zp_quant_0, zp_quant_1; - - if constexpr (w_type.size_bits() == 4) - { - zp_quant_0 = frag_qzp[k2][0]; - zp_quant_1 = zp_quant_0 >> 8; - } - else - { - static_assert(w_type.size_bits() == 8); - zp_quant_0 = frag_qzp[k2][0]; - zp_quant_1 = frag_qzp[k2][1]; - } - - dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); - dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) { + is_first_matmul_in_slice = false; + } + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } } - } - if constexpr (!dequant_skip_flop && has_zp && is_zp_float) - { - if (is_new_zp) - { - reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; + } } - } - if constexpr (w_type == host::kFE2M1f) - { - int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; - int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + if constexpr (w_type == host::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); - } + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll - for (int j = 0; j < 4; j++) - { - FragB frag_b0; - FragB frag_b1; - int b_quant_0, b_quant_1; - - if constexpr (w_type_id == host::kFE2M1f.id()) - { - b_quant_1 = frag_b_quant[k2][0][j]; - b_quant_0 = b_quant_1 << 8; - } - else if constexpr (w_type.size_bits() == 4) - { - b_quant_0 = frag_b_quant[k2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } - else - { - static_assert(w_type.size_bits() == 8); - int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type_id == host::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } - dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); - dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); - if constexpr (dequant_skip_flop && has_zp && !is_zp_float) - { - sub_zp(frag_b0, frag_zp[j], 0); - sub_zp(frag_b1, frag_zp[j], 1); - } + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } - // Apply scale to frag_b0 - if constexpr (has_act_order) - { - static_assert(group_blocks != -1); - scale4( - frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); - scale4( - frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); - } - else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) - { - int idx = (threadIdx.x / 4) % 2; - scalar_t2 s2 = Dtype::nums2num2( - reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], - reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); - if (is_new_zp) - frag_zp[j] = __hmul2(frag_zp[j], s2); - scale_and_sub(frag_b0, s2.x, frag_zp[j].x); - scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } - else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) - { - if (is_new_zp) - frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); - scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); - } - else if constexpr (group_blocks != -1) - { - scale(frag_b0, frag_s[k2][j], 0); - scale(frag_b1, frag_s[k2][j], 1); - } + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4( + frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4( + frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - { - if constexpr (m_block_size_8) - { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); - } - else - { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); - } + for (int i = 0; i < thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } } - } }; // Since we slice across the k dimension of a tile in order to increase the // number of warps while keeping the n dimension of a tile reasonable, we have // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() - { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) - { - auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) - { + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { #pragma unroll - for (int i = red_off; i > 0; i /= 2) - { - if (i <= red_idx && red_idx < 2 * i) - { + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { #pragma unroll - for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) - { - int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) - { - float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh_red[red_sh_wr]); + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh_red[red_sh_wr]); #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); } - sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) - { + if (red_idx == 0) { #pragma unroll - for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) - { - float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); } - } - __syncthreads(); } - } }; // Since multiple threadblocks may process parts of the same column slice, we // finally have to globally reduce over the results. As the striped // partitioning minimizes the number of such reductions and our outputs are // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce_fp16 = [&](bool first = false, bool last = false) - { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) - { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr; - if constexpr (m_block_size_8) - { - c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - } - else - { - c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - } - constexpr int c_sh_wr_delta = active_threads; - auto c_sh_wr = threadIdx.x; + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + auto c_sh_wr = threadIdx.x; - int row = (threadIdx.x % 32) / 4; + int row = (threadIdx.x % 32) / 4; - if (!first) - { + if (!first) { // Interestingly, doing direct global accesses here really seems to mess up // the compiler and lead to slowdowns, hence we also use async-copies even // though these fetches are not actually asynchronous. #pragma unroll - for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) - { - if constexpr (m_block_size_8) - { - cp_async4_pred( - &sh_red[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], - (threadIdx.x % 4) * 2 + i < prob_m); - } - else - { - cp_async4_pred( - &sh_red[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if constexpr (m_block_size_8) { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], + (threadIdx.x % 4) * 2 + i < prob_m); + } else { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + } + cp_async_fence(); + cp_async_wait<0>(); } - } - cp_async_fence(); - cp_async_wait<0>(); - } #pragma unroll - for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) - { - bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) || - (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); - if (mask) - { - if (!first) - { - int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) || (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); + if (mask) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; #pragma unroll - for (int j = 0; j < 2 * 4; j++) - { - int delta = 0; - if constexpr (m_block_size_8) - { - delta = j % 2 == 1 ? -2 : 0; - } - reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) - { - int4 c; + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; #pragma unroll - for (int j = 0; j < 2 * 4; j++) - { - int delta = 0; - if constexpr (m_block_size_8) - { - delta = j % 2 == 1 ? -2 : 0; + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + if constexpr (m_block_size_8) + C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; + else + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } } - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); - } - if constexpr (m_block_size_8) - C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; - else - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; } - } } - } }; // Globally reduce over threadblocks that compute the same column block. // We use a tmp C buffer to reduce in full fp32 precision. - auto global_reduce_fp32 = [&](bool first = false, bool last = false) - { - constexpr int tb_m = thread_m_blocks * 16; - constexpr int tb_n = thread_n_blocks * 16; + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; - constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; - constexpr int active_threads = 32 * thread_n_blocks / 4; - bool is_th_active = threadIdx.x < active_threads; + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; - constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; - constexpr int th_size = num_floats * sizeof(float) / 16; + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; - int c_cur_offset = locks_off * c_size; + int c_cur_offset = locks_off * c_size; - if (!is_th_active) - { - return; - } + if (!is_th_active) { + return; + } - if (!first) - { - float *frag_c_ptr = reinterpret_cast(&frag_c); + if (!first) { + float *frag_c_ptr = reinterpret_cast(&frag_c); #pragma unroll - for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) - { - sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; - float *sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); + float *sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); #pragma unroll - for (int f = 0; f < 4; f++) - { - frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; - } + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } } - } - if (!last) - { - int4 *frag_c_ptr = reinterpret_cast(&frag_c); + if (!last) { + int4 *frag_c_ptr = reinterpret_cast(&frag_c); #pragma unroll - for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) - { - C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } } - } }; // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. - auto write_result = [&]() - { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr; - if constexpr (m_block_size_8) - { - c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; - c_sh_wr += 64 * (threadIdx.x / 32); - } - else - { - c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - } - - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS &s) - { - scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr ( - !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) - { - res = __hmul2(res, s[0]); + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); } - if constexpr (w_type == host::kFE2M1f) - { - res = __hmul2(res, global_scale); - } + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); - if constexpr (m_block_size_8) - { - ((scalar_t *)sh_red)[idx] = res.x; - ((scalar_t *)sh_red)[idx + 8 * c_sh_stride] = res.y; - } - else - { - ((scalar_t2 *)sh_red)[idx] = res; - } - }; + int c_gl_wr_end = c_gl_stride * prob_m; + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS &s) { + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { + res = __hmul2(res, s[0]); + } + + if constexpr (w_type == host::kFE2M1f) { + res = __hmul2(res, global_scale); + } - if (threadIdx.x / 32 < thread_n_blocks / 4) - { + if constexpr (m_block_size_8) { + ((scalar_t *)sh_red)[idx] = res.x; + ((scalar_t *)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2 *)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - { + for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll - for (int j = 0; j < 4; j++) - { - if constexpr (m_block_size_8) - { - int wr = c_sh_wr + 16 * j; - write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - else - { - int wr = c_sh_wr + 8 * j; - write( - wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write( - wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write( - wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write( - wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write( + wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write( + wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); } - } - c_sh_wr += 16 * (4 * c_sh_stride); } - } - __syncthreads(); + __syncthreads(); #pragma unroll - for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) - { - if (c_gl_wr < c_gl_wr_end) - { - if (use_atomic_add && slice_count > 1) - { - scalar_t2 *C_half2 = reinterpret_cast(&C[c_gl_wr]); - scalar_t2 *sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); + for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (c_gl_wr < c_gl_wr_end) { + if (use_atomic_add && slice_count > 1) { + scalar_t2 *C_half2 = reinterpret_cast(&C[c_gl_wr]); + scalar_t2 *sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); #pragma unroll - for (int a = 0; a < 4; a++) - { - atomicAdd(&C_half2[a], sh_red_half2[a]); + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[c_gl_wr] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; } - } - else - { - C[c_gl_wr] = sh_red[c_sh_rd]; - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; } - } - __syncthreads(); + __syncthreads(); }; // Start global fetch and register load pipelines. - auto start_pipes = [&]() - { + auto start_pipes = [&]() { #pragma unroll - for (int i = 0; i < stages - 1; i++) - { - if (has_act_order && i == 0) - { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) - { - last_g_idx = prob_k - 1; - } - fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } - if constexpr (has_zp && !is_zp_float && group_blocks == -1) - { - if (i == 0) - { - fetch_col_zp_to_shared(); - if constexpr (!dequant_skip_flop) - { - fetch_col_scale_to_shared(); + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } + } } - } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - if constexpr (has_act_order) - { - slice_k_start_shared_fetch += tb_k * (stages - 1); - } }; - if (slice_iters) - { - start_pipes(); + if (slice_iters) { + start_pipes(); } // Main loop. - while (slice_iters) - { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. #pragma unroll - for (int pipe = 0; pipe < stages;) - { + for (int pipe = 0; pipe < stages;) { #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) - { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) - { - fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) - { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - - if constexpr (has_act_order) - { - slice_k_start += tb_k * stages; - - if (slice_k_start < prob_k) - { - slice_k_start_shared_fetch += tb_k * stages; - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) - { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) - { - fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) - { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) - { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) - { - if (s_sh_wr_pred) - { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + + a_gl_rd += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } } - cp_async_fence(); - } } - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) - { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) - { + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) - { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - if constexpr (m_block_size_8) - { - int idx = (threadIdx.x / 4) % 2; - scalar_t2 *frag_s_half2 = reinterpret_cast(frag_s); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 *frag_s_half2 = reinterpret_cast(frag_s); #pragma unroll - for (int i = 0; i < 8; i++) - { - frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } } - } } - } - } - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr ( - !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) - { - if (threadIdx.x / 32 < thread_n_blocks / 4) - { + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - { + for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll - for (int j = 0; j < 4; j++) - { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( - reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); - - if constexpr (!m_block_size_8) - { - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } } - } } - } - } - if (slice_count > 1 && !use_atomic_add) - { - // only globally reduce if there is more than one block in a slice - barrier_acquire(&locks[locks_off], slice_idx); - if (use_fp32_reduce) - { - global_reduce_fp32(slice_idx == 0, last); - } - else - { - global_reduce_fp16(slice_idx == 0, last); - } - barrier_release(&locks[locks_off], last); - } - if (use_atomic_add && slice_count > 1 && slice_idx != 0) - wait_negative_and_add(&locks[locks_off]); - if (last || use_atomic_add) - // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - is_first_matmul_in_slice = true; - init_slice(); - - if (slice_iters) - { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) { + wait_negative_and_add(&locks[locks_off]); + } + if (last || use_atomic_add) { + // only the last block in a slice actually writes the result + write_result(); + } + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) - { + for (int i = 0; i < b_sh_wr_iters; i++) { + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + } + if (slice_col == 0) { #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) - { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - } - else - { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } + for (int i = 0; i < b_sh_wr_iters; i++) { + B_ptr[i] -= b_gl_stride; + } + } - start_pipes(); + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } } - } } - } +} } // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp index 15f46457f..04ce7a537 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp @@ -18,276 +18,274 @@ namespace host { // here. // class ScalarType { - public: - enum NanRepr : uint8_t { - NAN_NONE = 0, // nans are not supported - NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s - NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s - - NAN_REPR_ID_MAX - }; - - constexpr ScalarType( - uint8_t exponent, - uint8_t mantissa, - bool signed_, - int32_t bias, - bool finite_values_only = false, - NanRepr nan_repr = NAN_IEEE_754) - : exponent(exponent), - mantissa(mantissa), - signed_(signed_), - bias(bias), - finite_values_only(finite_values_only), - nan_repr(nan_repr) {}; - - static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { - return ScalarType(0, size_bits - 1, true, bias); - } - - static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { - return ScalarType(0, size_bits, false, bias); - } - - // IEEE 754 compliant floating point type - static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { - assert(mantissa > 0 && exponent > 0); - return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); - } - - // IEEE 754 non-compliant floating point type - static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { - assert(nan_repr < NAN_REPR_ID_MAX); - assert(mantissa > 0 && exponent > 0); - assert(nan_repr != NAN_IEEE_754); - return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); - } - - uint8_t const exponent; // size of the exponent field (0 for integer types) - uint8_t const mantissa; // size of the mantissa field (size of the integer - // excluding the sign bit for integer types) - bool const signed_; // flag if the type supports negative numbers (i.e. has a - // sign bit) - int32_t const bias; // stored values equal value + bias, - // used for quantized type - - // Extra Floating point info - bool const finite_values_only; // i.e. no +/-inf if true - NanRepr const nan_repr; // how NaNs are represented - // (not applicable for integer types) - - using Id = int64_t; - - private: - // Field size in id - template - static constexpr size_t member_id_field_width() { - using T = std::decay_t; - return std::is_same_v ? 1 : sizeof(T) * 8; - } - - template - static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { - auto new_val = f(val, member); - if constexpr (sizeof...(rest) > 0) { - return reduce_members_helper(f, new_val, rest...); - } else { - return new_val; +public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX }; - } - - template - constexpr auto reduce_members(Fn f, Init init) const { - // Should be in constructor order for `from_id` - return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); - }; - - template - static constexpr auto reduce_member_types(Fn f, Init init) { - constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); - return dummy_type.reduce_members(f, init); - }; - - static constexpr auto id_size_bits() { - return reduce_member_types( - [](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); - } - - public: - // unique id for this scalar type that can be computed at compile time for - // c++17 template specialization this is not needed once we migrate to - // c++20 and can pass literal classes as template parameters - constexpr Id id() const { - static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); - - auto or_and_advance = [](std::pair result, auto member) -> std::pair { - auto [id, bit_offset] = result; - auto constexpr bits = member_id_field_width(); - return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits}; + + constexpr ScalarType( + uint8_t exponent, + uint8_t mantissa, + bool signed_, + int32_t bias, + bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr){}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { + assert(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { + assert(nan_repr < NAN_REPR_ID_MAX); + assert(mantissa > 0 && exponent > 0); + assert(nan_repr != NAN_IEEE_754); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + +private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); }; - return reduce_members(or_and_advance, std::pair{}).first; - } - - // create a ScalarType from an id, for c++17 template specialization, - // this is not needed once we migrate to c++20 and can pass literal - // classes as template parameters - static constexpr ScalarType from_id(Id id) { - auto extract_and_advance = [id](auto result, auto member) { - using T = decltype(member); - auto [tuple, bit_offset] = result; - auto constexpr bits = member_id_field_width(); - auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); - auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); - return std::pair{new_tuple, bit_offset + bits}; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); }; - auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); - return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args); - } - - constexpr int64_t size_bits() const { - return mantissa + exponent + is_signed(); - } - constexpr bool is_signed() const { - return signed_; - } - constexpr bool is_integer() const { - return exponent == 0; - } - constexpr bool is_floating_point() const { - return exponent > 0; - } - constexpr bool is_ieee_754() const { - return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; - } - constexpr bool has_nans() const { - return is_floating_point() && nan_repr != NAN_NONE; - } - constexpr bool has_infs() const { - return is_floating_point() && finite_values_only == false; - } - constexpr bool has_bias() const { - return bias != 0; - } + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); + } -#ifndef __CUDACC__ - private: - double _floating_point_max() const { - assert(mantissa <= 52 && exponent <= 11); +public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } - uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; - if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { - max_mantissa -= 1; + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args); } - uint64_t max_exponent = (uint64_t(1) << exponent) - 2; - if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { - assert(exponent < 11); - max_exponent += 1; + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { + return signed_; + } + constexpr bool is_integer() const { + return exponent == 0; + } + constexpr bool is_floating_point() const { + return exponent > 0; + } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { + return bias != 0; + } + +#ifndef __CUDACC__ +private: + double _floating_point_max() const { + assert(mantissa <= 52 && exponent <= 11); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + assert(exponent < 11); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; - // adjust the exponent to match that of a double - // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e - // is the exponent bits), there is some precedent for non-standard biases, - // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes - // but to avoid premature over complication we are just assuming the - // standard exponent bias until there is a need to support non-standard - // biases - uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; - uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 - - uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; - - // shift the mantissa into the position for a double and - // the exponent - uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); - - return *reinterpret_cast(&double_raw); - } - - constexpr std::variant _raw_max() const { - if (is_floating_point()) { - return {_floating_point_max()}; - } else { - assert(size_bits() < 64 || (size_bits() == 64 && is_signed())); - return {(int64_t(1) << mantissa) - 1}; + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); } - } - - constexpr std::variant _raw_min() const { - if (is_floating_point()) { - assert(is_signed()); - constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); - - double max = _floating_point_max(); - uint64_t max_raw = *reinterpret_cast(&max); - uint64_t min_raw = max_raw | sign_bit_double; - return {*reinterpret_cast(&min_raw)}; - } else { - assert(!is_signed() || size_bits() <= 64); - if (is_signed()) { - // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 - // then perform an arithmetic shift right to set all the bits above - // (size_bits() - 1) to 1 - return {INT64_MIN >> (64 - size_bits())}; - } else { - return {int64_t(0)}; - } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + assert(size_bits() < 64 || (size_bits() == 64 && is_signed())); + return {(int64_t(1) << mantissa) - 1}; + } } - } - - public: - // Max representable value for this scalar type. - // (accounting for bias if there is one) - constexpr std::variant max() const { - return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); - } - - // Min representable value for this scalar type. - // (accounting for bias if there is one) - constexpr std::variant min() const { - return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); - } -#endif // __CUDACC__ - - public: - std::string str() const { - /* naming generally follows: https://github.com/jax-ml/ml_dtypes - * for floating point types (leading f) the scheme is: - * `float_em[flags]` - * flags: - * - no-flags: means it follows IEEE 754 conventions - * - f: means finite values only (no infinities) - * - n: means nans are supported (non-standard encoding) - * for integer types the scheme is: - * `[u]int[b]` - * - if bias is not present it means its zero - */ - if (is_floating_point()) { - auto ret = - "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); - if (!is_ieee_754()) { - if (finite_values_only) { - ret += "f"; + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + assert(is_signed()); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + assert(!is_signed() || size_bits() <= 64); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } } - if (nan_repr != NAN_NONE) { - ret += "n"; + } + +public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); + } + + // Min representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); + } +#endif // __CUDACC__ + +public: + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; } - } - return ret; - } else { - auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); - if (has_bias()) { - ret += "b" + std::to_string(bias); - } - return ret; } - } - constexpr bool operator==(ScalarType const& other) const { - return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && - finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; - } + constexpr bool operator==(ScalarType const &other) const { + return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; + } }; using ScalarTypeId = ScalarType::Id; @@ -331,5 +329,4 @@ static inline constexpr auto kFloat16 = kHalf; static inline constexpr auto kBFloat16 = kFE8M7; static inline constexpr auto kFloat16Id = kFloat16.id(); -} // namespace host - +} // namespace host diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h index 57573171a..9a06fb380 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h @@ -16,26 +16,25 @@ using source_location_t = std::source_location; #else struct source_location_fallback { - public: - static constexpr source_location_fallback current() noexcept { - return source_location_fallback{}; - } - constexpr source_location_fallback() noexcept = default; - constexpr unsigned line() const noexcept { - return 0; - } - constexpr unsigned column() const noexcept { - return 0; - } - constexpr const char* file_name() const noexcept { - return ""; - } - constexpr const char* function_name() const noexcept { - return ""; - } +public: + static constexpr source_location_fallback current() noexcept { + return source_location_fallback{}; + } + constexpr source_location_fallback() noexcept = default; + constexpr unsigned line() const noexcept { + return 0; + } + constexpr unsigned column() const noexcept { + return 0; + } + constexpr const char *file_name() const noexcept { + return ""; + } + constexpr const char *function_name() const noexcept { + return ""; + } }; using source_location_t = source_location_fallback; #endif - diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h index 9f48edd96..f30492621 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h @@ -23,199 +23,178 @@ #include "utils.cuh" #endif -namespace host -{ - struct SymbolicSize; - struct SymbolicDType; - struct SymbolicDevice; - - namespace details - { - inline constexpr auto kAnyDeviceID = -1; - inline constexpr auto kAnySize = static_cast(-1); - inline constexpr auto kNullSize = static_cast(-1); - inline constexpr auto kNullDType = static_cast(18u); - inline constexpr auto kNullDevice = static_cast(-1); - - template - struct ArrayView - { - const T *data; - size_t size; - - __host__ __device__ ArrayView() : data(nullptr), size(0) {} - __host__ __device__ ArrayView(const T *d, size_t s) : data(d), size(s) {} - - template - __host__ __device__ ArrayView(const std::array &arr) - : data(arr.data()), size(arr.size()) {} - - __host__ __device__ const T &operator[](size_t i) const { return data[i]; } - __host__ __device__ bool empty() const { return size == 0; } - }; - - template - struct PrintAbleSpan - { - const T *data; - size_t length; - - PrintAbleSpan(const T *p, size_t l) : data(p), length(l) {} - size_t size() const { return length; } - const T &operator[](size_t i) const { return data[i]; } - }; - - inline constexpr const char *kDeviceStringMap[] = { - "", // 0 - "cpu", // 1 - "cuda", // 2 - "cuda_host", // 3 - "opencl", // 4 - "vulkan", // 5 - "metal", // 6 - "vpi", // 7 - "rocm", // 8 - "rocm_host", // 9 - "ext_dev", // 10 - "cuda_managed", // 11 - "oneapi", // 12 - "webgpu", // 13 - "hexagon", // 14 - "maia", // 15 - "trn", // 16 - }; - - constexpr int kMaxDeviceType = 16; - - struct PrintableDevice - { - DLDevice device; - }; - - template - struct _dtype_trait; - - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLInt, 8, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLInt, 16, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLInt, 32, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLInt, 64, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLUInt, 8, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLUInt, 16, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLUInt, 32, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLUInt, 64, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLFloat, 32, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLFloat, 64, 1}; - }; +namespace host { +struct SymbolicSize; +struct SymbolicDType; +struct SymbolicDevice; + +namespace details { +inline constexpr auto kAnyDeviceID = -1; +inline constexpr auto kAnySize = static_cast(-1); +inline constexpr auto kNullSize = static_cast(-1); +inline constexpr auto kNullDType = static_cast(18u); +inline constexpr auto kNullDevice = static_cast(-1); + +template +struct ArrayView { + const T *data; + size_t size; + + __host__ __device__ ArrayView() : data(nullptr), size(0) {} + __host__ __device__ ArrayView(const T *d, size_t s) : data(d), size(s) {} + + template + __host__ __device__ ArrayView(const std::array &arr) + : data(arr.data()), size(arr.size()) {} + + __host__ __device__ const T &operator[](size_t i) const { return data[i]; } + __host__ __device__ bool empty() const { return size == 0; } +}; + +template +struct PrintAbleSpan { + const T *data; + size_t length; + + PrintAbleSpan(const T *p, size_t l) : data(p), length(l) {} + size_t size() const { return length; } + const T &operator[](size_t i) const { return data[i]; } +}; + +inline constexpr const char *kDeviceStringMap[] = { + "", // 0 + "cpu", // 1 + "cuda", // 2 + "cuda_host", // 3 + "opencl", // 4 + "vulkan", // 5 + "metal", // 6 + "vpi", // 7 + "rocm", // 8 + "rocm_host", // 9 + "ext_dev", // 10 + "cuda_managed", // 11 + "oneapi", // 12 + "webgpu", // 13 + "hexagon", // 14 + "maia", // 15 + "trn", // 16 +}; + +constexpr int kMaxDeviceType = 16; + +struct PrintableDevice { + DLDevice device; +}; + +template +struct _dtype_trait; + +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLInt, 8, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLInt, 16, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLInt, 32, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLInt, 64, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLUInt, 8, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLUInt, 16, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLUInt, 32, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLUInt, 64, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLFloat, 32, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLFloat, 64, 1}; +}; #ifdef __CUDACC__ - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLFloat, 16, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLBfloat, 16, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLFloat8_e4m3fn, 8, 1}; - }; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLFloat, 16, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLBfloat, 16, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLFloat8_e4m3fn, 8, 1}; +}; #endif - template - struct _device_trait - { - static constexpr DLDevice value = {Code, kAnyDeviceID}; - }; +template +struct _device_trait { + static constexpr DLDevice value = {Code, kAnyDeviceID}; +}; - template - inline constexpr std::array kDTypeList = { - _dtype_trait::value...}; +template +inline constexpr std::array kDTypeList = { + _dtype_trait::value...}; - template - inline constexpr std::array kDeviceList = { - _device_trait::value...}; +template +inline constexpr std::array kDeviceList = { + _device_trait::value...}; - } // namespace details +} // namespace details - inline std::ostream &operator<<(std::ostream &os, DLDevice device) - { +inline std::ostream &operator<<(std::ostream &os, DLDevice device) { int code = static_cast(device.device_type); - if (code < 1 || code > details::kMaxDeviceType) - RuntimeCheck(false, "Unknown device: ", code); + if (code < 1 || code > details::kMaxDeviceType) { + RuntimeCheck(false, "Unknown device: ", code); + } os << details::kDeviceStringMap[code]; - if (device.device_id != details::kAnyDeviceID && device.device_type != kDLCPU) - os << ":" << device.device_id; + if (device.device_id != details::kAnyDeviceID && device.device_type != kDLCPU) { + os << ":" << device.device_id; + } return os; - } +} - inline std::ostream &operator<<(std::ostream &os, details::PrintableDevice pd) - { +inline std::ostream &operator<<(std::ostream &os, details::PrintableDevice pd) { return os << pd.device; - } +} - template - inline std::ostream &operator<<(std::ostream &os, const details::PrintAbleSpan &span) - { +template +inline std::ostream &operator<<(std::ostream &os, const details::PrintAbleSpan &span) { os << "["; - for (size_t i = 0; i < span.size(); ++i) - { - if (i > 0) - os << ", "; - os << span[i]; + for (size_t i = 0; i < span.size(); ++i) { + if (i > 0) { + os << ", "; + } + os << span[i]; } os << "]"; return os; - } - - // ============================================== - // SymbolicSize 完整定义 - // ============================================== - struct SymbolicSize - { - public: +} + +// ============================================== +// SymbolicSize 完整定义 +// ============================================== +struct SymbolicSize { +public: explicit SymbolicSize(std::string_view ann = {}) : m_value(details::kNullSize), m_ann(ann) {} @@ -225,275 +204,243 @@ namespace host std::string_view get_name() const { return m_ann; } bool has_value() const { return m_value != details::kNullSize; } - void set_value(int64_t v) - { - RuntimeCheck(!has_value(), "Size already set"); - m_value = v; + void set_value(int64_t v) { + RuntimeCheck(!has_value(), "Size already set"); + m_value = v; } - std::optional get_value() const - { - return has_value() ? std::optional(m_value) : std::nullopt; + std::optional get_value() const { + return has_value() ? std::optional(m_value) : std::nullopt; } - int64_t unwrap(DebugInfo info = {}) const - { - RuntimeCheck(info, has_value(), "Size not set"); - return m_value; + int64_t unwrap(DebugInfo info = {}) const { + RuntimeCheck(info, has_value(), "Size not set"); + return m_value; } - void verify(int64_t v, const char *prefix, int64_t dim) - { - if (has_value()) - { - if (m_value != v) [[unlikely]] - { - Panic("Size mismatch for ", m_name_str(prefix, dim), ": expected ", m_value, " got ", v); + void verify(int64_t v, const char *prefix, int64_t dim) { + if (has_value()) { + if (m_value != v) [[unlikely]] { + Panic("Size mismatch for ", m_name_str(prefix, dim), ": expected ", m_value, " got ", v); + } + } else { + set_value(v); } - } - else - { - set_value(v); - } } - std::string value_or_name(const char *prefix, int64_t dim) const - { - if (auto v = get_value()) - return std::to_string(*v); - return m_name_str(prefix, dim); + std::string value_or_name(const char *prefix, int64_t dim) const { + if (auto v = get_value()) { + return std::to_string(*v); + } + return m_name_str(prefix, dim); } - private: - std::string m_name_str(const char *prefix, int64_t dim) const - { - std::ostringstream os; - os << prefix << '#' << dim; - if (!m_ann.empty()) - os << "('" << m_ann << "')"; - return os.str(); +private: + std::string m_name_str(const char *prefix, int64_t dim) const { + std::ostringstream os; + os << prefix << '#' << dim; + if (!m_ann.empty()) { + os << "('" << m_ann << "')"; + } + return os.str(); } int64_t m_value; std::string_view m_ann; - }; +}; - inline bool operator==(DLDevice a, DLDevice b) - { +inline bool operator==(DLDevice a, DLDevice b) { return a.device_type == b.device_type && a.device_id == b.device_id; - } - - // ============================================== - // SymbolicDType 完整定义 - // ============================================== - struct SymbolicDType - { - public: +} + +// ============================================== +// SymbolicDType 完整定义 +// ============================================== +struct SymbolicDType { +public: SymbolicDType() : m_value({details::kNullDType, 0, 0}) {} SymbolicDType(const SymbolicDType &) = delete; SymbolicDType &operator=(const SymbolicDType &) = delete; bool has_value() const { return m_value.code != details::kNullDType; } - void set_value(DLDataType v) - { - RuntimeCheck(!has_value(), "DType already set"); - RuntimeCheck(m_check(v), "DType not allowed: ", v); - m_value = v; + void set_value(DLDataType v) { + RuntimeCheck(!has_value(), "DType already set"); + RuntimeCheck(m_check(v), "DType not allowed: ", v); + m_value = v; } - std::optional get_value() const - { - return has_value() ? std::optional(m_value) : std::nullopt; + std::optional get_value() const { + return has_value() ? std::optional(m_value) : std::nullopt; } - DLDataType unwrap(DebugInfo info = {}) const - { - RuntimeCheck(info, has_value(), "DType not set"); - return m_value; + DLDataType unwrap(DebugInfo info = {}) const { + RuntimeCheck(info, has_value(), "DType not set"); + return m_value; } void set_options(details::ArrayView opts) { m_opts = opts; } template - void set_options() - { - m_opts = details::ArrayView(details::kDTypeList.data(), details::kDTypeList.size()); + void set_options() { + m_opts = details::ArrayView(details::kDTypeList.data(), details::kDTypeList.size()); } - void verify(DLDataType dtype) - { - if (has_value()) - { - RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " got ", dtype); - } - else - { - set_value(dtype); - } + void verify(DLDataType dtype) { + if (has_value()) { + RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " got ", dtype); + } else { + set_value(dtype); + } } template - bool is_type() const - { - return m_value == details::_dtype_trait::value; + bool is_type() const { + return m_value == details::_dtype_trait::value; } - private: - bool m_check(DLDataType v) const - { - if (m_opts.empty()) - return true; - for (size_t i = 0; i < m_opts.size; ++i) - if (m_opts[i] == v) - return true; - return false; +private: + bool m_check(DLDataType v) const { + if (m_opts.empty()) { + return true; + } + for (size_t i = 0; i < m_opts.size; ++i) { + if (m_opts[i] == v) { + return true; + } + } + return false; } details::ArrayView m_opts; DLDataType m_value; - }; - - // ============================================== - // SymbolicDevice 完整定义 - // ============================================== - struct SymbolicDevice - { - public: +}; + +// ============================================== +// SymbolicDevice 完整定义 +// ============================================== +struct SymbolicDevice { +public: SymbolicDevice() : m_value({details::kNullDevice, details::kAnyDeviceID}) {} SymbolicDevice(const SymbolicDevice &) = delete; SymbolicDevice &operator=(const SymbolicDevice &) = delete; bool has_value() const { return m_value.device_type != details::kNullDevice; } - void set_value(DLDevice v) - { - RuntimeCheck(!has_value(), "Device already set"); - RuntimeCheck(m_check(v), "Device not allowed: ", details::PrintableDevice{v}); - m_value = v; + void set_value(DLDevice v) { + RuntimeCheck(!has_value(), "Device already set"); + RuntimeCheck(m_check(v), "Device not allowed: ", details::PrintableDevice{v}); + m_value = v; } - std::optional get_value() const - { - return has_value() ? std::optional(m_value) : std::nullopt; + std::optional get_value() const { + return has_value() ? std::optional(m_value) : std::nullopt; } - DLDevice unwrap(DebugInfo info = {}) const - { - RuntimeCheck(info, has_value(), "Device not set"); - return m_value; + DLDevice unwrap(DebugInfo info = {}) const { + RuntimeCheck(info, has_value(), "Device not set"); + return m_value; } void set_options(details::ArrayView opts) { m_opts = opts; } template - void set_options() - { - m_opts = details::ArrayView(details::kDeviceList.data(), details::kDeviceList.size()); - } - - void verify(DLDevice dev) - { - if (has_value()) - { - RuntimeCheck(m_value == dev, "Device mismatch: expected ", - details::PrintableDevice{m_value}, " got ", details::PrintableDevice{dev}); - } - else - { - set_value(dev); - } - } - - private: - bool m_check(DLDevice v) const - { - if (m_opts.empty()) - return true; - for (size_t i = 0; i < m_opts.size; ++i) - { - auto o = m_opts[i]; - if (o.device_type != v.device_type) - continue; - if (o.device_id == details::kAnyDeviceID || o.device_id == v.device_id) - return true; - } - return false; + void set_options() { + m_opts = details::ArrayView(details::kDeviceList.data(), details::kDeviceList.size()); + } + + void verify(DLDevice dev) { + if (has_value()) { + RuntimeCheck(m_value == dev, "Device mismatch: expected ", + details::PrintableDevice{m_value}, " got ", details::PrintableDevice{dev}); + } else { + set_value(dev); + } + } + +private: + bool m_check(DLDevice v) const { + if (m_opts.empty()) { + return true; + } + for (size_t i = 0; i < m_opts.size; ++i) { + auto o = m_opts[i]; + if (o.device_type != v.device_type) { + continue; + } + if (o.device_id == details::kAnyDeviceID || o.device_id == v.device_id) { + return true; + } + } + return false; } details::ArrayView m_opts; DLDevice m_value; - }; - - // ============================================== - // BaseRef / Ref 类型(现在类型已完整定义) - // ============================================== - namespace details - { - template - struct BaseRef - { - BaseRef() : m_ref(&m_cache) {} - explicit BaseRef(T &r) : m_ref(&r) {} - BaseRef(const BaseRef &) = delete; - BaseRef &operator=(const BaseRef &) = delete; - - T *operator->() const { return m_ref; } - T &operator*() const { return *m_ref; } - void rebind(T &r) { m_ref = &r; } - - private: - T *m_ref; - T m_cache; - }; - - struct SizeRef : public BaseRef - { - using BaseRef::BaseRef; - SizeRef(int64_t v); - }; - - struct DTypeRef : public BaseRef - { - using BaseRef::BaseRef; - DTypeRef(DLDataType); - DTypeRef(std::initializer_list); - DTypeRef(ArrayView); - }; - - struct DeviceRef : public BaseRef - { - using BaseRef::BaseRef; - DeviceRef(DLDevice); - DeviceRef(std::initializer_list); - DeviceRef(ArrayView); - }; - - inline SizeRef::SizeRef(int64_t v) - { - if (v != kAnySize) +}; + +// ============================================== +// BaseRef / Ref 类型(现在类型已完整定义) +// ============================================== +namespace details { +template +struct BaseRef { + BaseRef() : m_ref(&m_cache) {} + explicit BaseRef(T &r) : m_ref(&r) {} + BaseRef(const BaseRef &) = delete; + BaseRef &operator=(const BaseRef &) = delete; + + T *operator->() const { return m_ref; } + T &operator*() const { return *m_ref; } + void rebind(T &r) { m_ref = &r; } + +private: + T *m_ref; + T m_cache; +}; + +struct SizeRef : public BaseRef { + using BaseRef::BaseRef; + SizeRef(int64_t v); +}; + +struct DTypeRef : public BaseRef { + using BaseRef::BaseRef; + DTypeRef(DLDataType); + DTypeRef(std::initializer_list); + DTypeRef(ArrayView); +}; + +struct DeviceRef : public BaseRef { + using BaseRef::BaseRef; + DeviceRef(DLDevice); + DeviceRef(std::initializer_list); + DeviceRef(ArrayView); +}; + +inline SizeRef::SizeRef(int64_t v) { + if (v != kAnySize) { (**this).set_value(v); } - inline DTypeRef::DTypeRef(DLDataType v) { (**this).set_value(v); } - inline DTypeRef::DTypeRef(std::initializer_list l) : DTypeRef(ArrayView(l.begin(), l.size())) {} - inline DTypeRef::DTypeRef(ArrayView v) { (**this).set_options(v); } - inline DeviceRef::DeviceRef(DLDevice v) { (**this).set_value(v); } - inline DeviceRef::DeviceRef(std::initializer_list l) : DeviceRef(ArrayView(l.begin(), l.size())) {} - inline DeviceRef::DeviceRef(ArrayView v) { (**this).set_options(v); } +} +inline DTypeRef::DTypeRef(DLDataType v) { (**this).set_value(v); } +inline DTypeRef::DTypeRef(std::initializer_list l) : DTypeRef(ArrayView(l.begin(), l.size())) {} +inline DTypeRef::DTypeRef(ArrayView v) { (**this).set_options(v); } +inline DeviceRef::DeviceRef(DLDevice v) { (**this).set_value(v); } +inline DeviceRef::DeviceRef(std::initializer_list l) : DeviceRef(ArrayView(l.begin(), l.size())) {} +inline DeviceRef::DeviceRef(ArrayView v) { (**this).set_options(v); } - } // namespace details +} // namespace details - template - inline bool is_type(DLDataType dtype) - { +template +inline bool is_type(DLDataType dtype) { return dtype == details::_dtype_trait::value; - } +} - // ============================================== - // TensorMatcher - // ============================================== - struct TensorMatcher - { +// ============================================== +// TensorMatcher +// ============================================== +struct TensorMatcher { using SizeRef = details::SizeRef; using DTypeRef = details::DTypeRef; using DeviceRef = details::DeviceRef; @@ -504,47 +451,42 @@ namespace host explicit TensorMatcher(std::initializer_list s) : m_shape(s.begin(), s.size()), m_strides(nullptr, 0) {} - TensorMatcher &&with_strides(std::initializer_list s) && - { - RuntimeCheck(m_strides.empty(), "Strides already set"); - RuntimeCheck(m_shape.size == s.size(), "Stride/shape size mismatch"); - m_strides = details::ArrayView(s.begin(), s.size()); - return std::move(*this); + TensorMatcher &&with_strides(std::initializer_list s) && { + RuntimeCheck(m_strides.empty(), "Strides already set"); + RuntimeCheck(m_shape.size == s.size(), "Stride/shape size mismatch"); + m_strides = details::ArrayView(s.begin(), s.size()); + return std::move(*this); } template - TensorMatcher &&with_dtype(DTypeRef &&d) && - { - m_dtype.rebind(*d); - m_dtype->template set_options(); - return std::move(*this); + TensorMatcher &&with_dtype(DTypeRef &&d) && { + m_dtype.rebind(*d); + m_dtype->template set_options(); + return std::move(*this); } template - TensorMatcher &&with_dtype() && - { - m_dtype->template set_options(); - return std::move(*this); + TensorMatcher &&with_dtype() && { + m_dtype->template set_options(); + return std::move(*this); } template - TensorMatcher &&with_device(DeviceRef &&d) && - { - m_device.rebind(*d); - m_device->template set_options(); - return std::move(*this); + TensorMatcher &&with_device(DeviceRef &&d) && { + m_device.rebind(*d); + m_device->template set_options(); + return std::move(*this); } template - TensorMatcher &&with_device() && - { - m_device->template set_options(); - return std::move(*this); + TensorMatcher &&with_device() && { + m_device->template set_options(); + return std::move(*this); } const TensorMatcher &&verify(tvm::ffi::TensorView, DebugInfo = {}) const &&; - private: +private: static void s_print_tensor(std::ostringstream &, tvm::ffi::TensorView); void m_verify_impl(tvm::ffi::TensorView) const; @@ -552,70 +494,62 @@ namespace host details::ArrayView m_strides; DTypeRef m_dtype; DeviceRef m_device; - }; +}; - inline void TensorMatcher::s_print_tensor(std::ostringstream &os, tvm::ffi::TensorView v) - { +inline void TensorMatcher::s_print_tensor(std::ostringstream &os, tvm::ffi::TensorView v) { os << "Tensor<"; size_t d = 0; - for (int64_t s : v.shape()) - { - if (d++) - os << ", "; - os << s; + for (int64_t s : v.shape()) { + if (d++) { + os << ", "; + } + os << s; } os << ">[strides=<"; d = 0; - for (int64_t s : v.strides()) - { - if (d++) - os << ", "; - os << s; + for (int64_t s : v.strides()) { + if (d++) { + os << ", "; + } + os << s; } os << ">, dtype=" << v.dtype(); os << ", device=" << details::PrintableDevice{v.device()} << "]"; - } - - inline const TensorMatcher &&TensorMatcher::verify(tvm::ffi::TensorView v, DebugInfo info) const && - { - try - { - m_verify_impl(v); - } - catch (PanicError &e) - { - std::ostringstream os; - os << "Tensor match failed: "; - s_print_tensor(os, v); - os << " @ " << info.file_name() << ":" << info.line() << "\n- cause: " << e.root_cause(); - throw PanicError(os.str()); +} + +inline const TensorMatcher &&TensorMatcher::verify(tvm::ffi::TensorView v, DebugInfo info) const && { + try { + m_verify_impl(v); + } catch (PanicError &e) { + std::ostringstream os; + os << "Tensor match failed: "; + s_print_tensor(os, v); + os << " @ " << info.file_name() << ":" << info.line() << "\n- cause: " << e.root_cause(); + throw PanicError(os.str()); } return std::move(*this); - } +} - inline void TensorMatcher::m_verify_impl(tvm::ffi::TensorView v) const - { +inline void TensorMatcher::m_verify_impl(tvm::ffi::TensorView v) const { size_t dim = static_cast(v.dim()); RuntimeCheck(dim == m_shape.size, "Dim mismatch: expected ", m_shape.size, " got ", dim); - for (size_t i = 0; i < dim; ++i) - m_shape[i]->verify(v.size(i), "shape", (int64_t)i); - - if (!m_strides.empty()) - { - for (size_t i = 0; i < dim; ++i) - { - if (v.size(i) != 1 || !m_strides[i]->has_value()) - m_strides[i]->verify(v.stride(i), "stride", (int64_t)i); - } + for (size_t i = 0; i < dim; ++i) { + m_shape[i]->verify(v.size(i), "shape", (int64_t)i); } - else - { - RuntimeCheck(v.is_contiguous(), "Tensor not contiguous"); + + if (!m_strides.empty()) { + for (size_t i = 0; i < dim; ++i) { + if (v.size(i) != 1 || !m_strides[i]->has_value()) { + m_strides[i]->verify(v.stride(i), "stride", (int64_t)i); + } + } + } else { + RuntimeCheck(v.is_contiguous(), "Tensor not contiguous"); } m_dtype->verify(v.dtype()); m_device->verify(v.device()); - } +} } // namespace host diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh index d73c2ac04..18d5da7c3 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh @@ -127,7 +127,8 @@ template SGL_DEVICE void PDLWaitPrimary() { #if SGL_ARCH_HOPPER_OR_GREATER if constexpr (kUsePDL) { - asm volatile("griddepcontrol.wait;" ::: "memory"); + asm volatile("griddepcontrol.wait;" :: + : "memory"); } #endif } @@ -142,7 +143,8 @@ template SGL_DEVICE void PDLTriggerSecondary() { #if SGL_ARCH_HOPPER_OR_GREATER if constexpr (kUsePDL) { - asm volatile("griddepcontrol.launch_dependents;" :::); + asm volatile("griddepcontrol.launch_dependents;" :: + :); } #endif } diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h index bf7a5ce40..d6892d0dd 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h @@ -45,197 +45,172 @@ #include "source_location.h" #endif -#include -#include #include #include +#include #include #include #include +#include #include -namespace host -{ +namespace host { - template - inline constexpr bool dependent_false_v = false; +template +inline constexpr bool dependent_false_v = false; - /// \brief Source-location wrapper for debug/error messages. - struct DebugInfo : public source_location_t - { +/// \brief Source-location wrapper for debug/error messages. +struct DebugInfo : public source_location_t { DebugInfo(source_location_t loc = source_location_t::current()) : source_location_t(loc) {} - }; +}; - /// \brief Exception type thrown by `RuntimeCheck` and `Panic`. - struct PanicError : public std::runtime_error - { - public: +/// \brief Exception type thrown by `RuntimeCheck` and `Panic`. +struct PanicError : public std::runtime_error { +public: explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {} - auto root_cause() const -> std::string_view - { - const auto str = std::string_view{m_message}; - const auto pos = str.find(": "); - return pos == std::string_view::npos ? str : str.substr(pos + 2); + auto root_cause() const -> std::string_view { + const auto str = std::string_view{m_message}; + const auto pos = str.find(": "); + return pos == std::string_view::npos ? str : str.substr(pos + 2); } - private: +private: std::string m_message; - }; +}; - /// \brief Unconditionally abort with a formatted error message. - template - [[noreturn]] - inline auto panic(DebugInfo location, Args &&...args) -> void - { +/// \brief Unconditionally abort with a formatted error message. +template +[[noreturn]] inline auto panic(DebugInfo location, Args &&...args) -> void { std::ostringstream os; os << "Runtime check failed at " << location.file_name() << ":" << location.line(); - if constexpr (sizeof...(args) > 0) - { - os << ": "; - (os << ... << std::forward(args)); - } - else - { - os << " in " << location.function_name(); + if constexpr (sizeof...(args) > 0) { + os << ": "; + (os << ... << std::forward(args)); + } else { + os << " in " << location.function_name(); } throw PanicError(std::move(os).str()); - } - - /** - * \brief Runtime assertion: panics with a formatted message when `condition` - * is false. Extra `args` are streamed to the error message. - * - * Example: - * \code - * RuntimeCheck(n > 0, "n must be positive, got ", n); - * \endcode - */ - template - struct RuntimeCheck - { +} + +/** + * \brief Runtime assertion: panics with a formatted message when `condition` + * is false. Extra `args` are streamed to the error message. + * + * Example: + * \code + * RuntimeCheck(n > 0, "n must be positive, got ", n); + * \endcode + */ +template +struct RuntimeCheck { template - explicit RuntimeCheck(Cond &&condition, Args &&...args, DebugInfo location = {}) - { - if (condition) - return; - [[unlikely]] ::host::panic(location, std::forward(args)...); + explicit RuntimeCheck(Cond &&condition, Args &&...args, DebugInfo location = {}) { + if (condition) { + return; + } + [[unlikely]] ::host::panic(location, std::forward(args)...); } template - explicit RuntimeCheck(DebugInfo location, Cond &&condition, Args &&...args) - { - if (condition) - return; - [[unlikely]] ::host::panic(location, std::forward(args)...); + explicit RuntimeCheck(DebugInfo location, Cond &&condition, Args &&...args) { + if (condition) { + return; + } + [[unlikely]] ::host::panic(location, std::forward(args)...); } - }; - - template - struct Panic - { - explicit Panic(Args &&...args, DebugInfo location = {}) - { - ::host::panic(location, std::forward(args)...); +}; + +template +struct Panic { + explicit Panic(Args &&...args, DebugInfo location = {}) { + ::host::panic(location, std::forward(args)...); } - explicit Panic(DebugInfo location, Args &&...args) - { - ::host::panic(location, std::forward(args)...); + explicit Panic(DebugInfo location, Args &&...args) { + ::host::panic(location, std::forward(args)...); } - [[noreturn]] ~Panic() - { - std::terminate(); + [[noreturn]] ~Panic() { + std::terminate(); } - }; +}; - template - explicit RuntimeCheck(Cond &&, Args &&...) -> RuntimeCheck; +template +explicit RuntimeCheck(Cond &&, Args &&...) -> RuntimeCheck; - template - explicit RuntimeCheck(DebugInfo, Cond &&, Args &&...) -> RuntimeCheck; +template +explicit RuntimeCheck(DebugInfo, Cond &&, Args &&...) -> RuntimeCheck; - template - explicit Panic(Args &&...) -> Panic; +template +explicit Panic(Args &&...) -> Panic; - template - explicit Panic(DebugInfo, Args &&...) -> Panic; +template +explicit Panic(DebugInfo, Args &&...) -> Panic; - namespace pointer - { +namespace pointer { - // we only allow void * pointer arithmetic for safety +// we only allow void * pointer arithmetic for safety - template ::value && ...)>> - inline auto offset(void *ptr, U... offset) -> void * - { - return static_cast(ptr) + (... + offset); - } +template ::value && ...)>> +inline auto offset(void *ptr, U... offset) -> void * { + return static_cast(ptr) + (... + offset); +} - template ::value && ...)>> - inline auto offset(const void *ptr, U... offset) -> const void * - { - return static_cast(ptr) + (... + offset); - } +template ::value && ...)>> +inline auto offset(const void *ptr, U... offset) -> const void * { + return static_cast(ptr) + (... + offset); +} - } // namespace pointer +} // namespace pointer - /// \brief Integer ceiling division: ceil(a / b). - template - inline constexpr auto div_ceil(T a, U b) - { +/// \brief Integer ceiling division: ceil(a / b). +template +inline constexpr auto div_ceil(T a, U b) { static_assert(std::is_integral::value, "T must be integral"); static_assert(std::is_integral::value, "U must be integral"); return (a + b - 1) / b; - } +} - /// \brief Returns the byte width of a DLPack data type. - inline auto dtype_bytes(DLDataType dtype) -> std::size_t - { +/// \brief Returns the byte width of a DLPack data type. +inline auto dtype_bytes(DLDataType dtype) -> std::size_t { return static_cast(dtype.bits / 8); - } +} - // ====================== 修复开始:纯 C++11 兼容版 irange ====================== - // 移除所有 std::ranges / std::integral,完全兼容旧版 CUDA 编译器 +// ====================== 修复开始:纯 C++11 兼容版 irange ====================== +// 移除所有 std::ranges / std::integral,完全兼容旧版 CUDA 编译器 - template - struct IntegerRange - { +template +struct IntegerRange { T start_; T end_; - struct Iterator - { - T value; - - T operator*() const { return value; } - Iterator &operator++() - { - ++value; - return *this; - } - bool operator!=(const Iterator &other) const - { - return value != other.value; - } + struct Iterator { + T value; + + T operator*() const { return value; } + Iterator &operator++() { + ++value; + return *this; + } + bool operator!=(const Iterator &other) const { + return value != other.value; + } }; Iterator begin() const { return {start_}; } Iterator end() const { return {end_}; } - }; +}; - /// Python-style integer range: irange(n) -> [0, n) - template - IntegerRange irange(T end) - { +/// Python-style integer range: irange(n) -> [0, n) +template +IntegerRange irange(T end) { return {0, end}; - } +} - /// Python-style integer range: irange(start, end) -> [start, end) - template - IntegerRange irange(T start, T end) - { +/// Python-style integer range: irange(start, end) -> [start, end) +template +IntegerRange irange(T start, T end) { return {start, end}; - } - // ====================== 修复结束 ====================== +} +// ====================== 修复结束 ====================== } // namespace host diff --git a/test/infiniop/gptq_marlin_gemm.py b/test/infiniop/gptq_marlin_gemm.py index 9ba296d18..8119fbe8d 100644 --- a/test/infiniop/gptq_marlin_gemm.py +++ b/test/infiniop/gptq_marlin_gemm.py @@ -41,17 +41,21 @@ mnk_factors = MNK_FACTORS act_order = [False, True] + def to_iter(x): return x if isinstance(x, (list, tuple)) else (x,) -_TEST_CASES = list(itertools.product( - to_iter(k_chunk), - to_iter(n_chunk), - to_iter(quant_type), - to_iter(group_size), - to_iter(mnk_factors), - to_iter(act_order), -)) + +_TEST_CASES = list( + itertools.product( + to_iter(k_chunk), + to_iter(n_chunk), + to_iter(quant_type), + to_iter(group_size), + to_iter(mnk_factors), + to_iter(act_order), + ) +) _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16] @@ -70,6 +74,7 @@ def to_iter(x): SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + def quantize_weights( w: torch.Tensor, quant_type: ScalarType, @@ -164,10 +169,12 @@ def reshape_w(w): maybe_w_zp, ) + def get_pack_factor(num_bits): assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" return 32 // num_bits + def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): assert q_w.shape == (size_k, size_n) assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" @@ -182,6 +189,7 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): return q_w + def marlin_weights(q_w, size_k, size_n, num_bits, perm): # Permute q_w = marlin_permute_weights(q_w, size_k, size_n, perm) @@ -230,6 +238,7 @@ def get_weight_perm(num_bits: int): perm = torch.from_numpy(perm) return perm + def get_scale_perms(): scale_perm: list[int] = [] for i in range(8): @@ -253,6 +262,7 @@ def marlin_permute_scales( return s + def pack_cols( q_w: torch.Tensor, num_bits: int, @@ -278,6 +288,7 @@ def pack_cols( return q_res + def marlin_zero_points( zp: torch.Tensor, size_k: int, size_n: int, num_bits: int ) -> torch.Tensor: @@ -300,6 +311,7 @@ def marlin_zero_points( return zp + def permute_rows( q_w: torch.Tensor, w_ref: torch.Tensor, @@ -329,6 +341,7 @@ def permute_rows( rand_perm.to(device=orig_device), ) + def gptq_quantize_weights( w: torch.Tensor, quant_type: ScalarType, @@ -377,6 +390,7 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): sort_indices.to(device=orig_device), ) + def marlin_quantize( w: torch.Tensor, quant_type: ScalarType, @@ -415,6 +429,7 @@ def marlin_quantize( return res_list + def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): size_k, size_n = w.shape @@ -443,6 +458,7 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int return res_list + def marlin_make_workspace( device: torch.device, max_blocks_per_sm: int = 1 ) -> torch.Tensor: @@ -488,7 +504,7 @@ def test( if size_k % group_size != 0: return - + print( f"Testing Gptq Marlin Gemm on {InfiniDeviceNames[device]} with M-K-N:({size_m, size_k, size_n}), group_size:{group_size}, dtype:{InfiniDtypeNames[dtype]}" ) @@ -509,28 +525,63 @@ def test( marlin_zp = None marlin_s2 = None output_ref = torch.matmul(a_input.torch_tensor(), w_ref) - b = TestTensor(marlin_q_w.shape, marlin_q_w.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=marlin_q_w) + b = TestTensor( + marlin_q_w.shape, + marlin_q_w.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=marlin_q_w, + ) c = TestTensor(output_ref.shape, None, dtype, device) - b_scales = TestTensor(marlin_s.shape, marlin_s.stride(), dtype, device, mode="manual", set_tensor=marlin_s) + b_scales = TestTensor( + marlin_s.shape, + marlin_s.stride(), + dtype, + device, + mode="manual", + set_tensor=marlin_s, + ) global_scale = None if marlin_zp is not None: - b_zeros = TestTensor(marlin_zp.shape, marlin_zp.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=marlin_zp) + b_zeros = TestTensor( + marlin_zp.shape, + marlin_zp.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=marlin_zp, + ) else: b_zeros = None if g_idx is not None: - b_g_idx = TestTensor(g_idx.shape, g_idx.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=g_idx) + b_g_idx = TestTensor( + g_idx.shape, + g_idx.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=g_idx, + ) else: b_g_idx = None if sort_indices is not None: - perm = TestTensor(sort_indices.shape, sort_indices.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=sort_indices) + perm = TestTensor( + sort_indices.shape, + sort_indices.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=sort_indices, + ) else: perm = None - - is_k_full=True - use_atomic_add=False - use_fp32_reduce=False - is_zp_float=False - + + is_k_full = True + use_atomic_add = False + use_fp32_reduce = False + is_zp_float = False + if sync is not None: sync() @@ -554,7 +605,7 @@ def test( for tensor in [c, a_input, b, b_scales, global_scale, b_zeros, b_g_idx, perm]: if tensor is not None: tensor.destroy_desc() - + workspace_size = c_uint64(0) check_error( LIBINFINIOP.infiniopGetGptqMarlinGemmWorkspaceSize( @@ -577,18 +628,17 @@ def lib_gptq_marlin_gemm(): b_zeros.data() if b_zeros is not None else None, b_g_idx.data() if b_g_idx is not None else None, perm.data() if perm is not None else None, - quant_type.id, - is_k_full, - use_atomic_add, - use_fp32_reduce, - is_zp_float, + quant_type.id, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, None, ) ) lib_gptq_marlin_gemm() - max_diff = torch.mean(torch.abs(c.actual_tensor() - output_ref)) / torch.mean( torch.abs(output_ref) ) @@ -603,7 +653,11 @@ def lib_gptq_marlin_gemm(): NUM_ITERATIONS, ) profile_operation( - " lib", lambda: lib_gptq_marlin_gemm(), device, NUM_PRERUN, NUM_ITERATIONS + " lib", + lambda: lib_gptq_marlin_gemm(), + device, + NUM_PRERUN, + NUM_ITERATIONS, ) check_error(LIBINFINIOP.infiniopDestroyGptqMarlinGemmDescriptor(descriptor)) diff --git a/test/infiniop/libinfiniop/scalar_type.py b/test/infiniop/libinfiniop/scalar_type.py index bc9f067c1..571c1dca7 100644 --- a/test/infiniop/libinfiniop/scalar_type.py +++ b/test/infiniop/libinfiniop/scalar_type.py @@ -350,5 +350,3 @@ class scalar_types: # colloquial names bfloat16 = float16_e8m7 float16 = float16_e5m10 - - From ee0467279d0830b4e9180ad8442056b6a5c2ab34 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 1 Apr 2026 14:12:32 +0800 Subject: [PATCH 3/3] issue/1083: modified global --- .../ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu index 3e424ac4f..59271f78b 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu +++ b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu @@ -13,13 +13,13 @@ namespace device::marlin { -__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; +INFINIOP_CUDA_KERNEL MarlinDefault(MARLIN_KERNEL_PARAMS){}; using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -__global__ void permute_cols_kernel( +INFINIOP_CUDA_KERNEL permute_cols_kernel( int4 const *__restrict__ a_int4_ptr, int const *__restrict__ perm_int_ptr, int4 *__restrict__ out_int4_ptr, @@ -32,7 +32,7 @@ __global__ void permute_cols_kernel( // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. -__global__ void permute_cols_kernel( +INFINIOP_CUDA_KERNEL permute_cols_kernel( int4 const *__restrict__ a_int4_ptr, int const *__restrict__ perm_int_ptr, int4 *__restrict__ out_int4_ptr,