From 78891f7d23e7091e63d1ce9597805d7534fff416 Mon Sep 17 00:00:00 2001 From: sayantn Date: Mon, 2 Mar 2026 00:27:37 +0530 Subject: [PATCH] Add immediate AMX intrinsics --- crates/core_arch/src/x86_64/amx.rs | 181 +++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index 4e20e014cf..03bbe3e449 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -398,6 +398,22 @@ pub unsafe fn _tile_cvtrowd2ps(row: u32) -> __m512 { tcvtrowd2ps(TILE as i8, row).as_m512() } +/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer +/// elements to packed single-precision (32-bit) floating-point elements. +#[inline] +#[rustc_legacy_const_generics(0, 1)] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr( + all(test, any(target_os = "linux", target_env = "msvc")), + assert_instr(tcvtrowd2ps, TILE = 0, ROW = 0) +)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_cvtrowd2psi() -> __m512 { + static_assert_uimm_bits!(TILE, 3); + static_assert_uimm_bits!(ROW, 6); + tcvtrowd2psi(TILE as i8, ROW as u32).as_m512() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. @@ -414,6 +430,23 @@ pub unsafe fn _tile_cvtrowps2phh(row: u32) -> __m512h { tcvtrowps2phh(TILE as i8, row).as_m512h() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. +#[inline] +#[rustc_legacy_const_generics(0, 1)] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr( + all(test, any(target_os = "linux", target_env = "msvc")), + assert_instr(tcvtrowps2phh, TILE = 0, ROW = 0) +)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_cvtrowps2phhi() -> __m512h { + static_assert_uimm_bits!(TILE, 3); + static_assert_uimm_bits!(ROW, 6); + tcvtrowps2phhi(TILE as i8, ROW as u32).as_m512h() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. @@ -430,6 +463,23 @@ pub unsafe fn _tile_cvtrowps2phl(row: u32) -> __m512h { tcvtrowps2phl(TILE as i8, row).as_m512h() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. +#[inline] +#[rustc_legacy_const_generics(0, 1)] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr( + all(test, any(target_os = "linux", target_env = "msvc")), + assert_instr(tcvtrowps2phl, TILE = 0, ROW = 0) +)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_cvtrowps2phli() -> __m512h { + static_assert_uimm_bits!(TILE, 3); + static_assert_uimm_bits!(ROW, 6); + tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h() +} + /// Moves one row of tile data into a zmm vector register #[inline] #[rustc_legacy_const_generics(0)] @@ -444,6 +494,21 @@ pub unsafe fn _tile_movrow(row: u32) -> __m512i { tilemovrow(TILE as i8, row).as_m512i() } +/// Moves one row of tile data into a zmm vector register +#[inline] +#[rustc_legacy_const_generics(0, 1)] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr( + all(test, any(target_os = "linux", target_env = "msvc")), + assert_instr(tilemovrow, TILE = 0, ROW = 0) +)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_movrowi() -> __m512i { + static_assert_uimm_bits!(TILE, 3); + static_assert_uimm_bits!(ROW, 6); + tilemovrowi(TILE as i8, ROW as u32).as_m512i() +} + #[allow(improper_ctypes)] unsafe extern "C" { #[link_name = "llvm.x86.ldtilecfg"] @@ -492,12 +557,20 @@ unsafe extern "C" { fn tmmultf32ps(dst: i8, a: i8, b: i8); #[link_name = "llvm.x86.tcvtrowd2ps"] fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16; + #[link_name = "llvm.x86.tcvtrowd2psi"] + fn tcvtrowd2psi(tile: i8, row: u32) -> f32x16; #[link_name = "llvm.x86.tcvtrowps2phh"] fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phhi"] + fn tcvtrowps2phhi(tile: i8, row: u32) -> f16x32; #[link_name = "llvm.x86.tcvtrowps2phl"] fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phli"] + fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32; #[link_name = "llvm.x86.tilemovrow"] fn tilemovrow(tile: i8, row: u32) -> i32x16; + #[link_name = "llvm.x86.tilemovrowi"] + fn tilemovrowi(tile: i8, row: u32) -> i32x16; } #[cfg(test)] @@ -1032,6 +1105,50 @@ mod tests { } } + macro_rules! wrap_imm4 { + ($name:ident :: <$TILE:literal>, $row:expr) => { + match $row { + 0 => $name::<$TILE, 0>(), + 1 => $name::<$TILE, 1>(), + 2 => $name::<$TILE, 2>(), + 3 => $name::<$TILE, 3>(), + 4 => $name::<$TILE, 4>(), + 5 => $name::<$TILE, 5>(), + 6 => $name::<$TILE, 6>(), + 7 => $name::<$TILE, 7>(), + 8 => $name::<$TILE, 8>(), + 9 => $name::<$TILE, 9>(), + 10 => $name::<$TILE, 10>(), + 11 => $name::<$TILE, 11>(), + 12 => $name::<$TILE, 12>(), + 13 => $name::<$TILE, 13>(), + 14 => $name::<$TILE, 14>(), + 15 => $name::<$TILE, 15>(), + _ => panic!("row index out of range"), + } + }; + } + + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test_tile_movrowi() { + unsafe { + _init_amx(); + let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]); + + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_loadd::<0>(array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = wrap_imm4!(_tile_movrowi::<0>, i); + assert_eq!(*row.as_u8x64().as_array(), [i as _; _]); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowd2ps() { unsafe { @@ -1051,6 +1168,26 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test_tile_cvtrowd2psi() { + unsafe { + _init_amx(); + let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_loadd::<0>(array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = wrap_imm4!(_tile_cvtrowd2psi::<0>, i); + assert_eq!(*row.as_f32x16().as_array(), [i as _; _]); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2phh() { unsafe { @@ -1073,6 +1210,28 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test_tile_cvtrowps2phhi() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_loadd::<0>(array.as_ptr().cast(), 64); + for i in 0..16 { + let row = wrap_imm4!(_tile_cvtrowps2phhi::<0>, i); + assert_eq!( + *row.as_f16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ }) + ); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2phl() { unsafe { @@ -1095,6 +1254,28 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test_tile_cvtrowps2phli() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_loadd::<0>(array.as_ptr().cast(), 64); + for i in 0..16 { + let row = wrap_imm4!(_tile_cvtrowps2phli::<0>, i); + assert_eq!( + *row.as_f16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 }) + ); + } + } + } + #[simd_test(enable = "amx-tf32")] fn test_tile_mmultf32ps() { unsafe {