|
| 1 | +"""Qwen3_5 architecture adapter. |
| 2 | +
|
| 3 | +Qwen3_5ForCausalLM is a hybrid linear-attention + full-attention architecture |
| 4 | +with a dense gated MLP on every layer. Layers follow a repeating pattern of |
| 5 | +3 GatedDeltaNet (linear attention) layers followed by 1 standard full-attention |
| 6 | +layer (every 4th layer by default). |
| 7 | +
|
| 8 | +Since self_attn is absent on linear-attention layers, we only map submodules |
| 9 | +that exist on ALL layers (norms, MLP). The HF native forward handles |
| 10 | +linear/full attention dispatch internally, and GatedMLPBridge maps the dense |
| 11 | +gate_proj/up_proj/down_proj structure on every layer. |
| 12 | +
|
| 13 | +Hook coverage: |
| 14 | +- Block-level: hook_resid_pre, hook_resid_post on every layer |
| 15 | +- Normalization: ln1 (input_layernorm), ln2 (post_attention_layernorm) |
| 16 | +- MLP: hook_in, hook_out via GatedMLPBridge (gate_proj, up_proj, down_proj) |
| 17 | +- Attention internals are NOT individually hooked (self_attn absent on |
| 18 | + linear-attention layers; mapping it would crash on those layers) |
| 19 | +
|
| 20 | +Optional parameters: |
| 21 | +- n_key_value_heads: only set when using GQA (num_key_value_heads != num_attention_heads) |
| 22 | +""" |
| 23 | + |
| 24 | +from typing import Any |
| 25 | + |
| 26 | +import torch |
| 27 | + |
| 28 | +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter |
| 29 | +from transformer_lens.model_bridge.generalized_components import ( |
| 30 | + BlockBridge, |
| 31 | + EmbeddingBridge, |
| 32 | + GatedMLPBridge, |
| 33 | + LinearBridge, |
| 34 | + RMSNormalizationBridge, |
| 35 | + RotaryEmbeddingBridge, |
| 36 | + UnembeddingBridge, |
| 37 | +) |
| 38 | + |
| 39 | + |
class Qwen3_5ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen3_5 models.

    Qwen3_5ForCausalLM is a hybrid linear-attention + full-attention
    architecture with dense gated MLPs, sharing the same hybrid design as
    Qwen3Next but replacing the sparse MoE MLP with a standard dense MLP:
    - Uses RMSNorm for all normalizations
    - Uses rotary position embeddings (RoPE) with partial rotation
    - Every 4th layer is a full-attention layer (self_attn); the rest are
      GatedDeltaNet linear-attention layers (linear_attn)
    - Uses dense gated MLP (gate_proj + up_proj -> down_proj) on ALL layers
    - No biases on any linear layers
    - Full-attention layers have Q/K normalization (q_norm, k_norm)
    - Full-attention q_proj outputs n_heads * head_dim * 2 (interleaved
      query+gate layout); the preprocess_weights method slices the query half

    Since self_attn is absent on linear-attention layers, only universally
    present submodules (norms, MLP) are mapped as block submodules. The HF
    native forward handles per-layer attention dispatch internally.

    Optional parameters:
    - n_key_value_heads: set when num_key_value_heads != num_attention_heads (GQA)
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Qwen3_5 architecture adapter.

        Args:
            cfg: TransformerLens-style config object; mutated in place with
                architecture-specific attributes (normalization type, rotary
                positional embeddings, gated MLP, eager attention, etc.).
        """
        super().__init__(cfg)

        # Core config attributes
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.default_prepend_bos = False

        # Disable fold_ln: ln1 is followed by self_attn on full-attention
        # layers and by linear_attn (GatedDeltaNet) on linear-attention layers,
        # but neither is mapped as a bridge submodule (see class docstring for
        # why). With no bridge-mapped target to fold into, the standard fold_ln
        # pass leaves LN weights in an inconsistent state and the processed
        # bridge output diverges from the unprocessed / HF output. Skipping
        # fold_ln keeps processed-mode forward passes numerically equivalent.
        self.supports_fold_ln = False

        # Use eager attention to support output_attentions for hook_attn_scores
        # and hook_pattern. SDPA doesn't support output_attentions.
        self.cfg.attn_implementation = "eager"

        # GQA: only set n_key_value_heads when using grouped-query attention.
        # getattr with a default folds the hasattr + None check into one lookup.
        n_kv_heads = getattr(cfg, "n_key_value_heads", None)
        if n_kv_heads is not None:
            self.cfg.n_key_value_heads = n_kv_heads

        self.weight_processing_conversions: dict = {}
        self.component_mapping: dict = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    # Dense gated MLP present on every layer (unlike Qwen3Next's MoE).
                    # gate_proj + up_proj feed into down_proj via SwiGLU activation.
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Swap the multimodal Qwen3_5Config for its text-only Qwen3_5TextConfig.

        Published Qwen3.5 checkpoints (e.g. Qwen/Qwen3.5-0.8B) carry
        model_type='qwen3_5' and architectures=['Qwen3_5ForConditionalGeneration'].
        AutoModelForCausalLM would load the full VLM (Qwen3_5ForConditionalGeneration)
        with its vision tower, wasting memory and failing the bridge.

        Instead we replace model_kwargs['config'] with the nested text_config so
        AutoModelForCausalLM loads Qwen3_5ForCausalLM (text only).

        Args:
            model_name: HF model identifier (unused; kept for interface parity).
            model_kwargs: kwargs destined for AutoModelForCausalLM.from_pretrained;
                mutated in place when a multimodal config is present.
        """
        config = model_kwargs.get("config")
        if config is not None and hasattr(config, "text_config"):
            model_kwargs["config"] = config.text_config

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """No-op for hybrid models.

        Hybrid models don't map attention as a block submodule (self_attn is
        absent on linear-attention layers), so there are no rotary embedding
        references to set up.

        Note: to find which layers are full_attention at runtime, use:
            layer_types = getattr(hf_model.config, "layer_types", [])
            first_full_attn_idx = next(
                i for i, t in enumerate(layer_types) if t == "full_attention"
            )
        Do NOT use hf_model.config.full_attention_interval -- it is not stored
        on the config object (consumed during __init__ to build layer_types).
        """

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Slice query half from q_proj.weight (interleaved per-head layout).

        In Qwen3_5, q_proj.weight has shape (n_heads * head_dim * 2, hidden_size).
        Rows are organized as per-head interleaved:
            head_0_query (d_head rows), head_0_gate (d_head rows),
            head_1_query (d_head rows), head_1_gate (d_head rows), ...

        A naive first-half slice would be wrong. We must reshape by head, then
        take the first d_head rows of each head (the query half).

        Note: since self_attn is NOT currently mapped as a bridge submodule,
        these weights will not be loaded by the bridge. This method is included
        for correctness and forward-compatibility.

        Args:
            state_dict: checkpoint tensors keyed by HF parameter name;
                mutated in place (matching values replaced) and returned.

        Returns:
            The same dict, with every ``.self_attn.q_proj.weight`` entry
            reduced to its query half, shape (n_heads * d_head, hidden_size).
        """
        n_heads = self.cfg.n_heads
        d_head = self.cfg.d_head
        # Materialize the key list first; values are replaced during the loop.
        keys_to_update = [k for k in state_dict if k.endswith(".self_attn.q_proj.weight")]
        for key in keys_to_update:
            w = state_dict[key]  # shape: (n_heads * d_head * 2, hidden_size)
            # Expose the per-head layout. Use reshape (not view): view raises on
            # non-contiguous tensors, while reshape handles them and is
            # equivalent (no copy) for the common contiguous case.
            w = w.reshape(n_heads, d_head * 2, -1)
            # Keep only the first d_head rows of each head (the query half).
            state_dict[key] = w[:, :d_head, :].reshape(n_heads * d_head, -1)
        return state_dict
0 commit comments