diff --git a/aiter/utility/dtypes.py b/aiter/utility/dtypes.py index f51acd2ed5..fd2f6bcfd5 100644 --- a/aiter/utility/dtypes.py +++ b/aiter/utility/dtypes.py @@ -10,6 +10,7 @@ "gfx942": {"fp8": torch.float8_e4m3fnuz}, "gfx950": {"fp8": torch.float8_e4m3fn}, "gfx1250": {"fp8": torch.float8_e4m3fn}, + "gfx1201": {"fp8": torch.float8_e4m3fn}, } _8bit_fallback = torch.uint8 diff --git a/csrc/include/aiter_opus_plus.h b/csrc/include/aiter_opus_plus.h index cc4f395d58..6d5802b938 100644 --- a/csrc/include/aiter_opus_plus.h +++ b/csrc/include/aiter_opus_plus.h @@ -4,6 +4,7 @@ #include "hip_reduce.h" #include "opus.hpp" +#include "opus/opus.hpp" // todo: remove this to use aiterTensor dtype #include #include @@ -17,14 +18,28 @@ using index_t = int; ///////////////////////////////////////////////////////////////////////////////////////////////////////// // scaled type conversion: v_pk_mul_f32 + v_med3_f32 + v_cvt_pk_{fp8,bf8}_f32 -// Identical ISA to ck_tile::vec_convert for performance parity +// gfx11/gfx12 can use scalar/package-math style multiplies here. 
+#if defined(__gfx11__) || defined(__gfx12__) +OPUS_D fp32x2_t amd_scalar_mul_f32(fp32x2_t a, fp32x2_t b) +{ + fp32x2_t c; + c[0] = a[0] * b[0]; + c[1] = a[1] * b[1]; + return c; +} +OPUS_D fp32x2_t pk_mul_f32(fp32x2_t a, fp32x2_t b) +{ + return amd_scalar_mul_f32(a, b); +} +#else OPUS_D fp32x2_t pk_mul_f32(fp32x2_t a, fp32x2_t b) { fp32x2_t c; asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); return c; } +#endif // fp32x2 -> fp8x2 with scale + saturation clamp (E4M3) // ISA: v_pk_mul_f32 + v_med3_f32 x2 + v_cvt_pk_fp8_f32 @@ -32,11 +47,8 @@ template , bool> = true OPUS_D decltype(auto) fp32_to_fp8_scaled_x2(const S& s, float inverted_scale) { fp32x2_t tmp = pk_mul_f32(s, fp32x2_t{inverted_scale, inverted_scale}); -#if defined(__gfx950__) - constexpr float hi = 448.0f, lo = -448.0f; -#else - constexpr float hi = 240.0f, lo = -240.0f; -#endif + constexpr float hi = finfo::max(); + constexpr float lo = finfo::min(); float a = tmp[0], b = tmp[1]; int w; asm volatile("v_med3_f32 %1, %1, %3, %4\n" @@ -61,7 +73,8 @@ template , bool> = true OPUS_D decltype(auto) fp32_to_bf8_scaled_x2(const S& s, float inverted_scale) { fp32x2_t tmp = pk_mul_f32(s, fp32x2_t{inverted_scale, inverted_scale}); - constexpr float hi = 57344.0f, lo = -57344.0f; + constexpr float hi = finfo::max(); + constexpr float lo = finfo::min(); float a = tmp[0], b = tmp[1]; int w; asm volatile("v_med3_f32 %1, %1, %3, %4\n" diff --git a/csrc/include/ck_tile/vec_convert.h b/csrc/include/ck_tile/vec_convert.h index aaabcd3508..e8c954c799 100644 --- a/csrc/include/ck_tile/vec_convert.h +++ b/csrc/include/ck_tile/vec_convert.h @@ -43,6 +43,14 @@ CK_TILE_DEVICE fp32x2_v amd_assembly_pk_mul_f32(fp32x2_v a, fp32x2_t b) asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); return c; } +// use scalar math for RDNA4/3 without v_pk_mul_f32 +CK_TILE_DEVICE fp32x2_v amd_scalar_mul_f32(fp32x2_v a, fp32x2_t b) +{ + fp32x2_v c; + c[0] = a[0] * b[0]; + c[1] = a[1] * b[1]; + return c; +} 
CK_TILE_DEVICE fp8x2_v amd_assembly_cvt_pk_fp8_f32(fp32_t a, fp32_t b) { int16x2_t c; @@ -145,8 +153,12 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inv using vec_ti = vector_traits; constexpr int vec_size = vec_ti::vector_size; constexpr auto interpret = numeric_traits::f8_interpret; - fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); - + fp32x2_v tmp; +#if defined(__gfx11__) || defined(__gfx12__) + tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); +#else + tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); +#endif return (interpret == fp8_interpretation::E4M3_FNUZ) || (interpret == fp8_interpretation::E4M3_OCP) ? amd_assembly_cvt_pk_fp8_f32(tmp[0], tmp[1]) @@ -155,7 +167,12 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inv // fp32x2 -> int8x2 CK_TILE_HOST_DEVICE constexpr int8x2_v fp32x2_t_to_int8x2_t(fp32x2_v x, fp32_t inverted_scale) { - fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); + fp32x2_v tmp; +#if defined(__gfx11__) || defined(__gfx12__) + tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); +#else + tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); +#endif int8x2_v out; out[0] = static_cast(tmp[0]); diff --git a/csrc/include/hip_reduce.h b/csrc/include/hip_reduce.h index 0d787fa4f3..2850f77343 100644 --- a/csrc/include/hip_reduce.h +++ b/csrc/include/hip_reduce.h @@ -113,14 +113,28 @@ __device__ constexpr T wave_reduce(T local, F reduce_op) if constexpr(WarpSize > 16) { +// DPP broadcasts (0x142, 0x143) are not supported on GFX10+ (gfx12 included) +// Use ds_bpermute instead for cross-lane communication +#if defined(__gfx12__) || defined(__gfx11__) + // Use shuffle for gfx12 instead of DPP broadcast + T v_remote = rocprim::warp_shuffle(local, 15, WarpSize); + local = reduce_op(v_remote, local); +#else // row_bcast:15 local = 
reduce_op(rocprim::detail::warp_move_dpp(local), local); +#endif } if constexpr(WarpSize > 32) { +#if defined(__gfx12__) || defined(__gfx11__) + // Use shuffle for gfx12 instead of DPP broadcast + T v_remote = rocprim::warp_shuffle(local, 31, WarpSize); + local = reduce_op(v_remote, local); +#else // row_bcast:31 local = reduce_op(rocprim::detail::warp_move_dpp(local), local); +#endif } if constexpr(threadBroadcast && WarpSize > 4) @@ -167,7 +181,12 @@ __device__ constexpr T multithread_reduce(T data, F reduce_op, int thread_num) data = reduce_op(rocprim::detail::warp_move_dpp(data), data); data = reduce_op(rocprim::detail::warp_move_dpp(data), data); data = reduce_op(rocprim::detail::warp_move_dpp(data), data); +#if defined(__gfx12__) || defined(__gfx11__) + // DPP broadcast 0x142 not supported on gfx12, use shuffle + data = reduce_op(rocprim::warp_shuffle(data, 15, WarpSize), data); +#else data = reduce_op(rocprim::detail::warp_move_dpp(data), data); +#endif if constexpr(threadBroadcast) { data = rocprim::warp_shuffle(data, thread_num - 1, thread_num); @@ -180,8 +199,14 @@ __device__ constexpr T multithread_reduce(T data, F reduce_op, int thread_num) data = reduce_op(rocprim::detail::warp_move_dpp(data), data); data = reduce_op(rocprim::detail::warp_move_dpp(data), data); data = reduce_op(rocprim::detail::warp_move_dpp(data), data); +#if defined(__gfx12__) || defined(__gfx11__) + // DPP broadcasts not supported on gfx12, use shuffle + data = reduce_op(rocprim::warp_shuffle(data, 15, WarpSize), data); + data = reduce_op(rocprim::warp_shuffle(data, 31, WarpSize), data); +#else data = reduce_op(rocprim::detail::warp_move_dpp(data), data); data = reduce_op(rocprim::detail::warp_move_dpp(data), data); +#endif if constexpr(threadBroadcast) { data = rocprim::warp_shuffle(data, thread_num - 1, thread_num); diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 0b1839bbaa..b7a7dd4ac6 100644 --- a/csrc/include/opus/opus.hpp +++ 
b/csrc/include/opus/opus.hpp @@ -1405,7 +1405,7 @@ OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { // OPUS_H_D constexpr index_t get_warp_size() { -#if defined(__gfx1250__) +#if defined(__gfx1250__) || defined(__gfx12__) return 32; #elif defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) return 64; @@ -2078,8 +2078,8 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; #endif // __GFX9__ (mfma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// -// wmma (gfx1250 / RDNA4, wave32) -#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__) +// wmma (gfx1250 / gfx12 / RDNA4, wave32) +#if defined(__gfx1250__) || defined(__gfx12__) || !defined(__HIP_DEVICE_COMPILE__) // f16/bf16/f32 builtins: (neg_a, A, neg_b, B, matrix_fmts, C, clamp, neg_c) #define DISPATCH_WMMA_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ (std::is_same_v && std::is_same_v && std::is_same_v && \ @@ -2130,6 +2130,12 @@ struct wmma { // Format code for scaled WMMA (f8f6f4); -1 for types that don't support scaling static constexpr int fmt_a = std::is_same_v ? 0 : std::is_same_v ? 1 : std::is_same_v ? 4 : -1; static constexpr int fmt_b = std::is_same_v ? 0 : std::is_same_v ? 1 : std::is_same_v ? 4 : -1; +#if defined(__gfx12__) + static_assert( + wave_m == 16 && wave_n == 16 && wave_k == 16, + "gfx12 WMMA in OPUS only supports 16x16x16 shapes; use SWMMAC or a gfx1250 path for wider K" + ); +#endif // Regular (non-scaled) dispatch template @@ -2164,6 +2170,17 @@ struct wmma { else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8) +#elif defined(__gfx12__) + // RDNA4 gfx12 WMMA: 16x16x16 for the common floating-point cases. 
+ else if constexpr DISPATCH_WMMA_(fp16_t, fp16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_(fp16_t, fp16_t, fp16_t, 16, 16, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_(bf16_t, bf16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_(bf16_t, bf16_t, bf16_t, 16, 16, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_8BIT_(u8_t, u8_t, i32_t, 16, 16, 16, __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12) + else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, fp8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, bf8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12) + else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12) #endif __builtin_unreachable(); } @@ -2260,14 +2277,13 @@ struct wmma { #undef DISPATCH_WMMA_BF16F32_ #undef DISPATCH_WMMA_8BIT_ -// f16/bf16 16x16x32 +// gfx1250: wide-k WMMA and scaled WMMA +#if defined(__gfx1250__) using wmma_f32_16x16x32_f16 = wmma; using wmma_f16_16x16x32_f16 = wmma; using wmma_f32_16x16x32_bf16 = wmma; using wmma_bf16_16x16x32_bf16 = wmma; -// f32 16x16x4 using wmma_f32_16x16x4_f32 = wmma; -// fp8/bf8 16x16x64 using wmma_f32_16x16x64_fp8_fp8 = wmma; using wmma_f32_16x16x64_fp8_bf8 = wmma; using wmma_f32_16x16x64_bf8_fp8 = wmma; @@ -2276,7 +2292,6 @@ using wmma_f16_16x16x64_fp8_fp8 = wmma; using wmma_f16_16x16x64_fp8_bf8 = wmma; using wmma_f16_16x16x64_bf8_fp8 = wmma; using wmma_f16_16x16x64_bf8_bf8 = wmma; -// fp8/bf8 16x16x128 using wmma_f32_16x16x128_fp8_fp8 = wmma; using wmma_f32_16x16x128_fp8_bf8 = wmma; using wmma_f32_16x16x128_bf8_fp8 = 
wmma; @@ -2285,12 +2300,21 @@ using wmma_f16_16x16x128_fp8_fp8 = wmma; using wmma_f16_16x16x128_fp8_bf8 = wmma; using wmma_f16_16x16x128_bf8_fp8 = wmma; using wmma_f16_16x16x128_bf8_bf8 = wmma; -// Scaled WMMA (f8f6f4 unified instruction, supports fp8/bf8/fp4 via format code) using wmma_scale_f32_16x16x128_fp8_fp8 = wmma; using wmma_scale_f32_16x16x128_fp4_fp4 = wmma; -// Scaled WMMA (dedicated fp4 32x16x128 instruction) using wmma_scale_f32_32x16x128_fp4_fp4 = wmma; -#endif // __gfx1250__ (wmma) +#elif defined(__gfx12__) +using wmma_f32_16x16x16_f16_gfx12 = wmma; +using wmma_f16_16x16x16_f16_gfx12 = wmma; +using wmma_f32_16x16x16_bf16_gfx12 = wmma; +using wmma_bf16_16x16x16_bf16_gfx12 = wmma; +using wmma_i32_16x16x16_iu8_gfx12 = wmma; +using wmma_f32_16x16x16_fp8_fp8_gfx12 = wmma; +using wmma_f32_16x16x16_fp8_bf8_gfx12 = wmma; +using wmma_f32_16x16x16_bf8_fp8_gfx12 = wmma; +using wmma_f32_16x16x16_bf8_bf8_gfx12 = wmma; +#endif +#endif // __gfx1250__ || __gfx12__ (wmma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// // adaptor @@ -2499,11 +2523,11 @@ template = {}) { return A{}(mfma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } #endif // __GFX9__ -// wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250) +// wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250/gfx12) // A:[(grpm_a

), (rept_a, grpk_a

, pack_a)], MxK // B:[(grpn_b

), (rept_b, grpk_b

, pack_b)], NxK // C:[(grpm_c

