
feat: extract the common module of Transformer#115

Open
JYMiracle305 wants to merge 7 commits into master from feat/transformer

Conversation

JYMiracle305 (Contributor) commented Mar 13, 2026

Core changes

This PR abstracts a common construction architecture for Transformer-family models, unifying the GPT2 and LLaMA3 build processes into a single flow; future transformer-family models can reuse this unified flow. The enum fields in TransformerConfig directly control model behavior: the same TransformerModel code builds either GPT2 or LLaMA3 depending on the config passed in. Adding a new Transformer-family model later only requires extending the enum values and filling in the corresponding config, keeping overall architectural changes minimal.

Directory structure

infini_train/include/nn/modules/
├── transformer/
│   ├── transformer_config.h    # TransformerConfig + ModelType/AttentionType/MLPType/NormType enums
│   ├── transformer.h           # TransformerLayer / FirstStage / Chunk / LastStage / TransformerModel classes
│   ├── causal_self_attention.h # CausalSelfAttention (supports Standard/RoPE, includes GQA)
│   ├── mlp.h                   # MLP module (supports GELU/SwiGLU)
│   └── utils.h                 # RoPE helpers (PrecomputeFreqsCis)
├── activations.h               # activation function declarations (NewGELU / SwiGLU)
└── normalization.h             # normalization class declarations (LayerNorm / RMSNorm)

infini_train/src/nn/modules/
├── transformer/
│   ├── transformer.cc           # TransformerModel / FirstStage / Chunk / LastStage / Layer implementations
│   ├── causal_self_attention.cc # CausalSelfAttention implementation
│   ├── mlp.cc                   # MLP implementation (supports GELU / SwiGLU)
│   └── utils.cc                 # PrecomputeFreqsCis implementation
├── activations.cc               # NewGELU / SwiGLU definitions and Forward implementations
└── normalization.cc             # LayerNorm / RMSNorm definitions and Forward implementations

example/
├── gpt2/
│   ├── config.h                # GPT2Config preset (returns TransformerConfig)
│   └── checkpoint_loader.h/.cc # GPT2 weight loading (uses the unified TransformerModel)
└── llama3/
    ├── config.h                # LLaMA3Config preset (returns TransformerConfig)
    └── checkpoint_loader.h/.cc # LLaMA3 weight loading (uses the unified TransformerModel)


JYMiracle305 commented Mar 16, 2026

Single-node multi-GPU:
GPT2:
[screenshot]

LLaMA3:
[screenshot]

Multi-node training results:
GPT2:
[screenshot]

LLaMA3:
[screenshot]

@JYMiracle305 JYMiracle305 requested review from Chamberlain0w0, chen2021673 and kilinchange and removed request for Chamberlain0w0 March 16, 2026 05:42
@JYMiracle305 JYMiracle305 force-pushed the feat/transformer branch 3 times, most recently from dfdd913 to d833ec2 Compare March 16, 2026 08:10
first_stage.with_submodule(TransformerFirstStage::kWTELayerName, BuildVocabEmbeddingSpec(gpt2_config))
.with_submodule(TransformerFirstStage::kWPELayerName,
BuildPositionEmbeddingSpec(gpt2_config.block_size, gpt2_config.n_embd));
spec.with_submodule("first_stage", first_stage);



