Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aiter/utility/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"gfx942": {"fp8": torch.float8_e4m3fnuz},
"gfx950": {"fp8": torch.float8_e4m3fn},
"gfx1250": {"fp8": torch.float8_e4m3fn},
"gfx1201": {"fp8": torch.float8_e4m3fn},
}

_8bit_fallback = torch.uint8
Expand Down
27 changes: 20 additions & 7 deletions csrc/include/aiter_opus_plus.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "hip_reduce.h"
#include "opus.hpp"
#include "opus/opus.hpp"
// todo: remove this to use aiterTensor dtype
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
Expand All @@ -17,26 +18,37 @@ using index_t = int;

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// scaled type conversion: v_pk_mul_f32 + v_med3_f32 + v_cvt_pk_{fp8,bf8}_f32
// Identical ISA to ck_tile::vec_convert for performance parity
// gfx11/gfx12 can use scalar/package-math style multiplies here.
#if defined(__gfx11__) || defined(__gfx12__)
// Element-wise product of two packed f32 pairs using plain scalar multiplies,
// for targets on this code path (gfx11/gfx12) where the packed v_pk_mul_f32
// instruction is not used.
OPUS_D fp32x2_t amd_scalar_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
{
    fp32x2_t out;
    out[0] = lhs[0] * rhs[0];
    out[1] = lhs[1] * rhs[1];
    return out;
}

// gfx11/gfx12 variant of pk_mul_f32: delegates to the scalar helper so callers
// keep a single call site regardless of target.
OPUS_D fp32x2_t pk_mul_f32(fp32x2_t x, fp32x2_t y)
{
    return amd_scalar_mul_f32(x, y);
}
#else
// Packed f32 pair multiply emitted as a single v_pk_mul_f32 instruction
// (non-gfx11/gfx12 branch of this #if).
// NOTE(review): "=v"/"v" constraints place operands in VGPRs; assumes fp32x2_t
// maps to a VGPR pair on these targets — confirm against the type definition.
OPUS_D fp32x2_t pk_mul_f32(fp32x2_t a, fp32x2_t b)
{
fp32x2_t c;
// volatile keeps the compiler from reordering or eliding the instruction.
asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
return c;
}
#endif

