Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions crates/core_arch/src/x86_64/amx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,22 @@ pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
tcvtrowd2ps(TILE as i8, row).as_m512()
}

/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
/// elements to packed single-precision (32-bit) floating-point elements.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(tcvtrowd2ps, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowd2psi<const TILE: i32, const ROW: i32>() -> __m512 {
static_assert_uimm_bits!(TILE, 3);
static_assert_uimm_bits!(ROW, 6);
tcvtrowd2psi(TILE as i8, ROW as u32).as_m512()
}

/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
Expand All @@ -414,6 +430,23 @@ pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
tcvtrowps2phh(TILE as i8, row).as_m512h()
}

/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(tcvtrowps2phh, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phhi<const TILE: i32, const ROW: i32>() -> __m512h {
static_assert_uimm_bits!(TILE, 3);
static_assert_uimm_bits!(ROW, 6);
tcvtrowps2phhi(TILE as i8, ROW as u32).as_m512h()
}

/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
Expand All @@ -430,6 +463,23 @@ pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
tcvtrowps2phl(TILE as i8, row).as_m512h()
}

/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(tcvtrowps2phl, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phli<const TILE: i32, const ROW: i32>() -> __m512h {
static_assert_uimm_bits!(TILE, 3);
static_assert_uimm_bits!(ROW, 6);
tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h()
}

/// Moves one row of tile data into a zmm vector register
#[inline]
#[rustc_legacy_const_generics(0)]
Expand All @@ -444,6 +494,21 @@ pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
tilemovrow(TILE as i8, row).as_m512i()
}

/// Moves one row of tile data into a zmm vector register
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(tilemovrow, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_movrowi<const TILE: i32, const ROW: i32>() -> __m512i {
static_assert_uimm_bits!(TILE, 3);
static_assert_uimm_bits!(ROW, 6);
tilemovrowi(TILE as i8, ROW as u32).as_m512i()
}

#[allow(improper_ctypes)]
unsafe extern "C" {
#[link_name = "llvm.x86.ldtilecfg"]
Expand Down Expand Up @@ -492,12 +557,20 @@ unsafe extern "C" {
fn tmmultf32ps(dst: i8, a: i8, b: i8);
#[link_name = "llvm.x86.tcvtrowd2ps"]
fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
#[link_name = "llvm.x86.tcvtrowd2psi"]
fn tcvtrowd2psi(tile: i8, row: u32) -> f32x16;
#[link_name = "llvm.x86.tcvtrowps2phh"]
fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tcvtrowps2phhi"]
fn tcvtrowps2phhi(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tcvtrowps2phl"]
fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tcvtrowps2phli"]
fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tilemovrow"]
fn tilemovrow(tile: i8, row: u32) -> i32x16;
#[link_name = "llvm.x86.tilemovrowi"]
fn tilemovrowi(tile: i8, row: u32) -> i32x16;
}

#[cfg(test)]
Expand Down Expand Up @@ -1032,6 +1105,50 @@ mod tests {
}
}

macro_rules! wrap_imm4 {
($name:ident :: <$TILE:literal>, $row:expr) => {
match $row {
0 => $name::<$TILE, 0>(),
1 => $name::<$TILE, 1>(),
2 => $name::<$TILE, 2>(),
3 => $name::<$TILE, 3>(),
4 => $name::<$TILE, 4>(),
5 => $name::<$TILE, 5>(),
6 => $name::<$TILE, 6>(),
7 => $name::<$TILE, 7>(),
8 => $name::<$TILE, 8>(),
9 => $name::<$TILE, 9>(),
10 => $name::<$TILE, 10>(),
11 => $name::<$TILE, 11>(),
12 => $name::<$TILE, 12>(),
13 => $name::<$TILE, 13>(),
14 => $name::<$TILE, 14>(),
15 => $name::<$TILE, 15>(),
_ => panic!("row index out of range"),
}
};
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_movrowi() {
unsafe {
_init_amx();
let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);

let mut config = __tilecfg::default();
config.palette = 1;
config.colsb[0] = 64;
config.rows[0] = 16;
_tile_loadconfig(config.as_ptr());
_tile_loadd::<0>(array.as_ptr().cast(), 64);

for i in 0..16 {
let row = wrap_imm4!(_tile_movrowi::<0>, i);
assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
}
}
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowd2ps() {
unsafe {
Expand All @@ -1051,6 +1168,26 @@ mod tests {
}
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowd2psi() {
unsafe {
_init_amx();
let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);

let mut config = __tilecfg::default();
config.palette = 1;
config.colsb[0] = 64;
config.rows[0] = 16;
_tile_loadconfig(config.as_ptr());
_tile_loadd::<0>(array.as_ptr().cast(), 64);

for i in 0..16 {
let row = wrap_imm4!(_tile_cvtrowd2psi::<0>, i);
assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
}
}
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2phh() {
unsafe {
Expand All @@ -1073,6 +1210,28 @@ mod tests {
}
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2phhi() {
unsafe {
_init_amx();
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);

let mut config = __tilecfg::default();
config.palette = 1;
config.colsb[0] = 64;
config.rows[0] = 16;
_tile_loadconfig(config.as_ptr());
_tile_loadd::<0>(array.as_ptr().cast(), 64);
for i in 0..16 {
let row = wrap_imm4!(_tile_cvtrowps2phhi::<0>, i);
assert_eq!(
*row.as_f16x32().as_array(),
array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
);
}
}
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2phl() {
unsafe {
Expand All @@ -1095,6 +1254,28 @@ mod tests {
}
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2phli() {
unsafe {
_init_amx();
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);

let mut config = __tilecfg::default();
config.palette = 1;
config.colsb[0] = 64;
config.rows[0] = 16;
_tile_loadconfig(config.as_ptr());
_tile_loadd::<0>(array.as_ptr().cast(), 64);
for i in 0..16 {
let row = wrap_imm4!(_tile_cvtrowps2phli::<0>, i);
assert_eq!(
*row.as_f16x32().as_array(),
array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
);
}
}
}

#[simd_test(enable = "amx-tf32")]
fn test_tile_mmultf32ps() {
unsafe {
Expand Down