Commit a41297f

Initial XGLM Adapter setup (#1250)
1 parent 71b42a0 commit a41297f

5 files changed: 734 additions & 0 deletions
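Three of the five files appear below: the unit test suite, the factory registration, and the package re-export. The adapter module itself (transformer_lens/model_bridge/supported_architectures/xglm.py, imported by the tests) is not reproduced in this view. As an orientation aid, the correspondence the tests pin down between TransformerLens component paths and XGLM's Hugging Face module names can be summarized as a plain dict; everything here is inferred from the assertions below, not copied from the adapter source:

# Inferred from the unit tests below; illustrative, not the adapter source.
XGLM_COMPONENT_PATHS = {
    "embed": "model.embed_tokens",
    "blocks": "model.layers",                  # per-layer container
    "blocks.{i}.ln1": "self_attn_layer_norm",  # pre-attention LayerNorm
    "blocks.{i}.attn": "self_attn",
    "blocks.{i}.attn.q": "q_proj",
    "blocks.{i}.attn.k": "k_proj",
    "blocks.{i}.attn.v": "v_proj",
    "blocks.{i}.attn.o": "out_proj",           # XGLM uses out_proj, not o_proj
    "blocks.{i}.ln2": "final_layer_norm",      # pre-MLP LayerNorm
    "blocks.{i}.mlp.in": "fc1",
    "blocks.{i}.mlp.out": "fc2",
    "ln_final": "model.layer_norm",
    "unembed": "lm_head",
}
# No "pos_embed" entry: XGLM's sinusoidal position embeddings carry no weights.

The out_proj name is the detail the tests single out: sibling decoder adapters typically use o_proj, and copying that spelling over is called out below as a known scaffold error pattern.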

Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
"""Unit tests for XGLMArchitectureAdapter.

Tests cover:
- Config attribute validation (all required attributes set correctly) [Phase A]
- Weight conversion keys and structure [Phase A]
- Component mapping structure (correct bridge types and HF module paths) [Phase B]
- Embedding scale hook compatibility [Phase C]
- Factory registration (XGLMForCausalLM maps to the right adapter) [Phase D]
"""

import math
from types import SimpleNamespace

import pytest
import torch

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.model_bridge.generalized_components import (
    AttentionBridge,
    BlockBridge,
    EmbeddingBridge,
    NormalizationBridge,
    SymbolicBridge,
    UnembeddingBridge,
)
from transformer_lens.model_bridge.supported_architectures.xglm import (
    XGLMArchitectureAdapter,
)

# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


def _make_cfg(
    n_heads: int = 4,
    d_model: int = 64,
    n_layers: int = 2,
    d_mlp: int = 256,
    d_vocab: int = 1000,
    n_ctx: int = 512,
) -> TransformerBridgeConfig:
    """Return a minimal TransformerBridgeConfig for XGLM adapter tests."""
    return TransformerBridgeConfig(
        d_model=d_model,
        d_head=d_model // n_heads,
        n_layers=n_layers,
        n_ctx=n_ctx,
        n_heads=n_heads,
        d_vocab=d_vocab,
        d_mlp=d_mlp,
        default_prepend_bos=True,
        architecture="XGLMForCausalLM",
    )


@pytest.fixture
def cfg() -> TransformerBridgeConfig:
    return _make_cfg()


@pytest.fixture
def adapter(cfg: TransformerBridgeConfig) -> XGLMArchitectureAdapter:
    return XGLMArchitectureAdapter(cfg)


# ---------------------------------------------------------------------------
# Phase A: Config attribute tests
# ---------------------------------------------------------------------------


class TestXGLMAdapterConfig:
    """Adapter must set all required config attributes to the correct values."""

    def test_normalization_type_is_ln(self, adapter: XGLMArchitectureAdapter) -> None:
        assert adapter.cfg.normalization_type == "LN"

    def test_positional_embedding_type_is_standard(self, adapter: XGLMArchitectureAdapter) -> None:
        assert adapter.cfg.positional_embedding_type == "standard"

    def test_final_rms_is_false(self, adapter: XGLMArchitectureAdapter) -> None:
        assert adapter.cfg.final_rms is False

    def test_gated_mlp_is_false(self, adapter: XGLMArchitectureAdapter) -> None:
        assert adapter.cfg.gated_mlp is False

    def test_attn_only_is_false(self, adapter: XGLMArchitectureAdapter) -> None:
        assert adapter.cfg.attn_only is False

    def test_uses_rms_norm_is_false(self, adapter: XGLMArchitectureAdapter) -> None:
        assert adapter.cfg.uses_rms_norm is False


# ---------------------------------------------------------------------------
# Phase A: Weight processing conversion tests
# ---------------------------------------------------------------------------


class TestXGLMAdapterWeightConversions:
    """Adapter must define exactly the four standard QKVO weight conversions."""

    def test_q_weight_key_present(self, adapter: XGLMArchitectureAdapter) -> None:
        assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions

    def test_k_weight_key_present(self, adapter: XGLMArchitectureAdapter) -> None:
        assert "blocks.{i}.attn.k.weight" in adapter.weight_processing_conversions

    def test_v_weight_key_present(self, adapter: XGLMArchitectureAdapter) -> None:
        assert "blocks.{i}.attn.v.weight" in adapter.weight_processing_conversions

    def test_o_weight_key_present(self, adapter: XGLMArchitectureAdapter) -> None:
        assert "blocks.{i}.attn.o.weight" in adapter.weight_processing_conversions

    def test_exactly_four_conversion_keys(self, adapter: XGLMArchitectureAdapter) -> None:
        assert len(adapter.weight_processing_conversions) == 4


# ---------------------------------------------------------------------------
# Phase B: Component mapping structure tests
# ---------------------------------------------------------------------------


class TestXGLMAdapterComponentMapping:
    """Component mapping must have the correct bridge types and HF module paths."""

    def test_embed_is_embedding_bridge(self, adapter: XGLMArchitectureAdapter) -> None:
        assert isinstance(adapter.component_mapping["embed"], EmbeddingBridge)

    def test_embed_name(self, adapter: XGLMArchitectureAdapter) -> None:
        assert adapter.component_mapping["embed"].name == "model.embed_tokens"

    def test_no_pos_embed_in_mapping(self, adapter: XGLMArchitectureAdapter) -> None:
        # Sinusoidal embeddings have no weights; no bridge entry expected
        assert "pos_embed" not in adapter.component_mapping

    def test_blocks_is_block_bridge(self, adapter: XGLMArchitectureAdapter) -> None:
        assert isinstance(adapter.component_mapping["blocks"], BlockBridge)

    def test_blocks_name(self, adapter: XGLMArchitectureAdapter) -> None:
        assert adapter.component_mapping["blocks"].name == "model.layers"

    def test_ln_final_is_normalization_bridge(self, adapter: XGLMArchitectureAdapter) -> None:
        assert isinstance(adapter.component_mapping["ln_final"], NormalizationBridge)

    def test_ln_final_name(self, adapter: XGLMArchitectureAdapter) -> None:
        assert adapter.component_mapping["ln_final"].name == "model.layer_norm"

    def test_unembed_is_unembedding_bridge(self, adapter: XGLMArchitectureAdapter) -> None:
        assert isinstance(adapter.component_mapping["unembed"], UnembeddingBridge)

    def test_unembed_name(self, adapter: XGLMArchitectureAdapter) -> None:
        assert adapter.component_mapping["unembed"].name == "lm_head"

    def test_ln1_is_normalization_bridge(self, adapter: XGLMArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert isinstance(blocks.submodules["ln1"], NormalizationBridge)

    def test_ln1_name(self, adapter: XGLMArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert blocks.submodules["ln1"].name == "self_attn_layer_norm"

    def test_attn_is_attention_bridge(self, adapter: XGLMArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert isinstance(blocks.submodules["attn"], AttentionBridge)

    def test_attn_name(self, adapter: XGLMArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert blocks.submodules["attn"].name == "self_attn"

    def test_attn_requires_attention_mask(self, adapter: XGLMArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert blocks.submodules["attn"].requires_attention_mask is True

    def test_attn_attention_mask_4d(self, adapter: XGLMArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert blocks.submodules["attn"].attention_mask_4d is True

    def test_attn_q_name(self, adapter: XGLMArchitectureAdapter) -> None:
        attn = adapter.component_mapping["blocks"].submodules["attn"]
        assert attn.submodules["q"].name == "q_proj"

    def test_attn_k_name(self, adapter: XGLMArchitectureAdapter) -> None:
        attn = adapter.component_mapping["blocks"].submodules["attn"]
        assert attn.submodules["k"].name == "k_proj"

    def test_attn_v_name(self, adapter: XGLMArchitectureAdapter) -> None:
        attn = adapter.component_mapping["blocks"].submodules["attn"]
        assert attn.submodules["v"].name == "v_proj"

    def test_attn_o_name_is_out_proj(self, adapter: XGLMArchitectureAdapter) -> None:
        # Critical: XGLM uses out_proj, not o_proj (scaffold error pattern)
        attn = adapter.component_mapping["blocks"].submodules["attn"]
        assert attn.submodules["o"].name == "out_proj"

    def test_ln2_is_normalization_bridge(self, adapter: XGLMArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert isinstance(blocks.submodules["ln2"], NormalizationBridge)

    def test_ln2_name(self, adapter: XGLMArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert blocks.submodules["ln2"].name == "final_layer_norm"

    def test_mlp_is_symbolic_bridge(self, adapter: XGLMArchitectureAdapter) -> None:
        blocks = adapter.component_mapping["blocks"]
        assert isinstance(blocks.submodules["mlp"], SymbolicBridge)

    def test_mlp_in_name(self, adapter: XGLMArchitectureAdapter) -> None:
        mlp = adapter.component_mapping["blocks"].submodules["mlp"]
        assert mlp.submodules["in"].name == "fc1"

    def test_mlp_out_name(self, adapter: XGLMArchitectureAdapter) -> None:
        mlp = adapter.component_mapping["blocks"].submodules["mlp"]
        assert mlp.submodules["out"].name == "fc2"


# ---------------------------------------------------------------------------
# Phase C: Embedding scale hook compatibility tests
# ---------------------------------------------------------------------------


def _make_mock_bridge() -> SimpleNamespace:
    """Return a minimal mock bridge with embed.hook_out for hook-compat tests."""
    hook_out = SimpleNamespace(hook_conversion=None)
    embed = SimpleNamespace(hook_out=hook_out)
    return SimpleNamespace(embed=embed)


class TestXGLMAdapterHookCompatibility:
    """setup_hook_compatibility must attach a scale conversion to hook_embed."""

    def test_sets_hook_conversion_on_embed_hook_out(self, adapter: XGLMArchitectureAdapter) -> None:
        bridge = _make_mock_bridge()
        adapter.setup_hook_compatibility(bridge)
        assert bridge.embed.hook_out.hook_conversion is not None

    def test_scales_by_sqrt_d_model(self, adapter: XGLMArchitectureAdapter) -> None:
        # d_model=64, sqrt(64)=8 exactly
        bridge = _make_mock_bridge()
        adapter.setup_hook_compatibility(bridge)
        conv = bridge.embed.hook_out.hook_conversion
        x = torch.ones(2, 4, 64)
        result = conv.handle_conversion(x)
        expected_scale = math.sqrt(64)  # 8.0
        assert torch.allclose(result, x * expected_scale, atol=1e-6)

    def test_revert_inverts_scale(self, adapter: XGLMArchitectureAdapter) -> None:
        # round-trip: revert(handle_conversion(x)) == x; exact for sqrt(64)=8
        bridge = _make_mock_bridge()
        adapter.setup_hook_compatibility(bridge)
        conv = bridge.embed.hook_out.hook_conversion
        x = torch.randn(2, 4, 64)
        assert torch.allclose(conv.revert(conv.handle_conversion(x)), x, atol=1e-6)

    def test_no_error_when_embed_missing(self, adapter: XGLMArchitectureAdapter) -> None:
        # Guard: if bridge lacks embed, setup_hook_compatibility should not raise
        bridge = SimpleNamespace()  # no embed attribute
        adapter.setup_hook_compatibility(bridge)  # must not raise

    def test_no_error_when_hook_out_missing(self, adapter: XGLMArchitectureAdapter) -> None:
        # Guard: if embed lacks hook_out, no error expected
        bridge = SimpleNamespace(embed=SimpleNamespace())  # embed but no hook_out
        adapter.setup_hook_compatibility(bridge)  # must not raise


# ---------------------------------------------------------------------------
# Phase D: Factory registration tests
# ---------------------------------------------------------------------------


class TestXGLMFactoryRegistration:
    """XGLMForCausalLM must be registered in SUPPORTED_ARCHITECTURES and resolve correctly."""

    def test_factory_returns_xglm_adapter(self) -> None:
        from transformer_lens.factories.architecture_adapter_factory import (
            ArchitectureAdapterFactory,
        )

        cfg = _make_cfg()
        adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
        assert isinstance(adapter, XGLMArchitectureAdapter)

    def test_factory_key_is_xglm_for_causal_lm(self) -> None:
        from transformer_lens.factories.architecture_adapter_factory import (
            SUPPORTED_ARCHITECTURES,
        )

        assert "XGLMForCausalLM" in SUPPORTED_ARCHITECTURES

transformer_lens/factories/architecture_adapter_factory.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,7 @@
     QwenArchitectureAdapter,
     StableLmArchitectureAdapter,
     T5ArchitectureAdapter,
+    XGLMArchitectureAdapter,
 )

 # Export supported architectures
@@ -104,6 +105,7 @@
     "Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
     "StableLmForCausalLM": StableLmArchitectureAdapter,
     "T5ForConditionalGeneration": T5ArchitectureAdapter,
+    "XGLMForCausalLM": XGLMArchitectureAdapter,
     "NanoGPTForCausalLM": NanogptArchitectureAdapter,
     "MinGPTForCausalLM": MingptArchitectureAdapter,
     "GPTNeoForCausalLM": NeoArchitectureAdapter,

transformer_lens/model_bridge/supported_architectures/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -147,6 +147,9 @@
 from transformer_lens.model_bridge.supported_architectures.t5 import (
     T5ArchitectureAdapter,
 )
+from transformer_lens.model_bridge.supported_architectures.xglm import (
+    XGLMArchitectureAdapter,
+)

 __all__ = [
     "ApertusArchitectureAdapter",
@@ -197,4 +200,5 @@
     "Qwen3NextArchitectureAdapter",
     "StableLmArchitectureAdapter",
     "T5ArchitectureAdapter",
+    "XGLMArchitectureAdapter",
 ]
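The re-export keeps the adapter importable from the package root, presumably the flat form the factory's import block above relies on:

from transformer_lens.model_bridge.supported_architectures import (
    XGLMArchitectureAdapter,
)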
