diff --git a/README.md b/README.md index b60f79890..b1e04219d 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ API 定义以及使用方式详见 [`InfiniCore文档`](https://github.com/Infin ### 一、克隆项目 -由于仓库中含有子模块,所以在克隆时请添加 `--recursive` 或 `--recurse-submodules`,如: +由于仓库中含有子模块(如 `spdlog` / `nlohmann_json`),所以在克隆时请添加 `--recursive` 或 `--recurse-submodules`,如: ```shell git clone --recursive https://github.com/InfiniTensor/InfiniCore.git @@ -51,6 +51,10 @@ git clone --recursive https://github.com/InfiniTensor/InfiniCore.git git submodule update --init --recursive ``` +> 注:InfLLM-V2 CUDA kernels(`infllmv2_cuda_impl`)为**可选依赖**,不会随仓库子模块默认拉取。 +> 如需启用 `--infllmv2`(见下文),请自行在任意目录克隆/编译该项目,并将生成的 `infllm_v2/*.so` 路径传给 xmake; +> 或者将其手动放到 `InfiniCore/third_party/infllmv2_cuda_impl` 后再使用 `--infllmv2=y` 走自动探测。 + 配置`INFINI_ROOT` 和 `LD_LIBRARY_PATH` 环境变量。 默认`INFINI_ROOT`为`$HOME/.infini`,可以使用以下命令自动配置: @@ -108,6 +112,8 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] | `--ninetoothed=[y\|n]` | 是否编译九齿实现 | n | `--ccl=[y\|n]` | 是否编译 InfiniCCL 通信库接口实现 | n | `--graph=[y\|n]` | 是否编译 cuda graph 接口实现 | n +| `--aten=[y\|n]` | 是否链接 ATen / PyTorch(用于部分算子/对比测试) | n +| `--infllmv2=[y\|PATH]` | **可选**:启用 InfLLM-V2 attention(需 `--aten=y`)。值为 `y`(探测 `third_party/infllmv2_cuda_impl`)或指向 `libinfllm_v2.so` / `infllmv2_cuda_impl` 根目录 | (空) ##### 手动安装底层库 @@ -174,6 +180,64 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] ``` +##### 试验功能 -- 使用 InfLLM-V2 CUDA kernels(可选) + +InfLLM-V2 的 varlen/kvcache attention 需要额外的 CUDA kernels(`infllm_v2/*.so`)。该依赖为**可选**,需要你自行克隆并编译。 + +如果你希望将 `infllmv2_cuda_impl` 放在本仓库 `third_party/` 下(但不作为子模块管理),可以按以下方式拉取并编译,然后使用 `--infllmv2=y` 让 xmake 自动探测: + +```bash +cd InfiniCore + +# Core submodules only (InfLLM-v2 不作为子模块强制拉取) +git submodule sync third_party/spdlog third_party/nlohmann_json +git submodule update --init third_party/spdlog third_party/nlohmann_json + +# Fetch InfLLM-v2 into third_party if missing (NOT a git submodule). +INFLLMV2_DIR="$PWD/third_party/infllmv2_cuda_impl" +if [ ! 
-d "$INFLLMV2_DIR/.git" ]; then + rm -rf "$INFLLMV2_DIR" + git clone --depth 1 -b minicpm_sala_patches --recurse-submodules \ + https://github.com/Ceng23333/infllmv2_cuda_impl.git "$INFLLMV2_DIR" +fi + +cd "$INFLLMV2_DIR" +git submodule update --init --recursive +python3 setup.py install + +cd .. +python3 scripts/install.py --root --nv-gpu=y --cuda_arch=sm_80 --aten=y --infllmv2=y --ccl=y +xmake build -r _infinicore +xmake install _infinicore + +export PYTHONPATH="$PWD/test/infinicore:$PWD/python:${PYTHONPATH:-}" +python3 "$PWD/test/infinicore/ops/infllmv2_attention.py" --nvidia +python3 "$PWD/test/infinicore/ops/simple_gla_prefill.py" --nvidia +python3 "$PWD/test/infinicore/ops/simple_gla_decode_recurrent.py" --nvidia +``` + +1. 构建 `infllmv2_cuda_impl`(示例,路径可自定义): + +```shell +git clone /abs/path/to/infllmv2_cuda_impl +cd /abs/path/to/infllmv2_cuda_impl +python setup.py install +``` + +2. 配置并编译 InfiniCore(需要 `--aten=y`): + +```shell +# 方式 A:直接给 .so 的绝对路径(推荐,更明确) +xmake f --nv-gpu=y --aten=y --infllmv2=/abs/path/to/libinfllm_v2.so -cv +xmake build && xmake install + +# 方式 B:给 infllmv2_cuda_impl 根目录(会探测 build/lib.*/infllm_v2/*.so) +xmake f --nv-gpu=y --aten=y --infllmv2=/abs/path/to/infllmv2_cuda_impl -cv +xmake build && xmake install +``` + +运行时需要能找到该 `libinfllm_v2.so`(例如它的目录已在 rpath / `LD_LIBRARY_PATH` 中)。本项目在链接时会尝试写入 rpath 到对应目录,因此通常无需 `LD_PRELOAD`。 + 2. 编译安装 默认安装路径为 `$HOME/.infini`。 diff --git a/include/infinicore/adaptor/infllmv2_api.hpp b/include/infinicore/adaptor/infllmv2_api.hpp new file mode 100644 index 000000000..d70302be7 --- /dev/null +++ b/include/infinicore/adaptor/infllmv2_api.hpp @@ -0,0 +1,68 @@ +/** + * Vendor API declarations for InfLLM-v2 attention kernels. + * + * This header is intentionally placed under `infinicore/adaptor/` because it + * declares symbols provided by an external InfLLM-v2 shared library. + * + * NOTE: The vendor functions are declared in the global namespace to match the + * upstream InfLLM-v2 entrypoints (e.g. 
`entry.cu`) and to keep linkage stable.
+ */
+#pragma once
+
+#if defined(ENABLE_INFLLMV2) && defined(ENABLE_ATEN)
+
+#include <ATen/ATen.h>
+#include <c10/util/Optional.h>
+#include <vector>
+
+/** Varlen forward: unpadded Q/K/V with cu_seqlens. Returns {out, softmax_lse, ...}. */
+std::vector<at::Tensor> mha_varlen_fwd(
+    at::Tensor &q,
+    const at::Tensor &k,
+    const at::Tensor &v,
+    c10::optional<at::Tensor> &out_,
+    const at::Tensor &cu_seqlens_q,
+    const at::Tensor &cu_seqlens_k,
+    c10::optional<at::Tensor> &seqused_k,
+    c10::optional<at::Tensor> &leftpad_k_,
+    c10::optional<at::Tensor> &block_table_,
+    c10::optional<at::Tensor> &alibi_slopes_,
+    int max_seqlen_q,
+    int max_seqlen_k,
+    float p_dropout,
+    float softmax_scale,
+    bool zero_tensors,
+    bool is_causal,
+    int window_size_left,
+    int window_size_right,
+    float softcap,
+    bool return_softmax,
+    c10::optional<at::Generator> gen_,
+    c10::optional<at::Tensor> &blockmask_);
+
+/** KV-cache forward (decode). Returns {out, softmax_lse}. */
+std::vector<at::Tensor> mha_fwd_kvcache(
+    at::Tensor &q,
+    const at::Tensor &kcache,
+    const at::Tensor &vcache,
+    c10::optional<at::Tensor> &k_,
+    c10::optional<at::Tensor> &v_,
+    c10::optional<at::Tensor> &seqlens_k_,
+    c10::optional<at::Tensor> &rotary_cos_,
+    c10::optional<at::Tensor> &rotary_sin_,
+    c10::optional<at::Tensor> &cache_batch_idx_,
+    c10::optional<at::Tensor> &leftpad_k_,
+    c10::optional<at::Tensor> &block_table_,
+    c10::optional<at::Tensor> &alibi_slopes_,
+    c10::optional<at::Tensor> &out_,
+    float softmax_scale,
+    bool is_causal,
+    int window_size_left,
+    int window_size_right,
+    float softcap,
+    bool is_rotary_interleaved,
+    int num_splits,
+    c10::optional<at::Tensor> &blockmask_);
+
+#endif // ENABLE_INFLLMV2 && ENABLE_ATEN
+
diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp
index 18741c402..e8aad74d2 100644
--- a/include/infinicore/ops.hpp
+++ b/include/infinicore/ops.hpp
@@ -19,11 +19,20 @@
 #include "ops/flash_attention.hpp"
 #include "ops/fmin.hpp"
 #include "ops/fmod.hpp"
+#include "ops/simple_gla_attention.hpp"
+#include "ops/simple_gla_decode_step.hpp"
+#include "ops/simple_gla_recurrent_state_append.hpp"
+#include "ops/simple_gla_prefill.hpp"
+#include "ops/infllmv2_attention.hpp"
 #include "ops/hardswish.hpp"
