Skip to content
35 changes: 35 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,38 @@ link_infini_train_exe(test_precision_check)
add_executable(test_lora test/lora/test_lora.cc)
link_infini_train_exe(test_lora)

add_executable(test_scalar test/scalar/test_scalar.cc)
link_infini_train_exe(test_scalar)

add_executable(test_dtype_dispatch test/dispatch/test_dtype_dispatch.cc)
link_infini_train_exe(test_dtype_dispatch)

# Negative compile test: missing dtype registration must fail at compile time.
set(DTYPE_DISPATCH_COMPILE_FAIL_SOURCE
${PROJECT_SOURCE_DIR}/test/dispatch/test_dtype_dispatch_compile_fail.cc)

try_compile(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED
${CMAKE_BINARY_DIR}/CMakeFiles/try_compile_dtype_dispatch_missing_map
SOURCES ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}
CMAKE_FLAGS
"-DCMAKE_CXX_STANDARD=${CMAKE_CXX_STANDARD}"
"-DCMAKE_CXX_STANDARD_REQUIRED=ON"
"-DCMAKE_CXX_EXTENSIONS=OFF"
"-DCMAKE_CXX_FLAGS=-I${PROJECT_SOURCE_DIR}"
OUTPUT_VARIABLE DTYPE_DISPATCH_TRY_COMPILE_OUTPUT
)

if(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED)
message(FATAL_ERROR
"dtype dispatch compile-fail test unexpectedly succeeded.\n"
"Source: ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}\n"
"Output:\n${DTYPE_DISPATCH_TRY_COMPILE_OUTPUT}")
endif()

add_custom_target(test_dtype_dispatch_compile_fail
COMMAND ${CMAKE_COMMAND} -E echo
"dtype dispatch compile-fail check passed (missing dtype registration correctly fails to compile)."
VERBATIM
)

add_dependencies(test_dtype_dispatch test_dtype_dispatch_compile_fail)
210 changes: 210 additions & 0 deletions docs/device_guard_design.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Device Guard Design
device 注册初版基建 pr:https://github.com/InfiniTensor/InfiniTrain/pull/103

## 1. 设计背景与目标

### 1.1 背景

InfiniTrain 需要长期支持:

- 多种设备类型(CPU/CUDA/国产芯片)
- 多种运行时能力(stream、memory、blas、通信等)
- 在不侵入上层逻辑的前提下进行后端扩展与替换

在实际工程中,如果设备相关逻辑散落在框架各个模块,会导致:

- `#ifdef USE_CUDA/USE_MUSA/...` 泛滥
- 新硬件接入需要修改大量框架核心代码
- 设备切换与资源管理缺乏统一语义

### 1.2 设计目标

InfiniTrain 的 device 注册机制设计目标是:

1. 统一抽象:将所有与设备相关的运行时行为抽象到一个统一接口中。
2. 后端可插拔:新设备后端可通过注册机制接入,无需修改框架核心逻辑。
3. RAII 语义清晰:设备切换、资源恢复具备严格的作用域。
4. 最小上层侵入:上层模块(Tensor/Autograd/Module)只感知 DeviceGuard/DeviceGuardImpl,不感知具体后端实现。

## 2. 核心组件

InfiniTrain 的 device 机制由三类核心组件构成:

```C++
+-------------------+
| DeviceGuard | ← 对外 RAII 接口(public)
+-------------------+
|
v
+-------------------+
| DeviceGuardImpl | ← 后端抽象接口(virtual)
+-------------------+
^
|
+-------------------+
| DeviceGuardImpl |
| Registry | ← 全局注册表(singleton)
+-------------------+
```

其中 DeviceGuard 与 DeviceGuardImpl 的关系是:

| 组件 | 职责 |
| --------------- | ------------------------------------------------------------ |
| DeviceGuard | 管理 “当前在哪个 device 上” 的上下文语义(RAII),语义与 device index 绑定;负责 device 的保存/切换/恢复,并将具体 runtime 操作转发给对应的 DeviceGuardImpl。 |
| DeviceGuardImpl | 管理 “在该类 device 上如何执行 runtime 操作”,语义与 device type 绑定;对外提供 设备管理查询、stream、blas、同步、内存 等运行时能力接口。 |

### 2.1 DeviceGuardImpl:运行时能力抽象(对外暴露)

