diff --git a/CMakeLists.txt b/CMakeLists.txt index df636b27..d22a2583 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,3 +204,38 @@ link_infini_train_exe(test_precision_check) add_executable(test_lora test/lora/test_lora.cc) link_infini_train_exe(test_lora) +add_executable(test_scalar test/scalar/test_scalar.cc) +link_infini_train_exe(test_scalar) + +add_executable(test_dtype_dispatch test/dispatch/test_dtype_dispatch.cc) +link_infini_train_exe(test_dtype_dispatch) + +# Negative compile test: missing dtype registration must fail at compile time. +set(DTYPE_DISPATCH_COMPILE_FAIL_SOURCE + ${PROJECT_SOURCE_DIR}/test/dispatch/test_dtype_dispatch_compile_fail.cc) + +try_compile(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED + ${CMAKE_BINARY_DIR}/CMakeFiles/try_compile_dtype_dispatch_missing_map + SOURCES ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE} + CMAKE_FLAGS + "-DCMAKE_CXX_STANDARD=${CMAKE_CXX_STANDARD}" + "-DCMAKE_CXX_STANDARD_REQUIRED=ON" + "-DCMAKE_CXX_EXTENSIONS=OFF" + "-DCMAKE_CXX_FLAGS=-I${PROJECT_SOURCE_DIR}" + OUTPUT_VARIABLE DTYPE_DISPATCH_TRY_COMPILE_OUTPUT +) + +if(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED) + message(FATAL_ERROR + "dtype dispatch compile-fail test unexpectedly succeeded.\n" + "Source: ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}\n" + "Output:\n${DTYPE_DISPATCH_TRY_COMPILE_OUTPUT}") +endif() + +add_custom_target(test_dtype_dispatch_compile_fail + COMMAND ${CMAKE_COMMAND} -E echo + "dtype dispatch compile-fail check passed (missing dtype registration correctly fails to compile)." + VERBATIM +) + +add_dependencies(test_dtype_dispatch test_dtype_dispatch_compile_fail) diff --git a/docs/device_guard_design.md b/docs/device_guard_design.md new file mode 100644 index 00000000..64e59892 --- /dev/null +++ b/docs/device_guard_design.md @@ -0,0 +1,210 @@ +# Device Guard Design +device 注册初版基建 pr:https://github.com/InfiniTensor/InfiniTrain/pull/103 + +## 1. 
设计背景与目标 + +### 1.1 背景 + +InfiniTrain 需要长期支持: + +- 多种设备类型(CPU/CUDA/国产芯片) +- 多种运行时能力(stream、memory、blas、通信等) +- 在不侵入上层逻辑的前提下进行后端扩展与替换 + +在实际工程中,如果设备相关逻辑散落在框架各个模块,会导致: + +- `#ifdef USE_CUDA/USE_MUSA/...` 泛滥 +- 新硬件接入需要修改大量框架核心代码 +- 设备切换与资源管理缺乏统一语义 + +### 1.2 设计目标 + +InfiniTrain 的 device 注册机制设计目标是: + +1. 统一抽象:将所有与设备相关的运行时行为抽象到一个统一接口中。 +2. 后端可插拔:新设备后端可通过注册机制接入,无需修改框架核心逻辑。 +3. RAII 语义清晰:设备切换、资源恢复具备严格的作用域。 +4. 最小上层侵入:上层模块(Tensor/Autograd/Module)只感知 DeviceGuard/DeviceGuardImpl,不感知具体后端实现。 + +## 2. 核心组件 + +InfiniTrain 的 device 机制由三类核心组件构成: + +```C++ ++-------------------+ +| DeviceGuard | ← 对外 RAII 接口(public) ++-------------------+ + | + v ++-------------------+ +| DeviceGuardImpl | ← 后端抽象接口(virtual) ++-------------------+ + ^ + | ++-------------------+ +| DeviceGuardImpl | +| Registry | ← 全局注册表(singleton) ++-------------------+ +``` + +其中 DeviceGuard 与 DeviceGuardImpl 的关系是: + +| 组件 | 职责 | +| --------------- | ------------------------------------------------------------ | +| DeviceGuard | 管理 “当前在哪个 device 上” 的上下文语义(RAII),语义与 device index 绑定;负责 device 的保存/切换/恢复,并将具体 runtime 操作转发给对应的 DeviceGuardImpl。 | +| DeviceGuardImpl | 管理 “在该类 device 上如何执行 runtime 操作”,语义与 device type 绑定;对外提供 设备管理查询、stream、blas、同步、内存 等运行时能力接口。 | + +### 2.1 DeviceGuardImpl:运行时能力抽象(对外暴露) + +DeviceGuardImpl 是 InfiniTrain 中 device runtime 能力的统一抽象接口,并且是框架内部对外暴露的能力接口,封装了所有与 device 相关的行为(待补充 event 相关接口): + +```C++ +// ---------------------------------------------------------------------- +// Device management +// ---------------------------------------------------------------------- + +virtual Device GetDevice() const = 0; + +virtual void SetDevice(Device device) const; + +virtual int8_t DeviceCount() const; + +virtual Device::DeviceType Type() const = 0; + +// ---------------------------------------------------------------------- +// Stream management +// ---------------------------------------------------------------------- + +virtual Stream *GetStream(Device) const; + +// 
---------------------------------------------------------------------- +// Synchronization +// ---------------------------------------------------------------------- + +virtual void SynchronizeDevice(Device) const; + +virtual void SynchronizeStream(Stream *) const; + +// ---------------------------------------------------------------------- +// BLAS handle +// ---------------------------------------------------------------------- + +virtual BlasHandle *GetBlasHandle(Device) const; + +// ---------------------------------------------------------------------- +// Memory operations +// ---------------------------------------------------------------------- + +virtual void Malloc(void **dev_ptr, size_t size) = 0; + +virtual void MallocAsync(void **dev_ptr, size_t size, Stream *stream); + +virtual void Free(void *dev_ptr) = 0; + +virtual void FreeAsync(void *dev_ptr, Stream *stream); + +virtual void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) = 0; + +virtual void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream); + +virtual void ResetMemPoolHighWatermarks(Device device) const; + +virtual std::pair GetMemPoolPeakMB(Device device) const; +``` + +### 2.2 DeviceGuard:RAII 前端接口 + +DeviceGuard 是设备上下文的 RAII 管理器,其职责严格限定为: + +- 保存当前 device +- 切换到目标 device +- 在作用域结束时恢复原 device + +DeviceGuard 不直接提供任何运行时能力接口。 + +使用示例: + +```C++ +{ + DeviceGuard guard(Device(DeviceType::kCUDA, 1)); + // 当前线程的 device 上下文被切换到 CUDA:1 + // 所有 runtime 操作将发生在 CUDA:1 +} +// 离开作用域后,自动恢复进入前的 device +``` + +### 2.3 DeviceGuardImplRegistry:全局注册表 + +`DeviceGuardImplRegistry`是 InfiniTrain 中用于管理 device runtime 后端实现的全局注册表,采用 singleton 模式,生命周期覆盖整个进程。 + +其核心职责是维护`DeviceType -> DeviceGuardImpl`的一对一映射关系: + +```C++ +std::unordered_map> impls_; +``` + +## 3. 
Runtime Capability 获取与使用范式 + +### 3.1 获取入口 + +```C++ +DeviceGuardImpl* GetDeviceGuardImpl(Device::DeviceType type); +``` + +- 返回指定`DeviceType`的 DeviceGuardImpl +- 若未注册对应 backend,直接报错 + +### 3.2 推荐使用模式(标准范式) + +```C++ +auto device = tensor->GetDevice(); +const int64_t num_elements = tensor->NumElements(); +std::vector buffer(num_elements); + +{ + // 1. 切换 device 上下文(RAII scope) + core::DeviceGuard guard(device); + + // 2. 获取 runtime capability + auto* impl = core::GetDeviceGuardImpl(device.type()); + + // 3. 执行 runtime 操作 + const core::MemcpyKind kind = + device.type() == Device::DeviceType::kCPU + ? core::MemcpyKind::kD2D // CPU: host-host memcpy + : core::MemcpyKind::kH2D; // Device: host-device copy + + impl->MemcpyAsync( + tensor->DataPtr(), // dst + buffer.data(), // src + num_elements * sizeof(float), // count + kind, // kind(说明:在 CPU backend 中,kD2D 对应普通 memcpy) + impl->GetStream(device) // stream + ); +} // <-- DeviceGuard 在此处析构,device 上下文被恢复 +``` + +## 4. Backend 注册机制(静态注册) + +### 4.1 注册宏 + +```C++ +#define INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(device_type, class_impl) \ + static const bool __infini_train_device_guard_registered##__COUNTER__ = []() { \ + infini_train::core::DeviceGuardImplRegistry::Instance().Register(device_type, std::make_unique()); \ + return true; \ + }(); +``` + +采用静态变量 + lambda 在程序启动阶段完成注册。 + +### 4.2 使用示例(CUDA Backend) + +```C++ +class CudaGuardImpl : public DeviceGuardImpl { + ... +}; + +INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl) +``` + diff --git a/docs/dtype_registry_design.md b/docs/dtype_registry_design.md new file mode 100644 index 00000000..d667f76e --- /dev/null +++ b/docs/dtype_registry_design.md @@ -0,0 +1,96 @@ +# Low-Precision DType Abstraction & Backend Registration Design +统一低精度类型抽象与后端显式注册 pr:https://github.com/InfiniTensor/InfiniTrain/pull/114 + +## 1. 
背景与目标 + +InfiniTrain 在引入 BF16 / FP16 之前,框架层并没有低精度类型的统一抽象,所有 16-bit 浮点语义都直接绑定到后端原生类型:CUDA 侧使用 __half / __nv_bfloat16,CPU 侧则直接使用 uint16_t。这种设计带来了几个问题: + +1. **框架代码被 `#ifdef USE_CUDA` 污染。** + `infini_train/include/datatype.h`、`infini_train/src/nn/init.cc` 等通用模块都需要写出 `#ifdef USE_CUDA … #else …` 来在「有 CUDA」和「没有 CUDA」两个版本之间切换 16-bit 类型映射;非 CUDA 路径只能退化成 `uint16_t`,而 `uint16_t` 又会与 + `kUINT16` 的反向映射产生歧义。 +2. **`TypeMap` 是「全后端共享」的单点表。** + 旧 `TypeMap` 把所有标量类型直接映射到 C++ 类型。CPU 与 CUDA 共享同一个表,意味着不可能在不同后端把 `kFLOAT16` 映射到不同的本地标量;要扩展新硬件必须改框架头文件。 +3. **类型提升耦合具体后端类型。** + 旧的 `WidestType_t` 在 C++ 模板层面做提升,需要每个调用点先 dispatch 出一对具体的标量类型(例如 `nv_bfloat16` + `float`),再交给元函数做选择。这把「类型提升」这一纯 dtype 级别的逻辑跟「后端具体标量」捆死了。 +4. **静默 fallback 容易掩盖错误。** + 一旦某个后端忘记定义低精度类型,旧实现默认映射到 `uint16_t`,会得到一个语义错误的内核,而不是显式报错。 + +本工作的目标是: + +> **抽象出框架级通用低精度类型 FP16/BF16**,让框架代码不再直接依赖任何后端原生 16-bit 类型;同时把框架 [DataType -> 后端 C++ 类型] 的映射改为**显式注册**机制,未注册的类型如果被实例化,会在编译期被拦截报错。 + +## 2. Design In One Diagram + +``` +framework code ──► FP16 / BF16 (datatype.h, 纯软件实现,提供基本转换操作) + PromoteDataTypes(DataType, DataType) + +kernel code ──► DispatchCpuFunc / DispatchCudaFunc / DispatchXxxFunc + │ + ▼ + BackendTypeMap (主模板只声明不定义) + │ + ├─ kFLOAT16 / kBFLOAT16 → 后端在 *_dispatch.h 显式特化后注册 + │ └── CUDA: __half / __nv_bfloat16 + │ └── CPU : FP16 / BF16 + └─ 其它 10 个标量 dtype 使用默认注册 → INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) +``` + +要点: + +- 框架层不提供任何「DataType → 后端 C++ 类型」映射路径;所有具体类型绑定均在后端通过 `BackendTypeMap` 完成。 +- `BackendTypeMap` 主模板**只声明不定义**,只有后端显式特化并完成注册的组合才允许参与 kernel dispatch;未注册组合会在模板实例化阶段被 `static_assert` 于编译期拦截。 + +## 3. 
Core API + +| API | 位置 | 说明 | +| --- | --- | --- | +| `struct FP16 / BF16` | [datatype.h](../infini_train/include/datatype.h) | 16-bit 软件包装(IEEE-754 half / truncated bf16),承担框架身份、存储布局、fallback 转换;不承担后端高性能算术语义。 | +| `PromoteDataTypes(DataType, DataType)` | [datatype.h](../infini_train/include/datatype.h) | 纯枚举到枚举的类型提升。规则:FP16+BF16→FP32;浮点优先于整数;同类按字节宽取大。 | +| `BackendTypeMap` | [core/backend_type_map.h](../infini_train/include/core/backend_type_map.h) | 主模板**只声明不定义**;后端通过显式特化提供 `::type`。 | +| `INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV)` | [core/backend_type_map.h](../infini_train/include/core/backend_type_map.h) | 一次性注册 10 个非低精度 dtype(`kUINT8…kFLOAT64`)到对应 C++ 标量。 | +| `DispatchCpuFunc / DispatchCudaFunc` | `src/core/runtime/{cpu,cuda}/{cpu,cuda}_dispatch.h` | 后端 dispatch 入口,底层转发到 `DispatchByTypeMap`。 | + +## 4. How To Add A New Backend + +按以下清单操作,**不需要**修改 `infini_train/include/` 下的任何框架头文件,也不需要 `#ifdef`: + +1. 在后端的 `*_dispatch.h` 里 include `core/backend_type_map.h` 与 `dtype_dispatch.h`。 +2. 调用 `INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kXxx)` 注册 10 个标准 dtype。 +3. 若硬件支持低精度,显式特化 `BackendTypeMap` / `BackendTypeMap` 指向后端本地 16-bit 标量类型;不支持则直接跳过,调用方一旦 dispatch 到未注册的 dtype 会在编译期触发 `static_assert`。 +4. 定义 `XxxTypeMap` 转发/继承到 `BackendTypeMap`。 +5. 
提供 `DispatchXxxFunc` 入口,转发到 `DispatchByTypeMap`。 + +### 最小示例 + +```cpp +// xxx_dispatch.h +#include "infini_train/include/core/backend_type_map.h" +#include "infini_train/include/dtype_dispatch.h" + +namespace infini_train::core { +// 若硬件支持低精度,显式特化 FP16/BF16 +template <> struct BackendTypeMap { using type = xxx_half; }; +template <> struct BackendTypeMap { using type = xxx_bfloat; }; +} // namespace infini_train::core + +INFINI_REGISTER_STANDARD_BACKEND_TYPES(infini_train::Device::DeviceType::kXxx) + +namespace infini_train::core::xxx { +template +struct XxxTypeMap : BackendTypeMap {}; + +template +auto DispatchXxxFunc(DataType dtype, Functor &&f, std::string_view ctx = "", Args &&...a) { + return DispatchByTypeMap( + dtype, std::forward(f), ctx, std::forward(a)...); +} +} // namespace infini_train::core::xxx +``` + +## 5. Failure Modes + +| 情形 | 表现 | +| --- | --- | +| 后端未注册某个 dtype(`BackendTypeMap` 无特化),但被 dispatch 命中 | 编译期 `static_assert` 触发,错误信息指向 `BackendTypeMap` 的显式注册要求。 | +| dispatch 的 dtype 不在调用点 `AllowedDTypes...` 白名单内 | 运行期 `LOG_UNSUPPORTED_DTYPE` 报错。 | diff --git a/infini_train/include/common/common.h b/infini_train/include/common/common.h index b6a02543..80cba728 100644 --- a/infini_train/include/common/common.h +++ b/infini_train/include/common/common.h @@ -7,11 +7,21 @@ #include "infini_train/include/datatype.h" +/** + * General Utility Macros + */ +#define EXPAND(X) X +// This macro lets you pass an arbitrary expression that may contain internal +// commas to another macro without having the commas causing the expression +// to be interpreted as being multiple arguments +// Basically an alternative for __VA_OPTS__ before C++20 +// ref: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Dispatch_v2.h +#define WRAP(...) 
__VA_ARGS__ +#define CAT(a, b) CAT_(a, b) +#define CAT_(a, b) a##b + #define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) #define LOG_LOC(LEVEL, MSG) LOG(LEVEL) << MSG << " at " << __FILE__ << ":" << __LINE__ -#define LOG_UNSUPPORTED_DTYPE(DTYPE, CONTEXT_IDENTIFIER) \ - LOG_LOC(FATAL, WRAP(CONTEXT_IDENTIFIER << ": Unsupported data type: " \ - + kDataTypeToDesc.at(static_cast<DataType>(dtype)))) inline std::vector<int64_t> ComputeStrides(const std::vector<int64_t> &dims) { std::vector<int64_t> strides(dims.size(), 1); diff --git a/infini_train/include/common/cpu/common_cpu.h b/infini_train/include/common/cpu/common_cpu.h index d4c73e84..b8a01538 100644 --- a/infini_train/include/common/cpu/common_cpu.h +++ b/infini_train/include/common/cpu/common_cpu.h @@ -3,20 +3,41 @@ #include #include +#include "infini_train/include/datatype.h" + namespace infini_train::common::cpu { + +namespace detail { + +// FP16/BF16 don't support implicit conversion, so we route through float. +template <typename DST, typename SRC> DST CastImpl(SRC &&x) { + using SrcBase = std::remove_cvref_t<SRC>; + if constexpr (std::is_same_v<SrcBase, DST>) { + return x; + } else if constexpr (std::is_same_v<DST, FP16> || std::is_same_v<DST, BF16>) { + // Destination is a framework 16-bit type: convert via float + return DST(static_cast<float>(std::forward<SRC>(x))); + } else if constexpr (std::is_same_v<SrcBase, FP16> || std::is_same_v<SrcBase, BF16>) { + // Source is a framework 16-bit type: widen to float first + return static_cast<DST>(static_cast<float>(x)); + } else { + return static_cast<DST>(std::forward<SRC>(x)); + } +} + +} // namespace detail + /** - * Converts a value between arbitrary types. This offers perfect - * forwarding which preserves value categories (lvalues/rvalues) 
* - * @tparam DST Destination type (deduced) + * @tparam DST Destination type * @tparam SRC Source type (deduced) - * @param x Input value (preserves const/volatile and value category) + * @param x Input value * @return Value converted to DST type */ template DST Cast(SRC &&x) { static_assert(!std::is_reference_v, "Cast cannot return reference types"); - - // TODO(lzm): add cpu-version fp16 and bf16 - return (DST)(std::forward(x)); + return detail::CastImpl(std::forward(x)); } + } // namespace infini_train::common::cpu diff --git a/infini_train/include/core/backend_type_map.h b/infini_train/include/core/backend_type_map.h new file mode 100644 index 00000000..f67b8da7 --- /dev/null +++ b/infini_train/include/core/backend_type_map.h @@ -0,0 +1,81 @@ +#pragma once + +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" + +namespace infini_train::core { + +/** + * Backend type mapping: DataType -> backend-native dispatch type + * + * BackendTypeMap — maps DataType to the C++ type used by kernels/dispatch. + * Primary template intentionally undefined — there is NO + * default fallback to the framework TypeMap. + * + * Backends must register dtypes explicitly: + * - Standard types (int, float, double, ...): + * call INFINI_REGISTER_STANDARD_BACKEND_TYPES(Dev) + * once at file scope in the backend's dispatch header. + * - Low-precision types (FP16, BF16): + * directly specialize BackendTypeMap + * in the backend's dispatch header (the native scalar type + * differs per backend, e.g. __half on CUDA). + * + * If a backend does not register a dtype, HasMappedType_v returns false and + * DispatchByTypeMap fires a clear static_assert at compile time. + */ + +// ----------------------------------------------------------------------------- +// BackendTypeMap: DataType -> backend dispatch type +// Primary template intentionally undefined — no TypeMap fallback. 
+// ----------------------------------------------------------------------------- +template <Device::DeviceType Dev, DataType DT> struct BackendTypeMap; + +} // namespace infini_train::core + +// ----------------------------------------------------------------------------- +// INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) +// +// Explicitly registers the 10 standard (non-low-precision) dtypes for a backend +// device. Invoke once at file scope (outside any namespace) in the backend's +// dispatch header, e.g.: +// +// INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kCUDA) +// +// FP16 and BF16 are NOT registered here — backends must specialize +// BackendTypeMap<Dev, kFLOAT16/kBFLOAT16> directly with their native scalar +// type (e.g. __half / __nv_bfloat16 on CUDA). +// ----------------------------------------------------------------------------- +#define INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) \ + namespace infini_train::core { \ + template <> struct BackendTypeMap<DEV, DataType::kUINT8> { \ + using type = uint8_t; \ + }; \ + template <> struct BackendTypeMap<DEV, DataType::kINT8> { \ + using type = int8_t; \ + }; \ + template <> struct BackendTypeMap<DEV, DataType::kUINT16> { \ + using type = uint16_t; \ + }; \ + template <> struct BackendTypeMap<DEV, DataType::kINT16> { \ + using type = int16_t; \ + }; \ + template <> struct BackendTypeMap<DEV, DataType::kUINT32> { \ + using type = uint32_t; \ + }; \ + template <> struct BackendTypeMap<DEV, DataType::kINT32> { \ + using type = int32_t; \ + }; \ + template <> struct BackendTypeMap<DEV, DataType::kUINT64> { \ + using type = uint64_t; \ + }; \ + template <> struct BackendTypeMap<DEV, DataType::kINT64> { \ + using type = int64_t; \ + }; \ + template <> struct BackendTypeMap<DEV, DataType::kFLOAT32> { \ + using type = float; \ + }; \ + template <> struct BackendTypeMap<DEV, DataType::kFLOAT64> { \ + using type = double; \ + }; \ + } /* namespace infini_train::core */ diff --git a/infini_train/include/core/runtime/device_guard.h b/infini_train/include/core/runtime/device_guard.h index c9eeeb25..dc56fc6f 100644 --- a/infini_train/include/core/runtime/device_guard.h +++ b/infini_train/include/core/runtime/device_guard.h @@ -66,6 +66,7 @@ class DeviceGuardImpl { // Device management // 
---------------------------------------------------------------------- + // FIXME(dcj): impl should only bind with device type virtual Device GetDevice() const = 0; virtual void SetDevice(Device device) const; diff --git a/infini_train/include/datatype.h b/infini_train/include/datatype.h index 79f325db..e2f3e2f6 100644 --- a/infini_train/include/datatype.h +++ b/infini_train/include/datatype.h @@ -1,14 +1,88 @@ #pragma once +#include #include #include #include -#ifdef USE_CUDA -#include -#include -#endif namespace infini_train { + +// ----------------------------------------------------------------------------- +// Framework scalar types (16-bit storage + fallback scalar semantics) +// ----------------------------------------------------------------------------- +// FP16/BF16 are framework-level 16-bit scalar/storage types. +// They are used for: +// - framework type identity +// - baseline dtype mapping +// - metadata / storage layout +// - CPU/reference/fallback conversion paths +// +// They are NOT intended to define backend-native arithmetic semantics. +// Backend kernels should use backend-specific type maps, e.g.: +// - CUDA: __half / __nv_bfloat16 +// - CPU : FP16 / BF16 / widened compute types (as needed) +// ----------------------------------------------------------------------------- + +namespace detail { + +// --------------------------- +// BF16 helpers +// --------------------------- +uint16_t FloatToBf16Bits(float value); +float Bf16BitsToFloat(uint16_t bits); + +// --------------------------- +// FP16 helpers +// Pure software IEEE-754 half <-> float conversion for framework fallback use. 
+// --------------------------- +uint16_t FloatToFp16Bits(float value); +float Fp16BitsToFloat(uint16_t bits); + +} // namespace detail + +struct alignas(2) FP16 { + uint16_t x{0}; + + struct from_bits_t {}; + static constexpr from_bits_t from_bits() { return {}; } + + constexpr FP16() = default; + constexpr FP16(uint16_t bits, from_bits_t) : x(bits) {} + + explicit FP16(float value); + explicit FP16(double value); + explicit FP16(int value); + explicit FP16(int64_t value); + + explicit operator float() const; + explicit operator double() const; + + FP16 &operator++(); +}; + +struct alignas(2) BF16 { + uint16_t x{0}; + + struct from_bits_t {}; + static constexpr from_bits_t from_bits() { return {}; } + + constexpr BF16() = default; + constexpr BF16(uint16_t bits, from_bits_t) : x(bits) {} + + explicit BF16(float value); + explicit BF16(double value); + explicit BF16(int value); + explicit BF16(int64_t value); + + explicit operator float() const; + explicit operator double() const; + + BF16 &operator++(); +}; + +// ----------------------------------------------------------------------------- +// DataType enum and metadata tables +// ----------------------------------------------------------------------------- enum class DataType : int8_t { kUINT8, kINT8, @@ -37,164 +111,19 @@ inline const std::unordered_map kDataTypeToDesc = { {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"}, }; -/** - * Compile-time type mapping from DataType enum to concrete C++ types. - * - * - Primary template: Declared but undefined to enforce specialization - * - Specializations: Explicit mappings (DataType::kFLOAT32 → float, etc) - * - TypeMap_t alias: Direct access to mapped type (TypeMap_t → int32_t) - * - * Enables type-safe generic code where operations dispatch based on DataType tokens, - * with zero runtime overhead. Extend by adding new specializations. 
- */ -template struct TypeMap; -template using TypeMap_t = typename TypeMap::type; - -/** - * Compile-time type mapping from C++ types to DataType enum. - * - * Example usage: DataTypeMap::value // Returns DataType::kINT32 - * DataTypeMap_v for convenient access to the mapped value (e.g., DataTypeMap_v). - */ -template struct DataTypeMap; -template inline constexpr DataType DataTypeMap_v = DataTypeMap::value; - -// Macro to define TypeMap specializations and reverse mappings -#define DEFINE_DATA_TYPE_MAPPING(ENUM_VALUE, CPP_TYPE) \ - template <> struct TypeMap { \ - using type = CPP_TYPE; \ - }; \ - template <> struct DataTypeMap { \ - static constexpr DataType value = DataType::ENUM_VALUE; \ - }; - -DEFINE_DATA_TYPE_MAPPING(kUINT8, uint8_t) -DEFINE_DATA_TYPE_MAPPING(kINT8, int8_t) -DEFINE_DATA_TYPE_MAPPING(kUINT16, uint16_t) -DEFINE_DATA_TYPE_MAPPING(kINT16, int16_t) -DEFINE_DATA_TYPE_MAPPING(kUINT32, uint32_t) -DEFINE_DATA_TYPE_MAPPING(kINT32, int32_t) -DEFINE_DATA_TYPE_MAPPING(kUINT64, uint64_t) -DEFINE_DATA_TYPE_MAPPING(kINT64, int64_t) -DEFINE_DATA_TYPE_MAPPING(kFLOAT32, float) -DEFINE_DATA_TYPE_MAPPING(kFLOAT64, double) - -#ifdef USE_CUDA -DEFINE_DATA_TYPE_MAPPING(kBFLOAT16, nv_bfloat16) -DEFINE_DATA_TYPE_MAPPING(kFLOAT16, half) -#else -// Non-CUDA fallbacks -template <> struct TypeMap { - using type = uint16_t; -}; -template <> struct TypeMap { - using type = uint16_t; -}; - -// TODO(lzm): currently for non-CUDA/CPU, there's an ambiguity of uint16_t mapping to both kUINT16 and -// kFLOAT16/kBFLOAT16. When CPU custom bfloat16/float16 types are defined, we should replace uint16_t with those types. -#endif -#undef DEFINE_DATA_TYPE_MAPPING - -// Extends std::is_floating_point to support CUDA floating-point types. -template struct is_floating_point_ext : std::is_floating_point {}; - -// Extends std::is_arithmetic to support CUDA floating-point types. 
-template struct is_arithmetic_ext : std::is_arithmetic {}; - -// Specializations for CUDA types -#ifdef USE_CUDA -template <> struct is_floating_point_ext<__nv_bfloat16> : std::true_type {}; -template <> struct is_arithmetic_ext<__nv_bfloat16> : std::true_type {}; -template <> struct is_floating_point_ext<__half> : std::true_type {}; -template <> struct is_arithmetic_ext<__half> : std::true_type {}; -#endif - -namespace { -template struct LargerType { - static constexpr size_t size1 = sizeof(T1); - static constexpr size_t size2 = sizeof(T2); - using type = std::conditional_t<(size1 >= size2), T1, T2>; -}; - -// Specializations of LargerType for the specific 16-bit FP combinations -#ifdef USE_CUDA -template <> struct LargerType<__nv_bfloat16, __half> { - using type = float; -}; +// ============================================================================= +// DataType-level promotion (pure enum → enum, no concrete/backend types) +// ============================================================================= +// Rules (priority order): +// 1. FP16 + BF16 → FLOAT32 (neither is a lossless superset of the other) +// 2. Any float dominates any integer → keep the float type +// 3. Same category (float-float or int-int) → wider byte size wins +// ============================================================================= -template <> struct LargerType<__half, __nv_bfloat16> { - using type = float; -}; -#endif - -/** - * @brief Finds the first type in a parameter pack that satisfies the given predicate. If no type matches, - * returns the last type in the pack (base case). - * - * @tparam Predicate Template template parameter that takes one type and provides a static `value` member - * @tparam Ts Parameter pack of types to check - */ -template