Skip to content
58 changes: 42 additions & 16 deletions tests/unit/test_qwen3_5_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,23 @@
)
from transformer_lens.tools.model_registry import HF_SUPPORTED_ARCHITECTURES

try:
from transformers import Qwen3_5ForCausalLM as _Qwen3_5ForCausalLM
from transformers import Qwen3_5TextConfig

_QWEN3_5_AVAILABLE = True
except ImportError:
_QWEN3_5_AVAILABLE = False

# ============================================================================
# Test: Registration
# ============================================================================


@pytest.mark.skipif(
not _QWEN3_5_AVAILABLE,
reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers",
)
class TestQwen3_5Registration:
"""Verify the adapter is properly registered in all lookup tables."""

Expand Down Expand Up @@ -79,6 +91,10 @@ def _make_bridge_cfg(**overrides):
# ============================================================================


@pytest.mark.skipif(
not _QWEN3_5_AVAILABLE,
reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers",
)
class TestQwen3_5ComponentMapping:
"""Verify the component_mapping structure for Qwen3_5.

Expand Down Expand Up @@ -134,18 +150,28 @@ def test_unembed_path(self, adapter):
# ---- Block submodules ----

def test_block_submodules_keys(self, adapter):
    """blocks submodules must contain ln1, ln2, mlp, and optional attn + linear_attn.

    Hybrid architecture: each layer has either full attention (attn) or
    linear attention (linear_attn), so both are mapped and marked optional.
    """
    # Stale pre-edit docstring line removed; the key set below is the
    # post-hybrid-support contract for the block submodule mapping.
    submodules = adapter.component_mapping["blocks"].submodules
    assert set(submodules.keys()) == {"ln1", "ln2", "mlp", "attn", "linear_attn"}

Critical correctness test: self_attn is absent on linear-attention
layers, so mapping attn as a block submodule would crash on those layers.
"""
def test_attn_is_optional(self, adapter):
    """The attn submodule is flagged optional (absent on linear-attention layers)."""
    blocks_mapping = adapter.component_mapping["blocks"]
    assert blocks_mapping.submodules["attn"].optional is True

def test_linear_attn_is_optional(self, adapter):
    """linear_attn must be marked optional (absent on full-attention layers)."""
    # Stale deleted-line residue (old key-set assert) removed from the body.
    submodules = adapter.component_mapping["blocks"].submodules
    assert submodules["linear_attn"].optional is True

def test_linear_attn_bridge_type(self, adapter):
    """linear_attn must be a GatedDeltaNetBridge."""
    # Local import keeps the module importable when the bridge class is
    # unavailable in older installs; only this test depends on it.
    from transformer_lens.model_bridge.generalized_components.gated_delta_net import (
        GatedDeltaNetBridge,
    )

    # Lines of the removed test_no_attn_in_block_submodules were interleaved
    # here; this is the reconstructed post-edit body.
    submodules = adapter.component_mapping["blocks"].submodules
    assert isinstance(submodules["linear_attn"], GatedDeltaNetBridge)

def test_ln1_path(self, adapter):
"""ln1 maps to input_layernorm."""
Expand Down Expand Up @@ -257,6 +283,10 @@ def test_weight_processing_conversions_empty(self, adapter):
# ============================================================================


@pytest.mark.skipif(
not _QWEN3_5_AVAILABLE,
reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers",
)
class TestQwen3_5ConfigAttributes:
"""Verify all cfg attributes are set correctly by the adapter."""

Expand Down Expand Up @@ -341,6 +371,10 @@ def test_n_key_value_heads_not_set_when_absent(self):
# ============================================================================


@pytest.mark.skipif(
not _QWEN3_5_AVAILABLE,
reason="Qwen3_5TextConfig / Qwen3_5ForCausalLM not available in installed transformers",
)
class TestQwen3_5PreprocessWeights:
"""Verify preprocess_weights correctly slices q_proj.weight per-head.

Expand Down Expand Up @@ -478,14 +512,6 @@ def test_weight_processing_conversions_is_empty_dict(self, adapter):
# Test: Integration (Phase A+B)
# ============================================================================

try:
from transformers import Qwen3_5ForCausalLM as _Qwen3_5ForCausalLM
from transformers import Qwen3_5TextConfig

_QWEN3_5_AVAILABLE = True
except ImportError:
_QWEN3_5_AVAILABLE = False


def _make_tiny_hf_model():
"""Create a tiny Qwen3_5ForCausalLM for integration testing.
Expand Down
27 changes: 18 additions & 9 deletions tests/unit/test_qwen3_next_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,19 +135,28 @@ def test_unembed_path(self, adapter):
# ---- Block submodules ----

def test_block_submodules_keys(self, adapter):
    """blocks submodules must contain ln1, ln2, mlp, and optional attn + linear_attn.

    Hybrid architecture: a layer carries either attn or linear_attn, so both
    keys are present in the mapping and each is marked optional.
    """
    # Stale pre-edit docstring line removed; assert the post-edit key set.
    submodules = adapter.component_mapping["blocks"].submodules
    assert set(submodules.keys()) == {"ln1", "ln2", "mlp", "attn", "linear_attn"}

This is a critical correctness test: self_attn is absent on
linear-attention layers, so mapping attn as a block submodule
would crash on those layers.
"""
def test_attn_is_optional(self, adapter):
    """attn must be marked optional (absent on linear-attention layers)."""
    assert adapter.component_mapping["blocks"].submodules["attn"].optional is True

def test_linear_attn_is_optional(self, adapter):
    """linear_attn must be marked optional (absent on full-attention layers)."""
    # Stale deleted-line residue (old key-set assert) removed from the body.
    submodules = adapter.component_mapping["blocks"].submodules
    assert submodules["linear_attn"].optional is True

def test_linear_attn_bridge_type(self, adapter):
    """linear_attn must be a GatedDeltaNetBridge."""
    # Local import so the test module still imports when the bridge class is
    # missing; only this test needs it.
    from transformer_lens.model_bridge.generalized_components.gated_delta_net import (
        GatedDeltaNetBridge,
    )

    # Lines of the removed test_no_attn_in_block_submodules were interleaved
    # here; this is the reconstructed post-edit body.
    submodules = adapter.component_mapping["blocks"].submodules
    assert isinstance(submodules["linear_attn"], GatedDeltaNetBridge)

def test_ln1_path(self, adapter):
"""ln1 maps to input_layernorm."""
Expand Down
6 changes: 5 additions & 1 deletion transformer_lens/benchmarks/component_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,12 @@ def benchmark_all_components(
n_layers = self.cfg.n_layers

for layer_idx in range(n_layers):
# Recursively test each subcomponent and its nested subcomponents
# Get the actual block to check which submodules were bound
actual_block = getattr(self.bridge_model, block_type)[layer_idx]
for subcomp_name, subcomponent in blocks_component.submodules.items():
# Skip optional submodules absent on this layer (hybrid architectures)
if subcomp_name not in actual_block._modules:
continue
comp_path = f"{block_type}.{layer_idx}.{subcomp_name}"
self._test_component_recursive(
comp_path, subcomponent, test_inputs, results, skip_components
Expand Down
27 changes: 20 additions & 7 deletions transformer_lens/benchmarks/weight_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,10 +638,24 @@ def benchmark_mlp_output_centering(
message="Skipped for tiny/test model (random weights don't center meaningfully)",
)

# Check if this is an MoE model - MoE models don't have a single W_out weight
# Find an MLP-like submodule (may be "mlp", "shared_mlp", etc.)
from transformer_lens.model_bridge.generalized_components.moe import MoEBridge

if isinstance(bridge.blocks[0].mlp, MoEBridge):
mlp_module = None
block = bridge.blocks[0]
for name in ("mlp", "shared_mlp"):
if name in block._modules:
mlp_module = block._modules[name]
break
if mlp_module is None:
return BenchmarkResult(
name="mlp_output_centering",
severity=BenchmarkSeverity.WARNING,
message="No MLP submodule found on block 0",
passed=False,
)

if isinstance(mlp_module, MoEBridge):
return BenchmarkResult(
name="mlp_output_centering",
severity=BenchmarkSeverity.INFO,
Expand All @@ -651,11 +665,10 @@ def benchmark_mlp_output_centering(

# Check if W_out exists and is accessible (HT format or bridge format)
w_out = None
if hasattr(bridge.blocks[0].mlp, "W_out"):
w_out = bridge.blocks[0].mlp.W_out
elif hasattr(bridge.blocks[0].mlp, "out"):
# Bridge format: mlp.out is a LinearBridge wrapping nn.Linear
out_module = bridge.blocks[0].mlp.out
if hasattr(mlp_module, "W_out"):
w_out = mlp_module.W_out
elif hasattr(mlp_module, "out"):
out_module = mlp_module.out
if hasattr(out_module, "original_component") and hasattr(
out_module.original_component, "weight"
):
Expand Down
5 changes: 3 additions & 2 deletions transformer_lens/model_bridge/component_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ def setup_submodules(
else:
remote_path = submodule.name
is_optional = getattr(submodule, "optional", False)
# Fast path: first segment absent → skip without entering get_remote_component
# Fast path: first segment absent or None → skip
first_segment = remote_path.split(".")[0]
if is_optional and not hasattr(original_model, first_segment):
first_value = getattr(original_model, first_segment, None)
if is_optional and first_value is None:
logger.debug(
"Optional '%s' (path '%s') absent on %s",
module_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from transformer_lens.model_bridge.generalized_components.alibi_joint_qkv_attention import (
ALiBiJointQKVAttentionBridge,
)
from transformer_lens.model_bridge.generalized_components.gated_delta_net import (
GatedDeltaNetBridge,
)
from transformer_lens.model_bridge.generalized_components.gated_mlp import (
GatedMLPBridge,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
requires_position_embeddings: bool = False,
requires_attention_mask: bool = False,
attention_mask_4d: bool = False,
optional: bool = False,
):
"""Initialize the attention bridge.

Expand All @@ -82,7 +83,11 @@ def __init__(
if conversion_rule is None:
conversion_rule = AttentionAutoConversion(config)
super().__init__(
name, config=config, submodules=submodules or {}, conversion_rule=conversion_rule
name,
config=config,
submodules=submodules or {},
conversion_rule=conversion_rule,
optional=optional,
)
self.hook_attn_scores = HookPoint()
self.hook_pattern = HookPoint()
Expand Down
Loading
Loading