// fp32x2 -> fp8x2 with scale + saturation clamp (E4M3)
// ISA: v_pk_mul_f32 + v_med3_f32 x2 + v_cvt_pk_fp8_f32
template <typename S, std::enable_if_t<std::is_same_v<S, fp32x2_t>, bool> = true>
OPUS_D decltype(auto) fp32_to_fp8_scaled_x2(const S& s, float inverted_scale)
{
fp32x2_t tmp = pk_mul_f32(s, fp32x2_t{inverted_scale, inverted_scale});
#if defined(__gfx950__)
constexpr float hi = 448.0f, lo = -448.0f;
#else
constexpr float hi = 240.0f, lo = -240.0f;
#endif
constexpr float hi = finfo<fp8_t>::max();
constexpr float lo = finfo<fp8_t>::min();
float a = tmp[0], b = tmp[1];
int w;
asm volatile("v_med3_f32 %1, %1, %3, %4\n"
Expand All @@ -61,7 +73,8 @@ template <typename S, std::enable_if_t<std::is_same_v<S, fp32x2_t>, bool> = true
OPUS_D decltype(auto) fp32_to_bf8_scaled_x2(const S& s, float inverted_scale)
{
fp32x2_t tmp = pk_mul_f32(s, fp32x2_t{inverted_scale, inverted_scale});
constexpr float hi = 57344.0f, lo = -57344.0f;
constexpr float hi = finfo<bf8_t>::max();
constexpr float lo = finfo<bf8_t>::min();
float a = tmp[0], b = tmp[1];
int w;
asm volatile("v_med3_f32 %1, %1, %3, %4\n"
Expand Down
23 changes: 20 additions & 3 deletions csrc/include/ck_tile/vec_convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ CK_TILE_DEVICE fp32x2_v amd_assembly_pk_mul_f32(fp32x2_v a, fp32x2_t b)
asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
return c;
}
// Scalar fallback multiply of a packed f32 pair, for RDNA3/RDNA4 code paths
// that avoid the v_pk_mul_f32 inline-asm helper above.
CK_TILE_DEVICE fp32x2_v amd_scalar_mul_f32(fp32x2_v a, fp32x2_t b)
{
    const fp32_t lo = a[0] * b[0];
    const fp32_t hi = a[1] * b[1];
    fp32x2_v prod;
    prod[0] = lo;
    prod[1] = hi;
    return prod;
}
CK_TILE_DEVICE fp8x2_v amd_assembly_cvt_pk_fp8_f32(fp32_t a, fp32_t b)
{
int16x2_t c;
Expand Down Expand Up @@ -145,8 +153,12 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inv
using vec_ti = vector_traits<fp32x2_v>;
constexpr int vec_size = vec_ti::vector_size;
constexpr auto interpret = numeric_traits<fp8_t>::f8_interpret;
fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});

fp32x2_v tmp;
#if defined(__gfx11__) || defined(__gfx12__)
tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#else
tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#endif
return (interpret == fp8_interpretation::E4M3_FNUZ) ||
(interpret == fp8_interpretation::E4M3_OCP)
? amd_assembly_cvt_pk_fp8_f32(tmp[0], tmp[1])
Expand All @@ -155,7 +167,12 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inv
// fp32x2 -> int8x2
CK_TILE_HOST_DEVICE constexpr int8x2_v fp32x2_t_to_int8x2_t(fp32x2_v x, fp32_t inverted_scale)
{
fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
fp32x2_v tmp;
#if defined(__gfx11__) || defined(__gfx12__)
tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#else
tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
#endif

int8x2_v out;
out[0] = static_cast<int8_t>(tmp[0]);
Expand Down
25 changes: 25 additions & 0 deletions csrc/include/hip_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,28 @@ __device__ constexpr T wave_reduce(T local, F reduce_op)

if constexpr(WarpSize > 16)
{
// DPP broadcasts (0x142, 0x143) are not supported on GFX10+ (gfx12 included)
// Use ds_bpermute instead for cross-lane communication
#if defined(__gfx12__) || defined(__gfx11__)
// Use shuffle for gfx11/gfx12 instead of DPP broadcast
T v_remote = rocprim::warp_shuffle(local, 15, WarpSize);
local = reduce_op(v_remote, local);
#else
// row_bcast:15
local = reduce_op(rocprim::detail::warp_move_dpp<T, 0x142>(local), local);
#endif
}

if constexpr(WarpSize > 32)
{
#if defined(__gfx12__) || defined(__gfx11__)
// Use shuffle for gfx11/gfx12 instead of DPP broadcast
T v_remote = rocprim::warp_shuffle(local, 31, WarpSize);
local = reduce_op(v_remote, local);
#else
// row_bcast:31
local = reduce_op(rocprim::detail::warp_move_dpp<T, 0x143>(local), local);
#endif
}

if constexpr(threadBroadcast && WarpSize > 4)
Expand Down Expand Up @@ -167,7 +181,12 @@ __device__ constexpr T multithread_reduce(T data, F reduce_op, int thread_num)
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x4e>(data), data);
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x124>(data), data);
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x128>(data), data);
#if defined(__gfx12__) || defined(__gfx11__)
// DPP broadcast 0x142 not supported on gfx11/gfx12, use shuffle
data = reduce_op(rocprim::warp_shuffle(data, 15, WarpSize), data);
#else
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x142, 0xa>(data), data);
#endif
if constexpr(threadBroadcast)
{
data = rocprim::warp_shuffle(data, thread_num - 1, thread_num);
Expand All @@ -180,8 +199,14 @@ __device__ constexpr T multithread_reduce(T data, F reduce_op, int thread_num)
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x4e>(data), data);
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x124>(data), data);
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x128>(data), data);
#if defined(__gfx12__) || defined(__gfx11__)
// DPP broadcasts not supported on gfx11/gfx12, use shuffle
data = reduce_op(rocprim::warp_shuffle(data, 15, WarpSize), data);
data = reduce_op(rocprim::warp_shuffle(data, 31, WarpSize), data);
#else
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x142>(data), data);
data = reduce_op(rocprim::detail::warp_move_dpp<T, 0x143>(data), data);
#endif
if constexpr(threadBroadcast)
{
data = rocprim::warp_shuffle(data, thread_num - 1, thread_num);
Expand Down
52 changes: 38 additions & 14 deletions csrc/include/opus/opus.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,7 @@ OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) {
//
OPUS_H_D constexpr index_t get_warp_size()
{
#if defined(__gfx1250__)
#if defined(__gfx1250__) || defined(__gfx12__)
return 32;
#elif defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
return 64;
Expand Down Expand Up @@ -2078,8 +2078,8 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4;
#endif // __GFX9__ (mfma)

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// wmma (gfx1250 / RDNA4, wave32)
#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__)
// wmma (gfx1250 / gfx12 / RDNA4, wave32)
#if defined(__gfx1250__) || defined(__gfx12__) || !defined(__HIP_DEVICE_COMPILE__)
// f16/bf16/f32 builtins: (neg_a, A, neg_b, B, matrix_fmts, C, clamp, neg_c)
#define DISPATCH_WMMA_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \
(std::is_same_v<dtype_a, ta_> && std::is_same_v<dtype_b, tb_> && std::is_same_v<dtype_c, tc_> && \
Expand Down Expand Up @@ -2130,6 +2130,12 @@ struct wmma {
// Format code for scaled WMMA (f8f6f4); -1 for types that don't support scaling
static constexpr int fmt_a = std::is_same_v<dtype_a, fp8_t> ? 0 : std::is_same_v<dtype_a, bf8_t> ? 1 : std::is_same_v<dtype_a, fp4_t> ? 4 : -1;
static constexpr int fmt_b = std::is_same_v<dtype_b, fp8_t> ? 0 : std::is_same_v<dtype_b, bf8_t> ? 1 : std::is_same_v<dtype_b, fp4_t> ? 4 : -1;
#if defined(__gfx12__)
static_assert(
wave_m == 16 && wave_n == 16 && wave_k == 16,
"gfx12 WMMA in OPUS only supports 16x16x16 shapes; use SWMMAC or a gfx1250 path for wider K"
);
#endif

// Regular (non-scaled) dispatch
template<typename VA, typename VB, typename VC>
Expand Down Expand Up @@ -2164,6 +2170,17 @@ struct wmma {
else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8)
else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8)
else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8)
#elif defined(__gfx12__)
// RDNA4 gfx12 WMMA: 16x16x16 for the common floating-point cases.
else if constexpr DISPATCH_WMMA_(fp16_t, fp16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12)
else if constexpr DISPATCH_WMMA_(fp16_t, fp16_t, fp16_t, 16, 16, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12)
else if constexpr DISPATCH_WMMA_(bf16_t, bf16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12)
else if constexpr DISPATCH_WMMA_(bf16_t, bf16_t, bf16_t, 16, 16, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12)
else if constexpr DISPATCH_WMMA_8BIT_(u8_t, u8_t, i32_t, 16, 16, 16, __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12)
else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, fp8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12)
else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, bf8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12)
else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12)
else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12)
#endif
__builtin_unreachable();
}
Expand Down Expand Up @@ -2260,14 +2277,13 @@ struct wmma {
#undef DISPATCH_WMMA_BF16F32_
#undef DISPATCH_WMMA_8BIT_

// f16/bf16 16x16x32
// gfx1250: wide-k WMMA and scaled WMMA
#if defined(__gfx1250__)
using wmma_f32_16x16x32_f16 = wmma<fp16_t, fp16_t, fp32_t, 16, 16, 32>;
using wmma_f16_16x16x32_f16 = wmma<fp16_t, fp16_t, fp16_t, 16, 16, 32>;
using wmma_f32_16x16x32_bf16 = wmma<bf16_t, bf16_t, fp32_t, 16, 16, 32>;
using wmma_bf16_16x16x32_bf16 = wmma<bf16_t, bf16_t, bf16_t, 16, 16, 32>;
// f32 16x16x4
using wmma_f32_16x16x4_f32 = wmma<fp32_t, fp32_t, fp32_t, 16, 16, 4>;
// fp8/bf8 16x16x64
using wmma_f32_16x16x64_fp8_fp8 = wmma<fp8_t, fp8_t, fp32_t, 16, 16, 64>;
using wmma_f32_16x16x64_fp8_bf8 = wmma<fp8_t, bf8_t, fp32_t, 16, 16, 64>;
using wmma_f32_16x16x64_bf8_fp8 = wmma<bf8_t, fp8_t, fp32_t, 16, 16, 64>;
Expand All @@ -2276,7 +2292,6 @@ using wmma_f16_16x16x64_fp8_fp8 = wmma<fp8_t, fp8_t, fp16_t, 16, 16, 64>;
using wmma_f16_16x16x64_fp8_bf8 = wmma<fp8_t, bf8_t, fp16_t, 16, 16, 64>;
using wmma_f16_16x16x64_bf8_fp8 = wmma<bf8_t, fp8_t, fp16_t, 16, 16, 64>;
using wmma_f16_16x16x64_bf8_bf8 = wmma<bf8_t, bf8_t, fp16_t, 16, 16, 64>;
// fp8/bf8 16x16x128
using wmma_f32_16x16x128_fp8_fp8 = wmma<fp8_t, fp8_t, fp32_t, 16, 16, 128>;
using wmma_f32_16x16x128_fp8_bf8 = wmma<fp8_t, bf8_t, fp32_t, 16, 16, 128>;
using wmma_f32_16x16x128_bf8_fp8 = wmma<bf8_t, fp8_t, fp32_t, 16, 16, 128>;
Expand All @@ -2285,12 +2300,21 @@ using wmma_f16_16x16x128_fp8_fp8 = wmma<fp8_t, fp8_t, fp16_t, 16, 16, 128>;
using wmma_f16_16x16x128_fp8_bf8 = wmma<fp8_t, bf8_t, fp16_t, 16, 16, 128>;
using wmma_f16_16x16x128_bf8_fp8 = wmma<bf8_t, fp8_t, fp16_t, 16, 16, 128>;
using wmma_f16_16x16x128_bf8_bf8 = wmma<bf8_t, bf8_t, fp16_t, 16, 16, 128>;
// Scaled WMMA (f8f6f4 unified instruction, supports fp8/bf8/fp4 via format code)
using wmma_scale_f32_16x16x128_fp8_fp8 = wmma<fp8_t, fp8_t, fp32_t, 16, 16, 128>;
using wmma_scale_f32_16x16x128_fp4_fp4 = wmma<fp4_t, fp4_t, fp32_t, 16, 16, 128>;
// Scaled WMMA (dedicated fp4 32x16x128 instruction)
using wmma_scale_f32_32x16x128_fp4_fp4 = wmma<fp4_t, fp4_t, fp32_t, 32, 16, 128>;
#endif // __gfx1250__ (wmma)
#elif defined(__gfx12__)
using wmma_f32_16x16x16_f16_gfx12 = wmma<fp16_t, fp16_t, fp32_t, 16, 16, 16>;
using wmma_f16_16x16x16_f16_gfx12 = wmma<fp16_t, fp16_t, fp16_t, 16, 16, 16>;
using wmma_f32_16x16x16_bf16_gfx12 = wmma<bf16_t, bf16_t, fp32_t, 16, 16, 16>;
using wmma_bf16_16x16x16_bf16_gfx12 = wmma<bf16_t, bf16_t, bf16_t, 16, 16, 16>;
using wmma_i32_16x16x16_iu8_gfx12 = wmma<u8_t, u8_t, i32_t, 16, 16, 16>;
using wmma_f32_16x16x16_fp8_fp8_gfx12 = wmma<fp8_t, fp8_t, fp32_t, 16, 16, 16>;
using wmma_f32_16x16x16_fp8_bf8_gfx12 = wmma<fp8_t, bf8_t, fp32_t, 16, 16, 16>;
using wmma_f32_16x16x16_bf8_fp8_gfx12 = wmma<bf8_t, fp8_t, fp32_t, 16, 16, 16>;
using wmma_f32_16x16x16_bf8_bf8_gfx12 = wmma<bf8_t, bf8_t, fp32_t, 16, 16, 16>;
#endif
#endif // __gfx1250__ || __gfx12__ (wmma)

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// adaptor
Expand Down Expand Up @@ -2499,11 +2523,11 @@ template<typename d_a, typename d_b, typename d_c, typename WaveMNK /*seq<m, n,
OPUS_D decltype(auto) make_mfma(WaveMNK&&, A&& = {}, number<warp_size_> = {}) { return A{}(mfma<d_a, d_b, d_c, get<0>(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); }
#endif // __GFX9__

// wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250)
// wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250/gfx12)
// A:[(grpm_a<p>), (rept_a<y>, grpk_a<p>, pack_a<y>)], MxK
// B:[(grpn_b<p>), (rept_b<y>, grpk_b<p>, pack_b<y>)], NxK
// C:[(grpm_c<p>, rept_c<y>, pack_c<y>), (grpn_c<p>)], MxN
#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__)
#if defined(__gfx1250__) || defined(__gfx12__) || !defined(__HIP_DEVICE_COMPILE__)
namespace impl {
template<typename WMMA>
struct wmma_adaptor : public remove_cvref_t<WMMA> {
Expand Down Expand Up @@ -2731,14 +2755,14 @@ OPUS_D decltype(auto) make_tiled_mma(MMA&& mma, ES, TS, A&& = {}) {
}

template<typename d_a, typename d_b, typename d_c, typename ES /* expand-m/n/k */, typename TS /* tile-m/n/k */, typename WS /* wave-m/n/k*/,
#if defined(__gfx1250__)
#if defined(__gfx1250__) || defined(__gfx12__)
typename WA = wmma_adaptor,
#else
typename WA = mfma_adaptor,
#endif
typename TA = tiled_mma_adaptor, index_t warp_size = get_warp_size()>
OPUS_D decltype(auto) make_tiled_mma(ES, TS, WS, WA&& = {}, TA&& = {}) {
#if defined(__gfx1250__)
#if defined(__gfx1250__) || defined(__gfx12__)
return TA{}(make_wmma<d_a, d_b, d_c>(WS{}, WA{}, number<warp_size>{}),
#else
return TA{}(make_mfma<d_a, d_b, d_c>(WS{}, WA{}, number<warp_size>{}),
Expand Down
34 changes: 26 additions & 8 deletions csrc/kernels/quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,21 @@ __global__ void smooth_per_token_scaled_quant_kernel(DTYPE_O* __restrict__ out,
#pragma unroll
for(int i = 0; i < async_load_num; i++)
{
// buffer_hash.async_load(smooth_scale_map_hash_shared + threadIdx.x + i * block_size, threadIdx.x + i * block_size);
const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size))));
uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int);
asm volatile( "s_mov_b32 m0 %0\n\t"
#if defined(__gfx12__)
int idx = threadIdx.x + i * block_size;
if(idx < smooth_scale_map_hash_size)
{
// RDNA4 doesn't support buffer_load_* with LDS modifier
// Use standard global load to VGPR then write to LDS
smooth_scale_map_hash_shared[idx] = smooth_scale_map_hash[idx];
}
#else
const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size))));
uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int);
asm volatile( "s_mov_b32 m0 %0\n\t"
"buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t"
::"s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc): "memory", "m0");
#endif
}
}