namespace infini_train::nn {

void ModuleRegistry::Register(std::type_index type, ModuleCreator creator) { registry_[type] = std::move(creator); }


auto tok_emb = (*modules_[kWTELayerName])({x1});

// Add position embedding only for models that use absolute position encoding
if (config_.attention_type == AttentionType::kStandard) {


// ManualSeed(42);

LLaMA3Config model_config = LLaMA3Config();
nn::TransformerConfig model_config;



// ========== GPT2 Model Definition ==========
// Uses LayerNorm, GELU activation, standard multi-head attention
class GPT2 : public nn::TransformerLayer {


Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
AttentionType attention_type_;



// Architecture choices
AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type
MLPType mlp_type = MLPType::kGELU; // MLP activation type



namespace infini_train::nn {

class RMSNorm : public infini_train::nn::CloneableModule<RMSNorm> {


modules_[kCFcLayerName] = build_module(config, spec.submodules_.at(kCFcLayerName));

// For SwiGLU, add second projection
if (spec.submodules_.count(kCFc2LayerName) > 0) {



// ========== LLaMA3 Model Definition ==========
// Uses RMSNorm, SwiGLU activation, GQA attention, RoPE positional encoding
class LLaMA3 : public nn::TransformerLayer {



@JYMiracle305 JYMiracle305 force-pushed the feat/transformer branch 4 times, most recently from 6ba15c3 to 2ac0526 Compare March 26, 2026 03:30

static constexpr char kParamBiasName[] = "bias";

explicit CausalSelfAttention(const TransformerConfig &config, const ModuleSpec &spec = {});


class ModuleRegistry {
public:
static ModuleRegistry &Instance() {
static ModuleRegistry inst;


auto norm = x[0] * nn::function::Rsqrt(nn::function::Mean(nn::function::Pow(x[0], 2), -1, true) + eps_);
return {norm * parameters_[kParamWeightName]};
}
} // namespace infini_train::nn


return spec;
}

ModuleSpec BuildTransformerBlockSpec(const TransformerConfig &config) {


struct ModuleSpec {
ModuleSpec() = default;

explicit ModuleSpec(std::type_index m) : module_(m) {}
Collaborator

Introducing type_index here feels a bit odd. What is the point of having ModuleSpec bind a module?

JYMiracle305 (Contributor, Author) replied Mar 30, 2026

type_index serves as the key for looking up the concrete constructor in the registry. Megatron's ModuleSpec supports building from either a module or its submodules.


@JYMiracle305 JYMiracle305 force-pushed the feat/transformer branch 3 times, most recently from d97d661 to 5cec43f Compare April 1, 2026 08:35
@kilinchange kilinchange self-requested a review April 3, 2026 02:10
@JYMiracle305 JYMiracle305 force-pushed the feat/transformer branch 3 times, most recently from fab42f1 to 2e4611c Compare April 3, 2026 07:11


.use_scaled_rope = static_cast<bool>(use_scaled_rope),
.norm_eps = norm_eps,
.max_gen_batch_size = max_gen_bs});
nn::TransformerConfig llama3_config = nn::llama3::LLaMA3Config();
Collaborator

Why switch to a different style here? If it isn't necessary, please revert it.

JYMiracle305 (Contributor, Author) replied:

This adapts to the new config structure: first call each model's own initialization function so that the parameters belong to that model, then modify them according to the loaded data.


static std::shared_ptr<DecoderOnlyTransformer> FromLLMC_GPT2(const std::string &filepath);
static std::shared_ptr<DecoderOnlyTransformer> FromLLMC_LLaMA3(const std::string &filepath);
static void LoadWeightsFromLLMC(const std::string &filepath, DecoderOnlyTransformer *model,


static std::shared_ptr<DecoderOnlyTransformer> FromPretrained(ModelType model_type);

static std::shared_ptr<DecoderOnlyTransformer> FromLLMC_GPT2(const std::string &filepath);
static std::shared_ptr<DecoderOnlyTransformer> FromLLMC_LLaMA3(const std::string &filepath);


INFINI_TRAIN_REGISTER_MODULE(CausalSelfAttention);
INFINI_TRAIN_REGISTER_MODULE(MLP);

// NewGELU
Collaborator

In principle the basic modules below no longer need to be registered. To be clear: a spec should only describe things down to the mlp/attention level. For basic modules, the constructor of an upper-level module such as mlp/attention parses the parameters it needs from the spec and passes them directly to the basic module's constructor.


using ModuleCreator = std::function<std::shared_ptr<Module>(const TransformerConfig &, const ModuleSpec &)>;

class ModuleRegistry {
Collaborator

Since the creators of all spec-constructed modules are now explicit, the Registry seems unnecessary as well, and so ModuleSpec no longer needs to bind module information via type_index.

return *value;
}

std::shared_ptr<Module> BuildModule(const TransformerConfig &config, const ModuleSpec &spec);


@JYMiracle305 JYMiracle305 force-pushed the feat/transformer branch 6 times, most recently from 9d7cd4c to 9a8cce4 Compare April 9, 2026 07:06
@JYMiracle305 JYMiracle305 requested a review from kilinchange April 9, 2026 07:12
      - Remove ModuleRegistry and INFINI_TRAIN_REGISTER_MODULE macros
      - Replace BuildModule() with direct constructor calls
      - Simplify module instantiation in MLP, CausalSelfAttention, and TransformerLayer
