From f480615eeada573f76de5df34d2a8ec84fa6e169 Mon Sep 17 00:00:00 2001 From: Vladimir Moushkov Date: Tue, 31 Mar 2026 15:08:18 +0000 Subject: [PATCH] Add Qwen2.5-32B support with per-linear RMSNorm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds per-linear RMSNorm support to the Qwen2 forward pass in llama.cpp, enabling inference of BitNet QAT-trained Qwen2.5-32B models. The per-linear RMSNorm applies a separate RMSNorm before each quantized linear projection (Q, K, V, O, gate, up, down), adding 448 small norm weight tensors (7 per layer x 64 layers). These are loaded as optional tensors from the GGUF with names like: blk.{layer}.attn_q.rms_norm.weight blk.{layer}.ffn_gate.rms_norm.weight etc. When the norm weights are absent (standard BitNet models), the forward pass falls through to the existing code path — no regression. The qwen2-per-linear-rmsnorm.patch contains the full diff against the llama.cpp submodule (3rdparty/llama.cpp/src/llama.cpp). Co-Authored-By: Claude Opus 4.6 (1M context) --- qwen2-per-linear-rmsnorm.patch | 159 +++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 qwen2-per-linear-rmsnorm.patch diff --git a/qwen2-per-linear-rmsnorm.patch b/qwen2-per-linear-rmsnorm.patch new file mode 100644 index 000000000..ff2e3ee0b --- /dev/null +++ b/qwen2-per-linear-rmsnorm.patch @@ -0,0 +1,159 @@ +596a597,604 +> // per-linear RMSNorm (BitNet QAT) +> LLM_TENSOR_ATTN_Q_RMS_NORM, +> LLM_TENSOR_ATTN_K_RMS_NORM, +> LLM_TENSOR_ATTN_V_RMS_NORM, +> LLM_TENSOR_ATTN_OUT_RMS_NORM, +> LLM_TENSOR_FFN_GATE_RMS_NORM, +> LLM_TENSOR_FFN_DOWN_RMS_NORM, +> LLM_TENSOR_FFN_UP_RMS_NORM, +919a928,935 +> // per-linear RMSNorm (BitNet QAT) +> { LLM_TENSOR_ATTN_Q_RMS_NORM, "blk.%d.attn_q.rms_norm" }, +> { LLM_TENSOR_ATTN_K_RMS_NORM, "blk.%d.attn_k.rms_norm" }, +> { LLM_TENSOR_ATTN_V_RMS_NORM, "blk.%d.attn_v.rms_norm" }, +> { LLM_TENSOR_ATTN_OUT_RMS_NORM, "blk.%d.attn_output.rms_norm" }, +> { LLM_TENSOR_FFN_GATE_RMS_NORM, "blk.%d.ffn_gate.rms_norm" }, +> { LLM_TENSOR_FFN_DOWN_RMS_NORM, "blk.%d.ffn_down.rms_norm" }, +> { LLM_TENSOR_FFN_UP_RMS_NORM, "blk.%d.ffn_up.rms_norm" }, +2801a2818,2826 +> +> // per-linear RMSNorm weights (BitNet QAT) +> struct ggml_tensor * wq_rms_norm = nullptr; +> struct ggml_tensor * wk_rms_norm = nullptr; +> struct ggml_tensor * wv_rms_norm = nullptr; +> struct ggml_tensor * wo_rms_norm = nullptr; +> struct ggml_tensor * ffn_gate_rms_norm = nullptr; +> struct ggml_tensor * ffn_down_rms_norm = nullptr; +> struct ggml_tensor * ffn_up_rms_norm = nullptr; +7925a7951,7959 +> +> // optional per-linear RMSNorm weights (BitNet QAT) +> layer.wq_rms_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_RMS_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); +> layer.wk_rms_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_RMS_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); +> layer.wv_rms_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V_RMS_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); +> layer.wo_rms_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_RMS_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); +> layer.ffn_gate_rms_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_RMS_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); +> layer.ffn_down_rms_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN_RMS_NORM, "weight", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); +> layer.ffn_up_rms_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP_RMS_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); +12397a12432,12455 +> // per-linear RMSNorm before Q projection (BitNet QAT) +> struct ggml_tensor * cur_q = cur; +> if (model.layers[il].wq_rms_norm) { +> cur_q = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); +> cur_q = ggml_mul(ctx0, cur_q, model.layers[il].wq_rms_norm); +> cb(cur_q, "wq_rms_norm", il); +> } +> +> // per-linear RMSNorm before K projection (BitNet QAT) +> struct ggml_tensor * cur_k = cur; +> if (model.layers[il].wk_rms_norm) { +> cur_k = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); +> cur_k = ggml_mul(ctx0, cur_k, model.layers[il].wk_rms_norm); +> cb(cur_k, "wk_rms_norm", il); +> } +> +> // per-linear RMSNorm before V projection (BitNet QAT) +> struct ggml_tensor * cur_v = cur; +> if (model.layers[il].wv_rms_norm) { +> cur_v = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); +> cur_v = ggml_mul(ctx0, cur_v, model.layers[il].wv_rms_norm); +> cb(cur_v, "wv_rms_norm", il); +> } +> +12399c12457 +< struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); +--- +> struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur_q); +12404c12462 +< struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); +--- +> struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur_k); +12409c12467 +< struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); +--- +> struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur_v); +12428,12430c12486,12508 +< cur = llm_build_kv(ctx0, lctx, kv_self, gf, +< model.layers[il].wo, model.layers[il].bo, +< Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); +--- +> // For wo: if per-linear RMSNorm exists, pass NULL for wo to llm_build_kv +> // and apply RMSNorm + wo manually afterward +> if (model.layers[il].wo_rms_norm) { +> cur = llm_build_kv(ctx0, lctx, kv_self, gf, +> NULL, NULL, +> Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); +> +> // per-linear RMSNorm before wo projection (BitNet QAT) +> cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); +> cur = ggml_mul(ctx0, cur, model.layers[il].wo_rms_norm); +> cb(cur, "wo_rms_norm", il); +> +> cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); +> if (model.layers[il].bo) { +> cb(cur, "kqv_wo", il); +> cur = ggml_add(ctx0, cur, model.layers[il].bo); +> } +> cb(cur, "kqv_out", il); +> } else { +> cur = llm_build_kv(ctx0, lctx, kv_self, gf, +> model.layers[il].wo, model.layers[il].bo, +> Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); +> } +12449,12454c12527,12574 +< cur = llm_build_ffn(ctx0, lctx, cur, +< model.layers[il].ffn_up, NULL, NULL, +< model.layers[il].ffn_gate, NULL, NULL, +< model.layers[il].ffn_down, NULL, NULL, +< NULL, +< LLM_FFN_SILU, LLM_FFN_PAR, cb, il); +--- +> // If per-linear RMSNorm exists for FFN, handle projections manually +> if (model.layers[il].ffn_gate_rms_norm || model.layers[il].ffn_up_rms_norm || model.layers[il].ffn_down_rms_norm) { +> // per-linear RMSNorm before gate projection +> struct ggml_tensor * cur_gate = cur; +> if (model.layers[il].ffn_gate_rms_norm) { +> cur_gate = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); +> cur_gate = ggml_mul(ctx0, cur_gate, model.layers[il].ffn_gate_rms_norm); +> cb(cur_gate, "ffn_gate_rms_norm", il); +> } +> +> // per-linear RMSNorm before up projection +> struct ggml_tensor * cur_up = cur; +> if (model.layers[il].ffn_up_rms_norm) { +> cur_up = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); +> cur_up = ggml_mul(ctx0, cur_up, model.layers[il].ffn_up_rms_norm); +> cb(cur_up, "ffn_up_rms_norm", il); +> } +> +> // gate and up projections (parallel: SwiGLU) +> struct ggml_tensor * gate_out = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate, cur_gate); +> cb(gate_out, "ffn_gate", il); +> +> struct ggml_tensor * up_out = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_up, cur_up); +> cb(up_out, "ffn_up", il); +> +> // SiLU activation on gate, then element-wise multiply with up +> gate_out = ggml_silu(ctx0, gate_out); +> cb(gate_out, "ffn_silu", il); +> +> cur = ggml_mul(ctx0, gate_out, up_out); +> cb(cur, "ffn_gate_par", il); +> +> // per-linear RMSNorm before down projection +> if (model.layers[il].ffn_down_rms_norm) { +> cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); +> cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_rms_norm); +> cb(cur, "ffn_down_rms_norm", il); +> } +> +> cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur); +> } else { +> cur = llm_build_ffn(ctx0, lctx, cur, +> model.layers[il].ffn_up, NULL, NULL, +> model.layers[il].ffn_gate, NULL, NULL, +> model.layers[il].ffn_down, NULL, NULL, +> NULL, +> LLM_FFN_SILU, LLM_FFN_PAR, cb, il); +> }