, rept_c, pack_c), (grpn_c

)], MxN -#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__) +#if defined(__gfx1250__) || defined(__gfx12__) || !defined(__HIP_DEVICE_COMPILE__) namespace impl { template struct wmma_adaptor : public remove_cvref_t { @@ -2731,14 +2755,14 @@ OPUS_D decltype(auto) make_tiled_mma(MMA&& mma, ES, TS, A&& = {}) { } template OPUS_D decltype(auto) make_tiled_mma(ES, TS, WS, WA&& = {}, TA&& = {}) { -#if defined(__gfx1250__) +#if defined(__gfx1250__) || defined(__gfx12__) return TA{}(make_wmma(WS{}, WA{}, number{}), #else return TA{}(make_mfma(WS{}, WA{}, number{}), diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index dd394694d7..ae52964428 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -477,12 +477,21 @@ __global__ void smooth_per_token_scaled_quant_kernel(DTYPE_O* __restrict__ out, #pragma unroll for(int i = 0; i < async_load_num; i++) { - // buffer_hash.async_load(smooth_scale_map_hash_shared + threadIdx.x + i * block_size, threadIdx.x + i * block_size); - const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size)))); - uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int); - asm volatile( "s_mov_b32 m0 %0\n\t" + #if defined(__gfx12__) + int idx = threadIdx.x + i * block_size; + if(idx < smooth_scale_map_hash_size) + { + // RDNA4 doesn't support buffer_load_* with LDS modifier + // Use standard global load to VGPR then write to LDS + smooth_scale_map_hash_shared[idx] = smooth_scale_map_hash[idx]; + } + #else + const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size)))); + uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int); + asm volatile( "s_mov_b32 m0 %0\n\t" "buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t" ::"s"(lds_ptr_sgpr), "v"(offset), 
"s"(buffer_hash.cached_rsrc): "memory", "m0"); + #endif } } @@ -1187,12 +1196,21 @@ __global__ void moe_smooth_per_token_scaled_quant_kernel_v1(DTYPE_O* __restrict_ #pragma unroll for(int i = 0; i < async_load_num; i++) { - // buffer_hash.async_load(smooth_scale_map_hash_shared + threadIdx.x + i * block_size, threadIdx.x + i * block_size); - const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size)))); - uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int); - asm volatile( "s_mov_b32 m0 %0\n\t" + #if defined(__gfx12__) + int idx = threadIdx.x + i * block_size; + if(idx < smooth_scale_map_hash_size) + { + // RDNA4 doesn't support buffer_load_* with LDS modifier + // Use standard global load to VGPR then write to LDS + smooth_scale_map_hash_shared[idx] = smooth_scale_map_hash[idx]; + } + #else + const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size)))); + uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int); + asm volatile( "s_mov_b32 m0 %0\n\t" "buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t" ::"s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc): "memory", "m0"); + #endif } } int smscale_map_idx_list = 0; diff --git a/csrc/kernels/rmsnorm_quant_kernels.cu b/csrc/kernels/rmsnorm_quant_kernels.cu index 8a8359d445..fea8c46372 100644 --- a/csrc/kernels/rmsnorm_quant_kernels.cu +++ b/csrc/kernels/rmsnorm_quant_kernels.cu @@ -143,7 +143,12 @@ __global__ void add_rmsnorm_quant_kernel( vec2_f* thread_data_float2 = reinterpret_cast(&thread_data_float); for(int i = 0; i < thread_data_size / 2; i++) { - asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(rcp)); + #if defined(__gfx11__) || defined(__gfx12__) + thread_data_float[2*i] *= rcp[0]; + 
thread_data_float[2*i+1] *= rcp[1]; + #else + asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(rcp)); + #endif } float* thread_data_weight2 = reinterpret_cast(&thread_data_weight); @@ -170,7 +175,12 @@ __global__ void add_rmsnorm_quant_kernel( // : "v"(thread_data_weight2[i]) // ); // } - asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(thread_data_weight_float2)); + #if defined(__gfx11__) || defined(__gfx12__) + thread_data_float[2*i] *= thread_data_weight_float2[0]; + thread_data_float[2*i+1] *= thread_data_weight_float2[1]; + #else + asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(thread_data_weight_float2)); + #endif } if constexpr(FUSE_QUANT)