diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..ad6951d995 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "3rdparty/cccl"] + path = 3rdparty/cccl + url = https://github.com/NVIDIA/cccl.git diff --git a/3rdparty/cccl b/3rdparty/cccl new file mode 160000 index 0000000000..c262ef4c68 --- /dev/null +++ b/3rdparty/cccl @@ -0,0 +1 @@ +Subproject commit c262ef4c68e6e35105052f749d4e58b6f17cb515 diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 613aefc178..218e6f62a8 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -47,6 +47,7 @@ from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense +from transformer_engine.jax.cpp_extensions.cub import cub_topk GEMM_CASES = [ (256, 256, 512), @@ -1955,3 +1956,42 @@ def f(x): actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype) assert_allclose(actual, expected, dtype=dtype) + + +@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16, jnp.float32]) +@pytest.mark.parametrize( + "problem_size", [(10000, 100), (50000, 200), (100000, 500), (1000000, 1000), (5000000, 2000)] +) +class TestCubOps: + def test_cub_topk(self, dtype, problem_size): + n, k = problem_size + + prng_key = jax.random.PRNGKey(0) + keys = jax.random.split(prng_key, 3) + topk_values = jax.random.uniform(keys[0], shape=(k,), dtype=dtype, minval=1.5, maxval=2.5) + bottom_values = jax.random.uniform( + keys[1], shape=(n - k,), dtype=dtype, minval=0.0, maxval=1.0 + ) + x = jnp.concatenate([topk_values, bottom_values]) + x = jax.random.permutation(keys[2], x) + + ref_topk_jit = jax.jit(jax.lax.top_k, static_argnums=(1,)) + prim_topk_jit = jax.jit(cub_topk, static_argnums=(1,)) + + ref_topk, ref_indices = 
ref_topk_jit(x, k) + prim_topk, prim_indices = prim_topk_jit(x, k) + + # CUB output does not guarantee the order of the topk values, sort them for comparison + ref_topk, ref_indices = jax.lax.sort_key_val(ref_topk, ref_indices) + prim_topk, prim_indices = jax.lax.sort_key_val(prim_topk, prim_indices) + + assert_allclose(ref_topk, prim_topk, dtype=dtype) + + # sort and sort_key_val are ascending, make sure the smallest topk value + # prim_topk[0] is not smaller than the k+1 largest value in the original array + sorted_x = jax.lax.sort(x) + assert prim_topk[0] >= sorted_x[-(k + 1)] + + # TopK values can be duplicated, instead of directly comparing the indices, we check + # if the values at the returned indices are the same + assert_allclose(x[ref_indices], x[prim_indices], dtype=dtype) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b9e2b907e0..beed91425a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -95,6 +95,16 @@ set(CUTLASS_INCLUDE_DIR set(CUTLASS_TOOLS_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include") +# CCCL (CUDA Core Compute Libraries) +set(CCCL_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cccl") +if(NOT EXISTS "${CCCL_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find CCCL at ${CCCL_INCLUDE_DIR}. 
" + "Try running 'git submodule update --init --recursive' " + "within the Transformer Engine source.") +endif() + # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) @@ -151,6 +161,7 @@ list(APPEND transformer_engine_cuda_sources normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu util/padding.cu + util/cub.cu swizzle/swizzle.cu swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu @@ -262,8 +273,11 @@ target_link_libraries(transformer_engine PUBLIC target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +# Use CCCL from 3rdparty instead of the one from CUDA Toolkit target_include_directories(transformer_engine SYSTEM PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) + ${CCCL_INCLUDE_DIR}/thrust + ${CCCL_INCLUDE_DIR}/cub + ${CCCL_INCLUDE_DIR}/libcudacxx/include) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") target_include_directories(transformer_engine PRIVATE ${CUTLASS_INCLUDE_DIR} diff --git a/transformer_engine/common/include/transformer_engine/cub.h b/transformer_engine/common/include/transformer_engine/cub.h new file mode 100644 index 0000000000..ad42bd3582 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/cub.h @@ -0,0 +1,39 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_CUB_H_ +#define TRANSFORMER_ENGINE_CUB_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Compute the top-K largest (key, value) pairs using CUB. + * + * \param[in] stream CUDA stream used for the operation. 
+ * \param[in] keys_in Input 1D keys tensor, shape (num_items,) + * \param[in] values_in Input 1D values tensor, shape (num_items,) + * \param[in,out] keys_out Output 1D keys tensor, shape (k,) + * \param[in,out] values_out Output 1D values tensor, shape (k,) + * \param[in,out] workspace Workspace tensor, shape (workspace_bytes,) + * \param[in] num_items Number of items in the input tensor + * \param[in] k Number of top-K largest values to return + * \param[in] workspace_bytes Workspace size in bytes + * + * Requirements: + * - Only supports float32, float16, bfloat16 keys and int32 values. + */ +void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in, + NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, + const int num_items, const int k, const size_t workspace_bytes); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/transformer_engine/common/util/cub.cu b/transformer_engine/common/util/cub.cu new file mode 100644 index 0000000000..6e9db56fa4 --- /dev/null +++ b/transformer_engine/common/util/cub.cu @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/
+
+#include <transformer_engine/cub.h>
+
+#include <cub/cub.cuh>
+#include <cuda/execution>
+
+#include "../common.h"
+
+void nvte_cub_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor values_in,
+                   NVTETensor keys_out, NVTETensor values_out, NVTETensor workspace, int num_items,
+                   int k, size_t workspace_bytes) {
+  NVTE_API_CALL(nvte_cub_topk);
+  using namespace transformer_engine;
+
+  const Tensor *keys_in_tensor = convertNVTETensorCheck(keys_in);
+  const Tensor *values_in_tensor = convertNVTETensorCheck(values_in);
+  Tensor *keys_out_tensor = convertNVTETensor(keys_out);
+  Tensor *values_out_tensor = convertNVTETensor(values_out);
+  Tensor *workspace_tensor = convertNVTETensor(workspace);
+  auto keys_in_dtype = keys_in_tensor->data.dtype;
+  auto values_in_dtype = values_in_tensor->data.dtype;
+
+  auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
+                                               cuda::execution::output_ordering::unsorted);
+  cuda::stream_ref stream_ref{stream};
+  auto env = cuda::std::execution::env{stream_ref, requirements};
+
+#define DISPATCH_CUB_TOPK(KeyT, ValueT)                                                         \
+  do {                                                                                          \
+    KeyT *d_keys_in = reinterpret_cast<KeyT *>(keys_in_tensor->data.dptr);                      \
+    KeyT *d_keys_out = reinterpret_cast<KeyT *>(keys_out_tensor->data.dptr);                    \
+    ValueT *d_values_in = reinterpret_cast<ValueT *>(values_in_tensor->data.dptr);              \
+    ValueT *d_values_out = reinterpret_cast<ValueT *>(values_out_tensor->data.dptr);            \
+    void *d_workspace = reinterpret_cast<void *>(workspace_tensor->data.dptr);                  \
+    cub::DeviceTopK::MaxPairs(d_workspace, workspace_bytes, d_keys_in, d_keys_out, d_values_in, \
+                              d_values_out, num_items, k, env);                                 \
+  } while (0);
+
+  if (keys_in_dtype == DType::kFloat32 && values_in_dtype == DType::kInt32) {
+    DISPATCH_CUB_TOPK(float, int);
+  } else if (keys_in_dtype == DType::kFloat16 && values_in_dtype == DType::kInt32) {
+    DISPATCH_CUB_TOPK(__half, int);
+  } else if (keys_in_dtype == DType::kBFloat16 && values_in_dtype == DType::kInt32) {
+    DISPATCH_CUB_TOPK(__nv_bfloat16,
int); + } else { + NVTE_ERROR("Unsupported input key and value data types"); + } +#undef DISPATCH_CUB_TOPK +} diff --git a/transformer_engine/jax/cpp_extensions/cub.py b/transformer_engine/jax/cpp_extensions/cub.py new file mode 100644 index 0000000000..41ebe7fc49 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/cub.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""CUB custom ops""" + +from typing import Tuple + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi + +from .base import BasePrimitive, register_primitive + +__all__ = ["CubTopkPrimitive"] + + +def get_cub_topk_workspace_bytes() -> int: + """ + Get the workspace size for CUB Topk + The safe way is calling the CUB kernel to query the workspace size. + For convenience, we use a heuristic value based on experiments. + 4 MiB is enough for N up to 5,000,000 and K up to 100,000. + """ + return 4 * 1024 * 1024 + + +class CubTopkPrimitive(BasePrimitive): + """ + CUB Topk Primitive + """ + + name = "te_cub_topk_ffi" + multiple_results = True + impl_static_args = (2,) # k_value + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + in_keys_aval, + in_values_aval, + *, + k_value, + ): + keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype) + values_dtype = dtypes.canonicalize_dtype(in_values_aval.dtype) + assert keys_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert values_dtype == jnp.int32 + + workspace_bytes = get_cub_topk_workspace_bytes() + out_keys_aval = jax.core.ShapedArray(shape=(k_value,), dtype=keys_dtype) + out_values_aval = jax.core.ShapedArray(shape=(k_value,), dtype=jnp.int32) + workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8) + return (out_keys_aval, out_values_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + out_keys_aval, out_values_aval, _workspace_aval = 
CubTopkPrimitive.abstract(*args, **kwargs) + return (out_keys_aval, out_values_aval) + + @staticmethod + def lowering( + ctx, + in_keys, + in_values, + k_value, + ): + workspace_bytes = get_cub_topk_workspace_bytes() + return ffi.ffi_lowering( + CubTopkPrimitive.name, + )( + ctx, + in_keys, + in_values, + k_value=k_value, + workbuf_bytes=workspace_bytes, + ) + + @staticmethod + def impl( + in_keys, + in_values, + k_value, + ): + assert CubTopkPrimitive.inner_primitive is not None + out_keys, out_values, _workspace = CubTopkPrimitive.inner_primitive.bind( + in_keys, + in_values, + k_value=k_value, + ) + return (out_keys, out_values) + + +register_primitive(CubTopkPrimitive) + + +def cub_topk( + x: jnp.ndarray, + k_value: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + CUB Topk max pairs + """ + keys = x + values = jnp.arange(x.shape[0], dtype=jnp.int32) + out_keys, out_values = CubTopkPrimitive.outer_primitive.bind( + keys, + values, + k_value=k_value, + ) + return out_keys, out_values diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 0fe4e99239..c670ac6e05 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -171,6 +171,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); +// Cub Topk +XLA_FFI_DECLARE_HANDLER_SYMBOL(CubTopkHandler); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/cub.cpp b/transformer_engine/jax/csrc/extensions/cub.cpp new file mode 100644 index 0000000000..7cfc0d0421 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cub.cpp @@ -0,0 +1,73 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include "transformer_engine/cub.h"
+
+#include "../extensions.h"
+#include "xla/ffi/api/c_api.h"
+
+namespace transformer_engine {
+namespace jax {
+
+Error_Type CubTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type values_in_buf,
+                      Result_Type keys_out_buf, Result_Type values_out_buf,
+                      Result_Type workspace_buf, int64_t k_value, int64_t workbuf_bytes) {
+  auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type());
+  auto values_in_dtype = convert_ffi_datatype_to_te_dtype(values_in_buf.element_type());
+  auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type());
+  auto values_out_dtype = convert_ffi_datatype_to_te_dtype(values_out_buf->element_type());
+  NVTE_CHECK(keys_in_dtype == keys_out_dtype, "Input and output keys must have the same datatype");
+  NVTE_CHECK(values_in_dtype == values_out_dtype,
+             "Input and output values must have the same datatype");
+  NVTE_CHECK(values_in_dtype == DType::kInt32, "CubTopkFFI() only supports int32 values for now");
+
+  auto keys_in_shape = keys_in_buf.dimensions();
+  auto values_in_shape = values_in_buf.dimensions();
+  auto keys_out_shape = keys_out_buf->dimensions();
+  auto values_out_shape = values_out_buf->dimensions();
+  NVTE_CHECK(keys_in_shape.size() == 1, "Keys input must have 1 dimension");
+  NVTE_CHECK(values_in_shape.size() == 1, "Values input must have 1 dimension");
+  NVTE_CHECK(keys_out_shape.size() == 1, "Keys output must have 1 dimension");
+  NVTE_CHECK(values_out_shape.size() == 1, "Values output must have 1 dimension");
+  NVTE_CHECK(keys_in_shape[0] == values_in_shape[0],
+             "Keys and values input must have the same number of items");
+  NVTE_CHECK(keys_out_shape[0] == values_out_shape[0],
+             "Keys and values output must have the same number of items");
+  int num_items = static_cast<int>(keys_in_shape[0]);
+  int k =
static_cast<int>(k_value);
+
+  auto input_shape = std::vector<size_t>{keys_in_shape[0]};
+  auto output_shape = std::vector<size_t>{keys_out_shape[0]};
+  auto workspace_shape = std::vector<size_t>{workbuf_bytes};
+
+  auto keys_in_tensor = TensorWrapper(keys_in_buf.untyped_data(), input_shape, keys_in_dtype);
+  auto values_in_tensor = TensorWrapper(values_in_buf.untyped_data(), input_shape, values_in_dtype);
+  auto keys_out_tensor = TensorWrapper(keys_out_buf->untyped_data(), output_shape, keys_out_dtype);
+  auto values_out_tensor =
+      TensorWrapper(values_out_buf->untyped_data(), output_shape, values_out_dtype);
+  auto workspace_tensor =
+      TensorWrapper(workspace_buf->untyped_data(), workspace_shape, DType::kByte);
+
+  nvte_cub_topk(stream, keys_in_tensor.data(), values_in_tensor.data(), keys_out_tensor.data(),
+                values_out_tensor.data(), workspace_tensor.data(), num_items, k, workbuf_bytes);
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(CubTopkHandler, CubTopkFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()  // stream
+                                  .Arg<Buffer_Type>()      // keys_buf
+                                  .Arg<Buffer_Type>()      // values_buf
+                                  .Ret<Buffer_Type>()      // topk_buf
+                                  .Ret<Buffer_Type>()      // indices_buf
+                                  .Ret<Buffer_Type>()      // workspace_buf
+                                  .Attr<int64_t>("k_value")
+                                  .Attr<int64_t>("workbuf_bytes"),
+                              FFI_CudaGraph_Traits);
+
+}  // namespace jax
+}  // namespace transformer_engine
diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index 28cb39b5d1..67e6831739 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -100,6 +100,9 @@ pybind11::dict Registrations() {
   dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler);
   dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler);
+  // Cub Topk
+  dict["te_cub_topk_ffi"] = EncapsulateFFI(CubTopkHandler);
+
   return dict;
 }