DeviceGuardImpl 是 InfiniTrain 中 device runtime 能力的统一抽象接口,并且是框架内部对外暴露的能力接口,封装了所有与 device 相关的行为(待补充 event 相关接口):

```C++
// ----------------------------------------------------------------------
// Device management
// ----------------------------------------------------------------------

virtual Device GetDevice() const = 0;

virtual void SetDevice(Device device) const;

virtual int8_t DeviceCount() const;

virtual Device::DeviceType Type() const = 0;

// ----------------------------------------------------------------------
// Stream management
// ----------------------------------------------------------------------

virtual Stream *GetStream(Device) const;

// ----------------------------------------------------------------------
// Synchronization
// ----------------------------------------------------------------------

virtual void SynchronizeDevice(Device) const;

virtual void SynchronizeStream(Stream *) const;

// ----------------------------------------------------------------------
// BLAS handle
// ----------------------------------------------------------------------

virtual BlasHandle *GetBlasHandle(Device) const;

// ----------------------------------------------------------------------
// Memory operations
// ----------------------------------------------------------------------

virtual void Malloc(void **dev_ptr, size_t size) = 0;

virtual void MallocAsync(void **dev_ptr, size_t size, Stream *stream);

virtual void Free(void *dev_ptr) = 0;

virtual void FreeAsync(void *dev_ptr, Stream *stream);

virtual void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) = 0;

virtual void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream);

virtual void ResetMemPoolHighWatermarks(Device device) const;

virtual std::pair<size_t, size_t> GetMemPoolPeakMB(Device device) const;
```

### 2.2 DeviceGuard:RAII 前端接口

DeviceGuard 是设备上下文的 RAII 管理器,其职责严格限定为:

- 保存当前 device
- 切换到目标 device
- 在作用域结束时恢复原 device

DeviceGuard 不直接提供任何运行时能力接口。

使用示例:

```C++
{
DeviceGuard guard(Device(DeviceType::kCUDA, 1));
// 当前线程的 device 上下文被切换到 CUDA:1
// 所有 runtime 操作将发生在 CUDA:1
}
// 离开作用域后,自动恢复进入前的 device
```

### 2.3 DeviceGuardImplRegistry:全局注册表

`DeviceGuardImplRegistry`是 InfiniTrain 中用于管理 device runtime 后端实现的全局注册表,采用 singleton 模式,生命周期覆盖整个进程。

其核心职责是维护`DeviceType -> DeviceGuardImpl`的一对一映射关系:

```C++
std::unordered_map<Device::DeviceType, std::unique_ptr<DeviceGuardImpl>> impls_;
```

## 3. Runtime Capability 获取与使用范式

### 3.1 获取入口

```C++
DeviceGuardImpl* GetDeviceGuardImpl(Device::DeviceType type);
```

- 返回指定`DeviceType`的 DeviceGuardImpl
- 若未注册对应 backend,直接报错

### 3.2 推荐使用模式(标准范式)

```C++
auto device = tensor->GetDevice();
const int64_t num_elements = tensor->NumElements();
std::vector<float> buffer(num_elements);

{
// 1. 切换 device 上下文(RAII scope)
core::DeviceGuard guard(device);

// 2. 获取 runtime capability
auto* impl = core::GetDeviceGuardImpl(device.type());

// 3. 执行 runtime 操作
const core::MemcpyKind kind =
device.type() == Device::DeviceType::kCPU
? core::MemcpyKind::kD2D // CPU: host-host memcpy
: core::MemcpyKind::kH2D; // Device: host-device copy

impl->MemcpyAsync(
tensor->DataPtr(), // dst
buffer.data(), // src
num_elements * sizeof(float), // count
kind, // kind(说明:在 CPU backend 中,kD2D 对应普通 memcpy)
impl->GetStream(device) // stream
);
} // <-- DeviceGuard 在此处析构,device 上下文被恢复
```

## 4. Backend 注册机制(静态注册)

### 4.1 注册宏

```C++
#define INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(device_type, class_impl) \
static const bool __infini_train_device_guard_registered##__COUNTER__ = []() { \
infini_train::core::DeviceGuardImplRegistry::Instance().Register(device_type, std::make_unique<class_impl>()); \
return true; \
}();
```

采用静态变量 + lambda 在程序启动阶段完成注册。

### 4.2 使用示例(CUDA Backend)

```C++
class CudaGuardImpl : public DeviceGuardImpl {
...
};

INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl)
```

96 changes: 96 additions & 0 deletions docs/dtype_registry_design.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Low-Precision DType Abstraction & Backend Registration Design
统一低精度类型抽象与后端显式注册 pr:https://github.com/InfiniTensor/InfiniTrain/pull/114

## 1. 背景与目标

InfiniTrain 在引入 BF16 / FP16 之前,框架层并没有低精度类型的统一抽象,所有 16-bit 浮点语义都直接绑定到后端原生类型:CUDA 侧使用 __half / __nv_bfloat16,CPU 侧则直接使用 uint16_t。这种设计带来了几个问题:

1. **框架代码被 `#ifdef USE_CUDA` 污染。**
`infini_train/include/datatype.h`、`infini_train/src/nn/init.cc` 等通用模块都需要写出 `#ifdef USE_CUDA … #else …` 来在「有 CUDA」和「没有 CUDA」两个版本之间切换 16-bit 类型映射;非 CUDA 路径只能退化成 `uint16_t`,而 `uint16_t` 又会与
`kUINT16` 的反向映射产生歧义。
2. **`TypeMap<DType>` 是「全后端共享」的单点表。**
旧 `TypeMap` 把所有标量类型直接映射到 C++ 类型。CPU 与 CUDA 共享同一个表,意味着不可能在不同后端把 `kFLOAT16` 映射到不同的本地标量;要扩展新硬件必须改框架头文件。
3. **类型提升耦合具体后端类型。**
旧的 `WidestType_t<T1, T2>` 在 C++ 模板层面做提升,需要每个调用点先 dispatch 出一对具体的标量类型(例如 `nv_bfloat16` + `float`),再交给元函数做选择。这把「类型提升」这一纯 dtype 级别的逻辑跟「后端具体标量」捆死了。
4. **静默 fallback 容易掩盖错误。**
一旦某个后端忘记定义低精度类型,旧实现默认映射到 `uint16_t`,会得到一个语义错误的内核,而不是显式报错。

本工作的目标是:

> **抽象出框架级通用低精度类型 FP16/BF16**,让框架代码不再直接依赖任何后端原生 16-bit 类型;同时把框架 [DataType -> 后端 C++ 类型] 的映射改为**显式注册**机制,未注册的类型如果被实例化,会在编译期被拦截报错。

## 2. Design In One Diagram

```
framework code ──► FP16 / BF16 (datatype.h, 纯软件实现,提供基本转换操作)
PromoteDataTypes(DataType, DataType)

kernel code ──► DispatchCpuFunc / DispatchCudaFunc / DispatchXxxFunc
BackendTypeMap<Dev, DType> (主模板只声明不定义)
├─ kFLOAT16 / kBFLOAT16 → 后端在 *_dispatch.h 显式特化后注册
│ └── CUDA: __half / __nv_bfloat16
│ └── CPU : FP16 / BF16
└─ 其它 10 个标量 dtype 使用默认注册 → INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV)
```

要点:

- 框架层不提供任何「DataType → 后端 C++ 类型」映射路径;所有具体类型绑定均在后端通过 `BackendTypeMap<Dev, DType>` 完成。
- `BackendTypeMap<Dev, DType>` 主模板**只声明不定义**,只有后端显式特化并完成注册的组合才允许参与 kernel dispatch;未注册组合会在模板实例化阶段被 `static_assert` 于编译期拦截。

## 3. Core API

| API | 位置 | 说明 |
| --- | --- | --- |
| `struct FP16 / BF16` | [datatype.h](../infini_train/include/datatype.h) | 16-bit 软件包装(IEEE-754 half / truncated bf16),承担框架身份、存储布局、fallback 转换;不承担后端高性能算术语义。 |
| `PromoteDataTypes(DataType, DataType)` | [datatype.h](../infini_train/include/datatype.h) | 纯枚举到枚举的类型提升。规则:FP16+BF16→FP32;浮点优先于整数;同类按字节宽取大。 |
| `BackendTypeMap<Dev, DType>` | [core/backend_type_map.h](../infini_train/include/core/backend_type_map.h) | 主模板**只声明不定义**;后端通过显式特化提供 `::type`。 |
| `INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV)` | [core/backend_type_map.h](../infini_train/include/core/backend_type_map.h) | 一次性注册 10 个非低精度 dtype(`kUINT8…kFLOAT64`)到对应 C++ 标量。 |
| `DispatchCpuFunc / DispatchCudaFunc<AllowedDTypes...>` | `src/core/runtime/{cpu,cuda}/{cpu,cuda}_dispatch.h` | 后端 dispatch 入口,底层转发到 `DispatchByTypeMap<TypeMap, AllowedDTypes...>`。 |

## 4. How To Add A New Backend

按以下清单操作,**不需要**修改 `infini_train/include/` 下的任何框架头文件,也不需要 `#ifdef`:

1. 在后端的 `*_dispatch.h` 里 include `core/backend_type_map.h` 与 `dtype_dispatch.h`。
2. 调用 `INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kXxx)` 注册 10 个标准 dtype。
3. 若硬件支持低精度,显式特化 `BackendTypeMap<kXxx, kFLOAT16>` / `BackendTypeMap<kXxx, kBFLOAT16>` 指向后端本地 16-bit 标量类型;不支持则直接跳过,调用方一旦 dispatch 到未注册的 dtype 会在编译期触发 `static_assert`。
4. 定义 `XxxTypeMap<DType>` 转发/继承到 `BackendTypeMap<kXxx, DType>`。
5. 提供 `DispatchXxxFunc` 入口,转发到 `DispatchByTypeMap<XxxTypeMap, AllowedDTypes...>`。

### 最小示例

```cpp
// xxx_dispatch.h
#include "infini_train/include/core/backend_type_map.h"
#include "infini_train/include/dtype_dispatch.h"

namespace infini_train::core {
// 若硬件支持低精度,显式特化 FP16/BF16
template <> struct BackendTypeMap<Device::DeviceType::kXxx, DataType::kFLOAT16> { using type = xxx_half; };
template <> struct BackendTypeMap<Device::DeviceType::kXxx, DataType::kBFLOAT16> { using type = xxx_bfloat; };
} // namespace infini_train::core

INFINI_REGISTER_STANDARD_BACKEND_TYPES(infini_train::Device::DeviceType::kXxx)

namespace infini_train::core::xxx {
template <DataType DType>
struct XxxTypeMap : BackendTypeMap<Device::DeviceType::kXxx, DType> {};

template <DataType... AllowedDTypes, typename Functor, typename... Args>
auto DispatchXxxFunc(DataType dtype, Functor &&f, std::string_view ctx = "", Args &&...a) {
return DispatchByTypeMap<XxxTypeMap, AllowedDTypes...>(
dtype, std::forward<Functor>(f), ctx, std::forward<Args>(a)...);
}
} // namespace infini_train::core::xxx
```

## 5. Failure Modes

| 情形 | 表现 |
| --- | --- |
| 后端未注册某个 dtype(`BackendTypeMap<Dev, DType>` 无特化),但被 dispatch 命中 | 编译期 `static_assert` 触发,错误信息指向 `BackendTypeMap` 的显式注册要求。 |
| dispatch 的 dtype 不在调用点 `AllowedDTypes...` 白名单内 | 运行期 `LOG_UNSUPPORTED_DTYPE` 报错。 |
16 changes: 13 additions & 3 deletions infini_train/include/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@

#include "infini_train/include/datatype.h"

/**
* General Utility Macros
*/
#define EXPAND(X) X
// This macro lets you pass an arbitrary expression that may contain internal
// commas to another macro without having the commas causing the expression
// to be interpreted as being multiple arguments
// Basically an alternative for __VA_OPTS__ before C++20
// ref: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Dispatch_v2.h
#define WRAP(...) __VA_ARGS__
#define CAT(a, b) CAT_(a, b)
#define CAT_(a, b) a##b

#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
#define LOG_LOC(LEVEL, MSG) LOG(LEVEL) << MSG << " at " << __FILE__ << ":" << __LINE__
#define LOG_UNSUPPORTED_DTYPE(DTYPE, CONTEXT_IDENTIFIER) \
LOG_LOC(FATAL, WRAP(CONTEXT_IDENTIFIER << ": Unsupported data type: " \
+ kDataTypeToDesc.at(static_cast<infini_train::DataType>(dtype))))

inline std::vector<int64_t> ComputeStrides(const std::vector<int64_t> &dims) {
std::vector<int64_t> strides(dims.size(), 1);
Expand Down
Loading
Loading