diff --git a/.gitignore b/.gitignore
index b728e6ea..1d4781b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,4 +30,7 @@ __pycache__/
 *.http
-*.nsys-rep
+**/*.nsys-rep
+**/*.jsonl
+*.jsonl
+**/*.mem
diff --git a/csrc/models/minicpm_sala/minicpm_sala_allocate_kv_cache_tensors.cpp b/csrc/models/minicpm_sala/minicpm_sala_allocate_kv_cache_tensors.cpp
index f4cb3b55..3ad0b506 100644
--- a/csrc/models/minicpm_sala/minicpm_sala_allocate_kv_cache_tensors.cpp
+++ b/csrc/models/minicpm_sala/minicpm_sala_allocate_kv_cache_tensors.cpp
@@ -32,7 +32,7 @@ std::vector<infinicore::Tensor> minicpm_sala_allocate_kv_cache_tensors(const cac
     const size_t num_key_value_heads = text_config->get<size_t>("num_key_value_heads");
     const size_t max_position_embeddings = text_config->get<size_t>("max_position_embeddings");
-    const auto &dtype{text_config->get_dtype()};
+    const auto &dtype{text_config->get_kv_cache_dtype()};
     std::vector<std::string> mixer_types = text_config->get<std::vector<std::string>>("mixer_types");
     size_t current_layer_head_dim, current_layer_num_key_value_heads;
     for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) {
@@ -70,7 +70,7 @@ std::vector<infinicore::Tensor> minicpm_sala_allocate_kv_cache_tensors(const cac
     const size_t head_dim = text_config->get<size_t>("head_dim");
     const size_t num_key_value_heads = text_config->get<size_t>("num_key_value_heads");
-    const auto &dtype{text_config->get_dtype()};
+    const auto &dtype{text_config->get_kv_cache_dtype()};
     std::vector<std::string> mixer_types = text_config->get<std::vector<std::string>>("mixer_types");
     size_t current_layer_head_dim, current_layer_num_key_value_heads;
     for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) {
diff --git a/csrc/models/minicpm_sala/minicpm_sala_attention.cpp b/csrc/models/minicpm_sala/minicpm_sala_attention.cpp
index c1e20f76..27cb2275 100644
--- a/csrc/models/minicpm_sala/minicpm_sala_attention.cpp
+++ b/csrc/models/minicpm_sala/minicpm_sala_attention.cpp
@@ -1,133 +1,550 @@
 #include "minicpm_sala_attention.hpp"
+
+#include "infinicore/ops.hpp"
+#include "infinicore/ops/infllmv2_attention.hpp"
+#include "infinicore/ops/simple_gla_attention.hpp"
+#include "infinicore/ops/simple_gla_decode_step.hpp"
+#include "infinicore/ops/simple_gla_prefill.hpp"
+#include "infinicore/ops/simple_gla_recurrent_state_append.hpp"
+#include "infinicore/context/context.hpp"
 #include "../../global_state/global_state.hpp"
+#include "../debug_utils/tensor_utils.hpp"
+
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
 #include <memory>
+#include <stdexcept>

 namespace infinilm::models::minicpm_sala {

-AttentionBase::AttentionBase(std::shared_ptr model_config,
-                             size_t num_attention_heads,
-                             size_t num_key_value_heads,
-                             size_t layer_idx,
-                             const infinicore::Device &device)
-    : layer_idx_(layer_idx),
-      hidden_size_(model_config->get<size_t>("hidden_size")),
-      head_dim_(model_config->get<size_t>("head_dim")) {
-
-    const auto &dtype{model_config->get_dtype()};
-
-    use_bias_ = model_config->get_or("attention_bias", true);
-    use_output_bias_ = model_config->get_or("attention_output_bias", false);
-
-    attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend;
-    const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();
-    int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank();
-    int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size();
-
-    const size_t total_num_heads = num_attention_heads;
-    const size_t total_num_kv_heads = num_key_value_heads;
-    if ((total_num_kv_heads < static_cast<size_t>(tp_size)) || (0 != (total_num_kv_heads % static_cast<size_t>(tp_size)))) {
-        throw std::runtime_error("infinilm::models::minicpm_sala::AttentionBase: num_key_value_heads must be divisible by tp_size");
-    }
-
-    num_attention_heads_ = total_num_heads / static_cast<size_t>(tp_size);
-    num_key_value_heads_ = total_num_kv_heads / static_cast<size_t>(tp_size);
-
-    auto quant_scheme = model_config->get_quant_scheme();
-    auto quantization_method = model_config->get_quantization_method();
-    switch (quant_scheme) {
-    case infinicore::quantization::QuantScheme::NONE:
-        INFINICORE_NN_MODULE_INIT(q_proj, hidden_size_, total_num_heads * head_dim_, quantization_method,
-                                  use_bias_, dtype, device, tp_rank, tp_size);
-        INFINICORE_NN_MODULE_INIT(k_proj, hidden_size_, total_num_kv_heads * head_dim_, quantization_method,
-                                  use_bias_, dtype, device, tp_rank, tp_size);
-        INFINICORE_NN_MODULE_INIT(v_proj, hidden_size_, total_num_kv_heads * head_dim_, quantization_method,
-                                  use_bias_, dtype, device, tp_rank, tp_size);
-        INFINICORE_NN_MODULE_INIT(o_proj, total_num_heads * head_dim_, hidden_size_, quantization_method,
-                                  use_output_bias_, dtype, device, tp_rank, tp_size, rank_info.comm);
-        break;
-    default:
-        throw std::runtime_error("infinilm::models::minicpm_sala::AttentionBase: unsupported quantization scheme");
-        break;
+namespace {
+
+// Per-layer KV tensor layout from `StaticKVCache::create_layer_kv_cache`: [2, B, n_kv, max_len, D].
+void minicpm_sala_update_layer_kv_tensor(infinicore::Tensor &kv_bundle,
+                                         const infinicore::Tensor &k_permuted,
+                                         const infinicore::Tensor &v_permuted,
+                                         const infinicore::Tensor &past_sequence_lengths) {
+    auto k_cache_layer = kv_bundle->narrow({{0, 0, 1}})->squeeze(0);
+    auto v_cache_layer = kv_bundle->narrow({{0, 1, 1}})->squeeze(0);
+
+#ifdef ENABLE_KV_CACHING
+    infinicore::op::kv_caching_(
+        k_cache_layer,
+        v_cache_layer,
+        k_permuted,
+        v_permuted,
+        past_sequence_lengths);
+#else
+    const size_t cache_pos = static_cast<size_t>(
+        reinterpret_cast<const int32_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0]);
+    const size_t update_len = k_permuted->size(2);
+    const size_t result_len = cache_pos + update_len;
+    if (result_len > k_cache_layer->size(2)) {
+        throw std::runtime_error("MiniCPMSALAAttention(KV update): KV cache length exceeded");
+    }
+    k_cache_layer->narrow({{2, cache_pos, update_len}})->copy_from(k_permuted);
+    v_cache_layer->narrow({{2, cache_pos, update_len}})->copy_from(v_permuted);
+#endif
+}
+
+// Same as HF MiniCPM-SALA _build_slope_tensor (used for Simple GLA decay).
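+// For n a power of two this is the ALiBi slope construction:
+//   start = 2^(-2^-(log2(n) - 3)), slopes[i] = start^(i + 1);
+// e.g. n = 8 gives start = 1/2 and slopes [1/2, 1/4, ..., 1/256].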
+std::vector<float> build_slope_tensor(size_t n) {
+    auto get_slopes_power_of_2 = [](size_t n) -> std::vector<float> {
+        double log2n = std::log2(static_cast<double>(n));
+        double start = std::pow(2.0, -(std::pow(2.0, -(log2n - 3))));
+        double ratio = start;
+        std::vector<float> out;
+        out.reserve(n);
+        for (size_t i = 0; i < n; ++i) {
+            out.push_back(static_cast<float>(start * std::pow(ratio, static_cast<double>(i))));
+        }
+        return out;
+    };
+    if (n == 0) return {};
+    double log2n = std::log2(static_cast<double>(n));
+    if (std::abs(log2n - std::floor(log2n)) < 1e-9) {
+        return get_slopes_power_of_2(n);
+    }
+    size_t closest = static_cast<size_t>(std::pow(2.0, std::floor(log2n)));
+    auto first = get_slopes_power_of_2(closest);
+    auto rest = build_slope_tensor(2 * closest);
+    for (size_t i = 0; i < n - closest; ++i) {
+        first.push_back(rest[i * 2]);
+    }
+    return first;
+}
+
+} // namespace
+
+namespace {
+void ensure_gla_state_allocated(infinicore::Tensor &state,
+                                const infinicore::Device &device,
+                                size_t batch_size,
+                                size_t n_h,
+                                size_t head_dim) {
+    const std::vector<size_t> want = {batch_size, n_h, head_dim, head_dim};
+    if (!state || state->shape() != want || state->dtype() != infinicore::DataType::F32 || state->device() != device) {
+        state = infinicore::Tensor::zeros(want, infinicore::DataType::F32, device);
+    }
+}
+} // namespace
+
+MiniCPMSALALightningAttention::MiniCPMSALALightningAttention(std::shared_ptr model_config,
+                                                             const infinicore::Device &device,
+                                                             size_t layer_idx)
+    : layer_idx_(layer_idx) {
+    const auto dtype = model_config->get_dtype();
+    const size_t hidden_size = model_config->get<size_t>("hidden_size");
+    num_attention_heads_ = model_config->get_or("lightning_nh", model_config->get<size_t>("num_attention_heads"));
+    num_key_value_heads_ = model_config->get_or("lightning_nkv", model_config->get<size_t>("num_key_value_heads"));
+    head_dim_ = model_config->get_or("lightning_head_dim", model_config->get<size_t>("head_dim"));
+    scaling_ = static_cast<float>(1.0 / std::sqrt(static_cast<double>(head_dim_)));
+
+    use_rope_ = model_config->get_or("lightning_use_rope", true);
     rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device);
-    float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_));
-    attn_ = std::make_shared<infinilm::layers::attention::AttentionLayer>(num_attention_heads_, head_dim_, scaling,
-                                                                          num_key_value_heads_, layer_idx_,
-                                                                          kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_);
+    use_qk_norm_ = model_config->get_or("qk_norm", true);
+    use_output_gate_ = model_config->get_or("use_output_gate", true);
+
+    INFINICORE_NN_MODULE_INIT(q_proj, hidden_size, num_attention_heads_ * head_dim_, false, dtype, device);
+    INFINICORE_NN_MODULE_INIT(k_proj, hidden_size, num_key_value_heads_ * head_dim_, false, dtype, device);
+    INFINICORE_NN_MODULE_INIT(v_proj, hidden_size, num_key_value_heads_ * head_dim_, false, dtype, device);
+    INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads_ * head_dim_, hidden_size, false, dtype, device);

-    auto kv_quant_scheme = infinilm::global_state::get_infinilm_config().model_config->get_kv_quant_scheme();
-    switch (kv_quant_scheme) {
-    case (infinicore::quantization::KVQuantAlgo::NONE): {
-        break;
+    if (use_qk_norm_) {
+        INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, model_config->get<double>("rms_norm_eps"), dtype, device);
+        INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, model_config->get<double>("rms_norm_eps"), dtype, device);
     }
-    case (infinicore::quantization::KVQuantAlgo::INT8): {
-        INFINICORE_NN_PARAMETER_INIT(kv_cache_k_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
-        INFINICORE_NN_PARAMETER_INIT(kv_cache_v_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
-        break;
+    use_output_norm_ = true;
+    INFINICORE_NN_MODULE_INIT(o_norm, hidden_size, model_config->get<double>("rms_norm_eps"), dtype, device);
+    INFINICORE_NN_MODULE_INIT(z_proj, hidden_size, hidden_size, false, dtype, device);
+
+    std::vector<float> slopes = build_slope_tensor(num_attention_heads_);
+    auto g_cpu = infinicore::Tensor::empty(
+        {num_attention_heads_}, infinicore::DataType::F32, infinicore::Device::cpu());
+    float *ptr = reinterpret_cast<float *>(g_cpu->data());
+    for (size_t h = 0; h < num_attention_heads_; ++h)
+        ptr[h] = -slopes[h];
+    g_gamma_ = g_cpu->to(device);
+}
+
+void MiniCPMSALALightningAttention::reset_state() {
+    gla_state_valid_ = false;
+    gla_state_cached_len_ = 0;
+    gla_state_ = {};
+}
+
+infinicore::Tensor MiniCPMSALALightningAttention::forward(const infinicore::Tensor &position_ids,
+                                                          const infinicore::Tensor &hidden_states) const {
+    const auto &attn_meta = infinilm::global_state::get_forward_context().attn_metadata;
+    auto past_sequence_lengths = attn_meta.past_sequence_lengths;
+    auto total_sequence_lengths = attn_meta.total_sequence_lengths;
+    auto cu_seqlens = attn_meta.cu_seqlens;
+    // input_offsets/block_tables/slot_mapping are not used in this dense/per-layer-kv implementation yet.
+    (void)cu_seqlens;
+    // Input: [B, S, H]
+    auto shape = hidden_states->shape();
+    const size_t batch_size = shape[0];
+    const size_t seq_len = shape[1];
+
+    auto hs_mut = hidden_states;
+    auto q = q_proj_->forward(hs_mut);
+    auto k = k_proj_->forward(hs_mut);
+    auto v = v_proj_->forward(hs_mut);
+    // View requires contiguous layout; only call contiguous when needed (proj output often already contiguous).
+    auto q_reshaped = q->contiguous()->view({batch_size, seq_len, num_attention_heads_, head_dim_});
+    auto k_reshaped = k->contiguous()->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
+    auto v_reshaped = v->contiguous()->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
+
+    if (use_qk_norm_) {
+        // RMSNorm op only supports 2D/3D; normalize over head_dim with a 3D view.
+        auto q3 = q_reshaped->view({batch_size * seq_len, num_attention_heads_, head_dim_});
+        auto k3 = k_reshaped->view({batch_size * seq_len, num_key_value_heads_, head_dim_});
+        q3 = q_norm_->forward(q3);
+        k3 = k_norm_->forward(k3);
+        q_reshaped = q3->view({batch_size, seq_len, num_attention_heads_, head_dim_});
+        k_reshaped = k3->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
     }
-    default: {
-        throw std::runtime_error("infinilm::layers::attention: unsupported kv_quant_scheme");
-        break;
+
+    // RoPE only for lightning layers (HyPE)
+    if (use_rope_) {
+        if (!rotary_emb_) {
+            throw std::runtime_error("MiniCPMSALALightningAttention: rotary_emb is not set but use_rope=true");
+        }
+        // position_ids can be [B,S] or [S]; follow LlamaAttention behavior.
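+        // e.g. with B=1 at decode step t: position_ids [[t]] -> pos_ids_for_rope [t];
+        // on prefill: [[0, 1, ..., S-1]] -> [0, 1, ..., S-1].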
+        auto pos_shape = position_ids->shape();
+        infinicore::Tensor pos_ids_for_rope = position_ids;
+        if (pos_shape.size() == 2) {
+            auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
+            pos_ids_for_rope = pos_narrowed->contiguous()->view({pos_shape[1]});
+        } else if (pos_shape.size() == 1) {
+            pos_ids_for_rope = position_ids->contiguous();
+        } else {
+            throw std::runtime_error("MiniCPMSALALightningAttention: Unexpected position_ids shape");
+        }
+
+        rotary_emb_->forward(q_reshaped, pos_ids_for_rope, true);
+        rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true);
     }
+
+    // Compute dense attention (GQA): reshape as LlamaAttention does
+    size_t total_seq_len = seq_len;
+    size_t cache_pos = 0;
+    const bool has_cache_meta = past_sequence_lengths.has_value() && total_sequence_lengths.has_value();
+    if (has_cache_meta) {
+        auto past_cpu = past_sequence_lengths.value()->to(infinicore::Device::cpu());
+        cache_pos = reinterpret_cast<const int32_t *>(past_cpu->data())[0];
+        // `total_sequence_lengths` may be input length (e.g. 1 on decode); KV length is cache_pos + seq_len.
+        total_seq_len = cache_pos + seq_len;
+    } else if (total_sequence_lengths.has_value()) {
+        total_seq_len = reinterpret_cast<const int32_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
     }
-}

-InfLLMv2Attention::InfLLMv2Attention(std::shared_ptr model_config,
-                                     size_t layer_idx,
-                                     const infinicore::Device &device)
-    : AttentionBase(model_config,
-                    model_config->get<size_t>("num_attention_heads"),
-                    model_config->get<size_t>("num_key_value_heads"),
-                    layer_idx, device) {
-    use_output_gate_ = model_config->get_or("use_output_gate", false);
-    const auto &dtype{model_config->get_dtype()};
-    size_t num_attention_heads = model_config->get<size_t>("num_attention_heads");

+    // Cache expects [B, n_kv, S, D]. Keep this as a strided view and let the caching op handle strides
+    // to avoid a full rearrange (permute->contiguous) copy on long-context prefill.
+    // Correctness: kv_caching_ / StaticKVCache::update is sensitive to input stride/layout.
+    // Restore contiguous to match HF logits exactly before re-applying any strided optimizations.
+    auto k_permuted = k_reshaped->permute({0, 2, 1, 3})->contiguous(); // [B, n_kv, S, D]
+    auto v_permuted = v_reshaped->permute({0, 2, 1, 3})->contiguous(); // [B, n_kv, S, D]
+
+    // Per-layer KV tensors in `global_state::get_forward_context().kv_cache_vec` (same pattern as
+    // `InfinilmModel::reset_cache` / `StaticAttentionImpl`).
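+    // Layout reminder: kv_cache_vec[layer_idx] is one bundle [2, B, n_kv, max_len, D];
+    // index 0 along dim 0 is the K plane, index 1 the V plane (see the update helper above).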
+    infinicore::Tensor k_total = k_permuted;
+    infinicore::Tensor v_total = v_permuted;
+    bool use_forward_kv = false;
+    if (has_cache_meta) {
+        auto &kv_vec = infinilm::global_state::get_forward_context().kv_cache_vec;
+        if (layer_idx_ >= kv_vec.size()) {
+            throw std::runtime_error(
+                "MiniCPMSALALightningAttention: forward_context.kv_cache_vec is unset or too small (call reset_cache / align layer count)");
+        }
+        use_forward_kv = true;
+        minicpm_sala_update_layer_kv_tensor(
+            kv_vec[layer_idx_],
+            k_permuted,
+            v_permuted,
+            past_sequence_lengths.value());
+        auto k_cache_layer = kv_vec[layer_idx_]->narrow({{0, 0, 1}})->squeeze(0);
+        auto v_cache_layer = kv_vec[layer_idx_]->narrow({{0, 1, 1}})->squeeze(0);
+        k_total = k_cache_layer;
+        v_total = v_cache_layer;
+    } else {
+        total_seq_len = seq_len;
+    }
+
+    // Slice to total_seq_len (decode-only / cont-batch)
+    if (total_seq_len > k_total->shape()[2]) {
+        throw std::runtime_error("MiniCPMSALALightningAttention: total_seq_len exceeds available KV length (cache not correctly updated)");
+    }
+    k_total = k_total->narrow({{2, 0, total_seq_len}});
+    v_total = v_total->narrow({{2, 0, total_seq_len}});
+
+    infinicore::Tensor attn_output;
+    {
+        // Lightning-attn only: Simple GLA (HF-aligned).
+        // simple_gla_attention(q,k,v,g_gamma,scale) expects [B, T, H, D]; g_gamma [H].
+        const size_t n_h = num_attention_heads_;
+        const size_t n_kv = num_key_value_heads_;
+        infinicore::Tensor k_use = k_total;
+        infinicore::Tensor v_use = v_total;
+        if (n_kv < n_h) {
+            // Repeat KV heads to match n_h (same as HF repeat_kv / repeat_interleave).
+            // Use as_strided view then contiguous() so one copy instead of n_h narrow/copy_from calls.
+            const size_t ngroup = n_h / n_kv;
+            const std::vector<size_t> repeat_strides = {
+                static_cast<size_t>(n_kv * total_seq_len * head_dim_),
+                static_cast<size_t>(total_seq_len * head_dim_),
+                0,
+                static_cast<size_t>(head_dim_),
+                1,
+            };
+            k_use = k_total->as_strided(
+                         {batch_size, n_kv, ngroup, total_seq_len, head_dim_}, repeat_strides)
+                        ->contiguous()
+                        ->view({batch_size, n_h, total_seq_len, head_dim_});
+            v_use = v_total->as_strided(
+                         {batch_size, n_kv, ngroup, total_seq_len, head_dim_}, repeat_strides)
+                        ->contiguous()
+                        ->view({batch_size, n_h, total_seq_len, head_dim_});
+        }
+        // GLA expects [B, S, H, D]. `q_reshaped` is already [B, S, H, D], so avoid permute+contiguous.
+        auto q_bthd = q_reshaped; // [B, S_q, H, D]
+        // Correctness: restore contiguous layout for K/V before `simple_gla_attention`.
+        auto k_bthd = k_use->permute({0, 2, 1, 3})->contiguous(); // [B, S_kv, H, D]
+        auto v_bthd = v_use->permute({0, 2, 1, 3})->contiguous(); // [B, S_kv, H, D]
+
+        // Lightning fast decode: maintain recurrent state locally (do NOT depend on StaticKVCache extensions).
+        // We rebuild state on-demand if it is out-of-sync with cache_pos.
+        const bool is_decode = has_cache_meta && use_forward_kv && (seq_len == 1) && (total_seq_len > 1);
+        if (is_decode) {
+            ensure_gla_state_allocated(gla_state_, q_bthd->device(), batch_size, n_h, head_dim_);
+
+            // Ensure `state` corresponds to exactly `cache_pos` cached tokens (excluding current token).
+            if (!gla_state_valid_ || gla_state_cached_len_ != cache_pos) {
+                // Rebuild from available KV. This is O(T) once after reset / mismatch.
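+                // Conceptually, the per-head recurrence with log-decay g_h is
+                //   S_t = exp(g_h) * S_{t-1} + k_t^T v_t,   o_t = scale * (q_t S_t),
+                // so replaying K/V[0:cache_pos] through the append kernel reproduces S_{cache_pos}.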
+ infinicore::op::zeros_(gla_state_); + if (cache_pos > 0) { + auto k_prev = k_bthd->narrow({{1, 0, cache_pos}}); + auto v_prev = v_bthd->narrow({{1, 0, cache_pos}}); + infinicore::op::simple_gla_recurrent_state_append_segment(gla_state_, k_prev, v_prev, g_gamma_); + } + gla_state_cached_len_ = cache_pos; + gla_state_valid_ = true; + } + + // Decode-step uses only the newest KV at position (total_seq_len - 1). + auto q_new = q_bthd; // [B,1,H,D] + auto k_new = k_bthd->narrow({{1, total_seq_len - 1, 1}}); + auto v_new = v_bthd->narrow({{1, total_seq_len - 1, 1}}); + auto out_b1hd = infinicore::op::simple_gla_decode_step(q_new, k_new, v_new, gla_state_, g_gamma_, scaling_); + gla_state_cached_len_ = cache_pos + 1; + attn_output = out_b1hd->view({batch_size, seq_len, n_h * head_dim_}); + // Fall through to output norm/gate + o_proj below (do not run full-sequence GLA again). + } else { + // Prefill / non-decode batching: non-recurrent kernels, then update local recurrent state. + infinicore::Tensor q_full; + if (seq_len == total_seq_len) { + q_full = q_bthd; + } else { + // q shorter than KV: pad q to [B, total_seq_len, H, D]. + q_full = infinicore::Tensor::zeros( + {batch_size, total_seq_len, n_h, head_dim_}, q_bthd->dtype(), q_bthd->device()); + auto q_slot = q_full->narrow({{1, total_seq_len - seq_len, seq_len}}); + q_slot->copy_from(q_bthd); + } + + infinicore::Tensor gla_out; + // Fused prefill: naive kernel for head_dim<=64; chunked/tiled kernel for head_dim>64 (e.g. 128). + bool use_fused_prefill = (batch_size == 1) && (seq_len == total_seq_len); + if (use_fused_prefill) { + gla_out = infinicore::op::simple_gla_prefill(q_full, k_bthd, v_bthd, g_gamma_, scaling_); + } else { + gla_out = infinicore::op::simple_gla_attention(q_full, k_bthd, v_bthd, g_gamma_, scaling_); + } + + // Keep local recurrent state in sync for subsequent decode steps. + ensure_gla_state_allocated(gla_state_, q_bthd->device(), batch_size, n_h, head_dim_); + if (cache_pos == 0) { + infinicore::op::zeros_(gla_state_); + gla_state_cached_len_ = 0; + gla_state_valid_ = true; + } + // Append the segment we just wrote: [cache_pos, cache_pos + seq_len) + if (gla_state_valid_ && gla_state_cached_len_ == cache_pos) { + auto k_seg = k_bthd->narrow({{1, cache_pos, seq_len}}); + auto v_seg = v_bthd->narrow({{1, cache_pos, seq_len}}); + infinicore::op::simple_gla_recurrent_state_append_segment(gla_state_, k_seg, v_seg, g_gamma_); + gla_state_cached_len_ = cache_pos + seq_len; + } else { + // Out-of-sync; force rebuild next time we need recurrent decode. 
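+                // (Can happen when the scheduler replays a token: e.g. the full-prefix-hit path in
+                // static_scheduler.py re-runs the last prompt token, so cache_pos moves backwards.)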
+                gla_state_valid_ = false;
+            }
+
+            infinicore::Tensor out_slice = gla_out->narrow({{1, total_seq_len - seq_len, seq_len}});
+            attn_output = out_slice->view({batch_size, seq_len, n_h * head_dim_});
+        }
+    }
+
+    // Lightning output gate/norm
     if (use_output_gate_) {
-        INFINICORE_NN_MODULE_INIT(o_gate, hidden_size_, num_attention_heads * head_dim_,
-                                  model_config->get_quantization_method(), use_bias_, dtype, device);
+        auto z_in = hidden_states;
+        auto z = z_proj_->forward(z_in);
+        infinicore::op::sigmoid_(z, z);
+        if (use_output_norm_ && o_norm_) {
+            attn_output = o_norm_->forward(attn_output);
+        }
+        attn_output = infinicore::op::mul(attn_output, z);
+    } else if (use_output_norm_ && o_norm_) {
+        attn_output = o_norm_->forward(attn_output);
+    }
+
+    auto attn_out_mut = attn_output;
+    auto out = o_proj_->forward(attn_out_mut);
+
+    return out;
+}
+
+MiniCPMSALAMinicpm4Attention::MiniCPMSALAMinicpm4Attention(std::shared_ptr model_config,
+                                                           const infinicore::Device &device,
+                                                           size_t layer_idx)
+    : layer_idx_(layer_idx) {
+    (void)device;
+    const auto dtype = model_config->get_dtype();
+    const size_t hidden_size = model_config->get<size_t>("hidden_size");
+    num_attention_heads_ = model_config->get<size_t>("num_attention_heads");
+    num_key_value_heads_ = model_config->get<size_t>("num_key_value_heads");
+    head_dim_ = model_config->get<size_t>("head_dim");
+    scaling_ = static_cast<float>(1.0 / std::sqrt(static_cast<double>(head_dim_)));
+
+    int sparse_window_size = model_config->get_or("sparse_window_size", -1);
+    if (sparse_window_size <= 0) {
+        auto sparse_cfg = model_config->get_or("sparse_config", nlohmann::json{});
+        if (!sparse_cfg.is_null() && sparse_cfg.contains("window_size")) {
+            sparse_window_size = sparse_cfg["window_size"].get<int>();
+        } else {
+            sparse_window_size = model_config->get_or("window_size", -1);
+        }
     }
+    if (sparse_window_size > 0) {
+        infllmv2_window_left_ = sparse_window_size;
+        use_local_window_ = true;
+    }
+
+    INFINICORE_NN_MODULE_INIT(q_proj, hidden_size, num_attention_heads_ * head_dim_, false, dtype, device);
+    INFINICORE_NN_MODULE_INIT(k_proj, hidden_size, num_key_value_heads_ * head_dim_, false, dtype, device);
+    INFINICORE_NN_MODULE_INIT(v_proj, hidden_size, num_key_value_heads_ * head_dim_, false, dtype, device);
+    INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads_ * head_dim_, hidden_size, false, dtype, device);
+    INFINICORE_NN_MODULE_INIT(o_gate, hidden_size, hidden_size, false, dtype, device);
 }

-infinicore::Tensor InfLLMv2Attention::forward(const infinicore::Tensor &positions,
-                                              const infinicore::Tensor &hidden_states) const {
-    spdlog::error("InfLLMv2Attention is not implemented");
-    return hidden_states;
+void MiniCPMSALAMinicpm4Attention::reset_state() {
+    // no local recurrent state
 }

-LightningAttention::LightningAttention(std::shared_ptr model_config,
-                                       size_t layer_idx,
-                                       const infinicore::Device &device)
-    : AttentionBase(model_config,
-                    model_config->get<size_t>("num_attention_heads"),
-                    model_config->get<size_t>("lightning_nkv"),
-                    layer_idx, device) {
+infinicore::Tensor MiniCPMSALAMinicpm4Attention::forward(const infinicore::Tensor &position_ids,
+                                                         const infinicore::Tensor &hidden_states) const {
+    (void)position_ids;
+    const auto &attn_meta = infinilm::global_state::get_forward_context().attn_metadata;
+    auto past_sequence_lengths = attn_meta.past_sequence_lengths;
+    auto total_sequence_lengths = attn_meta.total_sequence_lengths;

-    qk_norm_ = model_config->get_or("qk_norm", false);
-    use_output_norm_ = model_config->get_or("use_output_norm", false);
-    use_output_gate_ = model_config->get_or("use_output_gate", false);
-    const auto &dtype{model_config->get_dtype()};
-    double rms_norm_eps = model_config->get<double>("rms_norm_eps");
-    size_t num_attention_heads = model_config->get<size_t>("num_attention_heads");
+    auto shape = hidden_states->shape();
+    const size_t batch_size = shape[0];
+    const size_t seq_len = shape[1];

-    if (qk_norm_) {
-        INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device);
-        INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device);
+    auto hs_mut = hidden_states;
+    auto q = q_proj_->forward(hs_mut);
+    auto k = k_proj_->forward(hs_mut);
+    auto v = v_proj_->forward(hs_mut);
+    auto q_reshaped = q->contiguous()->view({batch_size, seq_len, num_attention_heads_, head_dim_});
+    auto k_reshaped = k->contiguous()->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
+    auto v_reshaped = v->contiguous()->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
+
+    // KV update via per-layer kv_cache_vec when metadata present
+    size_t total_seq_len = seq_len;
+    size_t cache_pos = 0;
+    const bool has_cache_meta = past_sequence_lengths.has_value() && total_sequence_lengths.has_value();
+    if (has_cache_meta) {
+        auto past_cpu = past_sequence_lengths.value()->to(infinicore::Device::cpu());
+        cache_pos = reinterpret_cast<const int32_t *>(past_cpu->data())[0];
+        total_seq_len = cache_pos + seq_len;
     }

-    if (use_output_norm_) {
-        INFINICORE_NN_MODULE_INIT(o_norm, num_attention_heads * head_dim_, rms_norm_eps, dtype, device);
+    auto k_permuted = k_reshaped->permute({0, 2, 1, 3})->contiguous();
+    auto v_permuted = v_reshaped->permute({0, 2, 1, 3})->contiguous();
+
+    infinicore::Tensor k_total = k_permuted;
+    infinicore::Tensor v_total = v_permuted;
+    bool use_forward_kv = false;
+    if (has_cache_meta) {
+        auto &kv_vec = infinilm::global_state::get_forward_context().kv_cache_vec;
+        if (layer_idx_ >= kv_vec.size()) {
+            throw std::runtime_error(
+                "MiniCPMSALAMinicpm4Attention: forward_context.kv_cache_vec is unset or too small");
+        }
+        use_forward_kv = true;
+        minicpm_sala_update_layer_kv_tensor(
+            kv_vec[layer_idx_],
+            k_permuted,
+            v_permuted,
+            past_sequence_lengths.value());
+        auto k_cache_layer = kv_vec[layer_idx_]->narrow({{0, 0, 1}})->squeeze(0);
+        auto v_cache_layer = kv_vec[layer_idx_]->narrow({{0, 1, 1}})->squeeze(0);
+        k_total = k_cache_layer;
+        v_total = v_cache_layer;
+    } else {
+        total_seq_len = seq_len;
     }

-    if (use_output_gate_) {
-        INFINICORE_NN_MODULE_INIT(z_proj, hidden_size_, num_attention_heads * head_dim_,
-                                  model_config->get_quantization_method(), use_bias_, dtype, device);
+    if (total_seq_len > k_total->shape()[2]) {
+        throw std::runtime_error("MiniCPMSALAMinicpm4Attention: total_seq_len exceeds available KV length");
     }
-}
+    k_total = k_total->narrow({{2, 0, total_seq_len}});
+    v_total = v_total->narrow({{2, 0, total_seq_len}});
+
+    try {
+        if (!total_sequence_lengths.has_value()) {
+            throw std::runtime_error("MiniCPMSALAMinicpm4Attention: total_sequence_lengths is required for InfLLM-v2 path");
+        }
+        const auto cache_lens = total_sequence_lengths.value();
+        const bool force_varlen_decode = [&]() {
+            const char *env = std::getenv("INFINI_MINICPM4_DECODE_VARLEN");
+            return env && env[0] != '\0' && env[0] != '0';
+        }();

-infinicore::Tensor LightningAttention::forward(const infinicore::Tensor &positions,
-                                               const infinicore::Tensor &hidden_states) const {
-    spdlog::error("LightningAttention is not implemented");
-    return hidden_states;
+        infinicore::Tensor attn_output;
+        if (seq_len == total_seq_len || (force_varlen_decode && batch_size == 1)) {
+            if (batch_size != 1) {
+                throw std::runtime_error("MiniCPMSALAMinicpm4Attention: varlen path requires batch_size=1");
+            }
+            auto q_bshd = q_reshaped->contiguous();
+            auto k_btkd = k_total->permute({0, 2, 1, 3})->contiguous();
+            auto v_btkd = v_total->permute({0, 2, 1, 3})->contiguous();
+            auto q_var = q_bshd->view({static_cast<size_t>(seq_len), static_cast<size_t>(num_attention_heads_), static_cast<size_t>(head_dim_)});
+            auto k_var = k_btkd->view({static_cast<size_t>(total_seq_len), static_cast<size_t>(num_key_value_heads_), static_cast<size_t>(head_dim_)});
+            auto v_var = v_btkd->view({static_cast<size_t>(total_seq_len), static_cast<size_t>(num_key_value_heads_), static_cast<size_t>(head_dim_)});
+
+            auto cuq_cpu = infinicore::Tensor::empty({2}, infinicore::DataType::I32, infinicore::Device::cpu());
+            reinterpret_cast<int32_t *>(cuq_cpu->data())[0] = 0;
+            reinterpret_cast<int32_t *>(cuq_cpu->data())[1] = static_cast<int32_t>(seq_len);
+            infinicore::Tensor cu_q = cuq_cpu->to(q_var->device());
+            auto cuk_cpu = infinicore::Tensor::empty({2}, infinicore::DataType::I32, infinicore::Device::cpu());
+            reinterpret_cast<int32_t *>(cuk_cpu->data())[0] = 0;
+            reinterpret_cast<int32_t *>(cuk_cpu->data())[1] = static_cast<int32_t>(total_seq_len);
+            infinicore::Tensor cu_k = cuk_cpu->to(q_var->device());
+
+            const bool infllmv2_causal = !use_local_window_;
+            const int window_left = use_local_window_ ? infllmv2_window_left_ : -1;
+            const int window_right = use_local_window_ ? 0 : -1;
+
+            auto out_var = infinicore::op::infllmv2_varlen(
+                q_var, k_var, v_var,
+                cu_q, cu_k,
+                static_cast<int>(seq_len),
+                static_cast<int>(total_seq_len),
+                scaling_,
+                /*causal=*/infllmv2_causal,
+                /*window_size_left=*/window_left,
+                /*window_size_right=*/window_right);
+            attn_output = out_var->view({batch_size, seq_len, num_attention_heads_ * head_dim_});
+        } else if (use_forward_kv) {
+            if (batch_size != 1) {
+                throw std::runtime_error("MiniCPMSALAMinicpm4Attention: kvcache decode requires batch_size=1");
+            }
+            auto q_bshd = q_reshaped->contiguous();
+            auto k_bthd = k_total->permute({0, 2, 1, 3})->contiguous();
+            auto v_bthd = v_total->permute({0, 2, 1, 3})->contiguous();
+
+            const bool infllmv2_causal = !use_local_window_;
+            const int window_left = use_local_window_ ? infllmv2_window_left_ : -1;
+            const int window_right = use_local_window_ ? 0 : -1;
+
+            auto out_bshd = infinicore::op::infllmv2_kvcache(
+                q_bshd,
+                k_bthd,
+                v_bthd,
+                cache_lens,
+                scaling_,
+                /*causal=*/infllmv2_causal,
+                /*window_size_left=*/window_left,
+                /*window_size_right=*/window_right);
+            attn_output = out_bshd->contiguous()->view({batch_size, seq_len, num_attention_heads_ * head_dim_});
+        } else {
+            throw std::runtime_error("MiniCPMSALAMinicpm4Attention: decode requires KV cache");
+        }
+
+        // Sparse gate + o_proj
+        auto gate = o_gate_->forward(hs_mut);
+        infinicore::op::sigmoid_(gate, gate);
+        attn_output = infinicore::op::mul(attn_output, gate);
+        auto out = o_proj_->forward(attn_output);
+        return out;
+    } catch (const std::exception &e) {
+        throw std::runtime_error(
+            std::string("MiniCPMSALAMinicpm4Attention: InfLLM-v2 attention failed. ") +
+            "Original error: " + e.what());
+    }
+}

 } // namespace infinilm::models::minicpm_sala
diff --git a/csrc/models/minicpm_sala/minicpm_sala_attention.hpp b/csrc/models/minicpm_sala/minicpm_sala_attention.hpp
index 81a032b6..9af665aa 100644
--- a/csrc/models/minicpm_sala/minicpm_sala_attention.hpp
+++ b/csrc/models/minicpm_sala/minicpm_sala_attention.hpp
@@ -1,88 +1,108 @@
 #pragma once

-#include "../../layers/common_modules.hpp"
+#include "../../config/model_config.hpp"
+#include "../../layers/rotary_embedding/rotary_embedding.hpp"

-namespace infinilm::layers::attention {
-class AttentionLayer;
-}
+#include "infinicore/nn/linear.hpp"
+#include "infinicore/nn/module.hpp"
+#include "infinicore/nn/rmsnorm.hpp"
+#include "infinicore/nn/rope.hpp"
+#include "infinicore/tensor.hpp"
+
+#include <cstddef>
+#include <memory>

 namespace infinilm::models::minicpm_sala {

-class AttentionBase : public infinicore::nn::Module {
-protected:
-    AttentionBase(std::shared_ptr model_config,
-                  size_t num_attention_heads,
-                  size_t num_key_value_heads,
-                  size_t layer_idx,
-                  const infinicore::Device &device);
+class MiniCPMSALAAttentionBase : public infinicore::nn::Module {
+public:
+    virtual infinicore::Tensor forward(const infinicore::Tensor &position_ids,
+                                       const infinicore::Tensor &hidden_states) const = 0;
+    virtual void reset_state() = 0;
+    virtual ~MiniCPMSALAAttentionBase() = default;
+};

+// Lightning attention path (Simple GLA). Parameter names align with HF:
+// model.layers.N.self_attn.{q_proj,k_proj,v_proj,o_proj,q_norm,k_norm,o_norm,z_proj,...}
+class MiniCPMSALALightningAttention : public MiniCPMSALAAttentionBase {
 public:
-    size_t layer_idx() const { return layer_idx_; }
-    size_t num_heads() const { return num_attention_heads_; }
-    size_t num_kv_heads() const { return num_key_value_heads_; }
-    size_t head_dim() const { return head_dim_; }
-    size_t hidden_size() const { return hidden_size_; }
+    MiniCPMSALALightningAttention(std::shared_ptr model_config,
+                                  const infinicore::Device &device,
+                                  size_t layer_idx);
+
+    // Match `infinilm::layers::attention::Attention` API: metadata is pulled from
+    // `global_state::get_forward_context().attn_metadata`.
+ infinicore::Tensor forward(const infinicore::Tensor &position_ids, + const infinicore::Tensor &hidden_states) const override; + + void reset_state() override; protected: - INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, q_proj); - INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, k_proj); - INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, v_proj); - INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj); + // Projections (HF-aligned naming) + INFINICORE_NN_MODULE(infinicore::nn::Linear, q_proj); + INFINICORE_NN_MODULE(infinicore::nn::Linear, k_proj); + INFINICORE_NN_MODULE(infinicore::nn::Linear, v_proj); + INFINICORE_NN_MODULE(infinicore::nn::Linear, o_proj); + + // Optional (Lightning layers): q_norm/k_norm/o_norm + z_proj + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, o_norm); + INFINICORE_NN_MODULE(infinicore::nn::Linear, z_proj); - std::shared_ptr attn_; - ::infinilm::backends::AttentionBackend attention_backend_; std::shared_ptr rotary_emb_; size_t layer_idx_; - size_t hidden_size_; size_t num_attention_heads_; size_t num_key_value_heads_; size_t head_dim_; - bool use_bias_; - bool use_output_bias_; - - // For off-line kv cache quantization - INFINICORE_NN_PARAMETER(kv_cache_k_scale); - INFINICORE_NN_PARAMETER(kv_cache_v_scale); -}; + float scaling_; -/** - * @brief InfLLMv2 attention with optional output gate - */ -class InfLLMv2Attention : public AttentionBase { -public: - InfLLMv2Attention(std::shared_ptr model_config, - size_t layer_idx, - const infinicore::Device &device); + bool use_qk_norm_ = false; + bool use_output_gate_ = false; + bool use_output_norm_ = false; + bool use_rope_ = false; - infinicore::Tensor forward(const infinicore::Tensor &positions, - const infinicore::Tensor &hidden_states) const; + // Lightning layers only: per-head log-decay for Simple GLA (HF _build_slope_tensor * -1). + infinicore::Tensor g_gamma_; -protected: - bool use_output_gate_; - INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, o_gate); + // Lightning layers only: recurrent state for fast decode. + // Shape: [B, H, D, D] float32. Tracks how many KV tokens are folded into the state. + mutable infinicore::Tensor gla_state_; + mutable size_t gla_state_cached_len_ = 0; + mutable bool gla_state_valid_ = false; }; -/** - * @brief Lightning attention with optional output norm and gate - */ -class LightningAttention : public AttentionBase { +// Sparse attention path (`mixer_type=="minicpm4"`) using InfLLM-v2 operators. 
+// Parameter names align with HF:
+// model.layers.N.self_attn.{q_proj,k_proj,v_proj,o_proj,o_gate,...}
+class MiniCPMSALAMinicpm4Attention : public MiniCPMSALAAttentionBase {
 public:
-    LightningAttention(std::shared_ptr model_config,
-                       size_t layer_idx,
-                       const infinicore::Device &device);
+    MiniCPMSALAMinicpm4Attention(std::shared_ptr model_config,
+                                 const infinicore::Device &device,
+                                 size_t layer_idx);
+
+    infinicore::Tensor forward(const infinicore::Tensor &position_ids,
+                               const infinicore::Tensor &hidden_states) const override;

-    infinicore::Tensor forward(const infinicore::Tensor &positions,
-                               const infinicore::Tensor &hidden_states) const;
+    void reset_state() override;

 protected:
-    bool qk_norm_;
-    bool use_output_norm_;
-    bool use_output_gate_;
-    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm);
-    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm);
-    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, o_norm);
-    INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, z_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::Linear, q_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::Linear, k_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::Linear, v_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::Linear, o_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::Linear, o_gate);
+
+    size_t layer_idx_;
+    size_t num_attention_heads_;
+    size_t num_key_value_heads_;
+    size_t head_dim_;
+    float scaling_;
+
+    // InfLLM-v2 local-window masking plumbing.
+    int infllmv2_window_left_ = -1;
+    bool use_local_window_ = false;
 };

 } // namespace infinilm::models::minicpm_sala
diff --git a/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.cpp b/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.cpp
deleted file mode 100644
index ff3c113f..00000000
--- a/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.cpp
+++ /dev/null
@@ -1,61 +0,0 @@
-#include "minicpm_sala_decoderLayer.hpp"
-
-#include "infinicore/ops.hpp"
-#include <memory>
-#include <string>
-#include <variant>
-
-namespace infinilm::models::minicpm_sala {
-
-MiniCPMSALADecoderLayer::MiniCPMSALADecoderLayer(std::shared_ptr model_config,
-                                                 size_t layer_idx,
-                                                 const infinicore::Device &device)
-    : layer_idx_(layer_idx) {
-    const auto &dtype{model_config->get_dtype()};
-    size_t hidden_size = model_config->get<size_t>("hidden_size");
-    double rms_norm_eps = model_config->get<double>("rms_norm_eps");
-
-    INFINICORE_NN_MODULE_INIT(input_layernorm, hidden_size, rms_norm_eps, dtype, device);
-    INFINICORE_NN_MODULE_INIT(post_attention_layernorm, hidden_size, rms_norm_eps, dtype, device);
-    INFINICORE_NN_MODULE_INIT(mlp, model_config, device);
-
-    std::vector<std::string> mixer_types = model_config->get<std::vector<std::string>>("mixer_types");
-    std::string mixer_type = mixer_types[layer_idx];
-    if ("minicpm4" == mixer_type) {
-        self_attn_ = std::make_shared<MiniCPMSALAAttention>(this->register_module<InfLLMv2Attention>("self_attn", model_config, layer_idx, device));
-    } else if ("lightning" == mixer_type || "lightning_attn" == mixer_type || "lightning-attn" == mixer_type) {
-        self_attn_ = std::make_shared<MiniCPMSALAAttention>(this->register_module<LightningAttention>("self_attn", model_config, layer_idx, device));
-    } else {
-        throw std::runtime_error("infinilm::models::minicpm_sala::MiniCPMSALADecoderLayer: unsupported mixer_type '" + mixer_type + "' for layer " + std::to_string(layer_idx));
-    }
-}
-
-std::tuple<infinicore::Tensor, infinicore::Tensor> MiniCPMSALADecoderLayer::forward(const infinicore::Tensor &positions,
-                                                                                    infinicore::Tensor &hidden_states,
-                                                                                    infinicore::Tensor &residual) {
-    input_layernorm_->forward_inplace(hidden_states, residual);
-    hidden_states = std::visit(
-        [&](auto &attn_ptr) { return attn_ptr->forward(positions, hidden_states); }, *self_attn_);
-
-    post_attention_layernorm_->forward_inplace(hidden_states, residual);
-    hidden_states = mlp_->forward(hidden_states);
-    return std::make_tuple(hidden_states, residual);
-}
-
-infinicore::Tensor MiniCPMSALADecoderLayer::forward(const infinicore::Tensor &positions,
-                                                    infinicore::Tensor &hidden_states) {
-    auto residual = hidden_states;
-    hidden_states = input_layernorm_->forward(hidden_states);
-    hidden_states = std::visit(
-        [&](auto &attn_ptr) { return attn_ptr->forward(positions, hidden_states); }, *self_attn_);
-
-    hidden_states = infinicore::op::add(residual, hidden_states);
-
-    residual = hidden_states;
-    hidden_states = post_attention_layernorm_->forward(hidden_states);
-    hidden_states = mlp_->forward(hidden_states);
-    hidden_states = infinicore::op::add(residual, hidden_states);
-    return hidden_states;
-}
-
-} // namespace infinilm::models::minicpm_sala
diff --git a/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.hpp b/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.hpp
deleted file mode 100644
index 5e8faafb..00000000
--- a/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.hpp
+++ /dev/null
@@ -1,34 +0,0 @@
-#pragma once
-
-#include "../../layers/mlp/mlp.hpp"
-#include "minicpm_sala_attention.hpp"
-#include <memory>
-#include <variant>
-
-namespace infinilm::models::minicpm_sala {
-using MiniCPMMLP = infinilm::layers::MLP;
-using MiniCPMSALAAttention = std::variant<std::shared_ptr<InfLLMv2Attention>, std::shared_ptr<LightningAttention>>;
-
-class MiniCPMSALADecoderLayer : public infinicore::nn::Module {
-public:
-    MiniCPMSALADecoderLayer(std::shared_ptr model_config,
-                            size_t layer_idx,
-                            const infinicore::Device &device);
-
-    std::tuple<infinicore::Tensor, infinicore::Tensor> forward(const infinicore::Tensor &positions,
-                                                               infinicore::Tensor &hidden_states,
-                                                               infinicore::Tensor &residual);
-
-    infinicore::Tensor forward(const infinicore::Tensor &positions,
-                               infinicore::Tensor &hidden_states);
-
-protected:
-    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm);
-    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm);
-    INFINICORE_NN_MODULE(MiniCPMSALAAttention, self_attn);
-    INFINICORE_NN_MODULE(MiniCPMMLP, mlp);
-
-    size_t layer_idx_;
-};
-
-} // namespace infinilm::models::minicpm_sala
diff --git a/csrc/models/minicpm_sala/minicpm_sala_decoder_layer.cpp b/csrc/models/minicpm_sala/minicpm_sala_decoder_layer.cpp
new file mode 100644
index 00000000..6c04b480
--- /dev/null
+++ b/csrc/models/minicpm_sala/minicpm_sala_decoder_layer.cpp
@@ -0,0 +1,66 @@
+#include "minicpm_sala_decoder_layer.hpp"
+
+#include "infinicore/ops.hpp"
+#include "infinicore/context/context.hpp"
+#include <cmath>
+#include <cstddef>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace infinilm::models::minicpm_sala {
+
+
+MiniCPMSALADecoderLayer::MiniCPMSALADecoderLayer(std::shared_ptr model_config,
+                                                 const infinicore::Device &device,
+                                                 size_t layer_idx,
+                                                 const std::string &mixer_type) {
+    // Match parameter dtype with checkpoint `torch_dtype` (e.g. BF16 for MiniCPM-SALA).
+    const auto dtype = model_config->get_dtype();
+    const double eps = model_config->get<double>("rms_norm_eps");
+
+    // MuP residual scaling at forward (o_proj/down_proj not scaled in loader for minicpm_sala).
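+    // e.g. scale_depth = 1.4 with num_hidden_layers = 49 gives residual_scale = 1.4 / sqrt(49) = 0.2
+    // (illustrative numbers; the real values come from config.json).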
+    const double scale_depth = model_config->get_or("scale_depth", 1.0);
+    const size_t num_layers = model_config->get<size_t>("num_hidden_layers");
+    residual_scale_ = scale_depth / std::sqrt(static_cast<double>(num_layers));
+
+    INFINICORE_NN_MODULE_INIT(input_layernorm, model_config->get<size_t>("hidden_size"), eps, dtype, device);
+    if (mixer_type == "minicpm4") {
+        self_attn_ = this->register_module<MiniCPMSALAMinicpm4Attention>(
+            "self_attn", model_config, device, layer_idx);
+    } else {
+        self_attn_ = this->register_module<MiniCPMSALALightningAttention>(
+            "self_attn", model_config, device, layer_idx);
+    }
+    INFINICORE_NN_MODULE_INIT(post_attention_layernorm, model_config->get<size_t>("hidden_size"), eps, dtype, device);
+    INFINICORE_NN_MODULE_INIT(mlp, model_config, device);
+}
+
+void MiniCPMSALADecoderLayer::reset_attn_state() {
+    self_attn_->reset_state();
+}
+
+infinicore::Tensor MiniCPMSALADecoderLayer::forward(const infinicore::Tensor &hidden_states,
+                                                    const infinicore::Tensor &position_ids) const {
+    // Pre-norm attention
+    auto hs1 = input_layernorm_->forward(hidden_states);
+    auto attn_out = self_attn_->forward(position_ids, hs1);
+
+    // residual + residual_scale * attn_out (MuP)
+    auto ones_attn = infinicore::Tensor::empty(attn_out->shape(), attn_out->dtype(), attn_out->device());
+    infinicore::op::ones_(ones_attn);
+    auto out1 = infinicore::op::addcmul(hidden_states, attn_out, ones_attn, static_cast<float>(residual_scale_));
+
+    // Pre-norm MLP
+    auto hs2 = post_attention_layernorm_->forward(out1);
+    auto mlp_out = mlp_->forward(hs2);
+    // residual + residual_scale * mlp_out (MuP)
+    auto ones_mlp = infinicore::Tensor::empty(mlp_out->shape(), mlp_out->dtype(), mlp_out->device());
+    infinicore::op::ones_(ones_mlp);
+    auto out2 = infinicore::op::addcmul(out1, mlp_out, ones_mlp, static_cast<float>(residual_scale_));
+
+    return out2;
+}
+
+} // namespace infinilm::models::minicpm_sala
diff --git a/csrc/models/minicpm_sala/minicpm_sala_decoder_layer.hpp b/csrc/models/minicpm_sala/minicpm_sala_decoder_layer.hpp
new file mode 100644
index 00000000..305ab967
--- /dev/null
+++ b/csrc/models/minicpm_sala/minicpm_sala_decoder_layer.hpp
@@ -0,0 +1,40 @@
+#pragma once
+
+#include "minicpm_sala_attention.hpp"
+#include "minicpm_sala_mlp.hpp"
+
+#include "../../config/model_config.hpp"
+
+#include "infinicore/nn/module.hpp"
+#include "infinicore/nn/rmsnorm.hpp"
+#include "infinicore/tensor.hpp"
+
+#include <memory>
+#include <string>
+
+namespace infinilm::models::minicpm_sala {
+
+class MiniCPMSALADecoderLayer : public infinicore::nn::Module {
+public:
+    MiniCPMSALADecoderLayer(std::shared_ptr model_config,
+                            const infinicore::Device &device,
+                            size_t layer_idx,
+                            const std::string &mixer_type);
+
+    infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
+                               const infinicore::Tensor &position_ids) const;
+
+    void reset_attn_state();
+
+private:
+    double residual_scale_ = 1.0;
+
+protected:
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm);
+    // Registered under the HF-compatible name "self_attn" in ctor.
+    std::shared_ptr<MiniCPMSALAAttentionBase> self_attn_;
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm);
+    INFINICORE_NN_MODULE(MiniCPMSALAMLP, mlp);
+};
+
+} // namespace infinilm::models::minicpm_sala
diff --git a/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.cpp b/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.cpp
index 793f86bd..de6f34e1 100644
--- a/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.cpp
+++ b/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.cpp
@@ -1,49 +1,74 @@
 #include "minicpm_sala_for_causal_lm.hpp"
-#include "../../global_state/global_state.hpp"
 #include "../models_registry.hpp"
+
+#include "../../global_state/global_state.hpp"
+#include "infinicore/ops.hpp"
+#include <memory>
 #include <stdexcept>
 #include <string>

 namespace infinilm::models::minicpm_sala {

-MiniCPMSALAForCausalLM::MiniCPMSALAForCausalLM(std::shared_ptr model_config,
-                                               const infinicore::Device &device) {
+std::vector<infinicore::Tensor> minicpm_sala_allocate_kv_cache_tensors(
+    const cache::CacheConfig *cache_config,
+    const std::shared_ptr &text_config,
+    const backends::AttentionBackend &attention_backend);
+
+std::shared_ptr create_minicpm_sala_model_config(
+    std::shared_ptr model_config) {
+    const std::string &model_type = model_config->get<std::string>("model_type");
+    if ("minicpm_sala" != model_type) {
+        throw std::runtime_error("infinilm::models::minicpm_sala::create_minicpm_sala_model_config: model_type is not minicpm_sala");
+    }
+    return model_config;
+}
+
+MiniCPMSALAForCausalLM::MiniCPMSALAForCausalLM(
+    std::shared_ptr model_config,
+    const infinicore::Device &device) {
+    device_ = device;
     model_config_ = model_config;
-    size_t hidden_size = model_config->get<size_t>("hidden_size");
-    size_t vocab_size = model_config->get<size_t>("vocab_size");
-    const auto &dtype{model_config->get_dtype()};
+    // Match parameter dtype with checkpoint `torch_dtype` (e.g. BF16 for MiniCPM-SALA).
+    const auto dtype = model_config->get_dtype();

     INFINICORE_NN_MODULE_INIT(model, model_config, device);
+
+    const size_t hidden_size = model_config->get<size_t>("hidden_size");
+    const size_t vocab_size = model_config->get<size_t>("vocab_size");
+
     INFINICORE_NN_MODULE_INIT(lm_head, hidden_size, vocab_size, false, dtype, device);
 }

-infinilm::InfinilmModel::Output MiniCPMSALAForCausalLM::forward(const infinilm::InfinilmModel::Input &input) const {
-    auto hidden_states = model_->forward(input);
+MiniCPMSALAForCausalLM::Output MiniCPMSALAForCausalLM::forward(
+    const Input &input) const {
+    auto input_ids = input.input_ids.value();
+    auto position_ids = input.position_ids.value();
+
+    auto hidden_states = model_->forward(input_ids, position_ids);
+
+    // MuP lm_head scale baked into lm_head.weight at load time; no forward scaling here.
     auto logits = lm_head_->forward(hidden_states);
     return {logits};
 }

 void MiniCPMSALAForCausalLM::reset_cache(const cache::CacheConfig *cache_config) {
-    if (nullptr == cache_config) {
-        InfinilmModel::reset_cache(nullptr);
+    // Match `InfinilmModel::reset_cache`: own `cache_config_` + `kv_cache_vec` here; inner model only
+    // resets per-layer attention state. MiniCPM uses `minicpm_sala_allocate_kv_cache_tensors` instead of
+    // `default_allocate_kv_cache_tensors`.
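+    // Flow: nullptr tears the cache down; a real config reallocates one
+    // [2, B, n_kv, max_len, D] bundle per layer and clears the GLA recurrent state.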
+    if (cache_config == nullptr) {
+        cache_config_.reset();
+        infinilm::global_state::get_forward_context().kv_cache_vec.clear();
+        model_->reset_state();
         return;
     }

     cache_config_ = cache_config->unique_copy();
-
     auto &kv_cache_vec = infinilm::global_state::get_forward_context().kv_cache_vec;
     kv_cache_vec.clear();

-    const backends::AttentionBackend attention_backend = infinilm::global_state::get_infinilm_config().attention_backend;
-
-    auto new_kv_cache_vec = minicpm_sala_allocate_kv_cache_tensors(cache_config, model_config_, attention_backend);
-    kv_cache_vec = std::move(new_kv_cache_vec);
-}
-
-std::shared_ptr create_minicpm_sala_model_config(std::shared_ptr model_config) {
-    const std::string &model_type = model_config->get<std::string>("model_type");
-    if ("minicpm_sala" != model_type) {
-        throw std::runtime_error("infinilm::models::minicpm_sala::create_minicpm_sala_model_config: model_type is not minicpm_sala");
-    }
-    return model_config;
+    const backends::AttentionBackend attention_backend =
+        infinilm::global_state::get_infinilm_config().attention_backend;
+    kv_cache_vec = std::move(
+        minicpm_sala_allocate_kv_cache_tensors(cache_config, model_config_, attention_backend));
+    model_->reset_state();
 }

 } // namespace infinilm::models::minicpm_sala
@@ -54,3 +79,4 @@ INFINILM_REGISTER_CAUSAL_LM_MODEL(
     infinilm::models::minicpm_sala::MiniCPMSALAForCausalLM,
     infinilm::models::minicpm_sala::create_minicpm_sala_model_config);
 } // namespace
+
diff --git a/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.hpp b/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.hpp
index f0d0aaae..0a53e101 100644
--- a/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.hpp
+++ b/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.hpp
@@ -1,13 +1,18 @@
 #pragma once

-#include "minicpm_sala_decoderLayer.hpp"
-#include <memory>
-#include <vector>
+#include "../infinilm_model.hpp"
+#include "minicpm_sala_model.hpp"

-namespace infinilm::models::minicpm_sala {
+#include "../../config/model_config.hpp"
+#include "../../layers/linear/linear.hpp"
+
+#include "infinicore/device.hpp"

-using MiniCPMSALAModel = infinilm::layers::causal_lm_templates::TextModel<MiniCPMSALADecoderLayer>;
+namespace infinilm::models::minicpm_sala {

+// MiniCPM-SALA design:
+// - Lightning Attention (Simple GLA) layers + InfLLM-V2 sparse layers in a 1:3 ratio
+// - HyPE (RoPE on linear layers; NoPE on sparse layers)
 class MiniCPMSALAForCausalLM : public InfinilmModel {
 public:
     MiniCPMSALAForCausalLM(std::shared_ptr model_config,
@@ -17,15 +22,16 @@ class MiniCPMSALAForCausalLM : public InfinilmModel {

     void reset_cache(const cache::CacheConfig *cache_config) override;

-protected:
+private:
     INFINICORE_NN_MODULE(MiniCPMSALAModel, model);
     INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head);
 };

-std::shared_ptr create_minicpm_sala_model_config(std::shared_ptr model_config);
+} // namespace infinilm::models::minicpm_sala
+
+namespace infinilm::models::minicpm_sala {
+
+std::shared_ptr create_minicpm_sala_model_config(
+    std::shared_ptr model_config);

-/** Implemented in `minicpm_sala_allocate_kv_cache_tensors.cpp`. */
-std::vector<infinicore::Tensor> minicpm_sala_allocate_kv_cache_tensors(const cache::CacheConfig *cache_config,
-                                                                       const std::shared_ptr &text_config,
-                                                                       const backends::AttentionBackend &attention_backend);
 } // namespace infinilm::models::minicpm_sala
diff --git a/csrc/models/minicpm_sala/minicpm_sala_mlp.cpp b/csrc/models/minicpm_sala/minicpm_sala_mlp.cpp
new file mode 100644
index 00000000..b9ebd3c6
--- /dev/null
+++ b/csrc/models/minicpm_sala/minicpm_sala_mlp.cpp
@@ -0,0 +1,32 @@
+#include "minicpm_sala_mlp.hpp"
+
+#include "infinicore/ops.hpp"
+
+namespace infinilm::models::minicpm_sala {
+
+MiniCPMSALAMLP::MiniCPMSALAMLP(std::shared_ptr model_config,
+                               const infinicore::Device &device) {
+    // Match parameter dtype with checkpoint `torch_dtype` (e.g. BF16 for MiniCPM-SALA).
+    const auto dtype = model_config->get_dtype();
+    const size_t hidden_size = model_config->get<size_t>("hidden_size");
+    const size_t intermediate_size = model_config->get<size_t>("intermediate_size");
+
+    INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size, intermediate_size, false, dtype, device);
+    INFINICORE_NN_MODULE_INIT(up_proj, hidden_size, intermediate_size, false, dtype, device);
+    INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size, hidden_size, false, dtype, device);
+}
+
+infinicore::Tensor MiniCPMSALAMLP::forward(const infinicore::Tensor &x) const {
+    auto x_mut = x;
+    auto gate = gate_proj_->forward(x_mut);
+    auto up = up_proj_->forward(x_mut);
+
+    // SwiGLU: silu(gate) * up — fused single kernel (swiglu(a,b) = a*b*sigmoid(b) => swiglu(up,gate))
+    auto act = infinicore::op::swiglu(up, gate);
+
+    auto act_mut = act;
+    return down_proj_->forward(act_mut);
+}
+
+} // namespace infinilm::models::minicpm_sala
+
diff --git a/csrc/models/minicpm_sala/minicpm_sala_mlp.hpp b/csrc/models/minicpm_sala/minicpm_sala_mlp.hpp
new file mode 100644
index 00000000..3150670b
--- /dev/null
+++ b/csrc/models/minicpm_sala/minicpm_sala_mlp.hpp
@@ -0,0 +1,27 @@
+#pragma once
+
+#include "../../config/model_config.hpp"
+
+#include "infinicore/nn/linear.hpp"
+#include "infinicore/nn/module.hpp"
+#include "infinicore/tensor.hpp"
+
+#include <memory>
+
+namespace infinilm::models::minicpm_sala {
+
+class MiniCPMSALAMLP : public infinicore::nn::Module {
+public:
+    MiniCPMSALAMLP(std::shared_ptr model_config,
+                   const infinicore::Device &device);
+
+    infinicore::Tensor forward(const infinicore::Tensor &x) const;
+
+protected:
+    INFINICORE_NN_MODULE(infinicore::nn::Linear, gate_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::Linear, up_proj);
+    INFINICORE_NN_MODULE(infinicore::nn::Linear, down_proj);
+};
+
+} // namespace infinilm::models::minicpm_sala
+
diff --git a/csrc/models/minicpm_sala/minicpm_sala_model.cpp b/csrc/models/minicpm_sala/minicpm_sala_model.cpp
new file mode 100644
index 00000000..20c6d420
--- /dev/null
+++ b/csrc/models/minicpm_sala/minicpm_sala_model.cpp
@@ -0,0 +1,65 @@
+#include "minicpm_sala_model.hpp"
+
+#include "infinicore/context/context.hpp"
+#include "infinicore/ops.hpp"
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace infinilm::models::minicpm_sala {
+
+MiniCPMSALAModel::MiniCPMSALAModel(std::shared_ptr model_config,
+                                   const infinicore::Device &device) {
+
+    // Match parameter dtype with checkpoint `torch_dtype` (e.g. BF16 for MiniCPM-SALA).
+    const auto dtype = model_config->get_dtype();
+
+    hidden_size_ = model_config->get<size_t>("hidden_size");
+
+    const size_t vocab_size = model_config->get<size_t>("vocab_size");
+    const size_t num_layers = model_config->get<size_t>("num_hidden_layers");
+
+    INFINICORE_NN_MODULE_INIT(embed_tokens, vocab_size, hidden_size_, std::nullopt, dtype, device);
+    INFINICORE_NN_MODULE_INIT(norm, hidden_size_, model_config->get<double>("rms_norm_eps"), dtype, device);
+
+    // Mixer types per-layer decide attention flavor (minicpm4 vs lightning-attn).
+    std::vector<std::string> mixer_types;
+    try {
+        mixer_types = model_config->get<std::vector<std::string>>("mixer_types");
+    } catch (...) {
+        mixer_types.assign(num_layers, "minicpm4");
+    }
+    if (mixer_types.size() != num_layers) {
+        mixer_types.resize(num_layers, mixer_types.empty() ? "minicpm4" : mixer_types.back());
+    }
+
+    layers_.reserve(num_layers);
+    for (size_t i = 0; i < num_layers; ++i) {
+        layers_.push_back(this->register_module<MiniCPMSALADecoderLayer>(
+            "layers." + std::to_string(i), model_config, device, i, mixer_types[i]));
+    }
+}
+
+void MiniCPMSALAModel::reset_state() {
+    for (auto &layer : layers_) {
+        layer->reset_attn_state();
+    }
+}
+
+infinicore::Tensor MiniCPMSALAModel::forward(const infinicore::Tensor &input_ids,
+                                             const infinicore::Tensor &position_ids) const {
+    // MuP scaling baked into weights at load time for minicpm_sala; no forward scaling here.
+    auto hs = embed_tokens_->forward(input_ids);
+
+    for (size_t i = 0; i < layers_.size(); ++i) {
+        hs = layers_[i]->forward(hs, position_ids);
+    }
+
+    hs = norm_->forward(hs);
+    return hs;
+}
+
+} // namespace infinilm::models::minicpm_sala
diff --git a/csrc/models/minicpm_sala/minicpm_sala_model.hpp b/csrc/models/minicpm_sala/minicpm_sala_model.hpp
new file mode 100644
index 00000000..811ecbf7
--- /dev/null
+++ b/csrc/models/minicpm_sala/minicpm_sala_model.hpp
@@ -0,0 +1,39 @@
+#pragma once
+
+#include "minicpm_sala_decoder_layer.hpp"
+
+#include "../../config/model_config.hpp"
+#include "infinicore/nn/embedding.hpp"
+#include "infinicore/nn/module.hpp"
+#include "infinicore/nn/rmsnorm.hpp"
+#include "infinicore/tensor.hpp"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace infinilm::models::minicpm_sala {
+
+class MiniCPMSALAModel : public infinicore::nn::Module {
+public:
+    MiniCPMSALAModel(std::shared_ptr model_config,
+                     const infinicore::Device &device);
+
+    infinicore::Tensor forward(const infinicore::Tensor &input_ids,
+                               const infinicore::Tensor &position_ids) const;
+
+    void reset_state();
+
+    size_t hidden_size() const { return hidden_size_; }
+
+protected:
+    INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens);
+    INFINICORE_NN_MODULE_VEC(MiniCPMSALADecoderLayer, layers);
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm);
+
+private:
+    size_t hidden_size_;
+};
+
+} // namespace infinilm::models::minicpm_sala
+
diff --git a/python/infinilm/auto_config.py b/python/infinilm/auto_config.py
index 7e2d4afd..b6f96ff5 100644
--- a/python/infinilm/auto_config.py
+++ b/python/infinilm/auto_config.py
@@ -27,6 +27,8 @@ def from_pretrained(model_path):
         return LlamaConfig(**config_dict)
     elif config_dict["model_type"] == "minicpm":
         return LlamaConfig(**config_dict)
+    elif config_dict["model_type"] == "minicpm_sala":
+        return LlamaConfig(**config_dict)
     elif config_dict["model_type"] == "fm9g":
         return LlamaConfig(**config_dict)
     elif config_dict["model_type"] == "fm9g7b":
diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py
index a67add6f..f25b97b9 100644
--- a/python/infinilm/infer_engine.py
+++ b/python/infinilm/infer_engine.py
@@ -164,6 +164,32 
diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py
index a67add6f..f25b97b9 100644
--- a/python/infinilm/infer_engine.py
+++ b/python/infinilm/infer_engine.py
@@ -164,6 +164,32 @@ def generate(
             dtype=infinicore.int32,
         )
 
+        # Decode metadata fast path (batch=1, static cache): avoid per-step from_list() allocations
+        # for tiny scalar tensors (these live on CPU and are H2D-copied each forward).
+        fast_decode_meta = (not self.enable_paged_attn) and (initial_batch_size == 1)
+        if fast_decode_meta:
+            cpu = infinicore.device("cpu", 0)
+
+            # Reusable metadata tensors; values updated via pybind write_i32/write_i64.
+            position_ids_decode = infinicore.empty(
+                [1, 1], dtype=infinicore.int64, device=cpu
+            )
+            past_kv_lengths_decode = infinicore.empty(
+                [1], dtype=infinicore.int32, device=cpu
+            )
+            total_kv_lengths_decode = infinicore.empty(
+                [1], dtype=infinicore.int32, device=cpu
+            )
+            cu_seqlens_decode = infinicore.empty(
+                [2], dtype=infinicore.int32, device=cpu
+            )
+            input_offsets_decode = infinicore.empty(
+                [2], dtype=infinicore.int32, device=cpu
+            )
+            input_offsets_decode.write_i32(0, 0)
+            input_offsets_decode.write_i32(1, 1)
+
+        decode_total_open = False
         for iter in range(0, generation_config.max_new_tokens):
             if _measure_and_log_time:
                 start_time = time.perf_counter()
@@ -203,29 +229,54 @@ def generate(
                     dtype=infinicore.int64,
                 )
             else:
-                position_ids = infinicore.from_list(
-                    [
-                        list(range(past_seq_len, past_seq_len + seq_len))
-                        for _ in range(batch_size)
-                    ],
-                    dtype=infinicore.int64,
-                )
+                if fast_decode_meta and iter > 0 and batch_size == 1 and seq_len == 1:
+                    position_ids_decode.write_i64(0, int(past_seq_len))
+                    past_kv_lengths_decode.write_i32(0, int(past_seq_len))
+                    total_kv_lengths_decode.write_i32(0, int(past_seq_len + seq_len))
+                    cu_seqlens_decode.write_i32(0, 0)
+                    cu_seqlens_decode.write_i32(1, int(past_seq_len + seq_len))
+                    position_ids = position_ids_decode
+                    past_kv_lengths = past_kv_lengths_decode
+                    total_kv_lengths = total_kv_lengths_decode
+                    cu_seqlens = cu_seqlens_decode
+                    input_offsets = input_offsets_decode
+                else:
+                    position_ids = infinicore.from_list(
+                        [
+                            list(range(past_seq_len, past_seq_len + seq_len))
+                            for _ in range(batch_size)
+                        ],
+                        dtype=infinicore.int64,
+                    )
+                    past_kv_lengths = infinicore.from_list(
+                        [past_seq_len] * batch_size, dtype=infinicore.int32
+                    )
+                    total_kv_lengths = infinicore.from_list(
+                        [past_seq_len + seq_len] * batch_size, dtype=infinicore.int32
+                    )
+                    cu_seqlens = infinicore.from_list(
+                        [(past_seq_len + seq_len) * i for i in range(batch_size + 1)],
+                        dtype=infinicore.int32,
+                    )
+                    input_offsets = infinicore.from_list(
+                        [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int32
+                    )
             slot_mapping = None
-
-            past_kv_lengths = infinicore.from_list(
-                [past_seq_len] * batch_size, dtype=infinicore.int32
-            )
-            total_kv_lengths = infinicore.from_list(
-                [past_seq_len + seq_len] * batch_size, dtype=infinicore.int32
-            )
-            cu_seqlens = infinicore.from_list(
-                [(past_seq_len + seq_len) * i for i in range(batch_size + 1)],
-                dtype=infinicore.int32,
-            )
-            input_offsets = infinicore.from_list(
-                [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int32
-            )
+            if self.enable_paged_attn:
+                past_kv_lengths = infinicore.from_list(
+                    [past_seq_len] * batch_size, dtype=infinicore.int32
+                )
+                total_kv_lengths = infinicore.from_list(
+                    [past_seq_len + seq_len] * batch_size, dtype=infinicore.int32
+                )
+                cu_seqlens = infinicore.from_list(
+                    [(past_seq_len + seq_len) * i for i in range(batch_size + 1)],
+                    dtype=infinicore.int32,
+                )
+                input_offsets = infinicore.from_list(
+                    [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int32
+                )
 
             output_id = self(
                 input_ids=input_ids,
 
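
The fast path above replaces five per-step `from_list()` allocations with in-place updates of preallocated CPU tensors; `write_i32`/`write_i64` are the pybind setters the patch relies on. A rough PyTorch analogue of the pattern (tensor and helper names mine, element assignment standing in for the setters):

```python
import torch

# Allocated once, before the decode loop.
position_ids = torch.empty(1, 1, dtype=torch.int64)
past_kv_lengths = torch.empty(1, dtype=torch.int32)
total_kv_lengths = torch.empty(1, dtype=torch.int32)
cu_seqlens = torch.empty(2, dtype=torch.int32)
input_offsets = torch.tensor([0, 1], dtype=torch.int32)  # constant while seq_len == 1

def update_decode_meta(past_seq_len: int) -> None:
    # Each decode step feeds exactly one new token, so seq_len == 1 throughout.
    position_ids[0, 0] = past_seq_len
    past_kv_lengths[0] = past_seq_len
    total_kv_lengths[0] = past_seq_len + 1
    cu_seqlens[0] = 0
    cu_seqlens[1] = past_seq_len + 1

update_decode_meta(past_seq_len=17)
assert cu_seqlens.tolist() == [0, 18] and position_ids.item() == 17
```
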
diff --git a/python/infinilm/llm/static_scheduler.py b/python/infinilm/llm/static_scheduler.py
index de4d9d35..c9b0cb30 100644
--- a/python/infinilm/llm/static_scheduler.py
+++ b/python/infinilm/llm/static_scheduler.py
@@ -60,9 +60,17 @@ def build_model_inputs(
             tokens = req.get_input_tokens()
             prefix_hit_len = self.prefix_hit_len
             input_tokens = tokens[prefix_hit_len:]
-            input_ids = [input_tokens]
-            position_ids = [list(range(prefix_hit_len, len(tokens)))]
-            past_kv_len = prefix_hit_len
+            if len(input_tokens) == 0:
+                # Full prefix hit: avoid empty tensor conversion in model input path.
+                # Recompute the last prompt token as a one-token prefill step.
+                input_tokens = [tokens[-1]]
+                input_ids = [input_tokens]
+                position_ids = [[len(tokens) - 1]]
+                past_kv_len = len(tokens) - 1
+            else:
+                input_ids = [input_tokens]
+                position_ids = [list(range(prefix_hit_len, len(tokens)))]
+                past_kv_len = prefix_hit_len
             total_kv_len = len(tokens)
             input_offsets = [0, len(input_tokens)]
         else:
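
The full-prefix-hit branch rewinds the cache by one position and re-runs the final prompt token, so the model still emits logits for sampling the first generated token without ever building an empty input tensor. A small Python sketch of that adjustment (the helper name is mine):

```python
def adjust_for_full_prefix_hit(tokens: list[int], prefix_hit_len: int):
    """Return (input_tokens, position_ids, past_kv_len) for one request."""
    input_tokens = tokens[prefix_hit_len:]
    if not input_tokens:
        # Every prompt token is cached: re-run only the last one as a
        # one-token prefill so the forward pass still produces logits.
        return [tokens[-1]], [len(tokens) - 1], len(tokens) - 1
    return input_tokens, list(range(prefix_hit_len, len(tokens))), prefix_hit_len

tokens = [11, 22, 33]
assert adjust_for_full_prefix_hit(tokens, 3) == ([33], [2], 2)          # full hit
assert adjust_for_full_prefix_hit(tokens, 1) == ([22, 33], [1, 2], 1)   # partial hit
```
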
diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py
index 1d21f2d9..03d3c062 100644
--- a/python/infinilm/modeling_utils.py
+++ b/python/infinilm/modeling_utils.py
@@ -1,4 +1,6 @@
 import os
+import json
+import math
 from typing import Dict, Union
 import time
 import torch
@@ -152,6 +154,29 @@ def load_model_state_dict_by_file(
     torch_dtype = infinicore.utils.to_torch_dtype(dtype)
     model_keys = model.state_dict_keyname()
 
+    # MiniCPM-SALA scaling (bake selected MuP scales into weights).
+    # This matches `InfiniLM/scripts/jiuge.py` weight scaling behavior for `model_type=="minicpm_sala"`.
+    scale_input = 1.0
+    scale_output = 1.0
+    scale_o = 1.0
+    scale_down = 1.0
+    scale_lm_head = 1.0
+    try:
+        # TODO: fetch config from model rather than file directly
+        with open(os.path.join(model_path, "config.json")) as f:
+            cfg = json.load(f)
+        if cfg.get("model_type") == "minicpm_sala" and "scale_emb" in cfg and "scale_depth" in cfg:
+            scale_input = float(cfg["scale_emb"])
+            scale_o = float(cfg["scale_depth"]) / math.sqrt(float(cfg["num_hidden_layers"]))
+            scale_down = float(cfg["scale_depth"]) / math.sqrt(float(cfg["num_hidden_layers"]))
+            if "dim_model_base" in cfg and "hidden_size" in cfg:
+                scale_lm_head = float(cfg["dim_model_base"]) / float(cfg["hidden_size"])
+            # minicpm_sala: bake only the embed and lm_head scales; residual scaling is applied at forward time in C++
+            scale_o = 1.0
+            scale_down = 1.0
+    except Exception:
+        pass
+
     already_loaded_keys = []
 
     file_list = glob.glob(os.path.join(model_path, "*.safetensors"))
@@ -167,6 +192,24 @@ def load_model_state_dict_by_file(
         )
         already_loaded_keys.extend(model_param.keys())
 
+        # Apply MiniCPM scaling to loaded tensors (in torch space).
+        if scale_input != 1.0 and "model.embed_tokens.weight" in model_param:
+            model_param["model.embed_tokens.weight"] = (
+                model_param["model.embed_tokens.weight"] * scale_input
+            )
+        if scale_output != 1.0 and "model.norm.weight" in model_param:
+            model_param["model.norm.weight"] = (
+                model_param["model.norm.weight"] * scale_output
+            )
+        if scale_o != 1.0 or scale_down != 1.0:
+            for k, v in list(model_param.items()):
+                if scale_o != 1.0 and k.endswith(".self_attn.o_proj.weight"):
+                    model_param[k] = v * scale_o
+                elif scale_down != 1.0 and k.endswith(".mlp.down_proj.weight"):
+                    model_param[k] = v * scale_down
+        if scale_lm_head != 1.0 and "lm_head.weight" in model_param:
+            model_param["lm_head.weight"] = model_param["lm_head.weight"] * scale_lm_head
+
         # --------------------------------------------------------- #
         #    model_param_infini references torch.Tensor
         # --------------------------------------------------------- #
@@ -180,6 +223,19 @@ def load_model_state_dict_by_file(
         file_path = os.path.join(model_path, "pytorch_model.bin")
         model_params = torch.load(file_path, weights_only=True, map_location="cpu")
 
+        if scale_input != 1.0 and "model.embed_tokens.weight" in model_params:
+            model_params["model.embed_tokens.weight"] = model_params["model.embed_tokens.weight"] * scale_input
+        if scale_output != 1.0 and "model.norm.weight" in model_params:
+            model_params["model.norm.weight"] = model_params["model.norm.weight"] * scale_output
+        if scale_o != 1.0 or scale_down != 1.0:
+            for k, v in list(model_params.items()):
+                if scale_o != 1.0 and k.endswith(".self_attn.o_proj.weight"):
+                    model_params[k] = v * scale_o
+                elif scale_down != 1.0 and k.endswith(".mlp.down_proj.weight"):
+                    model_params[k] = v * scale_down
+        if scale_lm_head != 1.0 and "lm_head.weight" in model_params:
+            model_params["lm_head.weight"] = model_params["lm_head.weight"] * scale_lm_head
+
         model_param_infini = {}
         for key in model_params.keys():
             model_param_infini[key] = infinicore.from_torch(
diff --git a/xmake.lua b/xmake.lua
index 2b1b51d3..5282f6a7 100644
--- a/xmake.lua
+++ b/xmake.lua
@@ -56,7 +56,7 @@ target_end()
 target("_infinilm")
     add_packages("pybind11")
     set_default(false)
-    add_rules("python.module", {soabi = true})
+    add_rules("python.library", {soabi = true})
     set_languages("cxx17")
     set_kind("shared")
 
@@ -70,6 +70,7 @@ target("_infinilm")
     add_linkdirs(INFINI_ROOT.."/lib")
 
     add_links("infinicore_cpp_api", "infiniop", "infinirt", "infiniccl")
+    add_rpathdirs(INFINI_ROOT.."/lib")
 
     -- Add src files
     add_files("csrc/**.cpp")
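
For reference, the MuP-style scales the loader computes reduce to three numbers: `scale_emb` multiplies the embedding, `scale_depth / sqrt(num_hidden_layers)` would scale the residual branches (left to the C++ forward here), and `dim_model_base / hidden_size` scales `lm_head`. A Python sketch with illustrative config values (not taken from any specific checkpoint):

```python
import math

# Illustrative values only; the real ones come from the checkpoint's config.json.
cfg = {
    "scale_emb": 12.0,
    "scale_depth": 1.4,
    "num_hidden_layers": 32,
    "dim_model_base": 256,
    "hidden_size": 2560,
}

scale_input = float(cfg["scale_emb"])                                      # baked into embed_tokens
scale_residual = cfg["scale_depth"] / math.sqrt(cfg["num_hidden_layers"])  # applied in C++ forward
scale_lm_head = cfg["dim_model_base"] / cfg["hidden_size"]                 # baked into lm_head

print(scale_input, round(scale_residual, 4), scale_lm_head)  # 12.0 0.2475 0.1
```
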