Expand Down Expand Up @@ -1187,12 +1196,21 @@ __global__ void moe_smooth_per_token_scaled_quant_kernel_v1(DTYPE_O* __restrict_
#pragma unroll
for(int i = 0; i < async_load_num; i++)
{
// buffer_hash.async_load(smooth_scale_map_hash_shared + threadIdx.x + i * block_size, threadIdx.x + i * block_size);
const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size))));
uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int);
asm volatile( "s_mov_b32 m0 %0\n\t"
#if defined(__gfx12__)
int idx = threadIdx.x + i * block_size;
if(idx < smooth_scale_map_hash_size)
{
// RDNA4 doesn't support buffer_load_* with LDS modifier
// Use standard global load to VGPR then write to LDS
smooth_scale_map_hash_shared[idx] = smooth_scale_map_hash[idx];
}
#else
const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size))));
uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int);
asm volatile( "s_mov_b32 m0 %0\n\t"
"buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t"
::"s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc): "memory", "m0");
#endif
}
}
int smscale_map_idx_list = 0;
Expand Down
14 changes: 12 additions & 2 deletions csrc/kernels/rmsnorm_quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,12 @@ __global__ void add_rmsnorm_quant_kernel(
vec2_f* thread_data_float2 = reinterpret_cast<vec2_f*>(&thread_data_float);
for(int i = 0; i < thread_data_size / 2; i++)
{
asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(rcp));
#if defined(__gfx11__) || defined(__gfx12__)
thread_data_float[2*i] *= rcp[0];
thread_data_float[2*i+1] *= rcp[1];
#else
asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(rcp));
#endif
}

float* thread_data_weight2 = reinterpret_cast<float*>(&thread_data_weight);
Expand All @@ -170,7 +175,12 @@ __global__ void add_rmsnorm_quant_kernel(
// : "v"(thread_data_weight2[i])
// );
// }
asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(thread_data_weight_float2));
#if defined(__gfx11__) || defined(__gfx12__)
thread_data_float[2*i] *= rcp[0];
thread_data_float[2*i+1] *= rcp[1];
#else
asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(thread_data_weight_float2));
#endif
}

if constexpr(FUSE_QUANT)
Expand Down
Loading