…l promotion

- Replace the implicit TypeMap fallback in BackendTypeMap with explicit per-backend dtype registration (INFINI_REGISTER_STANDARD_BACKEND_TYPES), ensuring FP16/BF16 are only dispatched through backend-specific paths (DispatchCpuFunc/DispatchCudaFunc).
- Migrate CUDA kernel promotion from concrete-type WidestType_t to pure DataType enum-level PromoteDataTypes(), eliminating the need for backend scalar types at promotion time.
- Replace the runtime kDataTypeToSize map with constexpr DTypeSize().


Summary of this PR: it abstracts out framework-level generic low-precision types FP16/BF16 (primarily as storage types, but with the numeric operations needed for fallback logic), so that framework code no longer depends directly on any backend-specific low-precision type (e.g. __half / __nv_bfloat16). It also turns the framework's [DataType -> backend C++ type] mapping into an explicit registration mechanism: if an unregistered type is instantiated, the error is caught at compile time.
Main changes
docs directory: two new design documents. device_guard_design.md is the earlier device-registration design document from Feishu, now mirrored on GitHub (safe to skip); dtype_registry_design.md is the design document for the work in this PR and is best read alongside the code.

======================= Shared headers/implementations under the repository root =======================
datatype.[h|cc]: defines the FP16/BF16 types, which serve primarily as framework-level low-precision storage carriers and as the low-precision types for the CPU fallback, providing basic arithmetic and dtype-conversion capabilities. With TypeMap removed, the framework layer no longer provides a [DataType -> backend C++ type] mapping; the type-promotion logic that previously relied on template specialization is consolidated into PromoteDataTypes, a runtime function that provides the [DataType -> DataType] promotion rules.

dispatcher.h: with the dtype-dispatch logic split out, this file now retains only kernel-dispatch logic.

dtype_dispatch.h: takes over the type-dispatch templates and macros from the old dispatcher.h. Adds a HasMappedType template so that instantiating an unregistered type produces a clear static_assert error; DispatchByTypeMap replaces the old DispatchFunc, adding a TypeMap class-template parameter that receives the type mapping specialized and registered by each backend.

scalar.h: a unified scalar carrier. To avoid the framework API needing an overload for every numeric type (this solves the problem that Tensor::Fill previously required template-specialized implementations), Scalar is introduced as the unified scalar carrier. Scalar uses a lightweight Kind + union representation, storing input values in four host types (bool / double / int64_t / uint64_t): all floating-point values (including FP16/BF16) are stored as double; integers are stored as int64_t or uint64_t depending on signedness; bool gets its own Kind (but is stored as uint64_t). This scheme will later be extended to other scalar-accepting functions (e.g. Tensor::Add); the overall design follows the torch implementation, simplified.

tensor.[h|cc]: changes the Fill function's parameter to Scalar.

======================= Framework-layer shared header/implementation changes =======================
infini_train/include/common directory: common_cpu.h reworks the CPU Cast utility function, mainly because FP16/BF16 implement only a limited set of dtype conversions; conversions involving these two types are therefore special-cased to go through float as an intermediate type. common.h takes over the general-purpose macros previously scattered in datatype.h, while macros used only for dtype dispatch move to the dedicated file.

infini_train/[include|src]/core directory: backend_type_map.h declares BackendTypeMap, which replaces the old TypeMap from datatype.h and is responsible for the [DataType -> backend C++ type] mapping. The template is only declared, never defined; each backend registers its specializations explicitly in infini_train/src/core/runtime/[dev]/[dev]_dispatch.h. A shared macro is also provided for batch-registering the non-low-precision types.

======================= New backend implementations =======================
infini_train/src/core/runtime directory: each backend registers its BackendTypeMap implementation here, to be used as the class-template parameter of DispatchByTypeMap, and wraps a Dispatch[Dev]Func dispatch entry point.

test/dtype: three tests. test_dtype_dispatch.cc verifies that runtime dtype_dispatch correctly forwards to the backend C++ types; test_dtype_dispatch_compile_fail.cc verifies that instantiating an unregistered type is caught at compile time; test_scalar.cc verifies correct Scalar behavior.