From 53025a82b0b94af1b9d2112fddb0903a73a14791 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Thu, 19 Mar 2026 07:40:18 +0000 Subject: [PATCH 1/3] patch quant kernels for gfx12 --- aiter/utility/dtypes.py | 1 + csrc/include/ck_tile/vec_convert.h | 23 +++++++++++++++++--- csrc/include/hip_reduce.h | 25 ++++++++++++++++++++++ csrc/kernels/quant_kernels.cu | 34 +++++++++++++++++++++++------- 4 files changed, 72 insertions(+), 11 deletions(-) diff --git a/aiter/utility/dtypes.py b/aiter/utility/dtypes.py index 9a90fe48cd..a7468991db 100644 --- a/aiter/utility/dtypes.py +++ b/aiter/utility/dtypes.py @@ -9,6 +9,7 @@ "gfx942": {"fp8": torch.float8_e4m3fnuz}, "gfx950": {"fp8": torch.float8_e4m3fn}, "gfx1250": {"fp8": torch.float8_e4m3fn}, + "gfx1201": {"fp8": torch.float8_e4m3fn}, } _8bit_fallback = torch.uint8 diff --git a/csrc/include/ck_tile/vec_convert.h b/csrc/include/ck_tile/vec_convert.h index aaabcd3508..e8c954c799 100644 --- a/csrc/include/ck_tile/vec_convert.h +++ b/csrc/include/ck_tile/vec_convert.h @@ -43,6 +43,14 @@ CK_TILE_DEVICE fp32x2_v amd_assembly_pk_mul_f32(fp32x2_v a, fp32x2_t b) asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); return c; } +// use scalar math for RDNA4/3 without v_pk_mul_f32 +CK_TILE_DEVICE fp32x2_v amd_scalar_mul_f32(fp32x2_v a, fp32x2_t b) +{ + fp32x2_v c; + c[0] = a[0] * b[0]; + c[1] = a[1] * b[1]; + return c; +} CK_TILE_DEVICE fp8x2_v amd_assembly_cvt_pk_fp8_f32(fp32_t a, fp32_t b) { int16x2_t c; @@ -145,8 +153,12 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inv using vec_ti = vector_traits; constexpr int vec_size = vec_ti::vector_size; constexpr auto interpret = numeric_traits::f8_interpret; - fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); - + fp32x2_v tmp; +#if defined(__gfx11__) || defined(__gfx12__) + tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); +#else + tmp = amd_assembly_pk_mul_f32(x, 
fp32x2_t{inverted_scale, inverted_scale}); +#endif return (interpret == fp8_interpretation::E4M3_FNUZ) || (interpret == fp8_interpretation::E4M3_OCP) ? amd_assembly_cvt_pk_fp8_f32(tmp[0], tmp[1]) @@ -155,7 +167,12 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inv // fp32x2 -> int8x2 CK_TILE_HOST_DEVICE constexpr int8x2_v fp32x2_t_to_int8x2_t(fp32x2_v x, fp32_t inverted_scale) { - fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); + fp32x2_v tmp; +#if defined(__gfx11__) || defined(__gfx12__) + tmp = amd_scalar_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); +#else + tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale}); +#endif int8x2_v out; out[0] = static_cast(tmp[0]); diff --git a/csrc/include/hip_reduce.h b/csrc/include/hip_reduce.h index 79a88aec3e..cab6ee69db 100644 --- a/csrc/include/hip_reduce.h +++ b/csrc/include/hip_reduce.h @@ -112,14 +112,28 @@ __device__ constexpr T wave_reduce(T local, F reduce_op) if constexpr(WarpSize > 16) { +// DPP broadcasts (0x142, 0x143) are not supported on GFX10+ (gfx12 included) +// Use ds_bpermute instead for cross-lane communication +#if defined(__gfx12__) || defined(__gfx11__) + // Use shuffle for gfx12 instead of DPP broadcast + T v_remote = rocprim::warp_shuffle(local, 15, WarpSize); + local = reduce_op(v_remote, local); +#else // row_bcast:15 local = reduce_op(rocprim::detail::warp_move_dpp(local), local); +#endif } if constexpr(WarpSize > 32) { +#if defined(__gfx12__) || defined(__gfx11__) + // Use shuffle for gfx12 instead of DPP broadcast + T v_remote = rocprim::warp_shuffle(local, 31, WarpSize); + local = reduce_op(v_remote, local); +#else // row_bcast:31 local = reduce_op(rocprim::detail::warp_move_dpp(local), local); +#endif } if constexpr(threadBroadcast && WarpSize > 4) @@ -166,7 +180,12 @@ __device__ constexpr T multithread_reduce(T data, F reduce_op, int thread_num) data = 
reduce_op(rocprim::detail::warp_move_dpp(data), data); data = reduce_op(rocprim::detail::warp_move_dpp(data), data); data = reduce_op(rocprim::detail::warp_move_dpp(data), data); +#if defined(__gfx12__) || defined(__gfx11__) + // DPP broadcast 0x142 not supported on gfx12, use shuffle + data = reduce_op(rocprim::warp_shuffle(data, 15, WarpSize), data); +#else data = reduce_op(rocprim::detail::warp_move_dpp(data), data); +#endif if constexpr(threadBroadcast) { data = rocprim::warp_shuffle(data, thread_num - 1, thread_num); @@ -179,8 +198,14 @@ __device__ constexpr T multithread_reduce(T data, F reduce_op, int thread_num) data = reduce_op(rocprim::detail::warp_move_dpp(data), data); data = reduce_op(rocprim::detail::warp_move_dpp(data), data); data = reduce_op(rocprim::detail::warp_move_dpp(data), data); +#if defined(__gfx12__) || defined(__gfx11__) + // DPP broadcasts not supported on gfx12, use shuffle + data = reduce_op(rocprim::warp_shuffle(data, 15, WarpSize), data); + data = reduce_op(rocprim::warp_shuffle(data, 31, WarpSize), data); +#else data = reduce_op(rocprim::detail::warp_move_dpp(data), data); data = reduce_op(rocprim::detail::warp_move_dpp(data), data); +#endif if constexpr(threadBroadcast) { data = rocprim::warp_shuffle(data, thread_num - 1, thread_num); diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index 5c28742138..b0b58a1412 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -500,12 +500,21 @@ __global__ void smooth_per_token_scaled_quant_kernel(DTYPE_O* __restrict__ out, #pragma unroll for(int i = 0; i < async_load_num; i++) { - // buffer_hash.async_load(smooth_scale_map_hash_shared + threadIdx.x + i * block_size, threadIdx.x + i * block_size); - const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size)))); - uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int); - 
asm volatile( "s_mov_b32 m0 %0\n\t" + #if defined(__gfx12__) + int idx = threadIdx.x + i * block_size; + if(idx < smooth_scale_map_hash_size) + { + // RDNA4 doesn't support buffer_load_* with LDS modifier + // Use standard global load to VGPR then write to LDS + smooth_scale_map_hash_shared[idx] = smooth_scale_map_hash[idx]; + } + #else + const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size)))); + uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int); + asm volatile( "s_mov_b32 m0 %0\n\t" "buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t" ::"s"(lds_ptr_sgpr), "v"(offset), "s"(buffer_hash.cached_rsrc): "memory", "m0"); + #endif } } @@ -1210,12 +1219,21 @@ __global__ void moe_smooth_per_token_scaled_quant_kernel_v1(DTYPE_O* __restrict_ #pragma unroll for(int i = 0; i < async_load_num; i++) { - // buffer_hash.async_load(smooth_scale_map_hash_shared + threadIdx.x + i * block_size, threadIdx.x + i * block_size); - const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size)))); - uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int); - asm volatile( "s_mov_b32 m0 %0\n\t" + #if defined(__gfx12__) + int idx = threadIdx.x + i * block_size; + if(idx < smooth_scale_map_hash_size) + { + // RDNA4 doesn't support buffer_load_* with LDS modifier + // Use standard global load to VGPR then write to LDS + smooth_scale_map_hash_shared[idx] = smooth_scale_map_hash[idx]; + } + #else + const int lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast((smooth_scale_map_hash_shared + threadIdx.x / WARP_SIZE * WARP_SIZE + i * block_size)))); + uint32_t offset = threadIdx.x * sizeof(int) + i * block_size * sizeof(int); + asm volatile( "s_mov_b32 m0 %0\n\t" "buffer_load_dword %1, %2, 0 offen offset:0 lds\n\t" ::"s"(lds_ptr_sgpr), 
"v"(offset), "s"(buffer_hash.cached_rsrc): "memory", "m0"); + #endif } } int smscale_map_idx_list = 0; From 39f37253eb4d0880ccf6f07e8e645c3a204b63ad Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Thu, 19 Mar 2026 08:08:01 +0000 Subject: [PATCH 2/3] enable rmsnorm for gfx12xx --- csrc/kernels/rmsnorm_quant_kernels.cu | 54 ++++++++++++++++++++------- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/csrc/kernels/rmsnorm_quant_kernels.cu b/csrc/kernels/rmsnorm_quant_kernels.cu index b02e0cd12d..a1394dab03 100644 --- a/csrc/kernels/rmsnorm_quant_kernels.cu +++ b/csrc/kernels/rmsnorm_quant_kernels.cu @@ -143,7 +143,12 @@ __global__ void add_rmsnorm_quant_kernel( vec2_f* thread_data_float2 = reinterpret_cast(&thread_data_float); for(int i = 0; i < thread_data_size / 2; i++) { - asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(rcp)); + #if defined(__gfx11__) || defined(__gfx12__) + thread_data_float[2*i] *= rcp[0]; + thread_data_float[2*i+1] *= rcp[1]; + #else + asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(rcp)); + #endif } float* thread_data_weight2 = reinterpret_cast(&thread_data_weight); @@ -154,23 +159,46 @@ __global__ void add_rmsnorm_quant_kernel( // thread_data_weight_float2[1] = ck_tile::type_convert(thread_data_weight[2 * i + 1]); if constexpr(std::is_same_v) { - asm volatile( - "v_lshlrev_b32_e32 %0, 16 %2\n" - "v_and_b32_e32 %1 0xffff0000 %2\n" - : "=v"(thread_data_weight_float2[0]), "=v"(thread_data_weight_float2[1]) - : "v"(thread_data_weight2[i]) - ); + #if defined(__gfx11__) || defined(__gfx12__) + // RDNA: Use bit_cast + shift to unpack bf16 from packed storage + uint32_t w = ck_tile::bit_cast(thread_data_weight2[i]); + uint16_t lo = static_cast(w & 0xFFFF); + uint16_t hi = static_cast(w >> 16); + thread_data_weight_float2[0] = ck_tile::type_convert(ck_tile::bit_cast(lo)); + thread_data_weight_float2[1] = 
ck_tile::type_convert(ck_tile::bit_cast(hi)); + #else + asm volatile( + "v_lshlrev_b32_e32 %0, 16 %2\n" + "v_and_b32_e32 %1 0xffff0000 %2\n" + : "=v"(thread_data_weight_float2[0]), "=v"(thread_data_weight_float2[1]) + : "v"(thread_data_weight2[i]) + ); + #endif } else { - asm volatile( - "v_cvt_f32_f16_e32 %0 %2\n" - "v_cvt_f32_f16_sdwa %1 %2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1\n" - : "=v"(thread_data_weight_float2[0]), "=v"(thread_data_weight_float2[1]) - : "v"(thread_data_weight2[i]) - ); + #if defined(__gfx11__) || defined(__gfx12__) + // RDNA: Use bit_cast + shift to unpack fp16 from packed storage + uint32_t w = ck_tile::bit_cast(thread_data_weight2[i]); + uint16_t lo = static_cast(w & 0xFFFF); + uint16_t hi = static_cast(w >> 16); + thread_data_weight_float2[0] = ck_tile::type_convert(ck_tile::bit_cast(lo)); + thread_data_weight_float2[1] = ck_tile::type_convert(ck_tile::bit_cast(hi)); + #else + asm volatile( + "v_cvt_f32_f16_e32 %0 %2\n" + "v_cvt_f32_f16_sdwa %1 %2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1\n" + : "=v"(thread_data_weight_float2[0]), "=v"(thread_data_weight_float2[1]) + : "v"(thread_data_weight2[i]) + ); + #endif } + #if defined(__gfx11__) || defined(__gfx12__) + thread_data_float[2*i] *= thread_data_weight_float2[0]; // scale by the unpacked weights (scalar equivalent of the v_pk_mul_f32 in the #else branch) + thread_data_float[2*i+1] *= thread_data_weight_float2[1]; + #else asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(thread_data_float2[i]) : "v"(thread_data_float2[i]), "v"(thread_data_weight_float2)); + #endif } if constexpr(FUSE_QUANT) From 5b6a60acdb42348b6f83eb9cde6977b1da9e293e Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Mon, 30 Mar 2026 06:32:55 +0000 Subject: [PATCH 3/3] patch aiter opus dtyps range and wmma for rdna4 --- csrc/include/aiter_opus_plus.h | 27 +++++++++++++----- csrc/include/opus/opus.hpp | 52 +++++++++++++++++++++++++--------- 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/csrc/include/aiter_opus_plus.h b/csrc/include/aiter_opus_plus.h index cc4f395d58..6d5802b938 100644 --- 
a/csrc/include/aiter_opus_plus.h +++ b/csrc/include/aiter_opus_plus.h @@ -4,6 +4,7 @@ #include "hip_reduce.h" #include "opus.hpp" +#include "opus/opus.hpp" // todo: remove this to use aiterTensor dtype #include #include @@ -17,14 +18,28 @@ using index_t = int; ///////////////////////////////////////////////////////////////////////////////////////////////////////// // scaled type conversion: v_pk_mul_f32 + v_med3_f32 + v_cvt_pk_{fp8,bf8}_f32 -// Identical ISA to ck_tile::vec_convert for performance parity +// gfx11/gfx12 can use scalar/package-math style multiplies here. +#if defined(__gfx11__) || defined(__gfx12__) +OPUS_D fp32x2_t amd_scalar_mul_f32(fp32x2_t a, fp32x2_t b) +{ + fp32x2_t c; + c[0] = a[0] * b[0]; + c[1] = a[1] * b[1]; + return c; +} +OPUS_D fp32x2_t pk_mul_f32(fp32x2_t a, fp32x2_t b) +{ + return amd_scalar_mul_f32(a, b); +} +#else OPUS_D fp32x2_t pk_mul_f32(fp32x2_t a, fp32x2_t b) { fp32x2_t c; asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); return c; } +#endif // fp32x2 -> fp8x2 with scale + saturation clamp (E4M3) // ISA: v_pk_mul_f32 + v_med3_f32 x2 + v_cvt_pk_fp8_f32 @@ -32,11 +47,8 @@ template , bool> = true OPUS_D decltype(auto) fp32_to_fp8_scaled_x2(const S& s, float inverted_scale) { fp32x2_t tmp = pk_mul_f32(s, fp32x2_t{inverted_scale, inverted_scale}); -#if defined(__gfx950__) - constexpr float hi = 448.0f, lo = -448.0f; -#else - constexpr float hi = 240.0f, lo = -240.0f; -#endif + constexpr float hi = finfo::max(); + constexpr float lo = finfo::min(); float a = tmp[0], b = tmp[1]; int w; asm volatile("v_med3_f32 %1, %1, %3, %4\n" @@ -61,7 +73,8 @@ template , bool> = true OPUS_D decltype(auto) fp32_to_bf8_scaled_x2(const S& s, float inverted_scale) { fp32x2_t tmp = pk_mul_f32(s, fp32x2_t{inverted_scale, inverted_scale}); - constexpr float hi = 57344.0f, lo = -57344.0f; + constexpr float hi = finfo::max(); + constexpr float lo = finfo::min(); float a = tmp[0], b = tmp[1]; int w; asm volatile("v_med3_f32 %1, %1, %3, %4\n" 
diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 0b1839bbaa..b7a7dd4ac6 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -1405,7 +1405,7 @@ OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { // OPUS_H_D constexpr index_t get_warp_size() { -#if defined(__gfx1250__) +#if defined(__gfx1250__) || defined(__gfx12__) return 32; #elif defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) return 64; @@ -2078,8 +2078,8 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; #endif // __GFX9__ (mfma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// -// wmma (gfx1250 / RDNA4, wave32) -#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__) +// wmma (gfx1250 / gfx12 / RDNA4, wave32) +#if defined(__gfx1250__) || defined(__gfx12__) || !defined(__HIP_DEVICE_COMPILE__) // f16/bf16/f32 builtins: (neg_a, A, neg_b, B, matrix_fmts, C, clamp, neg_c) #define DISPATCH_WMMA_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ (std::is_same_v && std::is_same_v && std::is_same_v && \ @@ -2130,6 +2130,12 @@ struct wmma { // Format code for scaled WMMA (f8f6f4); -1 for types that don't support scaling static constexpr int fmt_a = std::is_same_v ? 0 : std::is_same_v ? 1 : std::is_same_v ? 4 : -1; static constexpr int fmt_b = std::is_same_v ? 0 : std::is_same_v ? 1 : std::is_same_v ? 
4 : -1; +#if defined(__gfx12__) + static_assert( + wave_m == 16 && wave_n == 16 && wave_k == 16, + "gfx12 WMMA in OPUS only supports 16x16x16 shapes; use SWMMAC or a gfx1250 path for wider K" + ); +#endif // Regular (non-scaled) dispatch template @@ -2164,6 +2170,17 @@ struct wmma { else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8) +#elif defined(__gfx12__) + // RDNA4 gfx12 WMMA: 16x16x16 for the common floating-point cases. + else if constexpr DISPATCH_WMMA_(fp16_t, fp16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_(fp16_t, fp16_t, fp16_t, 16, 16, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_(bf16_t, bf16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_(bf16_t, bf16_t, bf16_t, 16, 16, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_8BIT_(u8_t, u8_t, i32_t, 16, 16, 16, __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12) + else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, fp8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, bf8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12) + else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12) #endif __builtin_unreachable(); } @@ -2260,14 +2277,13 @@ struct wmma { #undef DISPATCH_WMMA_BF16F32_ #undef DISPATCH_WMMA_8BIT_ -// f16/bf16 
16x16x32 +// gfx1250: wide-k WMMA and scaled WMMA +#if defined(__gfx1250__) using wmma_f32_16x16x32_f16 = wmma; using wmma_f16_16x16x32_f16 = wmma; using wmma_f32_16x16x32_bf16 = wmma; using wmma_bf16_16x16x32_bf16 = wmma; -// f32 16x16x4 using wmma_f32_16x16x4_f32 = wmma; -// fp8/bf8 16x16x64 using wmma_f32_16x16x64_fp8_fp8 = wmma; using wmma_f32_16x16x64_fp8_bf8 = wmma; using wmma_f32_16x16x64_bf8_fp8 = wmma; @@ -2276,7 +2292,6 @@ using wmma_f16_16x16x64_fp8_fp8 = wmma; using wmma_f16_16x16x64_fp8_bf8 = wmma; using wmma_f16_16x16x64_bf8_fp8 = wmma; using wmma_f16_16x16x64_bf8_bf8 = wmma; -// fp8/bf8 16x16x128 using wmma_f32_16x16x128_fp8_fp8 = wmma; using wmma_f32_16x16x128_fp8_bf8 = wmma; using wmma_f32_16x16x128_bf8_fp8 = wmma; @@ -2285,12 +2300,21 @@ using wmma_f16_16x16x128_fp8_fp8 = wmma; using wmma_f16_16x16x128_fp8_bf8 = wmma; using wmma_f16_16x16x128_bf8_fp8 = wmma; using wmma_f16_16x16x128_bf8_bf8 = wmma; -// Scaled WMMA (f8f6f4 unified instruction, supports fp8/bf8/fp4 via format code) using wmma_scale_f32_16x16x128_fp8_fp8 = wmma; using wmma_scale_f32_16x16x128_fp4_fp4 = wmma; -// Scaled WMMA (dedicated fp4 32x16x128 instruction) using wmma_scale_f32_32x16x128_fp4_fp4 = wmma; -#endif // __gfx1250__ (wmma) +#elif defined(__gfx12__) +using wmma_f32_16x16x16_f16_gfx12 = wmma; +using wmma_f16_16x16x16_f16_gfx12 = wmma; +using wmma_f32_16x16x16_bf16_gfx12 = wmma; +using wmma_bf16_16x16x16_bf16_gfx12 = wmma; +using wmma_i32_16x16x16_iu8_gfx12 = wmma; +using wmma_f32_16x16x16_fp8_fp8_gfx12 = wmma; +using wmma_f32_16x16x16_fp8_bf8_gfx12 = wmma; +using wmma_f32_16x16x16_bf8_fp8_gfx12 = wmma; +using wmma_f32_16x16x16_bf8_bf8_gfx12 = wmma; +#endif +#endif // __gfx1250__ || __gfx12__ (wmma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// // adaptor @@ -2499,11 +2523,11 @@ template = {}) { return A{}(mfma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } #endif // __GFX9__ -// wmma_adaptor: 
same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250) +// wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250/gfx12) // A:[(grpm_a<p>), (rept_a, grpk_a<p>, pack_a)], MxK // B:[(grpn_b<p>), (rept_b, grpk_b<p>, pack_b)], NxK // C:[(grpm_c<p>, rept_c, pack_c), (grpn_c<p>)], MxN -#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__) +#if defined(__gfx1250__) || defined(__gfx12__) || !defined(__HIP_DEVICE_COMPILE__) namespace impl { template struct wmma_adaptor : public remove_cvref_t { @@ -2731,14 +2755,14 @@ OPUS_D decltype(auto) make_tiled_mma(MMA&& mma, ES, TS, A&& = {}) { } template OPUS_D decltype(auto) make_tiled_mma(ES, TS, WS, WA&& = {}, TA&& = {}) { -#if defined(__gfx1250__) +#if defined(__gfx1250__) || defined(__gfx12__) return TA{}(make_wmma(WS{}, WA{}, number{}), #else return TA{}(make_mfma(WS{}, WA{}, number{}),