#include "ops/hardtanh.hpp" #include "ops/kv_caching.hpp" #include "ops/matmul.hpp" +#include "ops/mha_kvcache.hpp" +#include "ops/mha_varlen.hpp" +#include "ops/mul.hpp" #include "ops/ones.hpp" +#include "ops/zeros.hpp" #include "ops/paged_attention.hpp" #include "ops/paged_attention_prefill.hpp" #include "ops/paged_caching.hpp" @@ -34,6 +43,7 @@ #include "ops/reciprocal.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" +#include "ops/sigmoid.hpp" #include "ops/silu.hpp" #include "ops/silu_and_mul.hpp" #include "ops/swiglu.hpp" diff --git a/include/infinicore/ops/infllmv2_api.hpp b/include/infinicore/ops/infllmv2_api.hpp new file mode 100644 index 000000000..c36894554 --- /dev/null +++ b/include/infinicore/ops/infllmv2_api.hpp @@ -0,0 +1,12 @@ +/** + * Backward-compatible include for the InfLLM-v2 vendor shim. + * + * The InfLLM-v2 entrypoints are provided by an external shared library and are + * now declared under `infinicore/adaptor/infllmv2_api.hpp` to make the + * dependency boundary explicit. + * + * The vendor symbols themselves remain in the global namespace. + */ +#pragma once + +#include "infinicore/adaptor/infllmv2_api.hpp" diff --git a/include/infinicore/ops/infllmv2_attention.hpp b/include/infinicore/ops/infllmv2_attention.hpp new file mode 100644 index 000000000..07f99dffc --- /dev/null +++ b/include/infinicore/ops/infllmv2_attention.hpp @@ -0,0 +1,228 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +// Graph-recordable InfLLM-v2 attention ops. +// +// These wrappers provide `_` variants that write into a pre-allocated output +// tensor so they can participate in the graph recording system. 
+INFINICORE_GRAPH_OP_CLASS( + InfllmV2AttentionVarlen, + Tensor, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + int, + int, + float, + bool, + int, + int); + +INFINICORE_GRAPH_OP_CLASS( + InfllmV2AttentionKVCache, + Tensor, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + float, + bool, + int, + int); + +INFINICORE_GRAPH_OP_CLASS( + InfllmV2AttentionKVCacheUpdate, + Tensor, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + float, + bool, + int, + int); + +// Varlen InfLLM-V2 attention over unpadded Q/K/V. +// +// Shapes follow the FlashAttn-style varlen convention: +// q : [total_q, nheads, head_dim] +// k, v : [total_k, nheads_k, head_dim] +// cu_seqlens_q: [batch_size + 1] (int32) +// cu_seqlens_k: [batch_size + 1] (int32) +// +// Returns: +// [total_q, nheads, head_dim] +void infllmv2_varlen_(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1); +Tensor infllmv2_varlen(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1); + +// Preferred names (attention-disambiguated). These are header-only aliases to the +// backward-compatible `infllmv2_*` symbols to avoid adding extra exported ABI. 
+inline void infllmv2_attention_varlen_(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1) { + infllmv2_varlen_(out, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, causal, window_size_left, window_size_right); +} +inline Tensor infllmv2_attention_varlen(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1) { + return infllmv2_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, causal, window_size_left, window_size_right); +} + +// Decode-time InfLLM-V2 attention with KV cache. +// +// Shapes: +// q : [batch, seqlen_q, nheads, head_dim] +// k_cache : [num_blocks, block_size, nheads_k, head_dim] or [batch, seqlen_cache, nheads_k, head_dim] +// v_cache : same as k_cache +// cache_lens : [batch] (int32) total KV length per sequence +// +// Returns: +// [batch, seqlen_q, nheads, head_dim] +void infllmv2_kvcache_(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1); +Tensor infllmv2_kvcache(const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1); + +inline void infllmv2_attention_kvcache_(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1) { + infllmv2_kvcache_(out, q, k_cache, v_cache, cache_lens, scale, causal, 
window_size_left, window_size_right); +} +inline Tensor infllmv2_attention_kvcache(const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1) { + return infllmv2_kvcache(q, k_cache, v_cache, cache_lens, scale, causal, window_size_left, window_size_right); +} + +// Decode-time InfLLM-V2 attention with KV cache, updating cache in-place. +// +// Shapes: +// q : [batch, seqlen_q, nheads, head_dim] +// k_cache : [batch, seqlen_cache, nheads_k, head_dim] (dense cache) +// v_cache : same as k_cache +// k_new/v_new: [batch, seqlen_new, nheads_k, head_dim] (new KV to append at cache_lens offsets) +// cache_lens : [batch] (int32) current KV length per sequence BEFORE appending +// +// Returns: +// [batch, seqlen_q, nheads, head_dim] +void infllmv2_kvcache_update_(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &k_new, + const Tensor &v_new, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1); +Tensor infllmv2_kvcache_update(const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &k_new, + const Tensor &v_new, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1); + +inline void infllmv2_attention_kvcache_update_(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &k_new, + const Tensor &v_new, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1) { + infllmv2_kvcache_update_(out, q, k_cache, v_cache, k_new, v_new, cache_lens, scale, causal, window_size_left, window_size_right); +} +inline Tensor infllmv2_attention_kvcache_update(const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &k_new, + const Tensor &v_new, + const 
Tensor &cache_lens, + float scale, + bool causal, + int window_size_left = -1, + int window_size_right = -1) { + return infllmv2_kvcache_update(q, k_cache, v_cache, k_new, v_new, cache_lens, scale, causal, window_size_left, window_size_right); +} + +} // namespace infinicore::op + diff --git a/include/infinicore/ops/sigmoid.hpp b/include/infinicore/ops/sigmoid.hpp new file mode 100644 index 000000000..949290392 --- /dev/null +++ b/include/infinicore/ops/sigmoid.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Sigmoid { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor output, Tensor input); + static common::OpDispatcher &dispatcher(); +}; + +Tensor sigmoid(Tensor input); +void sigmoid_(Tensor output, Tensor input); +} // namespace infinicore::op + diff --git a/include/infinicore/ops/simple_gla_attention.hpp b/include/infinicore/ops/simple_gla_attention.hpp new file mode 100644 index 000000000..beb60c2d9 --- /dev/null +++ b/include/infinicore/ops/simple_gla_attention.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +// Simple GLA (recurrent linear) attention with per-head decay. +// Shapes: q, k, v [B, T, H, D], g_gamma [H] (log-decay per head). +// Recurrence: gate = exp(g_gamma); S = S * gate + outer(k_t, v_t); o_t = (q_t * scale) @ S. +// Returns [B, T, H, D]. 
+class SimpleGlaAttention { +public: + using schema = void (*)(Tensor & out, const Tensor &q, const Tensor &k, const Tensor &v, + const Tensor &g_gamma, float scale); + static void execute(Tensor & out, const Tensor &q, const Tensor &k, const Tensor &v, + const Tensor &g_gamma, float scale); + static common::OpDispatcher &dispatcher(); +}; + +Tensor simple_gla_attention(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g_gamma, + float scale); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/simple_gla_decode_step.hpp b/include/infinicore/ops/simple_gla_decode_step.hpp new file mode 100644 index 000000000..69f823ecf --- /dev/null +++ b/include/infinicore/ops/simple_gla_decode_step.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +// One decode timestep of Simple GLA (same recurrence as SimpleGlaAttention). +// q, k, v: [B, 1, H, D]; g_gamma: [H] (log-decay per head); state: [B, H, D, D] float32 (in-place). +// Updates: state = state * exp(g_gamma) + outer(k, v); then out[b,0,h,:] = (q * scale) @ state[b,h]. +// Returns out with shape [B, 1, H, D] (same dtype as q). 
+class SimpleGlaDecodeStep { +public: + using schema = void (*)(Tensor &out, Tensor &state, const Tensor &q, const Tensor &k, const Tensor &v, + const Tensor &g_gamma, float scale); + static void execute(Tensor &out, Tensor &state, const Tensor &q, const Tensor &k, const Tensor &v, + const Tensor &g_gamma, float scale); + static common::OpDispatcher &dispatcher(); +}; + +Tensor simple_gla_decode_step(const Tensor &q, const Tensor &k, const Tensor &v, Tensor &state, + const Tensor &g_gamma, float scale); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/simple_gla_prefill.hpp b/include/infinicore/ops/simple_gla_prefill.hpp new file mode 100644 index 000000000..a701c2143 --- /dev/null +++ b/include/infinicore/ops/simple_gla_prefill.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "../tensor.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(SimpleGLAPrefill, + Tensor, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + float); + +// Fused/chunked Simple GLA prefill forward. +// q,k,v: [B,T,H,D] (F16/BF16), g_gamma: [H] (F32), returns [B,T,H,D] (same dtype). 
+Tensor simple_gla_prefill(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g_gamma, + float scale); + +} // namespace infinicore::op + diff --git a/include/infinicore/ops/simple_gla_recurrent_state_append.hpp b/include/infinicore/ops/simple_gla_recurrent_state_append.hpp new file mode 100644 index 000000000..a7f0c611f --- /dev/null +++ b/include/infinicore/ops/simple_gla_recurrent_state_append.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "../device.hpp" +#include "../tensor.hpp" +#include "common/dispatcher.hpp" + +namespace infinicore::op { + +// Batched update of Simple GLA recurrent state (float32 [B,H,D,D]) for a contiguous +// K/V segment [B,L,H,D], matching L repeated simple_gla_decode_step applications: +// S <- g^L * S + sum_{j=0}^{L-1} g^{L-1-j} * outer(k_j, v_j) +// g_gamma: [H] (same log-gate as simple_gla_decode_step; gate = exp(g_gamma)). +class SimpleGlaRecurrentStateAppend { +public: + using schema = void (*)(Tensor &state, const Tensor &k_seg, const Tensor &v_seg, const Tensor &g_gamma); + static void execute(Tensor &state, const Tensor &k_seg, const Tensor &v_seg, const Tensor &g_gamma); + static common::OpDispatcher &dispatcher(); +}; + +void simple_gla_recurrent_state_append_segment(Tensor &state, const Tensor &k_seg, const Tensor &v_seg, + const Tensor &g_gamma); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/zeros.hpp b/include/infinicore/ops/zeros.hpp new file mode 100644 index 000000000..709c41855 --- /dev/null +++ b/include/infinicore/ops/zeros.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "common/op.hpp" + +namespace infinicore::op { +class Zeros { + +public: + using schema = void (*)(Tensor); + static void execute(Tensor output); + static common::OpDispatcher &dispatcher(); +}; + +void zeros_(Tensor output); +} // namespace infinicore::op diff --git a/include/infiniop/ops/simple_gla_prefill.h b/include/infiniop/ops/simple_gla_prefill.h new file mode 100644 index 000000000..cdfd7d936 --- 
/dev/null +++ b/include/infiniop/ops/simple_gla_prefill.h @@ -0,0 +1,39 @@ +#ifndef __INFINIOP_SIMPLE_GLA_PREFILL_API_H__ +#define __INFINIOP_SIMPLE_GLA_PREFILL_API_H__ + +#include "../operator_descriptor.h" + +// Chunked/fused Simple GLA prefill forward. +// q, k, v: [B, T, H, D] (F16/BF16), g_gamma: [H] (F32), out: [B, T, H, D] (same dtype as q) +typedef struct InfiniopDescriptor *infiniopSimpleGLAPrefillDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateSimpleGLAPrefillDescriptor( + infiniopHandle_t handle, + infiniopSimpleGLAPrefillDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_gamma_desc); + +__INFINI_C __export infiniStatus_t infiniopGetSimpleGLAPrefillWorkspaceSize( + infiniopSimpleGLAPrefillDescriptor_t desc, + size_t *size); + +__INFINI_C __export infiniStatus_t infiniopSimpleGLAPrefill( + infiniopSimpleGLAPrefillDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + void const *q, + void const *k, + void const *v, + void const *g_gamma, + float scale, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroySimpleGLAPrefillDescriptor( + infiniopSimpleGLAPrefillDescriptor_t desc); + +#endif + diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index a2a807651..15624f91f 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -83,6 +83,7 @@ from infinicore.ops.hypot import hypot from infinicore.ops.index_add import index_add from infinicore.ops.index_copy import index_copy +from infinicore.ops.infllmv2_attention import infllmv2_kvcache, infllmv2_varlen from infinicore.ops.inner import inner from infinicore.ops.kron import kron from infinicore.ops.kthvalue import kthvalue @@ -108,6 +109,9 @@ from infinicore.ops.reciprocal import reciprocal from infinicore.ops.scatter import scatter from 
infinicore.ops.sinh import sinh +from infinicore.ops.simple_gla_attention import simple_gla_attention +from infinicore.ops.simple_gla_decode_step import simple_gla_decode_step +from infinicore.ops.simple_gla_prefill import simple_gla_prefill from infinicore.ops.squeeze import squeeze from infinicore.ops.sum import sum from infinicore.ops.take import take @@ -191,6 +195,11 @@ "block_diag", "kron", "bitwise_right_shift", + "infllmv2_varlen", + "infllmv2_kvcache", + "simple_gla_attention", + "simple_gla_decode_step", + "simple_gla_prefill", "kv_caching", "asinh", "baddbmm", diff --git a/python/infinicore/ops/infllmv2_attention.py b/python/infinicore/ops/infllmv2_attention.py new file mode 100644 index 000000000..ef4d91e4c --- /dev/null +++ b/python/infinicore/ops/infllmv2_attention.py @@ -0,0 +1,76 @@ +""" +InfLLM-V2 attention ops (varlen and kvcache). +Available only when InfiniCore is built with ENABLE_INFLLMV2 and linked to infllmv2 .so. +""" + +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + +_native_infllmv2_varlen = getattr(_infinicore, "infllmv2_varlen", None) +_native_infllmv2_kvcache = getattr(_infinicore, "infllmv2_kvcache", None) + +_MISSING_MSG = ( + "infllmv2_varlen / infllmv2_kvcache not found in _infinicore. " + "Build InfiniCore with: xmake f --aten=y --infllmv2=y (auto-detect under third_party/infllmv2_cuda_impl) " + "or --infllmv2=/abs/path/to/libinfllm_v2.so (recommended), then xmake build/install." +) + + +def infllmv2_varlen( + q: Tensor, + k: Tensor, + v: Tensor, + cu_seqlens_q: Tensor, + cu_seqlens_k: Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + scale: float = 1.0, + causal: bool = True, + window_size_left: int = -1, + window_size_right: int = -1, +): + """InfLLM-V2 varlen attention. q,k,v unpadded; cu_seqlens_q/k [batch+1]. 
Returns [total_q, nheads, head_dim].""" + if _native_infllmv2_varlen is None: + raise NotImplementedError(_MISSING_MSG) + return Tensor( + _native_infllmv2_varlen( + q._underlying, + k._underlying, + v._underlying, + cu_seqlens_q._underlying, + cu_seqlens_k._underlying, + max_seqlen_q, + max_seqlen_k, + scale, + causal, + window_size_left, + window_size_right, + ) + ) + + +def infllmv2_kvcache( + q: Tensor, + k_cache: Tensor, + v_cache: Tensor, + cache_lens: Tensor, + scale: float = 1.0, + causal: bool = True, + window_size_left: int = -1, + window_size_right: int = -1, +): + """InfLLM-V2 KV-cache (decode) attention. Returns [batch, seqlen_q, nheads, head_dim].""" + if _native_infllmv2_kvcache is None: + raise NotImplementedError(_MISSING_MSG) + return Tensor( + _native_infllmv2_kvcache( + q._underlying, + k_cache._underlying, + v_cache._underlying, + cache_lens._underlying, + scale, + causal, + window_size_left, + window_size_right, + ) + ) diff --git a/python/infinicore/ops/simple_gla_attention.py b/python/infinicore/ops/simple_gla_attention.py new file mode 100644 index 000000000..1600c1916 --- /dev/null +++ b/python/infinicore/ops/simple_gla_attention.py @@ -0,0 +1,24 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + +_native_simple_gla_attention = getattr(_infinicore, "simple_gla_attention", None) +if _native_simple_gla_attention is None: + _MISSING_MSG = ( + "simple_gla_attention not found in _infinicore. Rebuild InfiniCore extension: " + "cd InfiniCore && xmake build _infinicore" + ) + + +def simple_gla_attention(q, k, v, g_gamma, *, scale): + """Simple GLA (recurrent linear) attention. q, k, v [B, T, H, D], g_gamma [H]. 
Returns [B, T, H, D].""" + if _native_simple_gla_attention is None: + raise NotImplementedError(_MISSING_MSG) + return Tensor( + _native_simple_gla_attention( + q._underlying, + k._underlying, + v._underlying, + g_gamma._underlying, + float(scale), + ) + ) diff --git a/python/infinicore/ops/simple_gla_decode_step.py b/python/infinicore/ops/simple_gla_decode_step.py new file mode 100644 index 000000000..2593946eb --- /dev/null +++ b/python/infinicore/ops/simple_gla_decode_step.py @@ -0,0 +1,29 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + +_native = getattr(_infinicore, "simple_gla_decode_step", None) +if _native is None: + _MISSING_MSG = ( + "simple_gla_decode_step not found in _infinicore. Rebuild InfiniCore extension: " + "cd InfiniCore && xmake build _infinicore" + ) + + +def simple_gla_decode_step(q, k, v, state, g_gamma, *, scale): + """One Simple GLA decode step. + + q, k, v: [B, 1, H, D]. state: [B, H, D, D] float32, updated in-place (must be contiguous). + g_gamma: [H]. Returns output [B, 1, H, D]. + """ + if _native is None: + raise NotImplementedError(_MISSING_MSG) + return Tensor( + _native( + q._underlying, + k._underlying, + v._underlying, + state._underlying, + g_gamma._underlying, + float(scale), + ) + ) diff --git a/python/infinicore/ops/simple_gla_prefill.py b/python/infinicore/ops/simple_gla_prefill.py new file mode 100644 index 000000000..91a328fc1 --- /dev/null +++ b/python/infinicore/ops/simple_gla_prefill.py @@ -0,0 +1,24 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + +_native_simple_gla_prefill = getattr(_infinicore, "simple_gla_prefill", None) +if _native_simple_gla_prefill is None: + _MISSING_MSG = ( + "simple_gla_prefill not found in _infinicore. Rebuild InfiniCore extension: " + "cd InfiniCore && xmake build _infinicore" + ) + + +def simple_gla_prefill(q, k, v, g_gamma, *, scale): + """Simple GLA prefill fused kernel. q, k, v [B, T, H, D], g_gamma [H] (F32). 
Returns [B, T, H, D].""" + if _native_simple_gla_prefill is None: + raise NotImplementedError(_MISSING_MSG) + return Tensor( + _native_simple_gla_prefill( + q._underlying, + k._underlying, + v._underlying, + g_gamma._underlying, + float(scale), + ) + ) diff --git a/python/infinicore/tensor.py b/python/infinicore/tensor.py index 8e2c9b2d6..af42cf19a 100644 --- a/python/infinicore/tensor.py +++ b/python/infinicore/tensor.py @@ -80,6 +80,14 @@ def is_pinned(self): def copy_(self, src): self._underlying.copy_(src._underlying) + def write_i32(self, linear_index, value): + """Write one int32 element at a contiguous linear index (metadata fast path).""" + self._underlying.write_i32(linear_index, int(value)) + + def write_i64(self, linear_index, value): + """Write one int64 element at a contiguous linear index (metadata fast path).""" + self._underlying.write_i64(linear_index, int(value)) + def to(self, *args, **kwargs): return Tensor( self._underlying.to(*tuple(arg._underlying for arg in args), **kwargs) diff --git a/src/infinicore/context/allocators/pinnable_block_allocator.cc b/src/infinicore/context/allocators/pinnable_block_allocator.cc index 5574374a8..7151d162f 100644 --- a/src/infinicore/context/allocators/pinnable_block_allocator.cc +++ b/src/infinicore/context/allocators/pinnable_block_allocator.cc @@ -5,6 +5,7 @@ #include "../../utils.hpp" #include +#include #include #include @@ -72,6 +73,13 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) { block->frozen = pinned_mode_; block->in_use = true; + if (std::getenv("INFINICORE_DEBUG_ALLOC") != nullptr) { + infiniDevice_t dev; + int dev_id; + infinirtGetDevice(&dev, &dev_id); + spdlog::warn("PinnableBlockAllocator cudaMalloc request: requested={} aligned={} class={} device={} id={}", + size, size, cls.block_size, static_cast(dev), dev_id); + } INFINICORE_CHECK_ERROR(infinirtMalloc(&block->ptr, block->size)); all_blocks_[block->ptr] = block; @@ -97,6 +105,13 @@ std::byte 
*PinnableBlockAllocator::allocate(size_t size) { block->frozen = pinned_mode_; block->in_use = true; + if (std::getenv("INFINICORE_DEBUG_ALLOC") != nullptr) { + infiniDevice_t dev; + int dev_id; + infinirtGetDevice(&dev, &dev_id); + spdlog::warn("PinnableBlockAllocator cudaMalloc request (large): requested={} aligned={} device={} id={}", + size, size, static_cast(dev), dev_id); + } INFINICORE_CHECK_ERROR(infinirtMalloc(&block->ptr, block->size)); large_blocks_.push_back(block); diff --git a/src/infinicore/ops/infllmv2_attention/infllmv2_attention.cc b/src/infinicore/ops/infllmv2_attention/infllmv2_attention.cc new file mode 100644 index 000000000..fe04c7d02 --- /dev/null +++ b/src/infinicore/ops/infllmv2_attention/infllmv2_attention.cc @@ -0,0 +1,675 @@ +/** + * InfLLM-V2 attention ops (varlen + kvcache). + * - With ENABLE_FLASH_ATTN: uses mha_varlen / mha_kvcache (Flash-style) fallback. + * - With ENABLE_INFLLMV2 + ENABLE_ATEN: calls InfLLM-V2 C++ API (mha_varlen_fwd, mha_fwd_kvcache). + * Build InfiniCore with: + * xmake f --aten=y --infllmv2=/abs/path/to/libinfllm_v2.so (recommended) + * or: + * xmake f --aten=y --infllmv2=/abs/path/to/infllmv2_cuda_impl + * or: + * xmake f --aten=y --infllmv2=y (auto-detect under third_party/infllmv2_cuda_impl if you checked it out) + * Linking is handled in xmake (adds DT_NEEDED + rpath to the resolved .so). 
+ */ +#include "infinicore/ops/infllmv2_attention.hpp" + +#include "../../utils.hpp" + +#if defined(ENABLE_INFLLMV2) && defined(ENABLE_ATEN) +#include +#endif + +#if defined(ENABLE_INFLLMV2) && defined(ENABLE_ATEN) +#include "infinicore/adaptor/aten_adaptor.hpp" +#include "infinicore/adaptor/infllmv2_api.hpp" +#ifdef ENABLE_NVIDIA_API +#include +#include +#endif +#elif defined(ENABLE_FLASH_ATTN) +#include "infinicore/adaptor/flash_attention_adaptor.hpp" +#include "infinicore/ops/mha_kvcache.hpp" +#include "infinicore/ops/mha_varlen.hpp" +#endif + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(InfllmV2AttentionVarlen); +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(InfllmV2AttentionKVCache); +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(InfllmV2AttentionKVCacheUpdate); + +namespace { +#if defined(ENABLE_INFLLMV2) && defined(ENABLE_ATEN) +inline std::string int_list_to_string(const c10::ArrayRef &xs) { + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < xs.size(); ++i) { + if (i) { + oss << ", "; + } + oss << xs[i]; + } + oss << "]"; + return oss.str(); +} + +inline void maybe_log_kvcache_inputs(const char *op_name, + const at::Tensor &q, + const at::Tensor &kcache, + const at::Tensor &vcache, + const at::Tensor &seqlens_k, + bool causal, + float scale) { + const char *flag = std::getenv("INFINICORE_INFLLMV2_DUMP_ATEN"); + if (!flag || flag[0] == '\0' || flag[0] == '0') { + return; + } + try { + auto cpu_lens = seqlens_k.to(at::kCPU); + int32_t len0 = cpu_lens.numel() > 0 ? cpu_lens.data_ptr()[0] : -1; + SPDLOG_INFO( + "[infllmv2][{}] q={} kcache={} vcache={} seqlens_k={} seqlens0={} causal={} scale={} q_stride={} k_stride={} v_stride={}", + op_name, + int_list_to_string(q.sizes()), + int_list_to_string(kcache.sizes()), + int_list_to_string(vcache.sizes()), + int_list_to_string(seqlens_k.sizes()), + len0, + causal ? 
1 : 0, + scale, + int_list_to_string(q.strides()), + int_list_to_string(kcache.strides()), + int_list_to_string(vcache.strides())); + } catch (...) { + } +} +#else +inline void maybe_log_kvcache_inputs(const char * /*op_name*/, + ...) { + // no-op when ATen is not enabled +} +#endif +} // namespace + +namespace { +void infllmv2_varlen_impl(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, cu_seqlens_q, cu_seqlens_k); + +#if defined(ENABLE_INFLLMV2) && defined(ENABLE_ATEN) + // Direct InfLLM-V2 kernels (link against infllmv2_cuda_impl). +#ifdef ENABLE_NVIDIA_API + c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); +#endif + auto q_at = infinicore::adaptor::to_aten_tensor(q); + auto k_at = infinicore::adaptor::to_aten_tensor(k); + auto v_at = infinicore::adaptor::to_aten_tensor(v); + auto cu_q_at = infinicore::adaptor::to_aten_tensor(cu_seqlens_q); + auto cu_k_at = infinicore::adaptor::to_aten_tensor(cu_seqlens_k); + auto out_at = std::optional(infinicore::adaptor::to_aten_tensor(out)); + + c10::optional seqused_k = c10::nullopt; + c10::optional leftpad_k = c10::nullopt; + c10::optional block_table = c10::nullopt; + c10::optional alibi_slopes = c10::nullopt; + c10::optional gen_ = c10::nullopt; + c10::optional blockmask_ = c10::nullopt; + + mha_varlen_fwd( + q_at, + k_at, + v_at, + out_at, + cu_q_at, + cu_k_at, + seqused_k, + leftpad_k, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + 0.0f, + scale, + false, + causal, + window_size_left, + window_size_right, + 0.0f, + false, + gen_, + blockmask_); + return; + +#elif defined(ENABLE_FLASH_ATTN) + // Fallback: FlashAttention-based varlen op (same kernel family as InfLLM-V2). 
+ auto dummy_block_table = infinicore::Tensor::zeros( + {cu_seqlens_q->shape()[0] - 1, 1}, + cu_seqlens_q->dtype(), + cu_seqlens_q->device()); + (void)causal; + (void)window_size_left; + (void)window_size_right; + auto tmp = infinicore::op::mha_varlen( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + dummy_block_table, + max_seqlen_q, + max_seqlen_k, + std::nullopt, + scale); + out->copy_(tmp); + return; +#else + (void)k; + (void)v; + (void)cu_seqlens_q; + (void)cu_seqlens_k; + (void)max_seqlen_q; + (void)max_seqlen_k; + (void)scale; + (void)causal; + (void)window_size_left; + (void)window_size_right; + throw std::runtime_error( + "InfLLM-V2 varlen attention requires ENABLE_INFLLMV2+ENABLE_ATEN or ENABLE_FLASH_ATTN build"); +#endif +} + +void infllmv2_kvcache_impl(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, cache_lens); + +#if defined(ENABLE_INFLLMV2) && defined(ENABLE_ATEN) +#ifdef ENABLE_NVIDIA_API + c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); +#endif + auto q_at = infinicore::adaptor::to_aten_tensor(q); + auto kcache_at = infinicore::adaptor::to_aten_tensor(k_cache); + auto vcache_at = infinicore::adaptor::to_aten_tensor(v_cache); + auto seqlens_k_at = std::optional(infinicore::adaptor::to_aten_tensor(cache_lens)); + auto out_at = std::optional(infinicore::adaptor::to_aten_tensor(out)); + + c10::optional k_new = c10::nullopt; + c10::optional v_new = c10::nullopt; + c10::optional rotary_cos = c10::nullopt; + c10::optional rotary_sin = c10::nullopt; + c10::optional cache_batch_idx = c10::nullopt; + c10::optional leftpad_k = c10::nullopt; + c10::optional block_table = c10::nullopt; + c10::optional alibi_slopes = c10::nullopt; + c10::optional blockmask_ = c10::nullopt; + + maybe_log_kvcache_inputs("kvcache", q_at, kcache_at, 
vcache_at, seqlens_k_at.value(), causal, scale); + + // Let FlashAttn/InfLLM-v2 allocate output internally. Passing an explicit out_ tensor + // can interact badly with internal q reshapes in the seqlen_q==1 GQA fast path. + c10::optional out_kernel_opt = c10::nullopt; + auto outs = mha_fwd_kvcache( + q_at, + kcache_at, + vcache_at, + k_new, + v_new, + seqlens_k_at, + rotary_cos, + rotary_sin, + cache_batch_idx, + leftpad_k, + block_table, + alibi_slopes, + out_kernel_opt, + scale, + causal, + window_size_left, + window_size_right, + 0.0f, + false, + 0, + blockmask_); + out_at.value().copy_(outs[0]); + return; + +#elif defined(ENABLE_FLASH_ATTN) + (void)causal; + (void)window_size_left; + (void)window_size_right; + auto device = q->device(); + auto bs = cache_lens->shape()[0]; + auto one = infinicore::Tensor::ones({bs, 1}, cache_lens->dtype(), device); + auto block_table = one; + auto seqlens_k = cache_lens; + auto tmp = infinicore::op::mha_kvcache( + q, + k_cache, + v_cache, + seqlens_k, + block_table, + std::nullopt, + scale); + out->copy_(tmp); + return; +#else + (void)k_cache; + (void)v_cache; + (void)cache_lens; + (void)scale; + (void)causal; + (void)window_size_left; + (void)window_size_right; + throw std::runtime_error( + "InfLLM-V2 kvcache attention requires ENABLE_INFLLMV2+ENABLE_ATEN or ENABLE_FLASH_ATTN build"); +#endif +} + +void infllmv2_kvcache_update_impl(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &k_new, + const Tensor &v_new, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, k_new, v_new, cache_lens); + +#if defined(ENABLE_INFLLMV2) && defined(ENABLE_ATEN) +#ifdef ENABLE_NVIDIA_API + c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); +#endif + auto q_at = infinicore::adaptor::to_aten_tensor(q); + auto kcache_at = 
infinicore::adaptor::to_aten_tensor(k_cache); + auto vcache_at = infinicore::adaptor::to_aten_tensor(v_cache); + auto knew_at = infinicore::adaptor::to_aten_tensor(k_new); + auto vnew_at = infinicore::adaptor::to_aten_tensor(v_new); + auto seqlens_k_at = std::optional(infinicore::adaptor::to_aten_tensor(cache_lens)); + auto out_at = std::optional(infinicore::adaptor::to_aten_tensor(out)); + + c10::optional k_new_opt = std::optional(knew_at); + c10::optional v_new_opt = std::optional(vnew_at); + c10::optional rotary_cos = c10::nullopt; + c10::optional rotary_sin = c10::nullopt; + c10::optional cache_batch_idx = c10::nullopt; + c10::optional leftpad_k = c10::nullopt; + c10::optional block_table = c10::nullopt; + c10::optional alibi_slopes = c10::nullopt; + c10::optional blockmask_ = c10::nullopt; + + maybe_log_kvcache_inputs("kvcache_update", q_at, kcache_at, vcache_at, seqlens_k_at.value(), causal, scale); + + c10::optional out_kernel_opt = c10::nullopt; + auto outs = mha_fwd_kvcache( + q_at, + kcache_at, + vcache_at, + k_new_opt, + v_new_opt, + seqlens_k_at, + rotary_cos, + rotary_sin, + cache_batch_idx, + leftpad_k, + block_table, + alibi_slopes, + out_kernel_opt, + scale, + causal, + window_size_left, + window_size_right, + 0.0f, + false, + 0, + blockmask_); + out_at.value().copy_(outs[0]); + return; + +#elif defined(ENABLE_FLASH_ATTN) + (void)k_new; + (void)v_new; + // FlashAttn adaptor path currently doesn't support in-place cache update in this wrapper. + // Fall back to normal kvcache (expects cache already updated by caller). 
+ infllmv2_kvcache_impl(out, q, k_cache, v_cache, cache_lens, scale, causal, window_size_left, window_size_right); + return; +#else + (void)k_cache; + (void)v_cache; + (void)k_new; + (void)v_new; + (void)cache_lens; + (void)scale; + (void)causal; + (void)window_size_left; + (void)window_size_right; + throw std::runtime_error( + "InfLLM-V2 kvcache_update attention requires ENABLE_INFLLMV2+ENABLE_ATEN build"); +#endif +} +} // namespace + +InfllmV2AttentionVarlen::InfllmV2AttentionVarlen(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, cu_seqlens_q, cu_seqlens_k); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, causal, window_size_left, window_size_right); +} + +void InfllmV2AttentionVarlen::execute(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN( + InfllmV2AttentionVarlen, + out, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, causal, window_size_left, window_size_right); +} + +// NOTE: we implement run/cleanup with explicit types to keep compilation simple. 
+namespace { +struct VarlenPlanned { + graph::GraphTensor out, q, k, v, cu_q, cu_k; + int max_q, max_k; + float scale; + bool causal; + int wleft, wright; +}; +void run_varlen_typed(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + infllmv2_varlen_impl(p->out, p->q, p->k, p->v, p->cu_q, p->cu_k, p->max_q, p->max_k, p->scale, p->causal, p->wleft, p->wright); +} +void cleanup_varlen(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} +void *plan_varlen_typed(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + return new VarlenPlanned{graph::GraphTensor(out), graph::GraphTensor(q), graph::GraphTensor(k), graph::GraphTensor(v), + graph::GraphTensor(cu_seqlens_q), graph::GraphTensor(cu_seqlens_k), + max_seqlen_q, max_seqlen_k, scale, causal, window_size_left, window_size_right}; +} +static bool registered_infllmv2_attention_varlen = []() { + InfllmV2AttentionVarlen::plan_dispatcher().registerAll(&plan_varlen_typed, false); + InfllmV2AttentionVarlen::run_dispatcher().registerAll(&run_varlen_typed, false); + InfllmV2AttentionVarlen::cleanup_dispatcher().registerAll(&cleanup_varlen, false); + return true; +}(); +} // namespace + +void infllmv2_varlen_(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + InfllmV2AttentionVarlen::execute( + out, q, k, v, cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + scale, causal, + window_size_left, window_size_right); +} + +Tensor infllmv2_varlen(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_k, + 
int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); + infllmv2_varlen_(out, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, causal, window_size_left, window_size_right); + return out; +} + +InfllmV2AttentionKVCache::InfllmV2AttentionKVCache(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, cache_lens); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k_cache, v_cache, cache_lens, scale, causal, window_size_left, window_size_right); +} + +void InfllmV2AttentionKVCache::execute(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN( + InfllmV2AttentionKVCache, + out, q, k_cache, v_cache, cache_lens, scale, causal, window_size_left, window_size_right); +} + +namespace { +struct KVCachePlanned { + graph::GraphTensor out, q, k, v, lens; + float scale; + bool causal; + int wleft, wright; +}; +void *plan_kvcache(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + return new KVCachePlanned{graph::GraphTensor(out), graph::GraphTensor(q), graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), graph::GraphTensor(cache_lens), + scale, causal, window_size_left, window_size_right}; +} +void run_kvcache(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + infllmv2_kvcache_impl(p->out, p->q, p->k, p->v, p->lens, p->scale, p->causal, p->wleft, p->wright); +} +void cleanup_kvcache(void 
**planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} +static bool registered_infllmv2_attention_kvcache = []() { + InfllmV2AttentionKVCache::plan_dispatcher().registerAll(&plan_kvcache, false); + InfllmV2AttentionKVCache::run_dispatcher().registerAll(&run_kvcache, false); + InfllmV2AttentionKVCache::cleanup_dispatcher().registerAll(&cleanup_kvcache, false); + return true; +}(); +} // namespace + +void infllmv2_kvcache_(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + InfllmV2AttentionKVCache::execute( + out, q, k_cache, v_cache, cache_lens, + scale, causal, + window_size_left, window_size_right); +} + +Tensor infllmv2_kvcache(const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); + infllmv2_kvcache_(out, q, k_cache, v_cache, cache_lens, scale, causal, window_size_left, window_size_right); + return out; +} + +InfllmV2AttentionKVCacheUpdate::InfllmV2AttentionKVCacheUpdate(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &k_new, + const Tensor &v_new, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, k_new, v_new, cache_lens); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k_cache, v_cache, k_new, v_new, cache_lens, scale, causal, window_size_left, window_size_right); +} + +void InfllmV2AttentionKVCacheUpdate::execute(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &k_new, + const Tensor &v_new, + const Tensor &cache_lens, + float scale, + bool causal, + int 
window_size_left, + int window_size_right) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN( + InfllmV2AttentionKVCacheUpdate, + out, q, k_cache, v_cache, k_new, v_new, cache_lens, scale, causal, window_size_left, window_size_right); +} + +namespace { +struct KVCacheUpdatePlanned { + graph::GraphTensor out, q, k, v, knew, vnew, lens; + float scale; + bool causal; + int wleft, wright; +}; +void *plan_kvcache_update(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &k_new, + const Tensor &v_new, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + return new KVCacheUpdatePlanned{graph::GraphTensor(out), graph::GraphTensor(q), + graph::GraphTensor(k_cache), graph::GraphTensor(v_cache), + graph::GraphTensor(k_new), graph::GraphTensor(v_new), + graph::GraphTensor(cache_lens), + scale, causal, window_size_left, window_size_right}; +} +void run_kvcache_update(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + infllmv2_kvcache_update_impl(p->out, p->q, p->k, p->v, p->knew, p->vnew, p->lens, p->scale, p->causal, p->wleft, p->wright); +} +void cleanup_kvcache_update(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} +static bool registered_infllmv2_attention_kvcache_update = []() { + InfllmV2AttentionKVCacheUpdate::plan_dispatcher().registerAll(&plan_kvcache_update, false); + InfllmV2AttentionKVCacheUpdate::run_dispatcher().registerAll(&run_kvcache_update, false); + InfllmV2AttentionKVCacheUpdate::cleanup_dispatcher().registerAll(&cleanup_kvcache_update, false); + return true; +}(); +} // namespace + +void infllmv2_kvcache_update_(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &k_new, + const Tensor &v_new, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + InfllmV2AttentionKVCacheUpdate::execute( + out, 
q, k_cache, v_cache, k_new, v_new, cache_lens, + scale, causal, + window_size_left, window_size_right); +} + +Tensor infllmv2_kvcache_update(const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &k_new, + const Tensor &v_new, + const Tensor &cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); + infllmv2_kvcache_update_(out, q, k_cache, v_cache, k_new, v_new, cache_lens, scale, causal, window_size_left, window_size_right); + return out; +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/ones/ones.cc b/src/infinicore/ops/ones/ones.cc index c28403eaf..6014900fb 100644 --- a/src/infinicore/ops/ones/ones.cc +++ b/src/infinicore/ops/ones/ones.cc @@ -1,5 +1,9 @@ #include "infinicore/ops/ones.hpp" +#include "../../utils.hpp" + +#include + namespace infinicore::op { common::OpDispatcher &Ones::dispatcher() { @@ -8,6 +12,24 @@ common::OpDispatcher &Ones::dispatcher() { }; void Ones::execute(Tensor output) { + infinicore::context::setDevice(output->device()); + auto device_type = output->device().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No Ones implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(output); +} + +Tensor ones() { + INFINICORE_ASSERT(false && "Tensor ones() without shape is not supported."); + return {}; +} + +void ones_(Tensor output) { + Ones::execute(output); } } // namespace infinicore::op diff --git a/src/infinicore/ops/ones/ones_infiniop.cc b/src/infinicore/ops/ones/ones_infiniop.cc new file mode 100644 index 000000000..47231f585 --- /dev/null +++ b/src/infinicore/ops/ones/ones_infiniop.cc @@ -0,0 +1,51 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/ones.hpp" +#include + +namespace 
infinicore::op::ones_impl::infiniop { + +thread_local common::OpCache caches( + 100, + [](infiniopOnesDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyOnesDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor output) { + size_t seed = 0; + infinicore::hash_combine(seed, output); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopOnesDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateOnesDescriptor( + context::getInfiniopHandle(device), &desc, + output->desc(), output->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetOnesWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopOnes( + desc, workspace->data(), workspace_size, + output->data(), output->data(), context::getStream())); +} + +static bool registered = []() { + Ones::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::ones_impl::infiniop diff --git a/src/infinicore/ops/sigmoid/sigmoid.cc b/src/infinicore/ops/sigmoid/sigmoid.cc new file mode 100644 index 000000000..ef4c8cf18 --- /dev/null +++ b/src/infinicore/ops/sigmoid/sigmoid.cc @@ -0,0 +1,37 @@ +#include "infinicore/ops/sigmoid.hpp" + +#include "../../utils.hpp" + +#include + +namespace infinicore::op { + +common::OpDispatcher &Sigmoid::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Sigmoid::execute(Tensor output, Tensor input) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input); + infinicore::context::setDevice(output->device()); + auto device_type = output->device().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No Sigmoid implementation found for 
device type: " + std::to_string(static_cast(device_type))); + } + + func(output, input); +} + +Tensor sigmoid(Tensor input) { + Shape shape = input->shape(); + auto output = Tensor::empty(shape, input->dtype(), input->device()); + sigmoid_(output, input); + return output; +} + +void sigmoid_(Tensor output, Tensor input) { + Sigmoid::execute(output, input); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/sigmoid/sigmoid_infiniop.cc b/src/infinicore/ops/sigmoid/sigmoid_infiniop.cc new file mode 100644 index 000000000..f418d8ad5 --- /dev/null +++ b/src/infinicore/ops/sigmoid/sigmoid_infiniop.cc @@ -0,0 +1,51 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/sigmoid.hpp" + +#include + +namespace infinicore::op::sigmoid_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopSigmoidDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroySigmoidDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor output, Tensor input) { + size_t seed = hash_combine(output, input); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopSigmoidDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateSigmoidDescriptor( + context::getInfiniopHandle(device), &desc, + output->desc(), input->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetSigmoidWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopSigmoid( + desc, workspace->data(), workspace_size, + output->data(), input->data(), context::getStream())); +} + +static bool registered = []() { + Sigmoid::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // 
namespace infinicore::op::sigmoid_impl::infiniop diff --git a/src/infinicore/ops/simple_gla_attention/simple_gla_attention.cc b/src/infinicore/ops/simple_gla_attention/simple_gla_attention.cc new file mode 100644 index 000000000..8a893b985 --- /dev/null +++ b/src/infinicore/ops/simple_gla_attention/simple_gla_attention.cc @@ -0,0 +1,188 @@ +#include "infinicore/ops/simple_gla_attention.hpp" + +#include "../../../utils.h" +#include "../../utils.hpp" +#include "infinicore/context/context.hpp" +#include +#include +#include +#include + +namespace infinicore::op { + +namespace { + +// Read one element from tensor at flat index, convert to float. +template +inline float read_float(const std::byte *ptr, size_t idx) { + return static_cast(*reinterpret_cast(ptr + idx * sizeof(T))); +} + +inline float read_float_at(const std::byte *ptr, size_t idx, DataType dtype) { + switch (dtype) { + case DataType::F32: + return read_float(ptr, idx); + case DataType::F16: + return _f16_to_f32(*reinterpret_cast(ptr + idx * 2)); + case DataType::BF16: + return _bf16_to_f32(*reinterpret_cast(ptr + idx * 2)); + default: + throw std::runtime_error("simple_gla_attention: unsupported dtype (need F32, F16, or BF16)"); + } +} + +// Write one float to tensor at flat index. 
+inline void write_float_at(std::byte *ptr, size_t idx, DataType dtype, float val) { + switch (dtype) { + case DataType::F32: + *reinterpret_cast(ptr + idx * 4) = val; + break; + case DataType::F16: + *reinterpret_cast(ptr + idx * 2) = _f32_to_f16(val); + break; + case DataType::BF16: + *reinterpret_cast(ptr + idx * 2) = _f32_to_bf16(val); + break; + default: + throw std::runtime_error("simple_gla_attention: unsupported dtype (need F32, F16, or BF16)"); + } +} + +void simple_gla_attention_cpu_impl(Tensor &out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g_gamma, + float scale) { + const auto &q_shape = q->shape(); + const size_t B = q_shape[0]; + const size_t T = q_shape[1]; + const size_t H = q_shape[2]; + const size_t D = q_shape[3]; + + INFINICORE_ASSERT(k->shape() == q_shape && v->shape() == q_shape); + INFINICORE_ASSERT(g_gamma->shape().size() == 1 && g_gamma->shape()[0] == H); + + const DataType dtype = q->dtype(); + const std::byte *q_ptr = q->data(); + const std::byte *k_ptr = k->data(); + const std::byte *v_ptr = v->data(); + const std::byte *g_ptr = g_gamma->data(); + std::byte *out_ptr = out->data(); + + // Contiguous layout (B, T, H, D): index (b,t,h,d) = b*T*H*D + t*H*D + h*D + d + const size_t stride_b = T * H * D; + const size_t stride_t = H * D; + const size_t stride_h = D; + + // Gate (H,) in float + std::vector gate(H); + for (size_t h = 0; h < H; ++h) { + gate[h] = std::exp(read_float_at(g_ptr, h, g_gamma->dtype())); + } + + // State S: (B, H, D, D) in float, row-major + std::vector S(B * H * D * D, 0.f); + + for (size_t t = 0; t < T; ++t) { + const size_t t_offset = t * stride_t; + + // 1. 
S = S * gate + outer(k_t, v_t) + // k_t (b,h,d_k), v_t (b,h,d_v) -> kv(b,h,d_k,d_v) = k_t(b,h,d_k) * v_t(b,h,d_v) + for (size_t b = 0; b < B; ++b) { + const size_t b_offset = b * stride_b + t_offset; + for (size_t h = 0; h < H; ++h) { + const float g = gate[h]; + float *S_bh = S.data() + (b * H + h) * (D * D); + + // Scale S by gate + for (size_t i = 0; i < D * D; ++i) { + S_bh[i] *= g; + } + + // Add outer(k_t, v_t) + for (size_t dk = 0; dk < D; ++dk) { + size_t qk_idx = b_offset + h * stride_h + dk; + float k_val = read_float_at(k_ptr, qk_idx, dtype); + for (size_t dv = 0; dv < D; ++dv) { + size_t qv_idx = b_offset + h * stride_h + dv; + float v_val = read_float_at(v_ptr, qv_idx, dtype); + S_bh[dk * D + dv] += k_val * v_val; + } + } + } + } + + // 2. o_t = (q_t * scale) @ S -> (B, H, D) for each (b,h): o[b,h,:] = scale * (q_t[b,h,:] @ S[b,h,:,:]) + for (size_t b = 0; b < B; ++b) { + const size_t b_offset = b * stride_b + t_offset; + for (size_t h = 0; h < H; ++h) { + const float *S_bh = S.data() + (b * H + h) * (D * D); + for (size_t dv = 0; dv < D; ++dv) { + float acc = 0.f; + for (size_t dk = 0; dk < D; ++dk) { + size_t q_idx = b_offset + h * stride_h + dk; + float q_val = read_float_at(q_ptr, q_idx, dtype) * scale; + acc += q_val * S_bh[dk * D + dv]; + } + size_t out_idx = b_offset + h * stride_h + dv; + write_float_at(out_ptr, out_idx, dtype, acc); + } + } + } + } +} + +void simple_gla_attention_cpu_calculate(Tensor &out, const Tensor &q, const Tensor &k, + const Tensor &v, const Tensor &g_gamma, float scale) { + simple_gla_attention_cpu_impl(out, q, k, v, g_gamma, scale); +} + +static bool register_cpu = []() { + SimpleGlaAttention::dispatcher().registerDevice(Device::Type::CPU, &simple_gla_attention_cpu_calculate, + false); + return true; +}(); + +} // namespace + +common::OpDispatcher &SimpleGlaAttention::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void SimpleGlaAttention::execute(Tensor &out, const Tensor &q, 
const Tensor &k, const Tensor &v, + const Tensor &g_gamma, float scale) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(q, k, v, g_gamma); + infinicore::context::setDevice(q->device()); + auto device_type = infinicore::context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + if (func == nullptr) { + throw std::runtime_error("simple_gla_attention: no implementation for device type " + std::to_string(static_cast(device_type))); + } + func(out, q, k, v, g_gamma, scale); +} + +Tensor simple_gla_attention(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g_gamma, + float scale) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(q, k, v, g_gamma); + + const auto &q_shape = q->shape(); + INFINICORE_ASSERT(q_shape.size() == 4); + INFINICORE_ASSERT(k->shape() == q_shape && v->shape() == q_shape); + INFINICORE_ASSERT(g_gamma->shape().size() == 1 && g_gamma->shape()[0] == q_shape[2]); + + auto q_cont = q->contiguous(); + auto k_cont = k->contiguous(); + auto v_cont = v->contiguous(); + auto g_cont = g_gamma->contiguous(); + + auto out = Tensor::empty(q_shape, q->dtype(), q->device()); + SimpleGlaAttention::execute(out, q_cont, k_cont, v_cont, g_cont, scale); + return out; +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/simple_gla_attention/simple_gla_attention_nvidia.cc b/src/infinicore/ops/simple_gla_attention/simple_gla_attention_nvidia.cc new file mode 100644 index 000000000..41b7bc79e --- /dev/null +++ b/src/infinicore/ops/simple_gla_attention/simple_gla_attention_nvidia.cc @@ -0,0 +1,30 @@ +#include "infinicore/ops/simple_gla_attention.hpp" + +#include "../../utils.hpp" +#include "infinicore/ops/simple_gla_prefill.hpp" + +namespace infinicore::op { + +namespace { + +// Prefer the `simple_gla_prefill` implementation (InfiniOP-backed) on NVIDIA. 
+void simple_gla_attention_nvidia_calculate(Tensor &out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g_gamma, + float scale) { + SimpleGLAPrefill::execute(out, q, k, v, g_gamma, scale); +} + +static bool register_nvidia = []() { + SimpleGlaAttention::dispatcher().registerDevice( + Device::Type::NVIDIA, + &simple_gla_attention_nvidia_calculate, + false); + return true; +}(); + +} // namespace + +} // namespace infinicore::op diff --git a/src/infinicore/ops/simple_gla_decode_step/simple_gla_decode_step.cc b/src/infinicore/ops/simple_gla_decode_step/simple_gla_decode_step.cc new file mode 100644 index 000000000..8d56e47d9 --- /dev/null +++ b/src/infinicore/ops/simple_gla_decode_step/simple_gla_decode_step.cc @@ -0,0 +1,186 @@ +#include "infinicore/ops/simple_gla_decode_step.hpp" + +#include "../../../utils.h" +#include "../../utils.hpp" +#include "infinicore/context/context.hpp" +#include +#include +#include + +namespace infinicore::op { + +namespace { + +template +inline float read_float(const std::byte *ptr, size_t idx) { + return static_cast(*reinterpret_cast(ptr + idx * sizeof(T))); +} + +inline float read_float_at(const std::byte *ptr, size_t idx, DataType dtype) { + switch (dtype) { + case DataType::F32: + return read_float(ptr, idx); + case DataType::F16: + return _f16_to_f32(*reinterpret_cast(ptr + idx * 2)); + case DataType::BF16: + return _bf16_to_f32(*reinterpret_cast(ptr + idx * 2)); + default: + throw std::runtime_error("simple_gla_decode_step: q/k/v need F32, F16, or BF16"); + } +} + +inline void write_float_at(std::byte *ptr, size_t idx, DataType dtype, float val) { + switch (dtype) { + case DataType::F32: + *reinterpret_cast(ptr + idx * 4) = val; + break; + case DataType::F16: + *reinterpret_cast(ptr + idx * 2) = _f32_to_f16(val); + break; + case DataType::BF16: + *reinterpret_cast(ptr + idx * 2) = _f32_to_bf16(val); + break; + default: + throw std::runtime_error("simple_gla_decode_step: out dtype unsupported"); + } +} + 
+void simple_gla_decode_step_cpu_impl(Tensor &out, Tensor &state, const Tensor &q, const Tensor &k, + const Tensor &v, const Tensor &g_gamma, float scale) { + const auto &q_shape = q->shape(); + INFINICORE_ASSERT(q_shape.size() == 4 && q_shape[1] == 1); + INFINICORE_ASSERT(k->shape() == q_shape && v->shape() == q_shape); + + const size_t B = q_shape[0]; + const size_t H = q_shape[2]; + const size_t D = q_shape[3]; + + INFINICORE_ASSERT(state->shape().size() == 4 && state->shape()[0] == B && state->shape()[1] == H && state->shape()[2] == D && state->shape()[3] == D); + INFINICORE_ASSERT(state->dtype() == DataType::F32); + INFINICORE_ASSERT(g_gamma->shape().size() == 1 && g_gamma->shape()[0] == H); + + const auto &out_shape = out->shape(); + INFINICORE_ASSERT(out_shape == q_shape); + INFINICORE_ASSERT(out->dtype() == q->dtype()); + + const DataType q_dtype = q->dtype(); + const std::byte *q_ptr = q->data(); + const std::byte *k_ptr = k->data(); + const std::byte *v_ptr = v->data(); + const std::byte *g_ptr = g_gamma->data(); + std::byte *out_ptr = out->data(); + float *s_ptr = reinterpret_cast(state->data()); + + const size_t stride_b = H * D; + const size_t stride_h = D; + + std::vector gate(H); + for (size_t h = 0; h < H; ++h) { + gate[h] = std::exp(read_float_at(g_ptr, h, g_gamma->dtype())); + } + + const size_t t_offset = 0; + + for (size_t b = 0; b < B; ++b) { + const size_t b_offset = b * stride_b + t_offset; + for (size_t h = 0; h < H; ++h) { + const float g = gate[h]; + float *S_bh = s_ptr + (b * H + h) * (D * D); + + for (size_t i = 0; i < D * D; ++i) { + S_bh[i] *= g; + } + + for (size_t dk = 0; dk < D; ++dk) { + size_t k_idx = b_offset + h * stride_h + dk; + float k_val = read_float_at(k_ptr, k_idx, q_dtype); + for (size_t dv = 0; dv < D; ++dv) { + size_t v_idx = b_offset + h * stride_h + dv; + float v_val = read_float_at(v_ptr, v_idx, q_dtype); + S_bh[dk * D + dv] += k_val * v_val; + } + } + } + } + + for (size_t b = 0; b < B; ++b) { + const size_t 
b_offset = b * stride_b + t_offset; + for (size_t h = 0; h < H; ++h) { + const float *S_bh = s_ptr + (b * H + h) * (D * D); + for (size_t dv = 0; dv < D; ++dv) { + float acc = 0.f; + for (size_t dk = 0; dk < D; ++dk) { + size_t q_idx = b_offset + h * stride_h + dk; + float q_val = read_float_at(q_ptr, q_idx, q_dtype) * scale; + acc += q_val * S_bh[dk * D + dv]; + } + size_t out_idx = b_offset + h * stride_h + dv; + write_float_at(out_ptr, out_idx, q_dtype, acc); + } + } + } +} + +void simple_gla_decode_step_cpu_calculate(Tensor &out, Tensor &state, const Tensor &q, const Tensor &k, + const Tensor &v, const Tensor &g_gamma, float scale) { + simple_gla_decode_step_cpu_impl(out, state, q, k, v, g_gamma, scale); +} + +static bool register_cpu = []() { + SimpleGlaDecodeStep::dispatcher().registerDevice(Device::Type::CPU, &simple_gla_decode_step_cpu_calculate, + false); + return true; +}(); + +} // namespace + +common::OpDispatcher &SimpleGlaDecodeStep::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void SimpleGlaDecodeStep::execute(Tensor &out, Tensor &state, const Tensor &q, const Tensor &k, const Tensor &v, + const Tensor &g_gamma, float scale) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(q, k, v, g_gamma, state, out); + + const auto &q_shape = q->shape(); + INFINICORE_ASSERT(q_shape.size() == 4 && q_shape[1] == 1); + INFINICORE_ASSERT(k->shape() == q_shape && v->shape() == q_shape); + INFINICORE_ASSERT(out->shape() == q_shape && out->dtype() == q->dtype()); + + const size_t B = q_shape[0]; + const size_t H = q_shape[2]; + const size_t D = q_shape[3]; + INFINICORE_ASSERT(state->shape().size() == 4 && state->shape()[0] == B && state->shape()[1] == H && state->shape()[2] == D && state->shape()[3] == D); + INFINICORE_ASSERT(state->dtype() == DataType::F32); + INFINICORE_ASSERT(g_gamma->shape().size() == 1 && g_gamma->shape()[0] == H); + INFINICORE_ASSERT(state->is_contiguous() && out->is_contiguous()); + + 
infinicore::context::setDevice(q->device()); + auto device_type = infinicore::context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + if (func == nullptr) { + throw std::runtime_error("simple_gla_decode_step: no implementation for device type " + std::to_string(static_cast(device_type))); + } + func(out, state, q, k, v, g_gamma, scale); +} + +Tensor simple_gla_decode_step(const Tensor &q, const Tensor &k, const Tensor &v, Tensor &state, + const Tensor &g_gamma, float scale) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(q, k, v, g_gamma, state); + + const auto &q_shape = q->shape(); + INFINICORE_ASSERT(q_shape.size() == 4 && q_shape[1] == 1); + INFINICORE_ASSERT(k->shape() == q_shape && v->shape() == q_shape); + + auto q_cont = q->contiguous(); + auto k_cont = k->contiguous(); + auto v_cont = v->contiguous(); + auto g_cont = g_gamma->contiguous(); + + auto out = Tensor::empty(q_shape, q->dtype(), q->device()); + SimpleGlaDecodeStep::execute(out, state, q_cont, k_cont, v_cont, g_cont, scale); + return out; +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/simple_gla_decode_step/simple_gla_decode_step_nvidia.cc b/src/infinicore/ops/simple_gla_decode_step/simple_gla_decode_step_nvidia.cc new file mode 100644 index 000000000..dd2faf562 --- /dev/null +++ b/src/infinicore/ops/simple_gla_decode_step/simple_gla_decode_step_nvidia.cc @@ -0,0 +1,61 @@ +#ifdef ENABLE_ATEN +#include "infinicore/ops/simple_gla_decode_step.hpp" + +#include "../../utils.hpp" +#include "infinicore/adaptor/aten_adaptor.hpp" +#include "infinicore/context/context.hpp" + +#ifdef ENABLE_NVIDIA_API +#include +#include +#include +#endif + +namespace infinicore::op { + +#ifdef ENABLE_NVIDIA_API +namespace { + +void simple_gla_decode_step_nvidia_impl(Tensor &out, Tensor &state, const Tensor &q, const Tensor &k, + const Tensor &v, const Tensor &g_gamma, float scale) { + auto aq = infinicore::adaptor::to_aten_tensor(q); + auto ak = 
infinicore::adaptor::to_aten_tensor(k); + auto av = infinicore::adaptor::to_aten_tensor(v); + auto ag = infinicore::adaptor::to_aten_tensor(g_gamma); + auto aout = infinicore::adaptor::to_aten_tensor(out); + auto aS = infinicore::adaptor::to_aten_tensor(state); + + aq = aq.transpose(1, 2).contiguous(); + ak = ak.transpose(1, 2).contiguous(); + av = av.transpose(1, 2).contiguous(); + + auto gate = ag.exp().to(at::kFloat); + + at::Tensor k_t = ak.select(2, 0); + at::Tensor v_t = av.select(2, 0); + at::Tensor kv = k_t.unsqueeze(-1).mul(v_t.unsqueeze(-2)); + at::Tensor newS = aS.mul(gate.view({1, -1, 1, 1})).add(kv.to(aS.scalar_type())); + aS.copy_(newS); + + at::Tensor q_t = aq.select(2, 0).to(at::kFloat).mul(scale); + at::Tensor o_t = q_t.unsqueeze(-2).matmul(aS).squeeze(-2); + aout.select(1, 0).copy_(o_t.to(aout.scalar_type())); +} + +void simple_gla_decode_step_nvidia_calculate(Tensor &out, Tensor &state, const Tensor &q, const Tensor &k, + const Tensor &v, const Tensor &g_gamma, float scale) { + c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); + simple_gla_decode_step_nvidia_impl(out, state, q, k, v, g_gamma, scale); +} + +static bool register_nvidia = []() { + SimpleGlaDecodeStep::dispatcher().registerDevice(Device::Type::NVIDIA, &simple_gla_decode_step_nvidia_calculate, + false); + return true; +}(); + +} // namespace +#endif // ENABLE_NVIDIA_API + +} // namespace infinicore::op +#endif // ENABLE_ATEN diff --git a/src/infinicore/ops/simple_gla_prefill/simple_gla_prefill.cc b/src/infinicore/ops/simple_gla_prefill/simple_gla_prefill.cc new file mode 100644 index 000000000..d3627b5a5 --- /dev/null +++ b/src/infinicore/ops/simple_gla_prefill/simple_gla_prefill.cc @@ -0,0 +1,37 @@ +#include "infinicore/ops/simple_gla_prefill.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SimpleGLAPrefill); + +SimpleGLAPrefill::SimpleGLAPrefill(Tensor out, + const Tensor &q, + const Tensor &k, + const 
Tensor &v, + const Tensor &g_gamma, + float scale) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, g_gamma); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, q, k, v, g_gamma, scale); +} + +void SimpleGLAPrefill::execute(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g_gamma, + float scale) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(SimpleGLAPrefill, out, q, k, v, g_gamma, scale); +} + +Tensor simple_gla_prefill(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g_gamma, + float scale) { + auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); + SimpleGLAPrefill::execute(out, q, k, v, g_gamma, scale); + return out; +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/simple_gla_prefill/simple_gla_prefill_infiniop.cc b/src/infinicore/ops/simple_gla_prefill/simple_gla_prefill_infiniop.cc new file mode 100644 index 000000000..b79c9976b --- /dev/null +++ b/src/infinicore/ops/simple_gla_prefill/simple_gla_prefill_infiniop.cc @@ -0,0 +1,83 @@ +#include "infinicore/ops/simple_gla_prefill.hpp" + +#include "../infiniop_impl.hpp" +#include "infinicore/context/context.hpp" + +#include + +namespace infinicore::op::simple_gla_prefill_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SimpleGLAPrefill, 64); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace; + graph::GraphTensor out; + graph::GraphTensor q; + graph::GraphTensor k; + graph::GraphTensor v; + graph::GraphTensor g; + float scale; +}; + +static void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &g, float scale) { + size_t key = hash_combine(out, q, k, v, g, static_cast(scale * 1000000.0f)); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, SimpleGLAPrefill, + key, out->desc(), q->desc(), k->desc(), v->desc(), g->desc()); + + // Scratch workspace allocation can dominate VRAM growth for long-context prefill. 
+ // Reuse a single workspace tensor per (workspace_size, device) across layers. + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR( + infiniopGetSimpleGLAPrefillWorkspaceSize(descriptor->desc, &workspace_size)); + + thread_local common::OpCache workspace_caches(8 /*capacity*/); + auto device__ = context::getDevice(); + auto &cache__ = workspace_caches.getCache(device__); + + Tensor workspace; + if (auto cached = cache__.get(workspace_size); cached.has_value()) { + workspace = *cached; + } else { + workspace = Tensor::empty({workspace_size}, DataType::U8, device__); + cache__.put(workspace_size, workspace); + } + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(g), + scale, + }; +} + +static void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + INFINICORE_CHECK_ERROR( + infiniopSimpleGLAPrefill( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->out->data(), + p->q->data(), + p->k->data(), + p->v->data(), + p->g->data(), + p->scale, + context::getStream())); +} + +static void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SimpleGLAPrefill, &plan, &run, &cleanup); + +} // namespace infinicore::op::simple_gla_prefill_impl::infiniop diff --git a/src/infinicore/ops/simple_gla_recurrent_state_append/simple_gla_recurrent_state_append.cc b/src/infinicore/ops/simple_gla_recurrent_state_append/simple_gla_recurrent_state_append.cc new file mode 100644 index 000000000..ba4025b66 --- /dev/null +++ b/src/infinicore/ops/simple_gla_recurrent_state_append/simple_gla_recurrent_state_append.cc @@ -0,0 +1,76 @@ +#include "infinicore/ops/simple_gla_recurrent_state_append.hpp" + +#include "../../utils.hpp" +#include "infinicore/context/context.hpp" +#include 
"infinicore/ops/simple_gla_decode_step.hpp" + +namespace infinicore::op { + +namespace { + +void simple_gla_recurrent_state_append_cpu(Tensor &state, const Tensor &k_seg, const Tensor &v_seg, + const Tensor &g_gamma) { + const auto &sh = k_seg->shape(); + const size_t B = sh[0]; + const size_t L = sh[1]; + const size_t H = sh[2]; + const size_t D = sh[3]; + INFINICORE_ASSERT(v_seg->shape() == sh); + + auto q_zero = Tensor::zeros({B, 1, H, D}, k_seg->dtype(), k_seg->device()); + auto out = Tensor::empty({B, 1, H, D}, k_seg->dtype(), k_seg->device()); + for (size_t t = 0; t < L; ++t) { + auto kt = k_seg->narrow({{1, t, 1}}); + auto vt = v_seg->narrow({{1, t, 1}}); + SimpleGlaDecodeStep::execute(out, state, q_zero, kt, vt, g_gamma, 1.0f); + } +} + +static bool register_cpu = []() { + SimpleGlaRecurrentStateAppend::dispatcher().registerDevice(Device::Type::CPU, &simple_gla_recurrent_state_append_cpu, + false); + return true; +}(); + +} // namespace + +common::OpDispatcher &SimpleGlaRecurrentStateAppend::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void SimpleGlaRecurrentStateAppend::execute(Tensor &state, const Tensor &k_seg, const Tensor &v_seg, + const Tensor &g_gamma) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(state, k_seg, v_seg, g_gamma); + + const auto &sh = k_seg->shape(); + INFINICORE_ASSERT(sh.size() == 4); + const size_t B = sh[0]; + const size_t L = sh[1]; + const size_t H = sh[2]; + const size_t D = sh[3]; + INFINICORE_ASSERT(v_seg->shape() == sh); + INFINICORE_ASSERT(state->shape() == Shape({B, H, D, D})); + INFINICORE_ASSERT(state->dtype() == DataType::F32); + INFINICORE_ASSERT(g_gamma->shape() == Shape({H})); + INFINICORE_ASSERT(state->is_contiguous()); + + if (L == 0) { + return; + } + + infinicore::context::setDevice(state->device()); + auto dev = infinicore::context::getDevice().getType(); + auto fn = SimpleGlaRecurrentStateAppend::dispatcher().lookup(dev); + if (fn == nullptr) { + throw 
std::runtime_error("simple_gla_recurrent_state_append_segment: no implementation for device type " + std::to_string(static_cast(dev))); + } + fn(state, k_seg->contiguous(), v_seg->contiguous(), g_gamma); +} + +void simple_gla_recurrent_state_append_segment(Tensor &state, const Tensor &k_seg, const Tensor &v_seg, + const Tensor &g_gamma) { + SimpleGlaRecurrentStateAppend::execute(state, k_seg, v_seg, g_gamma); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/simple_gla_recurrent_state_append/simple_gla_recurrent_state_append_nvidia.cc b/src/infinicore/ops/simple_gla_recurrent_state_append/simple_gla_recurrent_state_append_nvidia.cc new file mode 100644 index 000000000..76d3bc646 --- /dev/null +++ b/src/infinicore/ops/simple_gla_recurrent_state_append/simple_gla_recurrent_state_append_nvidia.cc @@ -0,0 +1,70 @@ +#ifdef ENABLE_ATEN +#include "infinicore/ops/simple_gla_recurrent_state_append.hpp" + +#include "../../utils.hpp" +#include "infinicore/adaptor/aten_adaptor.hpp" +#include "infinicore/context/context.hpp" + +#ifdef ENABLE_NVIDIA_API +#include +#include +#include +#endif + +namespace infinicore::op { + +#ifdef ENABLE_NVIDIA_API +namespace { + +void simple_gla_recurrent_state_append_nvidia(Tensor &state, const Tensor &k_seg, const Tensor &v_seg, + const Tensor &g_gamma) { + auto ak = infinicore::adaptor::to_aten_tensor(k_seg); + auto av = infinicore::adaptor::to_aten_tensor(v_seg); + auto ag = infinicore::adaptor::to_aten_tensor(g_gamma); + auto aS = infinicore::adaptor::to_aten_tensor(state); + + const int64_t B = ak.size(0); + const int64_t L = ak.size(1); + const int64_t H = ak.size(2); + const int64_t D = ak.size(3); + INFINICORE_ASSERT(L > 0); + + auto gate = ag.exp().to(at::kFloat).contiguous(); + auto j = at::arange(L, at::TensorOptions().dtype(at::kFloat).device(gate.device())); + auto pows = (static_cast(L - 1) - j) * 0.5f; + auto w = at::pow(gate.unsqueeze(0), pows.unsqueeze(1)).to(at::kFloat).contiguous(); + + // [B,L,H,D] -> 
[B,H,L,D] so each (b,h) owns a contiguous [L,D] panel for bmm. + auto kf = ak.to(at::kFloat).permute({0, 2, 1, 3}).contiguous(); + auto vf = av.to(at::kFloat).permute({0, 2, 1, 3}).contiguous(); + // w[l,h] scales token l, head h; broadcast to [1,H,L,1] on [B,H,L,D]. + auto w_bhl = w.transpose(0, 1).contiguous().view({1, H, L, 1}); + kf.mul_(w_bhl); + vf.mul_(w_bhl); + + auto ks = kf.view({B * H, L, D}); + auto vs = vf.view({B * H, L, D}); + auto s_inc = at::bmm(ks.transpose(1, 2), vs); + + auto gL = at::pow(gate, static_cast(L)).view({1, H, 1, 1}); + aS.mul_(gL).add_(s_inc.view({B, H, D, D})); +} + +void simple_gla_recurrent_state_append_nvidia_calculate(Tensor &state, const Tensor &k_seg, const Tensor &v_seg, + const Tensor &g_gamma) { + c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); + simple_gla_recurrent_state_append_nvidia(state, k_seg, v_seg, g_gamma); +} + +static bool register_nvidia = []() { + SimpleGlaRecurrentStateAppend::dispatcher().registerDevice(Device::Type::NVIDIA, + &simple_gla_recurrent_state_append_nvidia_calculate, + false); + return true; +}(); + +} // namespace +#endif // ENABLE_NVIDIA_API + +} // namespace infinicore::op +#endif // ENABLE_ATEN diff --git a/src/infinicore/ops/zeros/zeros.cc b/src/infinicore/ops/zeros/zeros.cc new file mode 100644 index 000000000..cebcc7378 --- /dev/null +++ b/src/infinicore/ops/zeros/zeros.cc @@ -0,0 +1,30 @@ +#include "infinicore/ops/zeros.hpp" + +#include "../../utils.hpp" + +#include + +namespace infinicore::op { + +common::OpDispatcher &Zeros::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Zeros::execute(Tensor output) { + infinicore::context::setDevice(output->device()); + auto device_type = output->device().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No Zeros implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(output); 
+} + +void zeros_(Tensor output) { + Zeros::execute(output); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/zeros/zeros_infiniop.cc b/src/infinicore/ops/zeros/zeros_infiniop.cc new file mode 100644 index 000000000..7ea590515 --- /dev/null +++ b/src/infinicore/ops/zeros/zeros_infiniop.cc @@ -0,0 +1,51 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/zeros.hpp" +#include + +namespace infinicore::op::zeros_impl::infiniop { + +thread_local common::OpCache caches( + 100, + [](infiniopZerosDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyZerosDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor output) { + size_t seed = 0; + infinicore::hash_combine(seed, output); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopZerosDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateZerosDescriptor( + context::getInfiniopHandle(device), &desc, + output->desc(), output->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetZerosWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopZeros( + desc, workspace->data(), workspace_size, + output->data(), output->data(), context::getStream())); +} + +static bool registered = []() { + Zeros::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::zeros_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 9e3ac4377..c8d8da63a 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -49,6 +49,7 @@ #include "ops/hypot.hpp" #include "ops/index_add.hpp" #include "ops/index_copy.hpp" 
+#include "ops/infllmv2_attention.hpp" #include "ops/inner.hpp" #include "ops/interpolate.hpp" #include "ops/kron.hpp" @@ -86,6 +87,9 @@ #include "ops/selu.hpp" #include "ops/silu.hpp" #include "ops/silu_and_mul.hpp" +#include "ops/simple_gla_attention.hpp" +#include "ops/simple_gla_decode_step.hpp" +#include "ops/simple_gla_prefill.hpp" #include "ops/sinh.hpp" #include "ops/smooth_l1_loss.hpp" #include "ops/softplus.hpp" @@ -135,6 +139,10 @@ inline void bind(py::module &m) { bind_dist(m); bind_flash_attention(m); bind_hinge_embedding_loss(m); + bind_infllmv2_attention(m); + bind_simple_gla_attention(m); + bind_simple_gla_decode_step(m); + bind_simple_gla_prefill(m); bind_kv_caching(m); bind_fmod(m); bind_fmin(m); diff --git a/src/infinicore/pybind11/ops/infllmv2_attention.hpp b/src/infinicore/pybind11/ops/infllmv2_attention.hpp new file mode 100644 index 000000000..851c60bc2 --- /dev/null +++ b/src/infinicore/pybind11/ops/infllmv2_attention.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include + +#include "infinicore/ops/infllmv2_attention.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline Tensor py_infllmv2_varlen(Tensor q, + Tensor k, + Tensor v, + Tensor cu_seqlens_q, + Tensor cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + return op::infllmv2_attention_varlen( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + scale, causal, + window_size_left, window_size_right); +} + +inline Tensor py_infllmv2_kvcache(Tensor q, + Tensor k_cache, + Tensor v_cache, + Tensor cache_lens, + float scale, + bool causal, + int window_size_left, + int window_size_right) { + return op::infllmv2_attention_kvcache( + q, k_cache, v_cache, + cache_lens, + scale, causal, + window_size_left, window_size_right); +} + +inline void bind_infllmv2_attention(py::module &m) { + m.def( + "infllmv2_attention_varlen", + &ops::py_infllmv2_varlen, + py::arg("q"), + 
py::arg("k"), + py::arg("v"), + py::arg("cu_seqlens_q"), + py::arg("cu_seqlens_k"), + py::arg("max_seqlen_q"), + py::arg("max_seqlen_k"), + py::arg("scale"), + py::arg("causal"), + py::arg("window_size_left") = -1, + py::arg("window_size_right") = -1, + R"doc(InfLLM-V2 varlen attention. q,k,v unpadded; cu_seqlens_q/k [batch+1]. Returns [total_q, nheads, head_dim].)doc"); + + m.def( + "infllmv2_attention_kvcache", + &ops::py_infllmv2_kvcache, + py::arg("q"), + py::arg("k_cache"), + py::arg("v_cache"), + py::arg("cache_lens"), + py::arg("scale"), + py::arg("causal"), + py::arg("window_size_left") = -1, + py::arg("window_size_right") = -1, + R"doc(InfLLM-V2 KV-cache (decode) attention. Returns [batch, seqlen_q, nheads, head_dim].)doc"); + + // Backward-compatible Python names. + m.def( + "infllmv2_varlen", + &ops::py_infllmv2_varlen, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("cu_seqlens_q"), + py::arg("cu_seqlens_k"), + py::arg("max_seqlen_q"), + py::arg("max_seqlen_k"), + py::arg("scale"), + py::arg("causal"), + py::arg("window_size_left") = -1, + py::arg("window_size_right") = -1, + R"doc(Deprecated alias for infllmv2_attention_varlen.)doc"); + + m.def( + "infllmv2_kvcache", + &ops::py_infllmv2_kvcache, + py::arg("q"), + py::arg("k_cache"), + py::arg("v_cache"), + py::arg("cache_lens"), + py::arg("scale"), + py::arg("causal"), + py::arg("window_size_left") = -1, + py::arg("window_size_right") = -1, + R"doc(Deprecated alias for infllmv2_attention_kvcache.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/simple_gla_attention.hpp b/src/infinicore/pybind11/ops/simple_gla_attention.hpp new file mode 100644 index 000000000..cc0c6dfca --- /dev/null +++ b/src/infinicore/pybind11/ops/simple_gla_attention.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include "infinicore/ops/simple_gla_attention.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline Tensor py_simple_gla_attention(Tensor q, + Tensor k, 
+ Tensor v, + Tensor g_gamma, + float scale) { + return op::simple_gla_attention(q, k, v, g_gamma, scale); +} + +inline void bind_simple_gla_attention(py::module &m) { + m.def( + "simple_gla_attention", + &ops::py_simple_gla_attention, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("g_gamma"), + py::arg("scale"), + R"doc(Simple GLA (recurrent linear) attention. q, k, v [B, T, H, D], g_gamma [H]. Returns [B, T, H, D].)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/simple_gla_decode_step.hpp b/src/infinicore/pybind11/ops/simple_gla_decode_step.hpp new file mode 100644 index 000000000..f6b5b5423 --- /dev/null +++ b/src/infinicore/pybind11/ops/simple_gla_decode_step.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include + +#include "infinicore/ops/simple_gla_decode_step.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline Tensor py_simple_gla_decode_step(Tensor q, Tensor k, Tensor v, Tensor state, Tensor g_gamma, float scale) { + return op::simple_gla_decode_step(q, k, v, state, g_gamma, scale); +} + +inline void bind_simple_gla_decode_step(py::module &m) { + m.def( + "simple_gla_decode_step", + &ops::py_simple_gla_decode_step, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("state"), + py::arg("g_gamma"), + py::arg("scale"), + R"doc(Simple GLA one decode step. q,k,v [B,1,H,D] (same dtype); state [B,H,D,D] float32 in-place; g_gamma [H]. 
Returns [B,1,H,D].)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/simple_gla_prefill.hpp b/src/infinicore/pybind11/ops/simple_gla_prefill.hpp new file mode 100644 index 000000000..69bdbda36 --- /dev/null +++ b/src/infinicore/pybind11/ops/simple_gla_prefill.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include "infinicore/ops/simple_gla_prefill.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline Tensor py_simple_gla_prefill(Tensor q, + Tensor k, + Tensor v, + Tensor g_gamma, + float scale) { + return op::simple_gla_prefill(q, k, v, g_gamma, scale); +} + +inline void bind_simple_gla_prefill(py::module &m) { + m.def( + "simple_gla_prefill", + &ops::py_simple_gla_prefill, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("g_gamma"), + py::arg("scale"), + R"doc(Simple GLA prefill fused kernel. q, k, v [B, T, H, D], g_gamma [H] (F32). Returns [B, T, H, D].)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/tensor.hpp b/src/infinicore/pybind11/tensor.hpp index ff6c205a0..e5ca218a8 100644 --- a/src/infinicore/pybind11/tensor.hpp +++ b/src/infinicore/pybind11/tensor.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include #include @@ -36,7 +38,43 @@ inline void bind(py::module &m) { .def("permute", [](const Tensor &tensor, const Shape &dims) { return tensor->permute(dims); }) .def("view", [](const Tensor &tensor, const Shape &shape) { return tensor->view(shape); }) .def("unsqueeze", [](const Tensor &tensor, std::size_t dim) { return tensor->unsqueeze(dim); }) - .def("squeeze", [](const Tensor &tensor, std::size_t dim) { return tensor->squeeze(dim); }); + .def("squeeze", [](const Tensor &tensor, std::size_t dim) { return tensor->squeeze(dim); }) + + // Fast in-place scalar writes for tiny CPU metadata tensors (decode loops, etc.). 
+        .def(
+            "write_i32",
+            [](Tensor &tensor, std::size_t linear_index, std::int32_t value) {
+                if (tensor->dtype() != DataType::I32) {
+                    throw py::value_error("write_i32: dtype must be I32");
+                }
+                if (!tensor->is_contiguous()) {
+                    throw py::value_error("write_i32: tensor must be contiguous");
+                }
+                if (linear_index >= static_cast<std::size_t>(tensor->numel())) {
+                    throw py::index_error("write_i32: linear_index out of range");
+                }
+                auto *base = reinterpret_cast<std::int32_t *>(tensor->data());
+                base[linear_index] = value;
+            },
+            py::arg("linear_index"),
+            py::arg("value"))
+        .def(
+            "write_i64",
+            [](Tensor &tensor, std::size_t linear_index, std::int64_t value) {
+                if (tensor->dtype() != DataType::I64) {
+                    throw py::value_error("write_i64: dtype must be I64");
+                }
+                if (!tensor->is_contiguous()) {
+                    throw py::value_error("write_i64: tensor must be contiguous");
+                }
+                if (linear_index >= static_cast<std::size_t>(tensor->numel())) {
+                    throw py::index_error("write_i64: linear_index out of range");
+                }
+                auto *base = reinterpret_cast<std::int64_t *>(tensor->data());
+                base[linear_index] = value;
+            },
+            py::arg("linear_index"),
+            py::arg("value"));
 
     m.def("empty",
           &Tensor::empty,
           py::arg("shape"),
diff --git a/src/infinicore/tensor/tensor.cc b/src/infinicore/tensor/tensor.cc
index 34a9af601..232c195cc 100644
--- a/src/infinicore/tensor/tensor.cc
+++ b/src/infinicore/tensor/tensor.cc
@@ -3,6 +3,7 @@
 #include "../utils.hpp"
 #include "infinicore/context/context.hpp"
 #include "infinicore/dtype.hpp"
+#include "infinicore/ops/zeros.hpp"
 
 #include
@@ -242,8 +243,9 @@
 std::shared_ptr<TensorImpl> TensorImpl::zeros(const Shape &shape,
                                               const DataType &dtype,
                                               const Device &device,
                                               bool pin_memory) {
-    // TODO: Implement this.
- return empty(shape, dtype, device, pin_memory); + auto impl = empty(shape, dtype, device, pin_memory); + infinicore::op::zeros_(Tensor(impl)); + return impl; } std::shared_ptr TensorImpl::ones(const Shape &shape, const DataType &dtype, diff --git a/src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc b/src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc index 2f9bbde47..36cb0c1a9 100644 --- a/src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc +++ b/src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc @@ -41,7 +41,9 @@ infiniStatus_t rmsnorm(const RMSNormInfo *info, T *y, const T *x, const T *w) { T rms = (T)1 / std::sqrt(ss / (T)(dim) + (T)(info->epsilon)); for (size_t k = 0; k < dim; k++) { - y_ptr[k] = x_ptr[k] * w[k] * rms; + // Numerical stability: compute x * rms first, then multiply by weight. + // This avoids inf * 0 -> NaN when rms underflows to 0 for very large x. + y_ptr[k] = (x_ptr[k] * rms) * w[k]; } } @@ -74,10 +76,10 @@ infiniStatus_t rmsnormHalfPrecision(const RMSNormInfo *info, T *y, const T *x, c for (size_t k = 0; k < dim; k++) { if constexpr (std::is_same::value) { - float val = utils::cast(x_ptr[k]) * w[k] * rms; + float val = utils::cast(x_ptr[k]) * rms * w[k]; y_ptr[k] = utils::cast(val); } else if constexpr (std::is_same::value || std::is_same_v || std::is_same_v) { - float val = utils::cast(x_ptr[k]) * utils::cast(w[k]) * rms; + float val = utils::cast(x_ptr[k]) * rms * utils::cast(w[k]); y_ptr[k] = utils::cast(val); } else { std::abort(); diff --git a/src/infiniop/ops/rms_norm/cuda/kernel.cuh b/src/infiniop/ops/rms_norm/cuda/kernel.cuh index 2fdc36fad..4e2ef51e2 100644 --- a/src/infiniop/ops/rms_norm/cuda/kernel.cuh +++ b/src/infiniop/ops/rms_norm/cuda/kernel.cuh @@ -33,7 +33,9 @@ __device__ void rmsnormBlock( __syncthreads(); for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) { - y_ptr[i] = Tdata(Tcompute(x_ptr[i]) * Tcompute(w_ptr[i]) * rms); + // Numerical stability: compute x * rms first, then multiply by weight. 
+ // This avoids inf * 0 -> NaN when rms underflows to 0 for very large x. + y_ptr[i] = Tdata((Tcompute(x_ptr[i]) * rms) * Tcompute(w_ptr[i])); } } diff --git a/src/infiniop/ops/simple_gla_prefill/info.h b/src/infiniop/ops/simple_gla_prefill/info.h new file mode 100644 index 000000000..9cfd3d550 --- /dev/null +++ b/src/infiniop/ops/simple_gla_prefill/info.h @@ -0,0 +1,57 @@ +#ifndef __SIMPLE_GLA_PREFILL_CUDA_INFO_H__ +#define __SIMPLE_GLA_PREFILL_CUDA_INFO_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" + +namespace op::simple_gla_prefill_cuda { + +class SimpleGLAPrefillCudaInfo { + SimpleGLAPrefillCudaInfo() = default; + +public: + infiniDtype_t dtype; + size_t B, T, H, D; + + static utils::Result create( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_gamma_desc) { + + auto dtype = out_desc->dtype(); + if (dtype != q_desc->dtype() || dtype != k_desc->dtype() || dtype != v_desc->dtype()) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + // Only support half/bf16 for now. 
+        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
+
+        // Shapes must match and be 4D [B,T,H,D]
+        const auto &out_shape = out_desc->shape();
+        CHECK_SAME_SHAPE(out_shape, q_desc->shape(), k_desc->shape(), v_desc->shape());
+        if (out_desc->ndim() != 4) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+        // g_gamma: [H], F32
+        if (g_gamma_desc->ndim() != 1 || g_gamma_desc->shape()[0] != out_shape[2]) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+        if (g_gamma_desc->dtype() != INFINI_DTYPE_F32) {
+            return INFINI_STATUS_BAD_TENSOR_DTYPE;
+        }
+
+        return utils::Result<SimpleGLAPrefillCudaInfo>(SimpleGLAPrefillCudaInfo{
+            dtype,
+            out_shape[0],
+            out_shape[1],
+            out_shape[2],
+            out_shape[3],
+        });
+    }
+};
+
+} // namespace op::simple_gla_prefill_cuda
+
+#endif // __SIMPLE_GLA_PREFILL_CUDA_INFO_H__
diff --git a/src/infiniop/ops/simple_gla_prefill/nvidia/simple_gla_prefill_nvidia_cuda.cu b/src/infiniop/ops/simple_gla_prefill/nvidia/simple_gla_prefill_nvidia_cuda.cu
new file mode 100644
index 000000000..4ccde727d
--- /dev/null
+++ b/src/infiniop/ops/simple_gla_prefill/nvidia/simple_gla_prefill_nvidia_cuda.cu
@@ -0,0 +1,312 @@
+#include "simple_gla_prefill_nvidia_cuda.cuh"
+
+#include <cmath>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#include "../../../devices/nvidia/nvidia_handle.cuh"
+
+namespace op::simple_gla_prefill_cuda::nvidia {
+
+namespace {
+
+__device__ __forceinline__ float bf16_to_f32(__nv_bfloat16 x) { return __bfloat162float(x); }
+__device__ __forceinline__ float f16_to_f32(__half x) { return __half2float(x); }
+__device__ __forceinline__ __nv_bfloat16 f32_to_bf16(float x) { return __float2bfloat16_rn(x); }
+__device__ __forceinline__ __half f32_to_f16(float x) { return __float2half_rn(x); }
+
+template <typename T>
+struct Convert;
+
+template <>
+struct Convert<__half> {
+    __device__ static float to_f32(__half x) { return f16_to_f32(x); }
+    __device__ static __half from_f32(float x) { return f32_to_f16(x); }
+};
+
+template <>
+struct Convert<__nv_bfloat16> {
+    __device__ static float to_f32(__nv_bfloat16 x) { return
bf16_to_f32(x); }
+    __device__ static __nv_bfloat16 from_f32(float x) { return f32_to_bf16(x); }
+};
+
+// Naive but fused prefill kernel:
+// - One block per (b,h)
+// - Shared S[D*D] in fp32
+// - Loop over t and update S + compute o_t
+template <typename T>
+__global__ void simple_gla_prefill_kernel(
+    T *out,
+    const T *q,
+    const T *k,
+    const T *v,
+    const float *g_gamma,
+    int B,
+    int Tlen,
+    int H,
+    int D,
+    float scale) {
+
+    const int b = (int)blockIdx.x;
+    const int h = (int)blockIdx.y;
+    const int tid = (int)threadIdx.x;
+
+    extern __shared__ float smem[];
+    float *S = smem;         // D*D
+    float *kvec = S + D * D; // D
+    float *vvec = kvec + D;  // D
+    float *qvec = vvec + D;  // D
+
+    // Initialize S to 0.
+    for (int idx = tid; idx < D * D; idx += (int)blockDim.x) {
+        S[idx] = 0.0f;
+    }
+    __syncthreads();
+
+    const float gate = expf(g_gamma[h]);
+
+    // Base pointers (contiguous [B,T,H,D])
+    const int stride_b = Tlen * H * D;
+    const int stride_t = H * D;
+    const int stride_h = D;
+
+    for (int t = 0; t < Tlen; ++t) {
+        const int base = b * stride_b + t * stride_t + h * stride_h;
+
+        // Load q/k/v vectors to shared (fp32).
+        if (tid < D) {
+            qvec[tid] = Convert<T>::to_f32(q[base + tid]);
+            kvec[tid] = Convert<T>::to_f32(k[base + tid]);
+            vvec[tid] = Convert<T>::to_f32(v[base + tid]);
+        }
+        __syncthreads();
+
+        // Update S = S*gate + outer(k, v)
+        for (int idx = tid; idx < D * D; idx += (int)blockDim.x) {
+            const int dk = idx / D;
+            const int dv = idx - dk * D;
+            S[idx] = S[idx] * gate + kvec[dk] * vvec[dv];
+        }
+        __syncthreads();
+
+        // Compute out[t, d] for this (b,h) using D threads.
+        if (tid < D) {
+            float acc = 0.0f;
+            const int dv = tid;
+            for (int dk = 0; dk < D; ++dk) {
+                acc += (qvec[dk] * scale) * S[dk * D + dv];
+            }
+            out[base + dv] = Convert<T>::from_f32(acc);
+        }
+        __syncthreads();
+    }
+}
+
+// Chunked/tiled kernel for D > 64 (e.g. D=128) to stay under 64KB shared memory.
+// Grid: (ceil(D/BK), ceil(D/BV), B*H), block: 256 threads. Each block holds S_tile [BK][BV].
+constexpr int TILE = 32;
+
+template <typename T>
+__global__ void simple_gla_prefill_chunked_kernel(
+    float *out_float,
+    const T *q,
+    const T *k,
+    const T *v,
+    const float *g_gamma,
+    int B,
+    int Tlen,
+    int H,
+    int D,
+    float scale) {
+
+    const int i_k = (int)blockIdx.x; // tile index along K
+    const int i_v = (int)blockIdx.y; // tile index along V
+    const int bh = (int)blockIdx.z;
+    const int b = bh / H;
+    const int h = bh % H;
+    const int tid = (int)threadIdx.x;
+
+    const int k0 = i_k * TILE;
+    const int v0 = i_v * TILE;
+    const int nk = (D - k0) < TILE ? (D - k0) : TILE;
+    const int nv = (D - v0) < TILE ? (D - v0) : TILE;
+
+    extern __shared__ float smem[];
+    float *S_tile = smem; // [TILE][TILE] = 32*32*4 = 4KB
+    float *q_tile = S_tile + TILE * TILE;
+    float *k_tile = q_tile + TILE;
+    float *v_tile = k_tile + TILE;
+
+    const float gate = expf(g_gamma[h]);
+    const int stride_b = Tlen * H * D;
+    const int stride_t = H * D;
+    const int stride_h = D;
+    const int out_stride_b = Tlen * H * D;
+    const int out_stride_t = H * D;
+    const int out_stride_h = D;
+
+    // Initialize S_tile to 0
+    for (int i = tid; i < TILE * TILE; i += blockDim.x) {
+        S_tile[i] = 0.0f;
+    }
+    __syncthreads();
+
+    for (int t = 0; t < Tlen; ++t) {
+        const int base = b * stride_b + t * stride_t + h * stride_h;
+
+        // Load q, k, v tiles into shared (fp32)
+        if (tid < nk) {
+            q_tile[tid] = Convert<T>::to_f32(q[base + k0 + tid]);
+            k_tile[tid] = Convert<T>::to_f32(k[base + k0 + tid]);
+        }
+        if (tid < nv) {
+            v_tile[tid] = Convert<T>::to_f32(v[base + v0 + tid]);
+        }
+        __syncthreads();
+
+        // S_tile = S_tile * gate + outer(k_tile, v_tile)
+        for (int idx = tid; idx < TILE * TILE; idx += blockDim.x) {
+            const int dk = idx / TILE;
+            const int dv = idx - dk * TILE;
+            if (dk < nk && dv < nv) {
+                S_tile[idx] = S_tile[idx] * gate + k_tile[dk] * v_tile[dv];
+            }
+        }
+        __syncthreads();
+
+        // Output uses the updated state (matches naive kernel: update S then compute o_t).
+ if (tid < nv) { + float acc = 0.0f; + for (int kk = 0; kk < nk; ++kk) { + acc += (q_tile[kk] * scale) * S_tile[kk * TILE + tid]; + } + atomicAdd(&out_float[b * out_stride_b + t * out_stride_t + h * out_stride_h + v0 + tid], acc); + } + __syncthreads(); + } +} + +// Convert float buffer (B,T,H,D) to output dtype. +template +__global__ void simple_gla_prefill_convert_kernel( + T *out, + const float *in_float, + int B, + int Tlen, + int H, + int D) { + const int idx = (int)blockIdx.x * blockDim.x + threadIdx.x; + const int total = B * Tlen * H * D; + if (idx < total) { + out[idx] = Convert::from_f32(in_float[idx]); + } +} + +} // namespace + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_gamma_desc) { + + auto handle = reinterpret_cast(handle_); + auto info_r = op::simple_gla_prefill_cuda::SimpleGLAPrefillCudaInfo::create( + out_desc, q_desc, k_desc, v_desc, g_gamma_desc); + if (!info_r) { + return info_r.status(); + } + auto info = info_r.take(); + + // Workspace for chunked path (D > 64): float buffer B*T*H*D. + const size_t D_val = info.D; + const size_t workspace_size = (D_val > 64u) + ? 
(info.B * info.T * info.H * info.D * sizeof(float)) + : 0u; + *desc_ptr = new Descriptor( + /*opaque=*/nullptr, + info, + workspace_size, + handle->device, + handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k, + const void *v, + const void *g_gamma, + float scale, + void *stream) const { + + const int B = (int)_info.B; + const int Tlen = (int)_info.T; + const int H = (int)_info.H; + const int D = (int)_info.D; + const bool use_chunked = (D > 64); + cudaStream_t cuda_stream = (cudaStream_t)stream; + + if (use_chunked) { + // Chunked path: write to workspace (float), then convert to out. + const int nk_tiles = (D + TILE - 1) / TILE; + const int nv_tiles = (D + TILE - 1) / TILE; + dim3 grid_chunk(nk_tiles, nv_tiles, B * H); + dim3 block_chunk(256, 1, 1); + const size_t shmem_chunk = sizeof(float) * (TILE * TILE + 3 * TILE); + if (workspace == nullptr || workspace_size < (size_t)(B * Tlen * H * D * (int)sizeof(float))) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *out_float = static_cast(workspace); + cudaMemsetAsync(out_float, 0, (size_t)(B * Tlen * H * D) * sizeof(float), cuda_stream); + if (_info.dtype == INFINI_DTYPE_F16) { + simple_gla_prefill_chunked_kernel<<>>( + out_float, (const half *)q, (const half *)k, (const half *)v, (const float *)g_gamma, + B, Tlen, H, D, scale); + const int total = B * Tlen * H * D; + simple_gla_prefill_convert_kernel<<<(total + 255) / 256, 256, 0, cuda_stream>>>( + (half *)out, out_float, B, Tlen, H, D); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + simple_gla_prefill_chunked_kernel<<>>( + out_float, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k, (const __nv_bfloat16 *)v, (const float *)g_gamma, + B, Tlen, H, D, scale); + const int total = B * Tlen * H * D; + simple_gla_prefill_convert_kernel<<<(total + 255) / 256, 256, 0, cuda_stream>>>( + (__nv_bfloat16 *)out, out_float, B, 
Tlen, H, D); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; + } + + // Naive path (D <= 64) + (void)workspace; + (void)workspace_size; + dim3 grid(B, H, 1); + dim3 block(256, 1, 1); + const size_t shmem = sizeof(float) * (D * D + 3 * D); + + if (_info.dtype == INFINI_DTYPE_F16) { + simple_gla_prefill_kernel<<>>( + (half *)out, (const half *)q, (const half *)k, (const half *)v, (const float *)g_gamma, + B, Tlen, H, D, scale); + return INFINI_STATUS_SUCCESS; + } + if (_info.dtype == INFINI_DTYPE_BF16) { + simple_gla_prefill_kernel<<>>( + (__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k, (const __nv_bfloat16 *)v, (const float *)g_gamma, + B, Tlen, H, D, scale); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +} // namespace op::simple_gla_prefill_cuda::nvidia diff --git a/src/infiniop/ops/simple_gla_prefill/nvidia/simple_gla_prefill_nvidia_cuda.cuh b/src/infiniop/ops/simple_gla_prefill/nvidia/simple_gla_prefill_nvidia_cuda.cuh new file mode 100644 index 000000000..c0feaba53 --- /dev/null +++ b/src/infiniop/ops/simple_gla_prefill/nvidia/simple_gla_prefill_nvidia_cuda.cuh @@ -0,0 +1,8 @@ +#ifndef __SIMPLE_GLA_PREFILL_NVIDIA_CUDA_IMPL_H__ +#define __SIMPLE_GLA_PREFILL_NVIDIA_CUDA_IMPL_H__ + +#include "../simple_gla_prefill_cuda.h" + +DESCRIPTOR(nvidia) + +#endif diff --git a/src/infiniop/ops/simple_gla_prefill/operator.cc b/src/infiniop/ops/simple_gla_prefill/operator.cc new file mode 100644 index 000000000..a36ff1594 --- /dev/null +++ b/src/infiniop/ops/simple_gla_prefill/operator.cc @@ -0,0 +1,86 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/simple_gla_prefill.h" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/simple_gla_prefill_nvidia_cuda.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateSimpleGLAPrefillDescriptor( + infiniopHandle_t handle, + infiniopSimpleGLAPrefillDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, 
+ infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_gamma_desc) { + +#define CREATE_CUDA(CASE, NAMESPACE) \ + case CASE: \ + return op::simple_gla_prefill_cuda::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, q_desc, k_desc, v_desc, g_gamma_desc) + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE_CUDA(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE_CUDA +} + +__INFINI_C infiniStatus_t infiniopGetSimpleGLAPrefillWorkspaceSize( + infiniopSimpleGLAPrefillDescriptor_t desc, + size_t *size) { + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + case INFINI_DEVICE_NVIDIA: + *size = reinterpret_cast(desc)->workspaceSize(); + return INFINI_STATUS_SUCCESS; +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopSimpleGLAPrefill( + infiniopSimpleGLAPrefillDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + void const *q, + void const *k, + void const *v, + void const *g_gamma, + float scale, + void *stream) { + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + case INFINI_DEVICE_NVIDIA: + return reinterpret_cast(desc) + ->calculate(workspace, workspace_size, out, q, k, v, g_gamma, scale, stream); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopDestroySimpleGLAPrefillDescriptor( + infiniopSimpleGLAPrefillDescriptor_t desc) { + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + case INFINI_DEVICE_NVIDIA: + delete reinterpret_cast(desc); + return INFINI_STATUS_SUCCESS; +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} diff --git a/src/infiniop/ops/simple_gla_prefill/simple_gla_prefill_cuda.h b/src/infiniop/ops/simple_gla_prefill/simple_gla_prefill_cuda.h new file mode 
100644 index 000000000..a155dcd95 --- /dev/null +++ b/src/infiniop/ops/simple_gla_prefill/simple_gla_prefill_cuda.h @@ -0,0 +1,48 @@ +#ifndef SIMPLE_GLA_PREFILL_CUDA_H +#define SIMPLE_GLA_PREFILL_CUDA_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + namespace op::simple_gla_prefill_cuda::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + op::simple_gla_prefill_cuda::SimpleGLAPrefillCudaInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor(Opaque *opaque, \ + op::simple_gla_prefill_cuda::SimpleGLAPrefillCudaInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + size_t workspaceSize() const { return _workspace_size; } \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_desc, \ + infiniopTensorDescriptor_t v_desc, \ + infiniopTensorDescriptor_t g_gamma_desc); \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, \ + const void *q, \ + const void *k, \ + const void *v, \ + const void *g_gamma, \ + float scale, \ + void *stream) const; \ + }; \ + } + +#endif // SIMPLE_GLA_PREFILL_CUDA_H diff --git a/test/infinicore/ops/infllmv2_attention.py b/test/infinicore/ops/infllmv2_attention.py new file mode 100644 index 000000000..b4c04bc4c --- /dev/null +++ b/test/infinicore/ops/infllmv2_attention.py @@ -0,0 +1,509 @@ +""" +Operator unit tests for InfiniCore InfLLM-V2 attention ops: infllmv2_varlen and infllmv2_kvcache. +Uses the InfiniCore test framework (BaseOperatorTest, TestCase, GenericTestRunner). 
+Runs only when InfiniCore is built with ENABLE_INFLLMV2 and linked to the infllmv2 .so; +otherwise tests are skipped so CI without the .so still passes. + +Run (from InfiniCore dir): + python test/infinicore/run.py --ops infllmv2_attention --nvidia + +Direct: + python test/infinicore/ops/infllmv2_attention.py --nvidia +""" + +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import infinicore +import torch + +from framework import ( + BaseOperatorTest, + TestCase, + TensorSpec, + TensorInitializer, + GenericTestRunner, + get_args, +) +from framework.results import CaseResult +from framework.utils.tensor_utils import convert_infinicore_to_torch + +# Check for InfLLM-V2 ops; skip entire module if not built +infllmv2_varlen = getattr(infinicore, "infllmv2_varlen", None) +infllmv2_kvcache = getattr(infinicore, "infllmv2_kvcache", None) +INFLLMV2_AVAILABLE = infllmv2_varlen is not None and infllmv2_kvcache is not None + + +def _print_metrics(name, out_infinicore): + out_t = convert_infinicore_to_torch(out_infinicore) + l2 = float(out_t.norm()) + max_abs = float(out_t.abs().max()) + print( + f" {name}: shape={list(out_infinicore.shape)} L2={l2:.4f} max_abs={max_abs:.4f}" + ) + + +def _make_varlen_test_case(): + total_q, nheads, head_dim = 8, 2, 8 + total_k, nheads_k = 8, 2 + scale = 1.0 / (head_dim**0.5) + q_spec = TensorSpec.from_tensor((total_q, nheads, head_dim), None, infinicore.float16) + k_spec = TensorSpec.from_tensor((total_k, nheads_k, head_dim), None, infinicore.float16) + v_spec = TensorSpec.from_tensor((total_k, nheads_k, head_dim), None, infinicore.float16) + cu_q_spec = TensorSpec.from_tensor( + (3,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([0, 4, 8], dtype=torch.int32), + ) + cu_k_spec = TensorSpec.from_tensor( + (3,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([0, 4, 8], dtype=torch.int32), + ) + return 
TestCase( + inputs=[q_spec, k_spec, v_spec, cu_q_spec, cu_k_spec], + kwargs={ + "max_seqlen_q": 4, + "max_seqlen_k": 4, + "scale": scale, + "causal": True, + "_expected_out_shape": (total_q, nheads, head_dim), + }, + output_spec=None, + comparison_target=None, + tolerance={"atol": 1e-2, "rtol": 1e-2}, + description="InfLLMV2 varlen (2 batches, 4 tokens)", + ) + + +def _make_varlen_test_case_bf16(): + total_q, nheads, head_dim = 8, 2, 8 + total_k, nheads_k = 8, 2 + scale = 1.0 / (head_dim**0.5) + q_spec = TensorSpec.from_tensor( + (total_q, nheads, head_dim), None, infinicore.bfloat16 + ) + k_spec = TensorSpec.from_tensor( + (total_k, nheads_k, head_dim), None, infinicore.bfloat16 + ) + v_spec = TensorSpec.from_tensor( + (total_k, nheads_k, head_dim), None, infinicore.bfloat16 + ) + cu_q_spec = TensorSpec.from_tensor( + (3,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([0, 4, 8], dtype=torch.int32), + ) + cu_k_spec = TensorSpec.from_tensor( + (3,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([0, 4, 8], dtype=torch.int32), + ) + return TestCase( + inputs=[q_spec, k_spec, v_spec, cu_q_spec, cu_k_spec], + kwargs={ + "max_seqlen_q": 4, + "max_seqlen_k": 4, + "scale": scale, + "causal": True, + "_expected_out_shape": (total_q, nheads, head_dim), + }, + output_spec=None, + comparison_target=None, + tolerance={"atol": 1e-1, "rtol": 1e-1}, + description="InfLLMV2 varlen BF16 (2 batches, 4 tokens)", + ) + + +def _make_varlen_test_case_localwindow(): + total_q, nheads, head_dim = 8, 2, 8 + total_k, nheads_k = 8, 2 + scale = 1.0 / (head_dim**0.5) + q_spec = TensorSpec.from_tensor((total_q, nheads, head_dim), None, infinicore.float16) + k_spec = TensorSpec.from_tensor((total_k, nheads_k, head_dim), None, infinicore.float16) + v_spec = TensorSpec.from_tensor((total_k, nheads_k, head_dim), None, infinicore.float16) + cu_q_spec = TensorSpec.from_tensor( + (3,), + None, + 
infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([0, 4, 8], dtype=torch.int32), + ) + cu_k_spec = TensorSpec.from_tensor( + (3,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([0, 4, 8], dtype=torch.int32), + ) + return TestCase( + inputs=[q_spec, k_spec, v_spec, cu_q_spec, cu_k_spec], + kwargs={ + "max_seqlen_q": 4, + "max_seqlen_k": 4, + "scale": scale, + "causal": False, + "window_size_left": 2, + "window_size_right": 0, + "_expected_out_shape": (total_q, nheads, head_dim), + }, + output_spec=None, + comparison_target=None, + tolerance={"atol": 1e-2, "rtol": 1e-2}, + description="InfLLMV2 varlen local-window (causal=false, left=2, right=0)", + ) + + +def _make_varlen_test_case_localwindow_left0(): + total_q, nheads, head_dim = 8, 2, 8 + total_k, nheads_k = 8, 2 + scale = 1.0 / (head_dim**0.5) + q_spec = TensorSpec.from_tensor((total_q, nheads, head_dim), None, infinicore.float16) + k_spec = TensorSpec.from_tensor((total_k, nheads_k, head_dim), None, infinicore.float16) + v_spec = TensorSpec.from_tensor((total_k, nheads_k, head_dim), None, infinicore.float16) + cu_q_spec = TensorSpec.from_tensor( + (3,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([0, 4, 8], dtype=torch.int32), + ) + cu_k_spec = TensorSpec.from_tensor( + (3,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([0, 4, 8], dtype=torch.int32), + ) + return TestCase( + inputs=[q_spec, k_spec, v_spec, cu_q_spec, cu_k_spec], + kwargs={ + "max_seqlen_q": 4, + "max_seqlen_k": 4, + "scale": scale, + "causal": False, + "window_size_left": 0, + "window_size_right": 0, + "_expected_out_shape": (total_q, nheads, head_dim), + }, + output_spec=None, + comparison_target=None, + tolerance={"atol": 1e-2, "rtol": 1e-2}, + description="InfLLMV2 varlen local-window (causal=false, left=0, right=0)", + ) + + +def 
_make_varlen_test_case_localwindow_left3(): + total_q, nheads, head_dim = 8, 2, 8 + total_k, nheads_k = 8, 2 + scale = 1.0 / (head_dim**0.5) + q_spec = TensorSpec.from_tensor((total_q, nheads, head_dim), None, infinicore.float16) + k_spec = TensorSpec.from_tensor((total_k, nheads_k, head_dim), None, infinicore.float16) + v_spec = TensorSpec.from_tensor((total_k, nheads_k, head_dim), None, infinicore.float16) + cu_q_spec = TensorSpec.from_tensor( + (3,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([0, 4, 8], dtype=torch.int32), + ) + cu_k_spec = TensorSpec.from_tensor( + (3,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([0, 4, 8], dtype=torch.int32), + ) + return TestCase( + inputs=[q_spec, k_spec, v_spec, cu_q_spec, cu_k_spec], + kwargs={ + "max_seqlen_q": 4, + "max_seqlen_k": 4, + "scale": scale, + "causal": False, + "window_size_left": 3, + "window_size_right": 0, + "_expected_out_shape": (total_q, nheads, head_dim), + }, + output_spec=None, + comparison_target=None, + tolerance={"atol": 1e-2, "rtol": 1e-2}, + description="InfLLMV2 varlen local-window (causal=false, left=3, right=0)", + ) + + +def _make_kvcache_test_case(): + batch, seqlen_q, nheads, head_dim = 1, 1, 2, 8 + cache_len = 4 + scale = 1.0 / (head_dim**0.5) + q_spec = TensorSpec.from_tensor( + (batch, seqlen_q, nheads, head_dim), None, infinicore.float16 + ) + k_cache_spec = TensorSpec.from_tensor( + (batch, cache_len, nheads, head_dim), None, infinicore.float16 + ) + v_cache_spec = TensorSpec.from_tensor( + (batch, cache_len, nheads, head_dim), None, infinicore.float16 + ) + cache_lens_spec = TensorSpec.from_tensor( + (batch,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([cache_len], dtype=torch.int32), + ) + return TestCase( + inputs=[q_spec, k_cache_spec, v_cache_spec, cache_lens_spec], + kwargs={ + "scale": scale, + "causal": True, + "_expected_out_shape": 
(batch, seqlen_q, nheads, head_dim), + }, + output_spec=None, + comparison_target=None, + tolerance={"atol": 1e-2, "rtol": 1e-2}, + description="InfLLMV2 kvcache (1 batch, 1 query, cache_len=4)", + ) + + +def _make_kvcache_test_case_localwindow(): + batch, seqlen_q, nheads, head_dim = 1, 1, 2, 8 + cache_len = 4 + scale = 1.0 / (head_dim**0.5) + q_spec = TensorSpec.from_tensor( + (batch, seqlen_q, nheads, head_dim), None, infinicore.float16 + ) + k_cache_spec = TensorSpec.from_tensor( + (batch, cache_len, nheads, head_dim), None, infinicore.float16 + ) + v_cache_spec = TensorSpec.from_tensor( + (batch, cache_len, nheads, head_dim), None, infinicore.float16 + ) + cache_lens_spec = TensorSpec.from_tensor( + (batch,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([cache_len], dtype=torch.int32), + ) + return TestCase( + inputs=[q_spec, k_cache_spec, v_cache_spec, cache_lens_spec], + kwargs={ + "scale": scale, + "causal": False, + "window_size_left": 2, + "window_size_right": 0, + "_expected_out_shape": (batch, seqlen_q, nheads, head_dim), + }, + output_spec=None, + comparison_target=None, + tolerance={"atol": 1e-2, "rtol": 1e-2}, + description="InfLLMV2 kvcache local-window (causal=false, left=2, right=0)", + ) + + +def _make_kvcache_test_case_localwindow_left0(): + batch, seqlen_q, nheads, head_dim = 1, 1, 2, 8 + cache_len = 4 + scale = 1.0 / (head_dim**0.5) + q_spec = TensorSpec.from_tensor( + (batch, seqlen_q, nheads, head_dim), None, infinicore.float16 + ) + k_cache_spec = TensorSpec.from_tensor( + (batch, cache_len, nheads, head_dim), None, infinicore.float16 + ) + v_cache_spec = TensorSpec.from_tensor( + (batch, cache_len, nheads, head_dim), None, infinicore.float16 + ) + cache_lens_spec = TensorSpec.from_tensor( + (batch,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([cache_len], dtype=torch.int32), + ) + return TestCase( + inputs=[q_spec, k_cache_spec, v_cache_spec, 
cache_lens_spec], + kwargs={ + "scale": scale, + "causal": False, + "window_size_left": 0, + "window_size_right": 0, + "_expected_out_shape": (batch, seqlen_q, nheads, head_dim), + }, + output_spec=None, + comparison_target=None, + tolerance={"atol": 1e-2, "rtol": 1e-2}, + description="InfLLMV2 kvcache local-window (causal=false, left=0, right=0)", + ) + + +def _make_kvcache_test_case_bf16(): + batch, seqlen_q, nheads, head_dim = 1, 1, 2, 8 + cache_len = 4 + scale = 1.0 / (head_dim**0.5) + q_spec = TensorSpec.from_tensor( + (batch, seqlen_q, nheads, head_dim), None, infinicore.bfloat16 + ) + k_cache_spec = TensorSpec.from_tensor( + (batch, cache_len, nheads, head_dim), None, infinicore.bfloat16 + ) + v_cache_spec = TensorSpec.from_tensor( + (batch, cache_len, nheads, head_dim), None, infinicore.bfloat16 + ) + cache_lens_spec = TensorSpec.from_tensor( + (batch,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor([cache_len], dtype=torch.int32), + ) + return TestCase( + inputs=[q_spec, k_cache_spec, v_cache_spec, cache_lens_spec], + kwargs={ + "scale": scale, + "causal": True, + "_expected_out_shape": (batch, seqlen_q, nheads, head_dim), + }, + output_spec=None, + comparison_target=None, + tolerance={"atol": 1e-1, "rtol": 1e-1}, + description="InfLLMV2 kvcache BF16 (1 batch, 1 query, cache_len=4)", + ) + + +class InfLLMV2AttentionTest(BaseOperatorTest): + def __init__(self): + super().__init__("InfLLMV2Attention") + + def get_test_cases(self): + if not INFLLMV2_AVAILABLE: + return [] + return [ + _make_varlen_test_case(), + _make_kvcache_test_case(), + _make_kvcache_test_case_localwindow(), + _make_kvcache_test_case_localwindow_left0(), + _make_varlen_test_case_localwindow(), + _make_varlen_test_case_localwindow_left0(), + _make_varlen_test_case_localwindow_left3(), + _make_varlen_test_case_bf16(), + _make_kvcache_test_case_bf16(), + ] + + def torch_operator(self, *args, **kwargs): + raise NotImplementedError( + "InfLLM-V2 
has no PyTorch reference in this test (InfiniCore-only)" + ) + + def infinicore_operator(self, *args, **kwargs): + raise NotImplementedError("InfLLM-V2 uses run_test override (InfiniCore-only)") + + def run_test(self, device, test_case, config): + test_result = CaseResult( + success=False, + return_code=-1, + test_case=test_case, + device=device, + ) + + if not INFLLMV2_AVAILABLE: + test_result.return_code = -2 + test_result.error_message = ( + "infllmv2_varlen/infllmv2_kvcache not available (build without ENABLE_INFLLMV2?)" + ) + return test_result + + inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device) + expected_shape = kwargs.pop("_expected_out_shape", None) + infini_inputs, infini_kwargs, _ = self.prepare_infinicore_inputs_and_kwargs( + inputs, kwargs, test_case.comparison_target + ) + + if len(infini_inputs) == 5: + window_size_left = infini_kwargs.get("window_size_left", -1) + window_size_right = infini_kwargs.get("window_size_right", -1) + out = infllmv2_varlen( + infini_inputs[0], + infini_inputs[1], + infini_inputs[2], + infini_inputs[3], + infini_inputs[4], + max_seqlen_q=infini_kwargs["max_seqlen_q"], + max_seqlen_k=infini_kwargs["max_seqlen_k"], + scale=infini_kwargs["scale"], + causal=infini_kwargs["causal"], + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + name = "varlen" + elif len(infini_inputs) == 4: + window_size_left = infini_kwargs.get("window_size_left", -1) + window_size_right = infini_kwargs.get("window_size_right", -1) + out = infllmv2_kvcache( + infini_inputs[0], + infini_inputs[1], + infini_inputs[2], + infini_inputs[3], + scale=infini_kwargs["scale"], + causal=infini_kwargs["causal"], + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + name = "kvcache" + else: + test_result.error_message = f"Unexpected number of inputs: {len(infini_inputs)}" + return test_result + + infinicore.sync_stream() + + if out is None: + test_result.error_message = "InfiniCore 
operator returned None" + return test_result + + shape = out.shape + if expected_shape is not None and tuple(shape) != tuple(expected_shape): + test_result.error_message = ( + f"Shape mismatch: got {list(shape)}, expected {list(expected_shape)}" + ) + return test_result + + out_t = convert_infinicore_to_torch(out) + if torch.isnan(out_t).any() or torch.isinf(out_t).any(): + test_result.error_message = "Output contained NaN/Inf" + return test_result + + _print_metrics(name, out) + test_result.success = True + test_result.return_code = 0 + return test_result + + +def main(): + args = get_args() + if not args.nvidia: + print("InfLLM-V2 ops require CUDA; use --nvidia to run on GPU.") + sys.exit(0) + if not INFLLMV2_AVAILABLE: + print( + "infllmv2_varlen / infllmv2_kvcache not available. Build InfiniCore with --aten=y --infllmv2=..." + ) + sys.exit(0) + + runner = GenericTestRunner(InfLLMV2AttentionTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() + diff --git a/test/infinicore/ops/simple_gla_decode_recurrent.py b/test/infinicore/ops/simple_gla_decode_recurrent.py new file mode 100644 index 000000000..eff3a5902 --- /dev/null +++ b/test/infinicore/ops/simple_gla_decode_recurrent.py @@ -0,0 +1,237 @@ +""" +Operator unit test for InfiniCore simple_gla_decode_step. + +Runs a multi-step decode loop with shared float32 state [B,H,D,D] and compares stacked +outputs to simple_gla_attention on [B,T,H,D] and a PyTorch recurrent reference. +Optional cross-check against FLA naive_recurrent_simple_gla when flash-linear-attention is installed. 
+ +Run (from InfiniCore dir): + python test/infinicore/run.py --ops simple_gla_decode_recurrent --nvidia + +Direct: + python test/infinicore/ops/simple_gla_decode_recurrent.py --nvidia +""" + +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch + +import infinicore + +from framework import ( + BaseOperatorTest, + TestCase, + TensorSpec, + GenericTestRunner, + get_args, +) +from framework.devices import torch_device_map +from framework.results import CaseResult +from framework.utils.tensor_utils import ( + convert_infinicore_to_torch, + infinicore_tensor_from_torch, +) + + +SIMPLE_GLA_DECODE_STEP_AVAILABLE = hasattr(infinicore, "simple_gla_decode_step") + + +def _torch_simple_gla_recurrent_ref(q, k, v, g_gamma, scale: float) -> torch.Tensor: + dtype = q.dtype + qf = q.transpose(1, 2).float() + kf = k.transpose(1, 2).float() + vf = v.transpose(1, 2).float() + B, H, T, K = qf.shape + V = vf.shape[-1] + qf = qf * scale + o = vf.new_zeros(B, H, T, V) + S = qf.new_zeros(B, H, K, V) + gate = g_gamma.float().exp() + for i in range(T): + key = kf[:, :, i, :] + value = vf[:, :, i, :] + kv = key.unsqueeze(-1) * value.unsqueeze(-2) + S = S * gate.view(1, -1, 1, 1) + kv + q_i = qf[:, :, i, :] + o[:, :, i, :] = (q_i.unsqueeze(-1) * S).sum(-2) + return o.transpose(1, 2).to(dtype) + + +def _optional_fla_output(q, k, v, g_gamma, scale: float): + try: + from fla.ops.simple_gla import naive_recurrent_simple_gla + except Exception: + try: + from fla.ops.simple_gla.naive import naive_recurrent_simple_gla + except Exception: + return None + try: + return naive_recurrent_simple_gla( + q.contiguous(), + k.contiguous(), + v.contiguous(), + g_gamma.contiguous(), + scale=scale, + ) + except Exception: + return None + + +def _make_case(dtype: str, *, B=2, T=16, H=8, D=64): + dt = infinicore.bfloat16 if dtype == "bf16" else infinicore.float16 + scale = 1.0 / (D**0.5) + + q = TensorSpec.from_tensor((B, T, H, D), None, dt) + k = 
TensorSpec.from_tensor((B, T, H, D), None, dt) + v = TensorSpec.from_tensor((B, T, H, D), None, dt) + g_gamma = TensorSpec.from_tensor((H,), None, infinicore.float32) + + return TestCase( + inputs=[q, k, v, g_gamma], + kwargs={"scale": scale, "_dtype": dtype, "_shape": (B, T, H, D)}, + output_spec=None, + comparison_target=None, + tolerance={"atol": 5e-2, "rtol": 5e-2} + if dtype == "bf16" + else {"atol": 2e-2, "rtol": 2e-2}, + description=f"simple_gla_decode_step loop vs ref ({dtype}) B={B} T={T} H={H} D={D}", + ) + + +class SimpleGLADecodeRecurrentTest(BaseOperatorTest): + def __init__(self): + super().__init__("SimpleGLADecodeRecurrent") + + def get_test_cases(self): + if not SIMPLE_GLA_DECODE_STEP_AVAILABLE: + return [] + return [ + _make_case("fp16"), + _make_case("bf16"), + _make_case("fp16", D=128), + _make_case("bf16", D=128), + ] + + def torch_operator(self, *args, **kwargs): + raise NotImplementedError("This test overrides run_test.") + + def infinicore_operator(self, *args, **kwargs): + raise NotImplementedError("This test overrides run_test.") + + def run_test(self, device, test_case, config): + tr = CaseResult(success=False, return_code=-1, test_case=test_case, device=device) + + if not SIMPLE_GLA_DECODE_STEP_AVAILABLE: + tr.return_code = -2 + tr.error_message = "simple_gla_decode_step not available (pybind not built?)" + return tr + + torch.manual_seed(0) + dev_str = torch_device_map[device] + infinicore.set_device(infinicore.device(dev_str, 0)) + + inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device) + scale = float(kwargs["scale"]) + + for t in inputs: + if not t.is_floating_point(): + continue + with torch.no_grad(): + if t.dim() == 1: + t.uniform_(-1.2, -0.05) + else: + t.uniform_(-0.1, 0.1) + + q, k, v, g_gamma = inputs + B, T, H, D = q.shape + + ref_loop = _torch_simple_gla_recurrent_ref(q, k, v, g_gamma, scale).float() + + infini_inputs, _, _ = self.prepare_infinicore_inputs_and_kwargs( + inputs, {"scale": scale}, None + ) 
+ q_ic, k_ic, v_ic, g_ic = infini_inputs + + full_out = infinicore.simple_gla_attention(q_ic, k_ic, v_ic, g_ic, scale=scale) + infinicore.sync_stream() + full_t = convert_infinicore_to_torch(full_out).float() + + S_buf = torch.zeros(B, H, D, D, device=q.device, dtype=torch.float32) + S_ic = infinicore_tensor_from_torch(S_buf) + outs = [] + for ti in range(T): + qs = infinicore_tensor_from_torch(q[:, ti : ti + 1].contiguous()) + ks = infinicore_tensor_from_torch(k[:, ti : ti + 1].contiguous()) + vs = infinicore_tensor_from_torch(v[:, ti : ti + 1].contiguous()) + o_step = infinicore.simple_gla_decode_step(qs, ks, vs, S_ic, g_ic, scale=scale) + infinicore.sync_stream() + outs.append(convert_infinicore_to_torch(o_step)) + stacked = torch.cat(outs, dim=1).float() + + if not torch.isfinite(stacked).all(): + tr.error_message = "decode loop produced NaN/Inf" + return tr + if not torch.isfinite(ref_loop).all(): + tr.error_message = "torch reference produced NaN/Inf" + return tr + + tol = test_case.tolerance or {"atol": 2e-2, "rtol": 2e-2} + + def _diff_line(name, a, b): + d = (a - b).abs() + print(f" {name}: max_abs={float(d.max()):.6f} mean_abs={float(d.mean()):.6f}") + + _diff_line("decode_steps vs torch_ref", stacked, ref_loop) + if not torch.allclose( + stacked, ref_loop, atol=float(tol["atol"]), rtol=float(tol["rtol"]) + ): + d = (stacked - ref_loop).abs() + tr.error_message = f"vs torch ref: max_abs={float(d.max()):.6f}" + return tr + + _diff_line("decode_steps vs simple_gla_attention", stacked, full_t) + if not torch.allclose( + stacked, full_t, atol=float(tol["atol"]), rtol=float(tol["rtol"]) + ): + d = (stacked - full_t).abs() + tr.error_message = f"vs simple_gla_attention: max_abs={float(d.max()):.6f}" + return tr + + fla_o = _optional_fla_output(q, k, v, g_gamma, scale) + if fla_o is not None: + fla_f = fla_o.float() + _diff_line("decode_steps vs fla naive", stacked, fla_f) + if not torch.allclose( + stacked, fla_f, atol=float(tol["atol"]), 
rtol=float(tol["rtol"]) + ): + d = (stacked - fla_f).abs() + tr.error_message = f"vs FLA: max_abs={float(d.max()):.6f}" + return tr + print(" FLA naive_recurrent_simple_gla cross-check: ok") + else: + print(" FLA cross-check skipped (not installed or API mismatch)") + + tr.success = True + tr.return_code = 0 + return tr + + +def main(): + args = get_args() + if not args.nvidia: + print("simple_gla_decode_step recurrent test expects CUDA; run with --nvidia") + sys.exit(0) + if not SIMPLE_GLA_DECODE_STEP_AVAILABLE: + print("simple_gla_decode_step not available in python bindings (rebuild InfiniCore)") + sys.exit(1) + + runner = GenericTestRunner(SimpleGLADecodeRecurrentTest, args=args) + runner.run_and_exit() + + +if __name__ == "__main__": + main() + diff --git a/test/infinicore/ops/simple_gla_prefill.py b/test/infinicore/ops/simple_gla_prefill.py new file mode 100644 index 000000000..bc10a3ee5 --- /dev/null +++ b/test/infinicore/ops/simple_gla_prefill.py @@ -0,0 +1,158 @@ +""" +Operator unit test for InfiniCore simple_gla_prefill. + +Validates that simple_gla_prefill(q,k,v,g_gamma,scale) matches the existing +simple_gla_attention reference. Covers head_dim=64 (naive fused path) and +head_dim=128 (chunked/tiled path for MiniCPM-SALA). 
+ +Run (from InfiniCore dir): + python test/infinicore/run.py --ops simple_gla_prefill --nvidia + +Direct: + python test/infinicore/ops/simple_gla_prefill.py --nvidia +""" + +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch + +import infinicore + +from framework import ( + BaseOperatorTest, + TestCase, + TensorSpec, + GenericTestRunner, + get_args, +) +from framework.results import CaseResult +from framework.utils.tensor_utils import convert_infinicore_to_torch + + +SIMPLE_GLA_PREFILL_AVAILABLE = hasattr(infinicore, "simple_gla_prefill") + + +def _make_case(dtype: str, *, B=1, T=32, H=8, D=64): + dt = infinicore.bfloat16 if dtype == "bf16" else infinicore.float16 + scale = 1.0 / (D**0.5) + + q = TensorSpec.from_tensor((B, T, H, D), None, dt) + k = TensorSpec.from_tensor((B, T, H, D), None, dt) + v = TensorSpec.from_tensor((B, T, H, D), None, dt) + g_gamma = TensorSpec.from_tensor((H,), None, infinicore.float32) + + return TestCase( + inputs=[q, k, v, g_gamma], + kwargs={"scale": scale, "_dtype": dtype, "_shape": (B, T, H, D)}, + output_spec=None, + comparison_target=None, + tolerance={"atol": 5e-2, "rtol": 5e-2} + if dtype == "bf16" + else {"atol": 2e-2, "rtol": 2e-2}, + description=f"simple_gla_prefill vs simple_gla_attention ({dtype}) B={B} T={T} H={H} D={D}", + ) + + +class SimpleGLAPrefillTest(BaseOperatorTest): + def __init__(self): + super().__init__("SimpleGLAPrefill") + + def get_test_cases(self): + if not SIMPLE_GLA_PREFILL_AVAILABLE: + return [] + return [ + _make_case("fp16"), + _make_case("bf16"), + _make_case("fp16", D=128), + _make_case("bf16", D=128), + ] + + def torch_operator(self, *args, **kwargs): + raise NotImplementedError("This test compares two InfiniCore operators.") + + def infinicore_operator(self, *args, **kwargs): + raise NotImplementedError("This test overrides run_test.") + + def run_test(self, device, test_case, config): + tr = CaseResult(success=False, return_code=-1, 
test_case=test_case, device=device) + + if not SIMPLE_GLA_PREFILL_AVAILABLE: + tr.return_code = -2 + tr.error_message = "simple_gla_prefill not available (pybind not built?)" + return tr + + torch.manual_seed(0) + inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device) + scale = float(kwargs["scale"]) + + for t in inputs: + if not t.is_floating_point(): + continue + with torch.no_grad(): + if t.dim() == 1: + t.uniform_(-1.2, -0.05) + else: + t.uniform_(-0.1, 0.1) + + infini_inputs, _, _ = self.prepare_infinicore_inputs_and_kwargs( + inputs, {"scale": scale}, None + ) + + q, k, v, g_gamma = infini_inputs + + out_ref = infinicore.simple_gla_attention(q, k, v, g_gamma, scale=scale) + out_new = infinicore.simple_gla_prefill(q, k, v, g_gamma, scale=scale) + infinicore.sync_stream() + + if out_ref is None or out_new is None: + tr.error_message = "operator returned None" + return tr + + t_ref = convert_infinicore_to_torch(out_ref).float() + t_new = convert_infinicore_to_torch(out_new).float() + + if not torch.isfinite(t_ref).all(): + tr.error_message = "reference simple_gla_attention produced NaN/Inf" + return tr + if not torch.isfinite(t_new).all(): + tr.error_message = "simple_gla_prefill produced NaN/Inf" + return tr + + diff = (t_ref - t_new).abs() + max_abs = float(diff.max()) + mean_abs = float(diff.mean()) + print(f" diff max_abs={max_abs:.6f} mean_abs={mean_abs:.6f}") + + tol = test_case.tolerance or {"atol": 2e-2, "rtol": 2e-2} + if not torch.allclose( + t_new, t_ref, atol=float(tol["atol"]), rtol=float(tol["rtol"]) + ): + tr.error_message = ( + f"mismatch: max_abs={max_abs:.6f}, atol={tol['atol']} rtol={tol['rtol']}" + ) + return tr + + tr.success = True + tr.return_code = 0 + return tr + + +def main(): + args = get_args() + if not args.nvidia: + print("simple_gla_prefill requires CUDA; run with --nvidia") + sys.exit(0) + if not SIMPLE_GLA_PREFILL_AVAILABLE: + print("simple_gla_prefill not available in python bindings (rebuild InfiniCore)") + 
sys.exit(1)
+
+    runner = GenericTestRunner(SimpleGLAPrefillTest, args=args)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/xmake.lua b/xmake.lua
index 8f32bf7cc..de135b7f8 100644
--- a/xmake.lua
+++ b/xmake.lua
@@ -1,5 +1,8 @@
 add_rules("mode.debug", "mode.release")
-add_requires("boost", {configs = {stacktrace = true}})
+-- In CI/docker or non-interactive shells, run xmake with -y (e.g. xmake clean -y) to avoid hanging on package prompts.
+if is_mode("debug") then
+    add_requires("boost", {configs = {stacktrace = true}})
+end
 add_requires("pybind11")
 
 -- Define color codes
@@ -55,6 +58,10 @@ option_end()
 if has_config("nv-gpu") then
     add_defines("ENABLE_NVIDIA_API")
     includes("xmake/nvidia.lua")
+    -- Ensure CUDA toolkit headers (e.g. cuda_runtime_api.h) are visible to
+    -- C++ sources that include ATen CUDA wrappers like CUDAContextLight.h.
+    local cuda_dir = get_config("cuda") or os.getenv("CUDA_HOME") or os.getenv("CUDA_ROOT") or "/usr/local/cuda"
+    add_includedirs(path.join(cuda_dir, "include"), { public = true })
 end
 
 option("cudnn")
@@ -73,7 +80,7 @@ option("cutlass")
     set_description("Whether to compile cutlass for Nvidia GPU")
 option_end()
 
-if has_config("cutlass") then
+if has_config("cutlass") then 
     add_defines("ENABLE_CUTLASS_API")
 end
 
@@ -242,11 +249,42 @@ option_end()
 
 if has_config("aten") then
     add_defines("ENABLE_ATEN")
-    if get_config("flash-attn") ~= false then
+    -- Only enable FlashAttention integration when a non-empty path is provided.
+    local flash_attn_cfg = get_config("flash-attn")
+    if flash_attn_cfg ~= nil and flash_attn_cfg ~= "" and flash_attn_cfg ~= false then
         add_defines("ENABLE_FLASH_ATTN")
     end
 end
 
+-- InfLLM-V2 direct kernels (requires aten; link against infllm_v2 shared library)
+--
+-- Policy: InfLLM-V2 is optional and must be checked out/built by the user.
+-- We do NOT auto-run `git submodule update` or `python setup.py install` from xmake.
+-- +-- Usage: +-- - auto-detect (if you manually checked out to third_party/infllmv2_cuda_impl): +-- xmake f --aten=y --infllmv2=y +-- - or specify a path (recommended; works without any checkout under this repo): +-- xmake f --aten=y --infllmv2=/abs/path/to/libinfllm_v2.so +-- xmake f --aten=y --infllmv2=/abs/path/to/infllmv2_cuda_impl # will auto-detect under build/lib.*/ +option("infllmv2") + set_default("") + set_showmenu(true) + set_description("Enable InfLLM-V2 support. Value: 'y' (auto-detect under third_party/infllmv2_cuda_impl) or a path to libinfllm_v2.so / infllmv2_cuda_impl root. Requires --aten=y.") +option_end() + +local function _infllmv2_enabled() + local cfg = get_config("infllmv2") + return cfg ~= nil and cfg ~= "" and cfg ~= false +end + +if _infllmv2_enabled() then + -- Fail fast: C++ code is gated on ENABLE_INFLLMV2 && ENABLE_ATEN. + if not has_config("aten") then + error("--infllmv2 requires --aten=y") + end +end + -- cuda graph option("graph") set_default(false) @@ -489,11 +527,11 @@ target("infinicore_cpp_api") local TORCH_DIR = outdata target:add( - "includedirs", - path.join(TORCH_DIR, "include"), + "includedirs", + path.join(TORCH_DIR, "include"), path.join(TORCH_DIR, "include/torch/csrc/api/include"), { public = true }) - + target:add( "linkdirs", path.join(TORCH_DIR, "lib"), @@ -509,6 +547,78 @@ target("infinicore_cpp_api") ) end + -- InfLLM-V2: locate + link infllm_v2 .so + local resolved_infllmv2 = nil + if _infllmv2_enabled() then + local infllmv2_cfg = get_config("infllmv2") + + local function detect_infllmv2_so(infllmv2_root) + local candidates = os.files(path.join(infllmv2_root, "build", "lib.*", "infllm_v2", "*.so")) + if candidates and #candidates > 0 then + table.sort(candidates) + return candidates[1] + end + return nil + end + + local function is_truthy_enable(v) + if v == true then + return true + end + if type(v) == "string" then + local s = v:lower() + return s == "y" or s == "yes" or s == "true" or s == "1" or s == "on" 
+ end + return false + end + + -- 1) If user passed a file path (libinfllm_v2.so / *.so), use it directly. + if type(infllmv2_cfg) == "string" and infllmv2_cfg ~= "" and os.isfile(infllmv2_cfg) then + resolved_infllmv2 = infllmv2_cfg + end + + -- 2) If user passed a directory, try to auto-detect under it. + if not resolved_infllmv2 and type(infllmv2_cfg) == "string" and infllmv2_cfg ~= "" and os.isdir(infllmv2_cfg) then + resolved_infllmv2 = detect_infllmv2_so(infllmv2_cfg) + end + + -- 3) If user passed y/true, try the conventional in-tree location (if present). + if not resolved_infllmv2 and is_truthy_enable(infllmv2_cfg) then + local infllmv2_root = path.join(os.projectdir(), "third_party", "infllmv2_cuda_impl") + if os.isdir(infllmv2_root) then + resolved_infllmv2 = detect_infllmv2_so(infllmv2_root) + end + end + + if not resolved_infllmv2 then + local default_root = path.join(os.projectdir(), "third_party", "infllmv2_cuda_impl") + error( + "[InfLLM-V2] Cannot find built InfLLM-V2 shared library (infllm_v2/*.so).\n" .. + "You must build it first, then point xmake to it.\n\n" .. + "Options:\n" .. + " (A) Pass a direct .so path:\n" .. + " xmake f --aten=y --infllmv2=/abs/path/to/libinfllm_v2.so -cv\n" .. + " (B) Pass an infllmv2_cuda_impl root directory (auto-detects build/lib.*/infllm_v2/*.so):\n" .. + " xmake f --aten=y --infllmv2=/abs/path/to/infllmv2_cuda_impl -cv\n" .. + " (C) If you checked it out under this repo:\n" .. + " " .. default_root .. "\n" .. + " xmake f --aten=y --infllmv2=y -cv\n" + ) + end + + if has_config("aten") then + target:add("defines", "ENABLE_INFLLMV2") + end + + local abs = path.absolute(resolved_infllmv2) + local so_dir = path.directory(abs) + local so_name = path.filename(abs) + -- IMPORTANT: ensure `infinicore_cpp_api` gets a DT_NEEDED on infllm_v2 .so. + -- Using `shflags` (not `ldflags`) and `--no-as-needed` avoids the linker + -- dropping the dependency and leaving runtime undefined symbols + -- (e.g. 
`mha_varlen_fwd`) that would otherwise require LD_PRELOAD/ctypes preload. + target:add("shflags", "-Wl,--no-as-needed -L" .. so_dir .. " -l:" .. so_name .. " -Wl,-rpath," .. so_dir, { public = true }) + end end) -- Add InfiniCore C++ source files (needed for RoPE and other nn modules) @@ -534,8 +644,8 @@ target("infinicore_cpp_api") target_end() target("_infinicore") - add_packages("boost") if is_mode("debug") then + add_packages("boost") add_defines("BOOST_STACKTRACE_USE_BACKTRACE") add_links("backtrace") else @@ -559,6 +669,18 @@ target("_infinicore") add_files("src/infinicore/pybind11/**.cc") set_installdir("python/infinicore") + on_install(function (target) + -- Make the in-tree Python package usable after `xmake install _infinicore`. + -- (Reviewer request: keep install logic in install phase, not after_build.) + local targetfile = target:targetfile() + if targetfile and os.isfile(targetfile) then + local libdir = path.join(os.projectdir(), "python", "infinicore", "lib") + if not os.isdir(libdir) then + os.mkdir(libdir) + end + os.cp(targetfile, path.join(libdir, path.filename(targetfile))) + end + end) target_end() option("editable") diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 602fb190d..3a1baffd5 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -68,7 +68,10 @@ target("infiniop-nvidia") for _, arch in ipairs(arch_opt:split(",")) do arch = arch:trim() local compute = arch:gsub("sm_", "compute_") - add_cuflags("-gencode=arch=" .. compute .. ",code=" .. arch) + local gencode = "-gencode=arch=" .. compute .. ",code=" .. arch + -- NVCC compile + device-link must both get gencode. 
+ add_cuflags(gencode, {force = true}) + add_culdflags(gencode, {force = true}) end else add_cugencodes("native") @@ -151,7 +154,7 @@ target("flash-attn-nvidia") local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim() local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim() local LIB_PYTHON = os.iorunv("python", {"-c", "import glob,sysconfig,os;print(glob.glob(os.path.join(sysconfig.get_config_var('LIBDIR'),'libpython*.so'))[0])"}):trim() - + -- Include dirs (needed for both device and host) target:add("includedirs", FLASH_ATTN_ROOT .. "/csrc/flash_attn/src", {public = false}) target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include", {public = false}) @@ -167,14 +170,14 @@ target("flash-attn-nvidia") add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/flash_api.cpp") add_files(FLASH_ATTN_ROOT .. "/csrc/flash_attn/src/*.cu") - + -- Link options - add_ldflags("-Wl,--no-undefined", {force = true}) - + add_ldflags("-Wl,--no-undefined") + -- Compile options - add_cxflags("-fPIC", {force = true}) + add_cxflags("-fPIC") add_cuflags("-Xcompiler=-fPIC") - add_cuflags("--forward-unknown-to-host-compiler --expt-relaxed-constexpr --use_fast_math", {force = true}) + add_cuflags("--forward-unknown-to-host-compiler --expt-relaxed-constexpr --use_fast_math") set_values("cuda.rdc", false) else -- If flash-attn is not available, just create an empty target