Skip to content

Commit 22722f9

Browse files
authored
Qwen 3.5 Causal Adapter (#1253)
* Initial Qwen 3.5 adapter * Latest verifications
1 parent 9612ebe commit 22722f9

8 files changed

Lines changed: 916 additions & 4 deletions

File tree

tests/unit/test_qwen3_5_adapter.py

Lines changed: 643 additions & 0 deletions
Large diffs are not rendered by default.

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
Phi3ArchitectureAdapter,
5050
PhiArchitectureAdapter,
5151
Qwen2ArchitectureAdapter,
52+
Qwen3_5ArchitectureAdapter,
5253
Qwen3ArchitectureAdapter,
5354
Qwen3MoeArchitectureAdapter,
5455
Qwen3NextArchitectureAdapter,
@@ -107,6 +108,7 @@
107108
"Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
108109
"Qwen3MoeForCausalLM": Qwen3MoeArchitectureAdapter,
109110
"Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
111+
"Qwen3_5ForCausalLM": Qwen3_5ArchitectureAdapter,
110112
"StableLmForCausalLM": StableLmArchitectureAdapter,
111113
"T5ForConditionalGeneration": T5ArchitectureAdapter,
112114
"XGLMForCausalLM": XGLMArchitectureAdapter,

transformer_lens/model_bridge/sources/transformers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ def determine_architecture_from_hf_config(hf_config):
223223
"qwen": "QwenForCausalLM",
224224
"qwen2": "Qwen2ForCausalLM",
225225
"qwen3": "Qwen3ForCausalLM",
226+
# qwen3_5 is the top-level multimodal config type; qwen3_5_text is
227+
# the text-only sub-config. Both map to the text-only adapter so
228+
# Qwen3.5 checkpoints (which report qwen3_5 even when loaded as
229+
# text-only) are routed to Qwen3_5ForCausalLM.
230+
"qwen3_5": "Qwen3_5ForCausalLM",
231+
"qwen3_5_text": "Qwen3_5ForCausalLM",
226232
"openelm": "OpenELMForCausalLM",
227233
"stablelm": "StableLmForCausalLM",
228234
"t5": "T5ForConditionalGeneration",

transformer_lens/model_bridge/supported_architectures/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@
147147
from transformer_lens.model_bridge.supported_architectures.qwen3_next import (
148148
Qwen3NextArchitectureAdapter,
149149
)
150+
from transformer_lens.model_bridge.supported_architectures.qwen3_5 import (
151+
Qwen3_5ArchitectureAdapter,
152+
)
150153
from transformer_lens.model_bridge.supported_architectures.stablelm import (
151154
StableLmArchitectureAdapter,
152155
)
@@ -206,6 +209,7 @@
206209
"Qwen3ArchitectureAdapter",
207210
"Qwen3MoeArchitectureAdapter",
208211
"Qwen3NextArchitectureAdapter",
212+
"Qwen3_5ArchitectureAdapter",
209213
"StableLmArchitectureAdapter",
210214
"T5ArchitectureAdapter",
211215
"XGLMArchitectureAdapter",
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
"""Qwen3_5 architecture adapter.
2+
3+
Qwen3_5ForCausalLM is a hybrid linear-attention + full-attention architecture
4+
with a dense gated MLP on every layer. Layers follow a repeating pattern of
5+
3 GatedDeltaNet (linear attention) layers followed by 1 standard full-attention
6+
layer (every 4th layer by default).
7+
8+
Since self_attn is absent on linear-attention layers, we only map submodules
9+
that exist on ALL layers (norms, MLP). The HF native forward handles
10+
linear/full attention dispatch internally, and GatedMLPBridge maps the dense
11+
gate_proj/up_proj/down_proj structure on every layer.
12+
13+
Hook coverage:
14+
- Block-level: hook_resid_pre, hook_resid_post on every layer
15+
- Normalization: ln1 (input_layernorm), ln2 (post_attention_layernorm)
16+
- MLP: hook_in, hook_out via GatedMLPBridge (gate_proj, up_proj, down_proj)
17+
- Attention internals are NOT individually hooked (self_attn absent on
18+
linear-attention layers; mapping it would crash on those layers)
19+
20+
Optional parameters:
21+
- n_key_value_heads: only set when using GQA (num_key_value_heads != num_attention_heads)
22+
"""
23+
24+
from typing import Any
25+
26+
import torch
27+
28+
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
29+
from transformer_lens.model_bridge.generalized_components import (
30+
BlockBridge,
31+
EmbeddingBridge,
32+
GatedMLPBridge,
33+
LinearBridge,
34+
RMSNormalizationBridge,
35+
RotaryEmbeddingBridge,
36+
UnembeddingBridge,
37+
)
38+
39+
40+
class Qwen3_5ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen3_5 models.

    Qwen3_5ForCausalLM mixes GatedDeltaNet linear-attention layers with
    standard full-attention layers (every 4th layer by default) and uses a
    dense gated MLP (gate_proj + up_proj -> down_proj) on every layer:
    - RMSNorm for all normalizations
    - Rotary position embeddings (RoPE) with partial rotation
    - No biases on any linear projection
    - Full-attention layers carry Q/K normalization (q_norm, k_norm) and a
      q_proj whose output is n_heads * head_dim * 2 (per-head interleaved
      query+gate rows); preprocess_weights slices out the query half

    Because self_attn exists only on the full-attention layers, the block
    mapping covers just the submodules present on every layer (the two norms
    and the MLP). The HF native forward dispatches linear vs. full attention
    internally.

    Optional parameters:
    - n_key_value_heads: set when num_key_value_heads != num_attention_heads (GQA)
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Qwen3_5 architecture adapter."""
        super().__init__(cfg)

        # Architecture flags shared by every Qwen3.5 checkpoint.
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.default_prepend_bos = False

        # fold_ln is disabled: ln1 feeds self_attn on full-attention layers
        # and linear_attn (GatedDeltaNet) on the rest, and neither target is
        # mapped as a bridge submodule (see class docstring). With nothing to
        # fold into, the standard fold_ln pass would leave LN weights in an
        # inconsistent state and processed output would diverge from HF;
        # skipping it keeps processed-mode forward passes numerically
        # equivalent.
        self.supports_fold_ln = False

        # SDPA cannot return attention probabilities; eager attention is
        # required so output_attentions works for hook_attn_scores and
        # hook_pattern.
        self.cfg.attn_implementation = "eager"

        # Record the KV-head count only for grouped-query attention configs.
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        self.weight_processing_conversions: dict = {}
        self.component_mapping: dict = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    # Dense SwiGLU MLP on every layer (where Qwen3Next uses
                    # a sparse MoE): gate_proj + up_proj feed down_proj.
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Swap the multimodal Qwen3_5Config for its text-only Qwen3_5TextConfig.

        Published Qwen3.5 checkpoints (e.g. Qwen/Qwen3.5-0.8B) ship
        model_type='qwen3_5' with architectures=['Qwen3_5ForConditionalGeneration'],
        so AutoModelForCausalLM would otherwise instantiate the full VLM with
        its vision tower, wasting memory and failing the bridge. Substituting
        the nested text_config into model_kwargs makes it load the text-only
        Qwen3_5ForCausalLM instead.
        """
        loaded_config = model_kwargs.get("config")
        if loaded_config is not None and hasattr(loaded_config, "text_config"):
            model_kwargs["config"] = loaded_config.text_config

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """No-op for hybrid models.

        Attention is not mapped as a block submodule here (self_attn is
        absent on the linear-attention layers), so there are no rotary
        embedding references to wire up.

        Note: to locate full-attention layers at runtime, scan
        hf_model.config.layer_types for entries equal to "full_attention",
        e.g. take the first index whose type matches. Do NOT read
        hf_model.config.full_attention_interval — it is consumed during
        __init__ to build layer_types and is not stored on the config object.
        """

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Slice query half from q_proj.weight (interleaved per-head layout).

        Qwen3_5 stores q_proj.weight as (n_heads * head_dim * 2, hidden_size)
        with rows grouped per head: d_head query rows followed by d_head gate
        rows for head 0, then head 1, and so on. A naive first-half slice
        would therefore mix queries and gates; instead we reshape per head
        and keep only each head's leading d_head (query) rows.

        Note: self_attn is not currently mapped as a bridge submodule, so the
        bridge never loads these weights; this method exists for correctness
        and forward-compatibility.
        """
        n_heads = self.cfg.n_heads
        d_head = self.cfg.d_head
        for key in [k for k in state_dict if k.endswith(".self_attn.q_proj.weight")]:
            # (n_heads * d_head * 2, hidden) -> (n_heads, d_head * 2, hidden)
            per_head = state_dict[key].view(n_heads, d_head * 2, -1)
            # Keep each head's query rows and flatten back to 2-D.
            state_dict[key] = per_head[:, :d_head, :].reshape(n_heads * d_head, -1)
        return state_dict

transformer_lens/tools/model_registry/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
"Qwen2ForCausalLM",
8282
"Qwen3ForCausalLM",
8383
"Qwen3NextForCausalLM",
84+
"Qwen3_5ForCausalLM",
8485
"StableLmForCausalLM",
8586
"T5ForConditionalGeneration",
8687
}

transformer_lens/tools/model_registry/data/supported_models.json

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
"min_downloads": 500,
77
"scan_duration_seconds": 3.9
88
},
9-
"total_architectures": 40,
10-
"total_models": 6868,
11-
"total_verified": 699,
9+
"total_architectures": 43,
10+
"total_models": 7006,
11+
"total_verified": 704,
1212
"models": [
1313
{
1414
"architecture_id": "Qwen3NextForCausalLM",
@@ -99551,6 +99551,57 @@
9955199551
"phase4_score": null,
9955299552
"phase7_score": null,
9955399553
"phase8_score": null
99554+
},
99555+
{
99556+
"architecture_id": "Qwen3_5ForCausalLM",
99557+
"model_id": "Qwen/Qwen3.5-0.8B",
99558+
"status": 1,
99559+
"verified_date": "2026-04-14",
99560+
"metadata": {
99561+
"downloads": 2577198,
99562+
"total_params": 950000000
99563+
},
99564+
"note": "Full verification completed with issues: P3=94.1% (failed: attention_output_centering)",
99565+
"phase1_score": 100.0,
99566+
"phase2_score": 100.0,
99567+
"phase3_score": 94.1,
99568+
"phase4_score": 91.5,
99569+
"phase7_score": null,
99570+
"phase8_score": null
99571+
},
99572+
{
99573+
"architecture_id": "Qwen3_5ForCausalLM",
99574+
"model_id": "Qwen/Qwen3.5-4B",
99575+
"status": 1,
99576+
"verified_date": "2026-04-14",
99577+
"metadata": {
99578+
"downloads": 2920685,
99579+
"total_params": 3660000000
99580+
},
99581+
"note": "Full verification completed with issues: P3=94.1% (failed: attention_output_centering)",
99582+
"phase1_score": 100.0,
99583+
"phase2_score": 100.0,
99584+
"phase3_score": 94.1,
99585+
"phase4_score": 98.5,
99586+
"phase7_score": null,
99587+
"phase8_score": null
99588+
},
99589+
{
99590+
"architecture_id": "Qwen3_5ForCausalLM",
99591+
"model_id": "Qwen/Qwen3.5-9B",
99592+
"status": 0,
99593+
"verified_date": null,
99594+
"metadata": {
99595+
"downloads": 5662081,
99596+
"total_params": 8750000000
99597+
},
99598+
"note": null,
99599+
"phase1_score": null,
99600+
"phase2_score": null,
99601+
"phase3_score": null,
99602+
"phase4_score": null,
99603+
"phase7_score": null,
99604+
"phase8_score": null
9955499605
}
9955599606
]
9955699607
}

transformer_lens/tools/model_registry/data/verification_history.json

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"last_updated": "2026-04-10T18:43:37.000957",
2+
"last_updated": "2026-04-14T13:03:57.367589",
33
"records": [
44
{
55
"model_id": "Macropodus/macbert4mdcspell_v1",
@@ -11260,6 +11260,36 @@
1126011260
"notes": "Full verification completed",
1126111261
"invalidated": false,
1126211262
"invalidation_reason": null
11263+
},
11264+
{
11265+
"model_id": "Qwen/Qwen3.5-0.8B",
11266+
"architecture_id": "Qwen3_5ForCausalLM",
11267+
"verified_date": "2026-04-14",
11268+
"verified_by": "verify_models",
11269+
"transformerlens_version": null,
11270+
"notes": "Below threshold: P1=0.0% < 100.0% (failed: load_bridge_unprocessed) \u2014 Failed to load unprocessed TransformerBridge: Could not determine supported architecture from config. Available architectures: ['ApertusForCausalLM', ",
11271+
"invalidated": false,
11272+
"invalidation_reason": null
11273+
},
11274+
{
11275+
"model_id": "Qwen/Qwen3.5-0.8B",
11276+
"architecture_id": "Qwen3_5ForCausalLM",
11277+
"verified_date": "2026-04-14",
11278+
"verified_by": "verify_models",
11279+
"transformerlens_version": null,
11280+
"notes": "Full verification completed with issues: P3=94.1% (failed: attention_output_centering)",
11281+
"invalidated": false,
11282+
"invalidation_reason": null
11283+
},
11284+
{
11285+
"model_id": "Qwen/Qwen3.5-4B",
11286+
"architecture_id": "Qwen3_5ForCausalLM",
11287+
"verified_date": "2026-04-14",
11288+
"verified_by": "verify_models",
11289+
"transformerlens_version": null,
11290+
"notes": "Full verification completed with issues: P3=94.1% (failed: attention_output_centering)",
11291+
"invalidated": false,
11292+
"invalidation_reason": null
1126311293
}
1126411294
]
1126511295
}

0 commit comments

Comments
 (0)