From f889b18879f416dcb83fb23ac3a5ca4653ee7ea5 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Wed, 28 Jan 2026 06:02:35 -0800 Subject: [PATCH 1/2] Add CUB TopK MaxPairs JAX FFI interface Signed-off-by: Hua Huang --- .gitmodules | 3 + 3rdparty/cccl | 1 + tests/jax/test_custom_call_compute.py | 36 ++++++ transformer_engine/common/CMakeLists.txt | 16 ++- .../common/include/transformer_engine/cub.h | 39 +++++++ transformer_engine/common/util/cub.cu | 58 ++++++++++ transformer_engine/jax/cpp_extensions/cub.py | 109 ++++++++++++++++++ transformer_engine/jax/csrc/extensions.h | 3 + .../jax/csrc/extensions/cub.cpp | 68 +++++++++++ .../jax/csrc/extensions/pybind.cpp | 3 + 10 files changed, 335 insertions(+), 1 deletion(-) create mode 160000 3rdparty/cccl create mode 100644 transformer_engine/common/include/transformer_engine/cub.h create mode 100644 transformer_engine/common/util/cub.cu create mode 100644 transformer_engine/jax/cpp_extensions/cub.py create mode 100644 transformer_engine/jax/csrc/extensions/cub.cpp diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..ad6951d995 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "3rdparty/cccl"] + path = 3rdparty/cccl + url = https://github.com/NVIDIA/cccl.git diff --git a/3rdparty/cccl b/3rdparty/cccl new file mode 160000 index 0000000000..c262ef4c68 --- /dev/null +++ b/3rdparty/cccl @@ -0,0 +1 @@ +Subproject commit c262ef4c68e6e35105052f749d4e58b6f17cb515 diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 613aefc178..fcb7860bad 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -47,6 +47,7 @@ from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense +from transformer_engine.jax.cpp_extensions.cub import cub_topk GEMM_CASES = [ (256, 256, 512), @@ -1955,3 +1956,38 @@ def f(x): actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype) assert_allclose(actual, expected, dtype=dtype) + + +@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16, jnp.float32]) +@pytest.mark.parametrize("problem_size", [(10000, 100), (50000, 200), (100000, 500), (1000000, 1000), (5000000, 2000)]) +class TestCubOps: + def test_cub_topk(self, dtype, problem_size): + n, k = problem_size + + prng_key = jax.random.PRNGKey(0) + keys = jax.random.split(prng_key, 3) + topk_values = jax.random.uniform(keys[0], shape=(k,), dtype=dtype, minval=1.5, maxval=2.5) + bottom_values = jax.random.uniform(keys[1], shape=(n-k,), dtype=dtype, minval=0.0, maxval=1.0) + x = jnp.concatenate([topk_values, bottom_values]) + x = jax.random.permutation(keys[2], x) + + ref_topk_jit = jax.jit(jax.lax.top_k, static_argnums=(1,)) + prim_topk_jit = jax.jit(cub_topk, static_argnums=(1,)) + + ref_topk, ref_indices = ref_topk_jit(x, k) + prim_topk, prim_indices = prim_topk_jit(x, k) + + # CUB output does not guarantee the order of the topk values, sort them for comparison + ref_topk, ref_indices = jax.lax.sort_key_val(ref_topk, ref_indices) + prim_topk, prim_indices = jax.lax.sort_key_val(prim_topk, prim_indices) + + assert_allclose(ref_topk, prim_topk, dtype=dtype) + + # sort and sort_key_val are ascending, make sure the smallest topk value + # prim_topk[0] is not smaller than the k+1 largest value in the original array + sorted_x = jax.lax.sort(x) + assert(prim_topk[0] >= sorted_x[-(k+1)]) + + # TopK values can be duplicated, instead of directly comparing the indices, we check + # if the values at the returned indices are the same + assert_allclose(x[ref_indices], x[prim_indices], dtype=dtype) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b9e2b907e0..beed91425a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -95,6 +95,16 @@ set(CUTLASS_INCLUDE_DIR set(CUTLASS_TOOLS_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include") +# CCCL (CUDA Core Compute Libraries) +set(CCCL_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cccl") +if(NOT EXISTS "${CCCL_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find CCCL at ${CCCL_INCLUDE_DIR}. " + "Try running 'git submodule update --init --recursive' " + "within the Transformer Engine source.") +endif() + # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) @@ -151,6 +161,7 @@ list(APPEND transformer_engine_cuda_sources normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu util/padding.cu + util/cub.cu swizzle/swizzle.cu swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu @@ -262,8 +273,11 @@ target_link_libraries(transformer_engine PUBLIC target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +# Use CCCL from 3rdparty instead of the one from CUDA Toolkit target_include_directories(transformer_engine SYSTEM PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) + ${CCCL_INCLUDE_DIR}/thrust + ${CCCL_INCLUDE_DIR}/cub + ${CCCL_INCLUDE_DIR}/libcudacxx/include) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") target_include_directories(transformer_engine PRIVATE ${CUTLASS_INCLUDE_DIR} diff --git a/transformer_engine/common/include/transformer_engine/cub.h b/transformer_engine/common/include/transformer_engine/cub.h new file mode 100644 index 0000000000..ad42bd3582 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/cub.h @@ -0,0 +1,39 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_CUB_H_ +#define TRANSFORMER_ENGINE_CUB_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Compute the top-K largest (key, value) pairs using CUB. + * + * \param[in] stream CUDA stream used for the operation. + * \param[in] keys_in Input 1D keys tensor, shape (num_items,) + * \param[in] values_in Input 1D values tensor, shape (num_items,) + * \param[in,out] keys_out Output 1D keys tensor, shape (k,) + * \param[in,out] values_out Output 1D values tensor, shape (k,) + * \param[in,out] workspace Workspace tensor, shape (workspace_bytes,) + * \param[in] num_items Number of items in the input tensor + * \param[in] k Number of top-K largest values to return + * \param[in] workspace_bytes Workspace size in bytes + * + * Requirements: + * - Only supports float32, float16, bfloat16 keys and int32 values. + */ +void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in, + NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, + const int num_items, const int k, const size_t workspace_bytes); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/transformer_engine/common/util/cub.cu b/transformer_engine/common/util/cub.cu new file mode 100644 index 0000000000..f2ba3b02f0 --- /dev/null +++ b/transformer_engine/common/util/cub.cu @@ -0,0 +1,58 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include "../common.h" +#include +#include + +void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in, + NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, + int num_items, int k, size_t workspace_bytes) { + NVTE_API_CALL(nvte_cub_topk); + using namespace transformer_engine; + + const Tensor *keys_in_tensor = convertNVTETensorCheck(keys_in); + const Tensor *values_in_tensor = convertNVTETensorCheck(values_in); + Tensor *keys_out_tensor = convertNVTETensor(keys_out); + Tensor *values_out_tensor = convertNVTETensor(values_out); + Tensor *workspace_tensor = convertNVTETensor(workspace); + auto keys_in_dtype = keys_in_tensor->data.dtype; + auto values_in_dtype = values_in_tensor->data.dtype; + + auto requirements = cuda::execution::require( + cuda::execution::determinism::not_guaranteed, + cuda::execution::output_ordering::unsorted + ); + cuda::stream_ref stream_ref{stream}; + auto env = cuda::std::execution::env{stream_ref, requirements}; + + #define DISPATCH_CUB_TOPK(KeyT, ValueT) \ + do { \ + KeyT *d_keys_in = reinterpret_cast(keys_in_tensor->data.dptr); \ + KeyT *d_keys_out = reinterpret_cast(keys_out_tensor->data.dptr); \ + ValueT *d_values_in = reinterpret_cast(values_in_tensor->data.dptr); \ + ValueT *d_values_out = reinterpret_cast(values_out_tensor->data.dptr); \ + void *d_workspace = reinterpret_cast(workspace_tensor->data.dptr); \ + cub::DeviceTopK::MaxPairs( \ + d_workspace, workspace_bytes, \ + d_keys_in, d_keys_out, \ + d_values_in, d_values_out, \ + num_items, k, env \ + ); \ + } while (0); + + if (keys_in_dtype == DType::kFloat32 && values_in_dtype == DType::kInt32) { + DISPATCH_CUB_TOPK(float, int); + } else if (keys_in_dtype == DType::kFloat16 && values_in_dtype == DType::kInt32) { + DISPATCH_CUB_TOPK(__half, int); + } else if (keys_in_dtype == DType::kBFloat16 && values_in_dtype == DType::kInt32) { + DISPATCH_CUB_TOPK(__nv_bfloat16, int); + } else { + NVTE_ERROR("Unsupported input key and value data types"); + } + #undef DISPATCH_CUB_TOPK +} diff --git a/transformer_engine/jax/cpp_extensions/cub.py b/transformer_engine/jax/cpp_extensions/cub.py new file mode 100644 index 0000000000..ac08e20b54 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/cub.py @@ -0,0 +1,109 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""CUB custom ops""" + +from typing import Tuple + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi + +from .base import BasePrimitive, register_primitive + +__all__ = ["CubTopkPrimitive"] + +def get_cub_topk_workspace_bytes() -> int: + """ + Get the workspace size for CUB Topk + The safe way is calling the CUB kernel to query the workspace size. + For convenience, we use a heuristic value based on experiments. + 4 MiB is enough for N up to 5,000,000 and K up to 100,000. + """ + return 4 * 1024 * 1024 + + +class CubTopkPrimitive(BasePrimitive): + """ + CUB Topk Primitive + """ + + name = "te_cub_topk_ffi" + multiple_results = True + impl_static_args = (2,) # k_value + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + in_keys_aval, + in_values_aval, + *, + k_value, + ): + keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype) + values_dtype = dtypes.canonicalize_dtype(in_values_aval.dtype) + assert keys_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert values_dtype == jnp.int32 + + workspace_bytes = get_cub_topk_workspace_bytes() + out_keys_aval = jax.core.ShapedArray(shape=(k_value,), dtype=keys_dtype) + out_values_aval = jax.core.ShapedArray(shape=(k_value,), dtype=jnp.int32) + workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8) + return (out_keys_aval, out_values_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + out_keys_aval, out_values_aval, _workspace_aval = CubTopkPrimitive.abstract(*args, **kwargs) + return (out_keys_aval, out_values_aval) + + @staticmethod + def lowering( + ctx, + in_keys, + in_values, + k_value, + ): + workspace_bytes = get_cub_topk_workspace_bytes() + return ffi.ffi_lowering( + CubTopkPrimitive.name, + )( + ctx, + in_keys, + in_values, + k_value=k_value, + workbuf_bytes=workspace_bytes, + ) + + @staticmethod + def impl( + in_keys, + in_values, + k_value, + ): + assert CubTopkPrimitive.inner_primitive is not None + out_keys, out_values, _workspace = CubTopkPrimitive.inner_primitive.bind( + in_keys, + in_values, + k_value=k_value, + ) + return (out_keys, out_values) + + +register_primitive(CubTopkPrimitive) + +def cub_topk( + x: jnp.ndarray, + k_value: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + CUB Topk max pairs + """ + keys = x + values = jnp.arange(x.shape[0], dtype=jnp.int32) + out_keys, out_values = CubTopkPrimitive.outer_primitive.bind( + keys, + values, + k_value=k_value, + ) + return out_keys, out_values diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 0fe4e99239..c670ac6e05 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -171,6 +171,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); +// Cub Topk +XLA_FFI_DECLARE_HANDLER_SYMBOL(CubTopkHandler); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/cub.cpp b/transformer_engine/jax/csrc/extensions/cub.cpp new file mode 100644 index 0000000000..9151bf4626 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cub.cpp @@ -0,0 +1,68 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../extensions.h" +#include "xla/ffi/api/c_api.h" +#include "transformer_engine/cub.h" + +namespace transformer_engine { +namespace jax { + +Error_Type CubTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type values_in_buf, + Result_Type keys_out_buf, Result_Type values_out_buf, Result_Type workspace_buf, + int64_t k_value, int64_t workbuf_bytes) { + auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type()); + auto values_in_dtype = convert_ffi_datatype_to_te_dtype(values_in_buf.element_type()); + auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type()); + auto values_out_dtype = convert_ffi_datatype_to_te_dtype(values_out_buf->element_type()); + NVTE_CHECK(keys_in_dtype == keys_out_dtype, "Input and output keys must have the same datatype"); + NVTE_CHECK(values_in_dtype == values_out_dtype, "Input and output values must have the same datatype"); + NVTE_CHECK(values_in_dtype == DType::kInt32, "CubTopkFFI() only supports int32 values for now"); + + auto keys_in_shape = keys_in_buf.dimensions(); + auto values_in_shape = values_in_buf.dimensions(); + auto keys_out_shape = keys_out_buf->dimensions(); + auto values_out_shape = values_out_buf->dimensions(); + NVTE_CHECK(keys_in_shape.size() == 1, "Keys input must have 1 dimension"); + NVTE_CHECK(values_in_shape.size() == 1, "Values input must have 1 dimension"); + NVTE_CHECK(keys_out_shape.size() == 1, "Keys output must have 1 dimension"); + NVTE_CHECK(values_out_shape.size() == 1, "Values output must have 1 dimension"); + NVTE_CHECK(keys_in_shape[0] == values_in_shape[0], "Keys and values input must have the same number of items"); + NVTE_CHECK(keys_out_shape[0] == values_out_shape[0], "Keys and values output must have the same number of items"); + int num_items = static_cast(keys_in_shape[0]); + int k = static_cast(k_value); + + auto input_shape = std::vector{keys_in_shape[0]}; + auto output_shape = std::vector{keys_out_shape[0]}; + auto workspace_shape = std::vector{workbuf_bytes}; + + auto keys_in_tensor = TensorWrapper(keys_in_buf.untyped_data(), input_shape, keys_in_dtype); + auto values_in_tensor = TensorWrapper(values_in_buf.untyped_data(), input_shape, values_in_dtype); + auto keys_out_tensor = TensorWrapper(keys_out_buf->untyped_data(), output_shape, keys_out_dtype); + auto values_out_tensor = TensorWrapper(values_out_buf->untyped_data(), output_shape, values_out_dtype); + auto workspace_tensor = TensorWrapper(workspace_buf->untyped_data(), workspace_shape, DType::kByte); + + nvte_cub_topk(stream, keys_in_tensor.data(), values_in_tensor.data(), + keys_out_tensor.data(), values_out_tensor.data(), workspace_tensor.data(), + num_items, k, workbuf_bytes); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CubTopkHandler, CubTopkFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // keys_buf + .Arg() // values_buf + .Ret() // topk_buf + .Ret() // indices_buf + .Ret() // workspace_buf + .Attr("k_value") + .Attr("workbuf_bytes"), + FFI_CudaGraph_Traits); + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 28cb39b5d1..67e6831739 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -100,6 +100,9 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); + // Cub Topk + dict["te_cub_topk_ffi"] = EncapsulateFFI(CubTopkHandler); + return dict; } From 8f9c5edd6aaddc1bcbbde50092b16e9a733703f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 02:14:56 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 10 ++-- transformer_engine/common/util/cub.cu | 40 +++++++--------- transformer_engine/jax/cpp_extensions/cub.py | 2 + .../jax/csrc/extensions/cub.cpp | 47 ++++++++++--------- 4 files changed, 53 insertions(+), 46 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index fcb7860bad..218e6f62a8 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1959,7 +1959,9 @@ def f(x): @pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16, jnp.float32]) -@pytest.mark.parametrize("problem_size", [(10000, 100), (50000, 200), (100000, 500), (1000000, 1000), (5000000, 2000)]) +@pytest.mark.parametrize( + "problem_size", [(10000, 100), (50000, 200), (100000, 500), (1000000, 1000), (5000000, 2000)] +) class TestCubOps: def test_cub_topk(self, dtype, problem_size): n, k = problem_size @@ -1967,7 +1969,9 @@ def test_cub_topk(self, dtype, problem_size): prng_key = jax.random.PRNGKey(0) keys = jax.random.split(prng_key, 3) topk_values = jax.random.uniform(keys[0], shape=(k,), dtype=dtype, minval=1.5, maxval=2.5) - bottom_values = jax.random.uniform(keys[1], shape=(n-k,), dtype=dtype, minval=0.0, maxval=1.0) + bottom_values = jax.random.uniform( + keys[1], shape=(n - k,), dtype=dtype, minval=0.0, maxval=1.0 + ) x = jnp.concatenate([topk_values, bottom_values]) x = jax.random.permutation(keys[2], x) @@ -1986,7 +1990,7 @@ def test_cub_topk(self, dtype, problem_size): # sort and sort_key_val are ascending, make sure the smallest topk value # prim_topk[0] is not smaller than the k+1 largest value in the original array sorted_x = jax.lax.sort(x) - assert(prim_topk[0] >= sorted_x[-(k+1)]) + assert prim_topk[0] >= sorted_x[-(k + 1)] # TopK values can be duplicated, instead of directly comparing the indices, we check # if the values at the returned indices are the same diff --git a/transformer_engine/common/util/cub.cu b/transformer_engine/common/util/cub.cu index f2ba3b02f0..6e9db56fa4 100644 --- a/transformer_engine/common/util/cub.cu +++ b/transformer_engine/common/util/cub.cu @@ -5,13 +5,15 @@ ************************************************************************/ #include -#include "../common.h" -#include + #include +#include + +#include "../common.h" void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in, - NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, - int num_items, int k, size_t workspace_bytes) { + NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, int num_items, + int k, size_t workspace_bytes) { NVTE_API_CALL(nvte_cub_topk); using namespace transformer_engine; @@ -23,26 +25,20 @@ void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETens auto keys_in_dtype = keys_in_tensor->data.dtype; auto values_in_dtype = values_in_tensor->data.dtype; - auto requirements = cuda::execution::require( - cuda::execution::determinism::not_guaranteed, - cuda::execution::output_ordering::unsorted - ); + auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed, + cuda::execution::output_ordering::unsorted); cuda::stream_ref stream_ref{stream}; auto env = cuda::std::execution::env{stream_ref, requirements}; - #define DISPATCH_CUB_TOPK(KeyT, ValueT) \ - do { \ - KeyT *d_keys_in = reinterpret_cast(keys_in_tensor->data.dptr); \ - KeyT *d_keys_out = reinterpret_cast(keys_out_tensor->data.dptr); \ - ValueT *d_values_in = reinterpret_cast(values_in_tensor->data.dptr); \ - ValueT *d_values_out = reinterpret_cast(values_out_tensor->data.dptr); \ - void *d_workspace = reinterpret_cast(workspace_tensor->data.dptr); \ - cub::DeviceTopK::MaxPairs( \ - d_workspace, workspace_bytes, \ - d_keys_in, d_keys_out, \ - d_values_in, d_values_out, \ - num_items, k, env \ - ); \ +#define DISPATCH_CUB_TOPK(KeyT, ValueT) \ + do { \ + KeyT *d_keys_in = reinterpret_cast(keys_in_tensor->data.dptr); \ + KeyT *d_keys_out = reinterpret_cast(keys_out_tensor->data.dptr); \ + ValueT *d_values_in = reinterpret_cast(values_in_tensor->data.dptr); \ + ValueT *d_values_out = reinterpret_cast(values_out_tensor->data.dptr); \ + void *d_workspace = reinterpret_cast(workspace_tensor->data.dptr); \ + cub::DeviceTopK::MaxPairs(d_workspace, workspace_bytes, d_keys_in, d_keys_out, d_values_in, \ + d_values_out, num_items, k, env); \ } while (0); if (keys_in_dtype == DType::kFloat32 && values_in_dtype == DType::kInt32) { @@ -54,5 +50,5 @@ void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETens } else { NVTE_ERROR("Unsupported input key and value data types"); } - #undef DISPATCH_CUB_TOPK +#undef DISPATCH_CUB_TOPK } diff --git a/transformer_engine/jax/cpp_extensions/cub.py b/transformer_engine/jax/cpp_extensions/cub.py index ac08e20b54..41ebe7fc49 100644 --- a/transformer_engine/jax/cpp_extensions/cub.py +++ b/transformer_engine/jax/cpp_extensions/cub.py @@ -13,6 +13,7 @@ __all__ = ["CubTopkPrimitive"] + def get_cub_topk_workspace_bytes() -> int: """ Get the workspace size for CUB Topk @@ -92,6 +93,7 @@ def impl( register_primitive(CubTopkPrimitive) + def cub_topk( x: jnp.ndarray, k_value: int, diff --git a/transformer_engine/jax/csrc/extensions/cub.cpp b/transformer_engine/jax/csrc/extensions/cub.cpp index 9151bf4626..7cfc0d0421 100644 --- a/transformer_engine/jax/csrc/extensions/cub.cpp +++ b/transformer_engine/jax/csrc/extensions/cub.cpp @@ -4,22 +4,24 @@ * See LICENSE for license information. ************************************************************************/ +#include "transformer_engine/cub.h" + #include "../extensions.h" #include "xla/ffi/api/c_api.h" -#include "transformer_engine/cub.h" namespace transformer_engine { namespace jax { Error_Type CubTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type values_in_buf, - Result_Type keys_out_buf, Result_Type values_out_buf, Result_Type workspace_buf, - int64_t k_value, int64_t workbuf_bytes) { + Result_Type keys_out_buf, Result_Type values_out_buf, + Result_Type workspace_buf, int64_t k_value, int64_t workbuf_bytes) { auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type()); auto values_in_dtype = convert_ffi_datatype_to_te_dtype(values_in_buf.element_type()); auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type()); auto values_out_dtype = convert_ffi_datatype_to_te_dtype(values_out_buf->element_type()); NVTE_CHECK(keys_in_dtype == keys_out_dtype, "Input and output keys must have the same datatype"); - NVTE_CHECK(values_in_dtype == values_out_dtype, "Input and output values must have the same datatype"); + NVTE_CHECK(values_in_dtype == values_out_dtype, + "Input and output values must have the same datatype"); NVTE_CHECK(values_in_dtype == DType::kInt32, "CubTopkFFI() only supports int32 values for now"); auto keys_in_shape = keys_in_buf.dimensions(); @@ -30,8 +32,10 @@ Error_Type CubTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type NVTE_CHECK(values_in_shape.size() == 1, "Values input must have 1 dimension"); NVTE_CHECK(keys_out_shape.size() == 1, "Keys output must have 1 dimension"); NVTE_CHECK(values_out_shape.size() == 1, "Values output must have 1 dimension"); - NVTE_CHECK(keys_in_shape[0] == values_in_shape[0], "Keys and values input must have the same number of items"); - NVTE_CHECK(keys_out_shape[0] == values_out_shape[0], "Keys and values output must have the same number of items"); + NVTE_CHECK(keys_in_shape[0] == values_in_shape[0], + "Keys and values input must have the same number of items"); + NVTE_CHECK(keys_out_shape[0] == values_out_shape[0], + "Keys and values output must have the same number of items"); int num_items = static_cast(keys_in_shape[0]); int k = static_cast(k_value); @@ -42,27 +46,28 @@ Error_Type CubTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type auto keys_in_tensor = TensorWrapper(keys_in_buf.untyped_data(), input_shape, keys_in_dtype); auto values_in_tensor = TensorWrapper(values_in_buf.untyped_data(), input_shape, values_in_dtype); auto keys_out_tensor = TensorWrapper(keys_out_buf->untyped_data(), output_shape, keys_out_dtype); - auto values_out_tensor = TensorWrapper(values_out_buf->untyped_data(), output_shape, values_out_dtype); - auto workspace_tensor = TensorWrapper(workspace_buf->untyped_data(), workspace_shape, DType::kByte); + auto values_out_tensor = + TensorWrapper(values_out_buf->untyped_data(), output_shape, values_out_dtype); + auto workspace_tensor = + TensorWrapper(workspace_buf->untyped_data(), workspace_shape, DType::kByte); - nvte_cub_topk(stream, keys_in_tensor.data(), values_in_tensor.data(), - keys_out_tensor.data(), values_out_tensor.data(), workspace_tensor.data(), - num_items, k, workbuf_bytes); + nvte_cub_topk(stream, keys_in_tensor.data(), values_in_tensor.data(), keys_out_tensor.data(), + values_out_tensor.data(), workspace_tensor.data(), num_items, k, workbuf_bytes); return ffi_with_cuda_error_check(); } XLA_FFI_DEFINE_HANDLER_SYMBOL(CubTopkHandler, CubTopkFFI, - FFI::Bind() - .Ctx() // stream - .Arg() // keys_buf - .Arg() // values_buf - .Ret() // topk_buf - .Ret() // indices_buf - .Ret() // workspace_buf - .Attr("k_value") - .Attr("workbuf_bytes"), - FFI_CudaGraph_Traits); + FFI::Bind() + .Ctx() // stream + .Arg() // keys_buf + .Arg() // values_buf + .Ret() // topk_buf + .Ret() // indices_buf + .Ret() // workspace_buf + .Attr("k_value") + .Attr("workbuf_bytes"), + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine