Skip to content

dtype registry#114

Open
kilinchange wants to merge 9 commits intomasterfrom
feat/dtype_registry
Open

dtype registry#114
kilinchange wants to merge 9 commits intomasterfrom
feat/dtype_registry

Conversation

@kilinchange
Copy link
Copy Markdown
Collaborator

@kilinchange kilinchange commented Mar 12, 2026

本 PR 主要功能:抽象出框架级通用低精度类型 FP16/BF16(主要作为 storage type,但提供了必要的数值操作作为 fallback 逻辑),让框架代码不再直接依赖任何后端特有低精度类型(例如 __half/__nv_bfloat16);同时把框架 [DataType -> 后端 C++ 类型] 的映射改为显式注册机制,未注册的类型如果被实例化,会在编译期被拦截报错。

主要变更

  • docs 目录:新增了俩设计文档,device_guard_design.md 是之前飞书上的 device 注册设计文档,在 github 上补了一份,可以不用看;dtype_registry_design.md 是这个 pr 工作的设计文档,可以重点结合代码看下。

======================= 根目录下的通用头文件/实现改动 =======================

  • datatype.[h|cc]:定义 FP16/BF16 类型,主要作为框架层低精度存储载体,同时作为 CPU fallback 的低精度类型,提供基础运算与数据类型转换能力;移除 TypeMap 后,框架层不再提供 [DataType -> 后端 C++ 类型] 映射;原先依赖模板特化的类型提升逻辑统一收敛为 PromoteDataTypes,以运行时函数形式提供 [DataType -> DataType] 的提升规则映射。
  • dispatcher.h:将 dtype dispatch 相关逻辑拆分出去后,该文件仅保留 kernel dispatch 逻辑。
  • dtype_dispatch.h:承接原 dispatcher.h 中的类型分发模板与宏,新增 HasMappedType 模板,用于在未注册类型实例化时通过 static_assert 提供明确报错;DispatchByTypeMap 替代原 DispatchFunc,新增 TypeMap 类模板参数,用于接收后端特化并注册的类型映射。
  • scalar.h:统一标量载体,为避免框架 API 为每种数值类型提供重载(解决 Tensor::Fill 之前需要模板特化实现的问题),引入 Scalar 作为统一标量载体。Scalar 采用 Kind + union 的轻量表示,将输入值存储为 bool / double / int64_t / uint64_t 四类宿主类型:所有浮点(含 FP16/BF16)存储为 double 类型,整数按符号存储为 int64_t / uint64_t 类型,bool 独立一份 Kind 类别(但存成 uint64_t 类型)。后续将扩展这套逻辑到其他接受标量类型的函数(例如 Tensor::Add),整体设计上参考 torch 实现 并进行了简化。
  • tensor.[h|cc]:将 Fill 函数的入参改为 Scalar 类型。

======================= 框架层通用头文件/实现改动 =======================

  • infini_train/include/common 目录:common_cpu.h 里修改了 cpu Cast 工具函数的实现,这个主要原因是 FP16/BF16 实现的数据类型转换操作有限,所以在这里进行了特判,涉及到这俩类型的统一用 float 作为中间过渡类型;common.h 里将之前散落在 datatype.h 里的通用宏挪了过来,同时把仅用于 dtype dispatch 的宏挪到了特定文件。
  • infini_train/[include|src]/core 目录:backend_type_map.h 声明了 BackendTypeMap,用于替代原 datatype.h 中的 TypeMap,负责 [DataType -> 后端 C++ 类型] 映射。该模板仅声明不定义,具体实现由各后端在 infini_train/src/core/runtime/[dev]/[dev]_dispatch.h 中显式注册;同时提供公共宏,用于批量注册非低精度类型。

======================= 后端新增实现 =======================

  • infini_train/src/core/runtime 目录:各后端在此注册 BackendTypeMap 实现,作为 DispatchByTypeMap 的类模板参数使用,并封装 Dispatch[Dev]Func 分发入口。
  • infini_train/src/kernels 目录:更新 kernel 的 dtype dispatch、Fill 调用模式、promotion 使用方式。
  • test/dtype:包含三个测试,test_dtype_dispatch.cc 验证运行时 dtype_dispatch 能正确转发到后端 C++ 类型,test_dtype_dispatch_compile_fail.cc 验证编译期能够拦截未注册类型的实例化,test_scalar.cc 验证 Scalar 行为的正确性。

@kilinchange kilinchange force-pushed the feat/dtype_registry branch 2 times, most recently from c4d8812 to 2b0c909 Compare March 12, 2026 10:44
@kilinchange kilinchange force-pushed the feat/dtype_registry branch 11 times, most recently from e63cb38 to 480d546 Compare April 7, 2026 03:11
@kilinchange kilinchange force-pushed the feat/dtype_registry branch 4 times, most recently from 1e8b87a to e13bb38 Compare April 9, 2026 10:24
…l promotion

Replace the implicit TypeMap fallback in BackendTypeMap with explicit per-backend
dtype registration (INFINI_REGISTER_STANDARD_BACKEND_TYPES), ensuring FP16/BF16
are only dispatched through backend-specific paths (DispatchCpuFunc/DispatchCudaFunc).

Migrate CUDA kernel promotion from concrete-type WidestType_t to pure DataType
enum-level PromoteDataTypes(), eliminating the need for backend scalar types at
promotion time. Replace runtime kDataTypeToSize map with constexpr DTypeSize().
@kilinchange
Copy link
Copy Markdown
Collaborator Author

kilinchange commented Apr 10, 2026

精度比对结果:
image

@kilinchange
Copy link
Copy Markdown
Collaborator Author

性能比对结果:
image

@kilinchange kilinchange force-pushed the feat/dtype_registry branch from e13bb38 to de1e14b Compare April 10, 2026 12:59
@kilinchange kilinchange changed the title [WIP] dtype registry dtype registry Apr 10, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant