From 7870e5552eb6b20af903ba9764e7d98bc1b687f2 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Fri, 3 Apr 2026 10:32:01 +0800 Subject: [PATCH 1/4] metax support flash-attn Signed-off-by: Ceng23333 <441651826@qq.com> --- include/infinicore/adaptor/aten_adaptor.hpp | 12 +- .../adaptor/flash_attention_adaptor.hpp | 23 ++- scripts/install.py | 8 +- scripts/set_env.py | 75 ++++++++++ src/infinicore/adaptor/aten_adaptor.cc | 2 +- .../ops/mha_kvcache/mha_kvcache_flashattn.cc | 50 +++++-- .../mha_varlen_flashattn.cc | 136 +++++++++++++++++- ...inary_cross_entropy_with_logits_metax.maca | 3 +- src/infiniop/ops/equal/metax/equal_metax.maca | 28 ++-- src/infiniop/ops/hardswish/cuda/kernel.cuh | 9 ++ src/infiniop/ops/hardtanh/cuda/kernel.cuh | 9 ++ xmake.lua | 28 ++-- xmake/metax.lua | 100 ++++++++++++- 13 files changed, 438 insertions(+), 45 deletions(-) diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp index 0c3237dc9..2f1884e0c 100644 --- a/include/infinicore/adaptor/aten_adaptor.hpp +++ b/include/infinicore/adaptor/aten_adaptor.hpp @@ -5,9 +5,11 @@ #include -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) #include -#include +#endif + +#ifdef ENABLE_NVIDIA_API #include #endif @@ -30,12 +32,10 @@ inline at::ScalarType to_at_dtype(DataType dtype) { } inline at::Device to_at_device(const Device &device) { - if (device.getType() == Device::Type::NVIDIA) { + if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX) { return at::Device(at::kCUDA, device.getIndex()); } else if (device.getType() == Device::Type::CPU) { return at::Device(at::kCPU); - } else if (device.getType() == Device::Type::QY) { - return at::Device(at::kCUDA, device.getIndex()); } else { throw std::runtime_error("Unsupported device type for ATen"); } @@ -43,7 +43,7 @@ inline at::Device to_at_device(const Device &device) { at::Tensor to_aten_tensor(const infinicore::Tensor &t); -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) c10::cuda::CUDAStream get_cuda_stream(); #endif } // namespace infinicore::adaptor diff --git a/include/infinicore/adaptor/flash_attention_adaptor.hpp b/include/infinicore/adaptor/flash_attention_adaptor.hpp index 8a9e152fd..9ffcc42d6 100644 --- a/include/infinicore/adaptor/flash_attention_adaptor.hpp +++ b/include/infinicore/adaptor/flash_attention_adaptor.hpp @@ -2,7 +2,12 @@ #pragma once #include "aten_adaptor.hpp" +// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip/MetaX flash_attn_2_cuda extension +// exports the same entry points at global scope (no namespace), matching FLASH_NAMESPACE builds +// where the namespace is empty. +#if !defined(ENABLE_METAX_API) namespace flash { +#endif std::vector mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) @@ -39,7 +44,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_hea int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_); + std::optional gen_ +#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3) + // MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn. + , + std::optional &flash_attn_mars_ext_ +#endif + ); std::vector mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) @@ -108,7 +119,15 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size int window_size_right, const float softcap, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - int num_splits); + int num_splits +#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3) + // MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn. + , + std::optional &flash_attn_mars_ext_ +#endif + ); +#if !defined(ENABLE_METAX_API) } // namespace flash +#endif #endif // ENABLE_FLASH_ATTN diff --git a/scripts/install.py b/scripts/install.py index 2e420ee9f..bac2d95d1 100644 --- a/scripts/install.py +++ b/scripts/install.py @@ -2,7 +2,11 @@ import subprocess import platform import sys -from set_env import set_env +from set_env import ( + set_env, + ensure_metax_hpc_compiler_includes, + xmake_flags_need_metax_aten_torch_includes, +) PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) os.chdir(PROJECT_DIR) @@ -12,6 +16,8 @@ def run_cmd(cmd): def install(xmake_config_flags=""): + if xmake_flags_need_metax_aten_torch_includes(xmake_config_flags): + ensure_metax_hpc_compiler_includes() run_cmd(f"xmake f {xmake_config_flags} -cv") run_cmd("xmake") run_cmd("xmake install") diff --git a/scripts/set_env.py b/scripts/set_env.py index d1d4c2184..8fd1a63d9 100644 --- a/scripts/set_env.py +++ b/scripts/set_env.py @@ -2,6 +2,81 @@ import platform +def _maca_root_from_env(): + return ( + os.environ.get("MACA_PATH") + or os.environ.get("MACA_HOME") + or os.environ.get("MACA_ROOT") + or "" + ).strip() + + +def metax_hpc_compiler_include_dirs(): + """Directories needed so g++ finds cuda_runtime_api.h (cu-bridge) when compiling against PyTorch c10/cuda headers on MetaX/HPCC.""" + maca = _maca_root_from_env() + if not maca: + return [] + return [ + os.path.join(maca, "tools", "cu-bridge", "include"), + os.path.join(maca, "include", "hcr"), + os.path.join(maca, "include"), + ] + + +def _prepend_path_var(name, prefixes): + """Prepend colon-separated *prefixes* to env var *name* (POSIX).""" + if not prefixes: + return + chunk = ":".join(prefixes) + cur = os.environ.get(name, "") + os.environ[name] = f"{chunk}:{cur}" if cur else chunk + + +def ensure_metax_hpc_compiler_includes(): + """ + Prepend HPCC/cu-bridge includes to CPATH, CPLUS_INCLUDE_PATH, and C_INCLUDE_PATH. + g++ uses CPLUS_INCLUDE_PATH for .cc files; C_INCLUDE_PATH alone is not enough. + """ + dirs = metax_hpc_compiler_include_dirs() + if not dirs: + return + for var in ("CPATH", "CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"): + _prepend_path_var(var, dirs) + + +def _parse_xmake_cli_flag_values(flags: str): + """Parse a string like '--metax-gpu=y --aten=y' into {key: value}.""" + parts = flags.replace("=", " ").split() + d = {} + i = 0 + n = len(parts) + while i < n: + p = parts[i] + if p.startswith("--") and len(p) > 2: + key = p[2:].lower() + i += 1 + if i < n and not parts[i].startswith("--"): + d[key] = parts[i].lower() + i += 1 + else: + d[key] = "y" + else: + i += 1 + return d + + +def _truthy_flag_value(v: str) -> bool: + return v in ("y", "yes", "true", "1", "on") + + +def xmake_flags_need_metax_aten_torch_includes(flags: str) -> bool: + """True when install.py-style args enable MetaX GPU and ATen (PyTorch) together.""" + d = _parse_xmake_cli_flag_values(flags) + return _truthy_flag_value(d.get("metax-gpu", "n")) and _truthy_flag_value( + d.get("aten", "n") + ) + + def set_env(): if os.environ.get("INFINI_ROOT") == None: os.environ["INFINI_ROOT"] = os.path.expanduser("~/.infini") diff --git a/src/infinicore/adaptor/aten_adaptor.cc b/src/infinicore/adaptor/aten_adaptor.cc index 2ffe396ef..d701ae60b 100644 --- a/src/infinicore/adaptor/aten_adaptor.cc +++ b/src/infinicore/adaptor/aten_adaptor.cc @@ -32,7 +32,7 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) { options); } -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) c10::cuda::CUDAStream get_cuda_stream() { return c10::cuda::getStreamFromExternal( cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex()); diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc index 677b85d88..2eb89ce67 100644 --- a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc +++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc @@ -4,6 +4,18 @@ #include +#ifdef ENABLE_FLASH_ATTN +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) +#include +#endif +#endif + +#if defined(ENABLE_METAX_API) +#define INFINICORE_FLASH_OP(name) ::name +#else +#define INFINICORE_FLASH_OP(name) flash::name +#endif + namespace infinicore::op::mha_kvcache_impl::flashattn { struct PlannedMeta { @@ -33,22 +45,24 @@ void *plan(Tensor out, void run(void *planned_meta) { #ifdef ENABLE_FLASH_ATTN +#ifdef ENABLE_NVIDIA_API c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); +#elif defined(ENABLE_METAX_API) + c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); +#endif auto *p = reinterpret_cast(planned_meta); - auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out); - auto q = infinicore::adaptor::to_aten_tensor(p->q); -#if defined(ENABLE_NVIDIA_API) - auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache); - auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); -#elif defined(ENABLE_QY_API) + // FlashAttention kernels expect standard dense layout (contiguous last dimension). + auto out_at = infinicore::adaptor::to_aten_tensor(p->out); + const bool out_need_copy_back = !out_at.is_contiguous(); + auto out_tensor = out_need_copy_back ? out_at.contiguous() : out_at; + auto q = infinicore::adaptor::to_aten_tensor(p->q).contiguous(); auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous(); auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous(); -#endif - auto seqlens_k = std::optional(infinicore::adaptor::to_aten_tensor(p->seqlens_k)); - auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); + auto seqlens_k = std::optional(infinicore::adaptor::to_aten_tensor(p->seqlens_k).contiguous()); + auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table).contiguous()); auto alibi_slopes = p->alibi_slopes - ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) + ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous()) : std::nullopt; std::optional k_new = std::nullopt; @@ -65,7 +79,11 @@ void run(void *planned_meta) { auto out = use_dynamic_out ? std::optional(std::nullopt) : std::optional(out_tensor); - auto result = flash::mha_fwd_kvcache( +#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3) + std::optional flash_attn_mars_ext = std::nullopt; +#endif + + auto result = INFINICORE_FLASH_OP(mha_fwd_kvcache)( q, k_cache, v_cache, @@ -85,11 +103,19 @@ void run(void *planned_meta) { -1, 0.0f, false, - 0); + 0 +#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3) + , + flash_attn_mars_ext +#endif + ); if (use_dynamic_out) { out_tensor.copy_(result[0]); } + if (out_need_copy_back) { + out_at.copy_(out_tensor); + } #else throw std::runtime_error("FlashAttention is not enabled in this build"); #endif diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc index aff085898..2f0848b55 100644 --- a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc @@ -4,6 +4,12 @@ #include +#ifdef ENABLE_FLASH_ATTN +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) +#include +#endif +#endif + namespace infinicore::op::mha_varlen_impl::flashattn { struct PlannedMeta { @@ -39,11 +45,135 @@ void *plan(Tensor out, scale}; } -void run(void *planned_meta) { +namespace { + #ifdef ENABLE_FLASH_ATTN +struct VarlenFlashPrepared { + at::Tensor q; + at::Tensor k; + at::Tensor v; + at::Tensor out_at; + bool out_need_copy_back; + at::Tensor out_work; + std::optional out_opt; + at::Tensor cu_seqlens_q; + at::Tensor cu_seqlens_kv; + std::optional block_table; + std::optional alibi_slopes; + int max_seqlen_q; + int max_seqlen_k; + float scale; +}; + +VarlenFlashPrepared prepare_varlen_flash_tensors(PlannedMeta *p) { + VarlenFlashPrepared t; + // FlashAttention kernels expect standard dense layout (contiguous last dimension). + t.q = infinicore::adaptor::to_aten_tensor(p->q).contiguous(); + t.k = infinicore::adaptor::to_aten_tensor(p->k).contiguous(); + t.v = infinicore::adaptor::to_aten_tensor(p->v).contiguous(); + t.out_at = infinicore::adaptor::to_aten_tensor(p->out); + t.out_need_copy_back = !t.out_at.is_contiguous(); + t.out_work = t.out_need_copy_back ? t.out_at.contiguous() : t.out_at; + t.out_opt = std::optional(t.out_work); + t.cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q).contiguous(); + t.cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k).contiguous(); + t.block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table).contiguous()); + t.max_seqlen_q = p->max_seqlen_q; + t.max_seqlen_k = p->max_seqlen_k; + t.alibi_slopes = p->alibi_slopes + ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous()) + : std::nullopt; + t.scale = p->scale; + return t; +} + +void copy_varlen_flash_output_back(VarlenFlashPrepared &t) { + if (t.out_need_copy_back) { + t.out_at.copy_(t.out_work); + } +} + +#if defined(ENABLE_METAX_API) +/** + * MetaX / hpcc: pip `flash_attn_2_cuda` exposes `mha_varlen_fwd` in the global namespace and + * relies on the current CUDA stream. Kept separate from the NVIDIA `run()` body to reduce merge conflicts. + */ +void run_flashattn_varlen_metax(PlannedMeta *p) { c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); + auto t = prepare_varlen_flash_tensors(p); + std::optional seqused_k = std::nullopt; + std::optional leftpad_k = std::nullopt; + // `flash_attn_2_cuda` (pip / MetaX) may append an extra trailing argument (`flash_attn_mars_ext_`) + // depending on the HPCC/MetaX stack version. +#if defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3) + std::optional flash_attn_mars_ext = std::nullopt; + ::mha_varlen_fwd( + t.q, + t.k, + t.v, + t.out_opt, + t.cu_seqlens_q, + t.cu_seqlens_kv, + seqused_k, + leftpad_k, + t.block_table, + t.alibi_slopes, + t.max_seqlen_q, + t.max_seqlen_k, + 0.0, + t.scale, + false, + true, + -1, + -1, + 0.0, + false, + std::nullopt, + flash_attn_mars_ext); +#else + ::mha_varlen_fwd( + t.q, + t.k, + t.v, + t.out_opt, + t.cu_seqlens_q, + t.cu_seqlens_kv, + seqused_k, + leftpad_k, + t.block_table, + t.alibi_slopes, + t.max_seqlen_q, + t.max_seqlen_k, + 0.0, + t.scale, + false, + true, + -1, + -1, + 0.0, + false, + std::nullopt); +#endif + copy_varlen_flash_output_back(t); +} +#endif + +#endif // ENABLE_FLASH_ATTN +} // namespace + +void run(void *planned_meta) { +#ifdef ENABLE_FLASH_ATTN auto *p = reinterpret_cast(planned_meta); +#if defined(ENABLE_METAX_API) + run_flashattn_varlen_metax(p); + return; +#endif + + // Original InfiniCore path (NVIDIA + xmake flash-attn-nvidia). MetaX is handled above. +#if defined(ENABLE_NVIDIA_API) + c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); + auto q = infinicore::adaptor::to_aten_tensor(p->q); auto k = infinicore::adaptor::to_aten_tensor(p->k); auto v = infinicore::adaptor::to_aten_tensor(p->v); @@ -80,6 +210,10 @@ void run(void *planned_meta) { 0.0, false, std::nullopt); +#else + throw std::runtime_error("FlashAttention varlen: no supported GPU backend in this build"); +#endif + #else throw std::runtime_error("FlashAttention is not enabled in this build"); #endif diff --git a/src/infiniop/ops/binary_cross_entropy_with_logits/metax/binary_cross_entropy_with_logits_metax.maca b/src/infiniop/ops/binary_cross_entropy_with_logits/metax/binary_cross_entropy_with_logits_metax.maca index c14bb75bc..083821638 100644 --- a/src/infiniop/ops/binary_cross_entropy_with_logits/metax/binary_cross_entropy_with_logits_metax.maca +++ b/src/infiniop/ops/binary_cross_entropy_with_logits/metax/binary_cross_entropy_with_logits_metax.maca @@ -1,9 +1,8 @@ #include "../../../devices/metax/metax_common.h" #include "../../../devices/metax/metax_handle.h" #include "../../../devices/metax/metax_kernel_common.h" - #include "binary_cross_entropy_with_logits_metax.h" - +#include #include namespace op::bce_with_logits::metax { diff --git a/src/infiniop/ops/equal/metax/equal_metax.maca b/src/infiniop/ops/equal/metax/equal_metax.maca index 265e5b5a6..655b63561 100644 --- a/src/infiniop/ops/equal/metax/equal_metax.maca +++ b/src/infiniop/ops/equal/metax/equal_metax.maca @@ -1,11 +1,23 @@ #include "equal_metax.h" #include "../../../elementwise/metax/elementwise_metax.h" - -#include "../cuda/kernel.cuh" +#include namespace op::equal::metax { +struct EqualOp { + static constexpr size_t num_inputs = 2; + + template + __device__ __forceinline__ bool operator()(const Tin0 &a, const Tin1 &b) const { + if constexpr (std::is_same_v) { + return static_cast(a == b); + } else { + return false; + } + } +}; + Descriptor::~Descriptor() = default; infiniStatus_t Descriptor::create( @@ -50,17 +62,17 @@ infiniStatus_t Descriptor::calculate( switch (_dtype) { case INFINI_DTYPE_F16: - return _device_info->calculate<256, cuda::EqualOp, bool, half, half>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, EqualOp, bool, half, half>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_BF16: - return _device_info->calculate<256, cuda::EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F32: - return _device_info->calculate<256, cuda::EqualOp, bool, float, float>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, EqualOp, bool, float, float>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_I32: - return _device_info->calculate<256, cuda::EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_I64: - return _device_info->calculate<256, cuda::EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F64: - return _device_info->calculate<256, cuda::EqualOp, bool, double, double>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, EqualOp, bool, double, double>(_info, workspace, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/hardswish/cuda/kernel.cuh b/src/infiniop/ops/hardswish/cuda/kernel.cuh index 25dfd55a0..bfb775d2f 100644 --- a/src/infiniop/ops/hardswish/cuda/kernel.cuh +++ b/src/infiniop/ops/hardswish/cuda/kernel.cuh @@ -2,6 +2,15 @@ #define __HARDSWISH_CUDA_H__ #include +#if defined(ENABLE_METAX_API) +#include +#elif defined(__MACACC__) +#include +#include +#else +#include +#include +#endif namespace op::hardswish::cuda { diff --git a/src/infiniop/ops/hardtanh/cuda/kernel.cuh b/src/infiniop/ops/hardtanh/cuda/kernel.cuh index fa8c3d130..4f6efe319 100644 --- a/src/infiniop/ops/hardtanh/cuda/kernel.cuh +++ b/src/infiniop/ops/hardtanh/cuda/kernel.cuh @@ -1,6 +1,15 @@ #ifndef __HARDTANH_CUDA_H__ #define __HARDTANH_CUDA_H__ +#if defined(ENABLE_METAX_API) +#include +#elif defined(__MACACC__) +#include +#include +#else +#include +#include +#endif #include namespace op::hardtanh::cuda { diff --git a/xmake.lua b/xmake.lua index 8f32bf7cc..3c6db7750 100644 --- a/xmake.lua +++ b/xmake.lua @@ -167,6 +167,8 @@ option_end() if has_config("metax-gpu") then add_defines("ENABLE_METAX_API") + -- Container torch build expects this for ATen headers on hpcc. + add_defines("USE_HPCC") if has_config("use-mc") then add_defines("ENABLE_METAX_MC_API") end @@ -235,14 +237,14 @@ option_end() -- Flash-Attn option("flash-attn") - set_default("") + set_default(nil) set_showmenu(true) set_description("Path to flash-attention repo. If not set, flash-attention will not used.") option_end() if has_config("aten") then add_defines("ENABLE_ATEN") - if get_config("flash-attn") ~= false then + if get_config("flash-attn") and get_config("flash-attn") ~= "" then add_defines("ENABLE_FLASH_ATTN") end end @@ -258,6 +260,7 @@ if has_config("graph") then add_defines("USE_INFINIRT_GRAPH") end + -- InfiniCCL option("ccl") set_default(false) @@ -460,24 +463,29 @@ target("infinicore_cpp_api") add_linkdirs(INFINI_ROOT.."/lib") add_links("infiniop", "infinirt", "infiniccl") - if get_config("flash-attn") ~= "" and get_config("flash-attn") ~= nil then + if get_config("flash-attn") and get_config("flash-attn") ~= "" then add_installfiles("(builddir)/$(plat)/$(arch)/$(mode)/flash-attn*.so", {prefixdir = "lib"}) if has_config("nv-gpu") then add_deps("flash-attn-nvidia") end - if has_config("qy-gpu") then - add_deps("flash-attn-qy") + if has_config("metax-gpu") then + add_deps("flash-attn-metax") end end - if get_config("flash-attn") and get_config("flash-attn") ~= "" and has_config("qy-gpu") then - local flash_so_qy = _qy_flash_attn_cuda_so_path() - local flash_dir_qy = path.directory(flash_so_qy) - local flash_name_qy = path.filename(flash_so_qy) + -- MetaX: link pip-built flash_attn_2_cuda*.so. + -- The `.so` path resolver lives in `xmake/metax.lua` and reads: + -- - `FLASH_ATTN_2_CUDA_SO` (exact `.so` override) + -- - `FLASH_ATTN_METAX_CUDA_SO_CONTAINER` (override expected container path) + -- Path is fixed at target definition time (before_link sandbox has no os.iorunv in xmake 3.x). + if get_config("flash-attn") and get_config("flash-attn") ~= "" and has_config("metax-gpu") then + local flash_so_metax = _metax_flash_attn_cuda_so_path() + local flash_dir_metax = path.directory(flash_so_metax) + local flash_name_metax = path.filename(flash_so_metax) before_link(function (target) target:add( "shflags", - "-Wl,--no-as-needed -L" .. flash_dir_qy .. " -l:" .. flash_name_qy .. " -Wl,-rpath," .. flash_dir_qy, + "-Wl,--no-as-needed -L" .. flash_dir_metax .. " -l:" .. flash_name_metax .. " -Wl,-rpath," .. flash_dir_metax, {force = true} ) end) diff --git a/xmake/metax.lua b/xmake/metax.lua index e7071d4bb..eada04015 100644 --- a/xmake/metax.lua +++ b/xmake/metax.lua @@ -1,5 +1,80 @@ local MACA_ROOT = os.getenv("MACA_PATH") or os.getenv("MACA_HOME") or os.getenv("MACA_ROOT") +local FLASH_ATTN_ROOT = get_config("flash-attn") + +-- MetaX flash-attn (pip `flash_attn_2_cuda`) may append an extra trailing argument +-- (`flash_attn_mars_ext_`) depending on the underlying HPCC/MetaX stack version. +do + -- Intentionally empty: HPCC version parsing is deferred to `before_build` + -- on `infinicore_cpp_api` where `os.iorunv` is available in this xmake sandbox. +end + +-- Set numeric HPCC version macro for flash-attn signature/call compatibility. +-- Must be done before compiling `infinicore_cpp_api` sources. +target("infinicore_cpp_api") + before_build(function (target) + if not has_config("metax-gpu") then + return + end + if not (get_config("flash-attn") and get_config("flash-attn") ~= "") then + return + end + + local version_txt = "/opt/hpcc/Version.txt" + if os.isfile(version_txt) then + local content = os.iorunv("cat", {version_txt}) or "" + content = content:trim() + -- Example: `Version:2.32.0.6` + local hpcc_major_str = content:match("Version:(%d+)") or content:match("^(%d+)") + if hpcc_major_str and hpcc_major_str ~= "" then + local hpcc_major = tonumber(hpcc_major_str) + if hpcc_major then + local define = "INFINICORE_HPCC_VERSION_MAJOR=" .. tostring(hpcc_major) + -- `defines` is the logical flag list for the target, + -- but we also pass `-D...` directly to ensure it reaches compilation. + target:add("defines", define) + target:add("cxflags", "-D" .. define) + target:add("cxxflags", "-D" .. define) + end + end + end + end) +target_end() + +-- Resolve MetaX flash-attn .so path. +-- `xmake.lua` calls `_metax_flash_attn_cuda_so_path()` during `infinicore_cpp_api` +-- linking, so this helper must live here (and must be global, not `local`). +local FLASH_ATTN_METAX_CUDA_SO_CONTAINER_DEFAULT = + "/opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-aarch64-linux-gnu.so" + +function _metax_flash_attn_cuda_so_path() + -- Highest priority: override the exact `.so` file to link. + local env_path = os.getenv("FLASH_ATTN_2_CUDA_SO") + if env_path and env_path ~= "" then + env_path = env_path:trim() + if os.isfile(env_path) then + return env_path + end + print(string.format("warning: metax+flash-attn: FLASH_ATTN_2_CUDA_SO is not a file: %s, fallback to container/default path", env_path)) + end + + -- Second priority: allow overriding the "expected" container path via env. + local container_path = os.getenv("FLASH_ATTN_METAX_CUDA_SO_CONTAINER") + if not container_path or container_path == "" then + container_path = FLASH_ATTN_METAX_CUDA_SO_CONTAINER_DEFAULT + end + + if not os.isfile(container_path) then + print( + string.format( + "warning: metax+flash-attn: expected %s; install flash-attn in conda env, or export FLASH_ATTN_2_CUDA_SO.", + container_path + ) + ) + end + return container_path +end + add_includedirs(MACA_ROOT .. "/include") add_linkdirs(MACA_ROOT .. "/lib") if has_config("use-mc") then @@ -57,8 +132,8 @@ target("infiniop-metax") add_includedirs(MACA_ROOT .. "/include/mcr") add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp", { cxflags = { - "-include stdlib.h", - "-Wno-return-type", + "-include stdlib.h", + "-Wno-return-type", "-Wno-implicit-function-declaration", "-Wno-builtin-declaration-mismatch" } @@ -66,6 +141,27 @@ target("infiniop-metax") end target_end() +target("flash-attn-metax") + set_kind("phony") + set_default(false) + + if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then + before_build(function (target) + local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim() + local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim() + local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim() + + -- Validate build/runtime env in container and keep these paths available for downstream linking. + target:add("includedirs", TORCH_DIR .. "/include", TORCH_DIR .. "/include/torch/csrc/api/include", PYTHON_INCLUDE, {public = false}) + target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR, {public = false}) + end) + else + before_build(function (target) + print("Flash Attention not available, skipping flash-attn-metax integration") + end) + end +target_end() + target("infinirt-metax") set_kind("static") set_languages("cxx17") From 9c4d486692899afa74d44dafc2189938cc18372c Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Fri, 3 Apr 2026 12:34:31 +0800 Subject: [PATCH 2/4] fix rebased Signed-off-by: Ceng23333 <441651826@qq.com> --- include/infinicore/adaptor/aten_adaptor.hpp | 8 ++-- src/infinicore/adaptor/aten_adaptor.cc | 2 +- .../ops/mha_kvcache/mha_kvcache_flashattn.cc | 15 +++++--- src/infiniop/ops/equal/metax/equal_metax.maca | 38 ++++--------------- src/infiniop/ops/hardswish/cuda/kernel.cuh | 9 ----- src/infiniop/ops/hardtanh/cuda/kernel.cuh | 9 ----- 6 files changed, 24 insertions(+), 57 deletions(-) diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp index 2f1884e0c..29e532d29 100644 --- a/include/infinicore/adaptor/aten_adaptor.hpp +++ b/include/infinicore/adaptor/aten_adaptor.hpp @@ -5,7 +5,7 @@ #include -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) #include #endif @@ -32,7 +32,9 @@ inline at::ScalarType to_at_dtype(DataType dtype) { } inline at::Device to_at_device(const Device &device) { - if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX) { + // PyTorch ATen only exposes standard device types (e.g. kCPU/kCUDA). + // Treat MetaX/QY devices as CUDA devices for ATen tensor interoperability. + if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX || device.getType() == Device::Type::QY) { return at::Device(at::kCUDA, device.getIndex()); } else if (device.getType() == Device::Type::CPU) { return at::Device(at::kCPU); @@ -43,7 +45,7 @@ inline at::Device to_at_device(const Device &device) { at::Tensor to_aten_tensor(const infinicore::Tensor &t); -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) c10::cuda::CUDAStream get_cuda_stream(); #endif } // namespace infinicore::adaptor diff --git a/src/infinicore/adaptor/aten_adaptor.cc b/src/infinicore/adaptor/aten_adaptor.cc index d701ae60b..04db643f9 100644 --- a/src/infinicore/adaptor/aten_adaptor.cc +++ b/src/infinicore/adaptor/aten_adaptor.cc @@ -32,7 +32,7 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) { options); } -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) c10::cuda::CUDAStream get_cuda_stream() { return c10::cuda::getStreamFromExternal( cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex()); diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc index 2eb89ce67..423a5c793 100644 --- a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc +++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc @@ -52,17 +52,22 @@ void run(void *planned_meta) { #endif auto *p = reinterpret_cast(planned_meta); - // FlashAttention kernels expect standard dense layout (contiguous last dimension). + // Paged KV caches must be contiguous for flash-attn; avoid extra copies for q/metadata when already dense. auto out_at = infinicore::adaptor::to_aten_tensor(p->out); const bool out_need_copy_back = !out_at.is_contiguous(); auto out_tensor = out_need_copy_back ? out_at.contiguous() : out_at; - auto q = infinicore::adaptor::to_aten_tensor(p->q).contiguous(); + auto q = infinicore::adaptor::to_aten_tensor(p->q); +#if defined(ENABLE_NVIDIA_API) + auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache); + auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); +#elif defined(ENABLE_QY_API) || defined(ENABLE_METAX_API) auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous(); auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous(); - auto seqlens_k = std::optional(infinicore::adaptor::to_aten_tensor(p->seqlens_k).contiguous()); - auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table).contiguous()); +#endif + auto seqlens_k = std::optional(infinicore::adaptor::to_aten_tensor(p->seqlens_k)); + auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); auto alibi_slopes = p->alibi_slopes - ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous()) + ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt; std::optional k_new = std::nullopt; diff --git a/src/infiniop/ops/equal/metax/equal_metax.maca b/src/infiniop/ops/equal/metax/equal_metax.maca index 655b63561..7d3008fb6 100644 --- a/src/infiniop/ops/equal/metax/equal_metax.maca +++ b/src/infiniop/ops/equal/metax/equal_metax.maca @@ -1,22 +1,10 @@ #include "equal_metax.h" #include "../../../elementwise/metax/elementwise_metax.h" -#include -namespace op::equal::metax { - -struct EqualOp { - static constexpr size_t num_inputs = 2; +#include "../cuda/kernel.cuh" - template - __device__ __forceinline__ bool operator()(const Tin0 &a, const Tin1 &b) const { - if constexpr (std::is_same_v) { - return static_cast(a == b); - } else { - return false; - } - } -}; +namespace op::equal::metax { Descriptor::~Descriptor() = default; @@ -25,54 +13,44 @@ infiniStatus_t Descriptor::create( Descriptor **desc_ptr, infiniopTensorDescriptor_t out_desc, std::vector input_desc_vec) { - auto handle = reinterpret_cast(handle_); - const auto &a_desc = input_desc_vec.at(0); auto compute_dtype = a_desc->dtype(); auto out_dtype = out_desc->dtype(); - const auto &b_desc = input_desc_vec.at(1); const auto &c_shape = out_desc->shape(); const auto &a_shape = a_desc->shape(); const auto &b_shape = b_desc->shape(); - CHECK_DTYPE(compute_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_F64); - CHECK_DTYPE(out_dtype, INFINI_DTYPE_BOOL); - CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); - CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, compute_dtype, out_desc, input_desc_vec) - return INFINI_STATUS_SUCCESS; } - infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, void *output, std::vector inputs, void *stream) const { - if (workspace_size < _workspace_size) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } switch (_dtype) { case INFINI_DTYPE_F16: - return _device_info->calculate<256, EqualOp, bool, half, half>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::EqualOp, bool, half, half>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_BF16: - return _device_info->calculate<256, EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F32: - return _device_info->calculate<256, EqualOp, bool, float, float>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::EqualOp, bool, float, float>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_I32: - return _device_info->calculate<256, EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_I64: - return _device_info->calculate<256, EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F64: - return _device_info->calculate<256, EqualOp, bool, double, double>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::EqualOp, bool, double, double>(_info, workspace, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/hardswish/cuda/kernel.cuh b/src/infiniop/ops/hardswish/cuda/kernel.cuh index bfb775d2f..25dfd55a0 100644 --- a/src/infiniop/ops/hardswish/cuda/kernel.cuh +++ b/src/infiniop/ops/hardswish/cuda/kernel.cuh @@ -2,15 +2,6 @@ #define __HARDSWISH_CUDA_H__ #include -#if defined(ENABLE_METAX_API) -#include -#elif defined(__MACACC__) -#include -#include -#else -#include -#include -#endif namespace op::hardswish::cuda { diff --git a/src/infiniop/ops/hardtanh/cuda/kernel.cuh b/src/infiniop/ops/hardtanh/cuda/kernel.cuh index 4f6efe319..fa8c3d130 100644 --- a/src/infiniop/ops/hardtanh/cuda/kernel.cuh +++ b/src/infiniop/ops/hardtanh/cuda/kernel.cuh @@ -1,15 +1,6 @@ #ifndef __HARDTANH_CUDA_H__ #define __HARDTANH_CUDA_H__ -#if defined(ENABLE_METAX_API) -#include -#elif defined(__MACACC__) -#include -#include -#else -#include -#include -#endif #include namespace op::hardtanh::cuda { From af889fd88d01023eca5e46aee690e3687b3f54ee Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Fri, 3 Apr 2026 13:49:02 +0800 Subject: [PATCH 3/4] fix format Signed-off-by: Ceng23333 <441651826@qq.com> --- include/infinicore/adaptor/aten_adaptor.hpp | 1 + scripts/install.py | 8 +-- scripts/set_env.py | 62 +++++++++---------- .../mha_varlen_flashattn.cc | 4 +- ...inary_cross_entropy_with_logits_metax.maca | 3 +- src/infiniop/ops/equal/metax/equal_metax.maca | 10 +++ 6 files changed, 50 insertions(+), 38 deletions(-) diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp index 29e532d29..59dbcb638 100644 --- a/include/infinicore/adaptor/aten_adaptor.hpp +++ b/include/infinicore/adaptor/aten_adaptor.hpp @@ -7,6 +7,7 @@ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) #include +#include #endif #ifdef ENABLE_NVIDIA_API diff --git a/scripts/install.py b/scripts/install.py index bac2d95d1..62061922a 100644 --- a/scripts/install.py +++ b/scripts/install.py @@ -4,8 +4,8 @@ import sys from set_env import ( set_env, - ensure_metax_hpc_compiler_includes, - xmake_flags_need_metax_aten_torch_includes, + ensure_aten_torch_compiler_includes, + xmake_flags_need_aten_torch_compiler_includes, ) PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @@ -16,8 +16,8 @@ def run_cmd(cmd): def install(xmake_config_flags=""): - if xmake_flags_need_metax_aten_torch_includes(xmake_config_flags): - ensure_metax_hpc_compiler_includes() + if xmake_flags_need_aten_torch_compiler_includes(xmake_config_flags): + ensure_aten_torch_compiler_includes() run_cmd(f"xmake f {xmake_config_flags} -cv") run_cmd("xmake") run_cmd("xmake install") diff --git a/scripts/set_env.py b/scripts/set_env.py index 8fd1a63d9..a6e22e427 100644 --- a/scripts/set_env.py +++ b/scripts/set_env.py @@ -2,25 +2,15 @@ import platform -def _maca_root_from_env(): - return ( - os.environ.get("MACA_PATH") - or os.environ.get("MACA_HOME") - or os.environ.get("MACA_ROOT") - or "" - ).strip() - - -def metax_hpc_compiler_include_dirs(): - """Directories needed so g++ finds cuda_runtime_api.h (cu-bridge) when compiling against PyTorch c10/cuda headers on MetaX/HPCC.""" - maca = _maca_root_from_env() - if not maca: - return [] - return [ - os.path.join(maca, "tools", "cu-bridge", "include"), - os.path.join(maca, "include", "hcr"), - os.path.join(maca, "include"), - ] +def _hpcc_toolkit_root() -> str: + """HPCC/MACA install root (cu-bridge, headers). Env vars first; else common container path.""" + for key in ("MACA_PATH", "MACA_HOME", "MACA_ROOT"): + v = os.environ.get(key, "").strip() + if v: + return v + if os.path.isdir("/opt/hpcc"): + return "/opt/hpcc" + return "" def _prepend_path_var(name, prefixes): @@ -32,14 +22,16 @@ def _prepend_path_var(name, prefixes): os.environ[name] = f"{chunk}:{cur}" if cur else chunk -def ensure_metax_hpc_compiler_includes(): - """ - Prepend HPCC/cu-bridge includes to CPATH, CPLUS_INCLUDE_PATH, and C_INCLUDE_PATH. - g++ uses CPLUS_INCLUDE_PATH for .cc files; C_INCLUDE_PATH alone is not enough. - """ - dirs = metax_hpc_compiler_include_dirs() - if not dirs: +def ensure_aten_torch_compiler_includes() -> None: + """If HPCC root is known, prepend cu-bridge + HPCC headers for g++ compiling ATen .cc (c10/cuda).""" + root = _hpcc_toolkit_root() + if not root: return + dirs = [ + os.path.join(root, "tools", "cu-bridge", "include"), + os.path.join(root, "include", "hcr"), + os.path.join(root, "include"), + ] for var in ("CPATH", "CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"): _prepend_path_var(var, dirs) @@ -69,12 +61,20 @@ def _truthy_flag_value(v: str) -> bool: return v in ("y", "yes", "true", "1", "on") -def xmake_flags_need_metax_aten_torch_includes(flags: str) -> bool: - """True when install.py-style args enable MetaX GPU and ATen (PyTorch) together.""" +# xmake.lua GPU / accelerator backends (any of these + aten may compile C++ against torch+cuda-style headers). +_XMAKE_GPU_BACKEND_KEYS = frozenset( + { + "metax-gpu", + } +) + + +def xmake_flags_need_aten_torch_compiler_includes(flags: str) -> bool: + """True when ATen is enabled with any GPU/accelerator backend (install.py / xmake f ...).""" d = _parse_xmake_cli_flag_values(flags) - return _truthy_flag_value(d.get("metax-gpu", "n")) and _truthy_flag_value( - d.get("aten", "n") - ) + if not _truthy_flag_value(d.get("aten", "n")): + return False + return any(_truthy_flag_value(d.get(k, "n")) for k in _XMAKE_GPU_BACKEND_KEYS) def set_env(): diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc index 2f0848b55..161f9814f 100644 --- a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc @@ -81,8 +81,8 @@ VarlenFlashPrepared prepare_varlen_flash_tensors(PlannedMeta *p) { t.max_seqlen_q = p->max_seqlen_q; t.max_seqlen_k = p->max_seqlen_k; t.alibi_slopes = p->alibi_slopes - ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous()) - : std::nullopt; + ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous()) + : std::nullopt; t.scale = p->scale; return t; } diff --git a/src/infiniop/ops/binary_cross_entropy_with_logits/metax/binary_cross_entropy_with_logits_metax.maca b/src/infiniop/ops/binary_cross_entropy_with_logits/metax/binary_cross_entropy_with_logits_metax.maca index 083821638..c14bb75bc 100644 --- a/src/infiniop/ops/binary_cross_entropy_with_logits/metax/binary_cross_entropy_with_logits_metax.maca +++ b/src/infiniop/ops/binary_cross_entropy_with_logits/metax/binary_cross_entropy_with_logits_metax.maca @@ -1,8 +1,9 @@ #include "../../../devices/metax/metax_common.h" #include "../../../devices/metax/metax_handle.h" #include "../../../devices/metax/metax_kernel_common.h" + #include "binary_cross_entropy_with_logits_metax.h" -#include + #include namespace op::bce_with_logits::metax { diff --git a/src/infiniop/ops/equal/metax/equal_metax.maca b/src/infiniop/ops/equal/metax/equal_metax.maca index 7d3008fb6..265e5b5a6 100644 --- a/src/infiniop/ops/equal/metax/equal_metax.maca +++ b/src/infiniop/ops/equal/metax/equal_metax.maca @@ -13,27 +13,37 @@ infiniStatus_t Descriptor::create( Descriptor **desc_ptr, infiniopTensorDescriptor_t out_desc, std::vector input_desc_vec) { + auto handle = reinterpret_cast(handle_); + const auto &a_desc = input_desc_vec.at(0); auto compute_dtype = a_desc->dtype(); auto out_dtype = out_desc->dtype(); + const auto &b_desc = input_desc_vec.at(1); const auto &c_shape = out_desc->shape(); const auto &a_shape = a_desc->shape(); const auto &b_shape = b_desc->shape(); + CHECK_DTYPE(compute_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_F64); + CHECK_DTYPE(out_dtype, INFINI_DTYPE_BOOL); + CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); + CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, compute_dtype, out_desc, input_desc_vec) + return INFINI_STATUS_SUCCESS; } + infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, void *output, std::vector inputs, void *stream) const { + if (workspace_size < _workspace_size) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } From f8fab935b1eae79ec49c407e8106d91ebfb49026 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Fri, 3 Apr 2026 14:11:33 +0800 Subject: [PATCH 4/4] simplify mha_varlen_flashattn.cc Signed-off-by: Ceng23333 <441651826@qq.com> --- .../mha_varlen_flashattn.cc | 43 ++++++------------- 1 file changed, 12 insertions(+), 31 deletions(-) diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc index 161f9814f..c38b7a81c 100644 --- a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc @@ -67,21 +67,21 @@ struct VarlenFlashPrepared { VarlenFlashPrepared prepare_varlen_flash_tensors(PlannedMeta *p) { VarlenFlashPrepared t; - // FlashAttention kernels expect standard dense layout (contiguous last dimension). - t.q = infinicore::adaptor::to_aten_tensor(p->q).contiguous(); + // Varlen flash-attn: keep k/v contiguous for dense/paged layout; avoid extra copies for q/metadata when already dense. + t.q = infinicore::adaptor::to_aten_tensor(p->q); t.k = infinicore::adaptor::to_aten_tensor(p->k).contiguous(); t.v = infinicore::adaptor::to_aten_tensor(p->v).contiguous(); t.out_at = infinicore::adaptor::to_aten_tensor(p->out); t.out_need_copy_back = !t.out_at.is_contiguous(); t.out_work = t.out_need_copy_back ? t.out_at.contiguous() : t.out_at; t.out_opt = std::optional(t.out_work); - t.cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q).contiguous(); - t.cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k).contiguous(); - t.block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table).contiguous()); + t.cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q); + t.cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k); + t.block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); t.max_seqlen_q = p->max_seqlen_q; t.max_seqlen_k = p->max_seqlen_k; t.alibi_slopes = p->alibi_slopes - ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes).contiguous()) + ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt; t.scale = p->scale; return t; @@ -107,6 +107,7 @@ void run_flashattn_varlen_metax(PlannedMeta *p) { // depending on the HPCC/MetaX stack version. #if defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3) std::optional flash_attn_mars_ext = std::nullopt; +#endif ::mha_varlen_fwd( t.q, t.k, @@ -128,32 +129,12 @@ void run_flashattn_varlen_metax(PlannedMeta *p) { -1, 0.0, false, - std::nullopt, - flash_attn_mars_ext); -#else - ::mha_varlen_fwd( - t.q, - t.k, - t.v, - t.out_opt, - t.cu_seqlens_q, - t.cu_seqlens_kv, - seqused_k, - leftpad_k, - t.block_table, - t.alibi_slopes, - t.max_seqlen_q, - t.max_seqlen_k, - 0.0, - t.scale, - false, - true, - -1, - -1, - 0.0, - false, - std::nullopt); + std::nullopt +#if defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3) + , + flash_attn_mars_ext #endif + ); copy_varlen_flash_output_back(t); } #endif