diff --git a/tests/integration/model_bridge/test_mpt_adapter.py b/tests/integration/model_bridge/test_mpt_adapter.py new file mode 100644 index 000000000..a9c94cbfd --- /dev/null +++ b/tests/integration/model_bridge/test_mpt_adapter.py @@ -0,0 +1,182 @@ +"""Integration tests for MPT architecture adapter — Phase C. + +Builds a tiny MptForCausalLM programmatically (no HF Hub download) and wraps +it in TransformerBridge via MPTArchitectureAdapter. Verifies: + +- Forward output matches HF at max_diff < 1e-4 +- Attention hooks fire: blocks.0.attn.hook_q/k/v, hook_attn_scores, hook_pattern +- MLP hooks fire: blocks.0.mlp.hook_in, blocks.0.mlp.hook_out +- Norm hooks fire: blocks.0.ln1.hook_out, blocks.0.ln2.hook_out +- Residual stream hooks fire: blocks.0.hook_resid_pre, blocks.0.hook_resid_post +""" + +import pytest +import torch +from transformers import MptConfig +from transformers.models.mpt.modeling_mpt import MptForCausalLM + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.model_bridge.bridge import TransformerBridge +from transformer_lens.model_bridge.supported_architectures.mpt import ( + MPTArchitectureAdapter, +) + +# --------------------------------------------------------------------------- +# Tiny model parameters — deterministic, no download, fits in <50 MB RAM +# --------------------------------------------------------------------------- +_D_MODEL = 64 +_N_HEADS = 2 +_N_LAYERS = 2 +_D_MLP = 256 +_D_VOCAB = 256 +_MAX_SEQ_LEN = 32 + + +def _make_hf_config() -> MptConfig: + """Return a tiny MptConfig. max_seq_len is the MPT-specific name.""" + return MptConfig( + d_model=_D_MODEL, + n_heads=_N_HEADS, + n_layers=_N_LAYERS, + expansion_ratio=_D_MLP // _D_MODEL, + max_seq_len=_MAX_SEQ_LEN, + vocab_size=_D_VOCAB, + no_bias=True, + ) + + +def _make_bridge() -> TransformerBridge: + """Construct a TransformerBridge from a programmatic tiny MptForCausalLM. 
+ + Bypasses boot_transformers (which calls AutoConfig.from_pretrained) and + directly instantiates the adapter and bridge. Safe for CI — no download. + """ + hf_cfg = _make_hf_config() + hf_model = MptForCausalLM(hf_cfg) + hf_model.eval() + + bridge_cfg = TransformerBridgeConfig( + d_model=_D_MODEL, + d_head=_D_MODEL // _N_HEADS, + n_layers=_N_LAYERS, + n_ctx=_MAX_SEQ_LEN, + n_heads=_N_HEADS, + d_vocab=_D_VOCAB, + d_mlp=_D_MLP, + default_prepend_bos=False, + architecture="MPTForCausalLM", + device="cpu", + ) + + adapter = MPTArchitectureAdapter(bridge_cfg) + return TransformerBridge(hf_model, adapter, tokenizer=None) + + +# --------------------------------------------------------------------------- +# Module-scoped fixture — one bridge for the whole file +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def mpt_bridge() -> TransformerBridge: + return _make_bridge() + + +# --------------------------------------------------------------------------- +# Forward pass: HF numerical equivalence +# --------------------------------------------------------------------------- + + +class TestMPTForwardPass: + """Bridge forward must match HF MptForCausalLM.forward at atol=1e-4.""" + + def test_forward_output_shape(self, mpt_bridge: TransformerBridge) -> None: + tokens = torch.randint(0, _D_VOCAB, (1, 8)) + with torch.no_grad(): + out = mpt_bridge(tokens) + assert out.shape == (1, 8, _D_VOCAB) + assert not torch.isnan(out).any() + assert not torch.isinf(out).any() + + def test_forward_matches_hf(self, mpt_bridge: TransformerBridge) -> None: + """Logits from bridge must match HF at max_diff < 1e-4.""" + tokens = torch.randint(0, _D_VOCAB, (1, 8)) + hf_model = mpt_bridge.original_model + with torch.no_grad(): + bridge_out = mpt_bridge(tokens) + hf_out = hf_model(tokens).logits + max_diff = (bridge_out - hf_out).abs().max().item() + assert max_diff < 1e-4, f"Bridge vs HF max diff = {max_diff:.2e} (threshold 1e-4)" 
+ + def test_forward_batch2_matches_hf(self, mpt_bridge: TransformerBridge) -> None: + """Batch=2 check: no batch-broadcast bug in ALiBi unsqueeze(0) path.""" + tokens = torch.randint(0, _D_VOCAB, (2, 8)) + hf_model = mpt_bridge.original_model + with torch.no_grad(): + bridge_out = mpt_bridge(tokens) + hf_out = hf_model(tokens).logits + max_diff = (bridge_out - hf_out).abs().max().item() + assert max_diff < 1e-4, f"Batch=2 bridge vs HF max diff = {max_diff:.2e}" + + +# --------------------------------------------------------------------------- +# Hook coverage via run_with_cache +# --------------------------------------------------------------------------- + + +class TestMPTHookCoverage: + """All required hooks must appear in the cache after a single forward pass.""" + + @pytest.fixture(scope="class") + def cache(self, mpt_bridge: TransformerBridge) -> dict: + tokens = torch.randint(0, _D_VOCAB, (1, 8)) + with torch.no_grad(): + _, cache = mpt_bridge.run_with_cache(tokens) + return cache + + # Attention hooks + def test_hook_q_fires(self, cache: dict) -> None: + assert "blocks.0.attn.hook_q" in cache, f"keys: {sorted(cache.keys())}" + + def test_hook_k_fires(self, cache: dict) -> None: + assert "blocks.0.attn.hook_k" in cache + + def test_hook_v_fires(self, cache: dict) -> None: + assert "blocks.0.attn.hook_v" in cache + + def test_hook_attn_scores_fires(self, cache: dict) -> None: + assert "blocks.0.attn.hook_attn_scores" in cache + + def test_hook_pattern_fires(self, cache: dict) -> None: + assert "blocks.0.attn.hook_pattern" in cache + + # MLP hooks + def test_hook_mlp_in_fires(self, cache: dict) -> None: + assert "blocks.0.mlp.hook_in" in cache + + def test_hook_mlp_out_fires(self, cache: dict) -> None: + assert "blocks.0.mlp.hook_out" in cache + + # Norm hooks + def test_hook_ln1_fires(self, cache: dict) -> None: + assert "blocks.0.ln1.hook_out" in cache + + def test_hook_ln2_fires(self, cache: dict) -> None: + assert "blocks.0.ln2.hook_out" in cache + + # 
Residual stream hooks + def test_hook_resid_pre_fires(self, cache: dict) -> None: + assert "blocks.0.hook_resid_pre" in cache + + def test_hook_resid_post_fires(self, cache: dict) -> None: + assert "blocks.0.hook_resid_post" in cache + + # Shape sanity: attention pattern must be causal (lower-triangular) + def test_attn_pattern_is_causal(self, cache: dict) -> None: + """Attention pattern upper triangle must be zero (causal structure).""" + pattern = cache["blocks.0.attn.hook_pattern"] # [batch, n_heads, seq, seq] + seq = pattern.shape[-1] + upper = torch.triu(pattern[0, 0], diagonal=1) + assert ( + upper.abs() < 1e-6 + ).all(), f"Attention pattern is not causal; upper-triangle max = {upper.abs().max():.2e}" diff --git a/tests/unit/model_bridge/generalized_components/test_mpt_alibi_attention.py b/tests/unit/model_bridge/generalized_components/test_mpt_alibi_attention.py new file mode 100644 index 000000000..da8e2515a --- /dev/null +++ b/tests/unit/model_bridge/generalized_components/test_mpt_alibi_attention.py @@ -0,0 +1,321 @@ +"""Unit tests for MPTALiBiAttentionBridge. + +Exercises the reimplemented MPT ALiBi attention against a live MptAttention +module — no Hub download, tiny programmatic config only. 
+ +Covers: +- Numerical match vs HF MptAttention.forward at atol=1e-5 (batch=2) +- ALiBi slicing when seq_len < max_seq_len +- Boolean causal mask enforces causal structure +- hook_q, hook_k, hook_v, hook_attn_scores, hook_pattern all fire +""" + +import torch +import torch.nn as nn +from transformers import MptConfig +from transformers.models.mpt.modeling_mpt import MptAttention + +from transformer_lens.model_bridge.generalized_components import LinearBridge +from transformer_lens.model_bridge.generalized_components.mpt_alibi_attention import ( + MPTALiBiAttentionBridge, + _build_mpt_alibi_tensor, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _MockConfig: + """Minimal config for MPTALiBiAttentionBridge.""" + + def __init__(self, n_heads: int, d_model: int) -> None: + self.n_heads = n_heads + self.d_model = d_model + self.d_head = d_model // n_heads + self.n_key_value_heads = None + self.positional_embedding_type = "alibi" + + +def _make_tiny_mpt_attention(d_model: int = 64, n_heads: int = 2) -> MptAttention: + """Create a tiny MptAttention with random weights — no download.""" + cfg = MptConfig( + d_model=d_model, + n_heads=n_heads, + n_layers=1, + expansion_ratio=4, + max_seq_len=32, + vocab_size=256, + no_bias=True, + ) + return MptAttention(cfg) + + +def _make_split_fn(d_model: int) -> object: + """Return a split_qkv_matrix function for MPT's Wqkv [3*d_model, d_model] layout. + + Splits row-wise along dim=0 (NOT the GPT-2 style dim=1 split). 
+ """ + + def split_qkv(attn_component: object) -> tuple[nn.Linear, nn.Linear, nn.Linear]: + w = attn_component.Wqkv.weight.detach().clone() # [3*d_model, d_model] + w_q, w_k, w_v = torch.chunk(w, 3, dim=0) # each [d_model, d_model] + + def make_linear(weight: torch.Tensor) -> nn.Linear: + lin = nn.Linear(d_model, d_model, bias=False, device=weight.device, dtype=weight.dtype) + lin.weight = nn.Parameter(weight.contiguous()) + return lin + + return make_linear(w_q), make_linear(w_k), make_linear(w_v) + + return split_qkv + + +def _build_bridge(hf_attn: MptAttention) -> MPTALiBiAttentionBridge: + """Wrap a live MptAttention in MPTALiBiAttentionBridge with the correct QKV split.""" + d_model = hf_attn.hidden_size + n_heads = hf_attn.n_heads + cfg = _MockConfig(n_heads=n_heads, d_model=d_model) + + bridge = MPTALiBiAttentionBridge( + name="attn", + config=cfg, + split_qkv_matrix=_make_split_fn(d_model), + submodules={ + "qkv": LinearBridge(name="Wqkv"), + "o": LinearBridge(name="out_proj"), + }, + ) + bridge.set_original_component(hf_attn) + return bridge + + +def _make_inputs( + d_model: int, + n_heads: int, + max_seq_len: int, + seq_len: int, + batch_size: int = 2, +) -> dict: + """Build matching inputs for both HF MptAttention and bridge forward.""" + hidden = torch.randn(batch_size, seq_len, d_model) + # position_bias: [n_heads, 1, max_seq_len] as MptModel.forward produces + position_bias = _build_mpt_alibi_tensor(n_heads, max_seq_len) + # bool causal mask: [batch, 1, seq, seq] + causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1) + causal_mask = causal.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1) + return { + "hidden": hidden, + "position_bias": position_bias, + "causal_mask": causal_mask, + } + + +# --------------------------------------------------------------------------- +# Numerical match against HF MptAttention.forward +# --------------------------------------------------------------------------- + + +class 
TestMPTALiBiMatchesHF: + """Bridge output must numerically match HF MptAttention.forward.""" + + def test_forward_matches_hf_batch2(self) -> None: + """Primary correctness test: batch_size=2, seq_len=8. + + Uses batch_size >= 2 to catch any latent batch-broadcast bug in the + position_bias unsqueeze(0) path. + """ + hf_attn = _make_tiny_mpt_attention(d_model=64, n_heads=2) + hf_attn.eval() + bridge = _build_bridge(hf_attn) + + inputs = _make_inputs(d_model=64, n_heads=2, max_seq_len=32, seq_len=8, batch_size=2) + + with torch.no_grad(): + hf_out, *_ = hf_attn( + inputs["hidden"], + position_bias=inputs["position_bias"], + attention_mask=inputs["causal_mask"], + ) + bridge_out, *_ = bridge( + inputs["hidden"], + position_bias=inputs["position_bias"], + attention_mask=inputs["causal_mask"], + ) + + max_diff = (bridge_out - hf_out).abs().max().item() + assert max_diff < 1e-5, f"Bridge vs HF max diff = {max_diff:.2e} (threshold 1e-5)" + + def test_forward_matches_hf_batch1(self) -> None: + """Single-batch sanity check — no batch-dim interaction.""" + hf_attn = _make_tiny_mpt_attention(d_model=64, n_heads=2) + hf_attn.eval() + bridge = _build_bridge(hf_attn) + + inputs = _make_inputs(d_model=64, n_heads=2, max_seq_len=32, seq_len=6, batch_size=1) + + with torch.no_grad(): + hf_out, *_ = hf_attn( + inputs["hidden"], + position_bias=inputs["position_bias"], + attention_mask=inputs["causal_mask"], + ) + bridge_out, *_ = bridge( + inputs["hidden"], + position_bias=inputs["position_bias"], + attention_mask=inputs["causal_mask"], + ) + + max_diff = (bridge_out - hf_out).abs().max().item() + assert max_diff < 1e-5, f"Bridge vs HF max diff = {max_diff:.2e}" + + +# --------------------------------------------------------------------------- +# ALiBi slicing: seq_len < max_seq_len +# --------------------------------------------------------------------------- + + +class TestALiBiSlicing: + """position_bias covers max_seq_len; bridge must slice to current kv_len.""" + + def 
test_short_seq_slicing_no_error(self) -> None: + """seq_len=4 with max_seq_len=32: bridge must slice without error.""" + hf_attn = _make_tiny_mpt_attention(d_model=64, n_heads=2) + hf_attn.eval() + bridge = _build_bridge(hf_attn) + + # max_seq_len=32, but only use seq_len=4 + inputs = _make_inputs(d_model=64, n_heads=2, max_seq_len=32, seq_len=4, batch_size=2) + + with torch.no_grad(): + out, *_ = bridge( + inputs["hidden"], + position_bias=inputs["position_bias"], + attention_mask=inputs["causal_mask"], + ) + + assert out.shape == (2, 4, 64) + assert not torch.isnan(out).any() + + def test_slicing_matches_hf_short_seq(self) -> None: + """Bridge output at seq_len=4 must match HF at same seq_len.""" + hf_attn = _make_tiny_mpt_attention(d_model=64, n_heads=2) + hf_attn.eval() + bridge = _build_bridge(hf_attn) + + inputs = _make_inputs(d_model=64, n_heads=2, max_seq_len=32, seq_len=4, batch_size=2) + + with torch.no_grad(): + hf_out, *_ = hf_attn( + inputs["hidden"], + position_bias=inputs["position_bias"], + attention_mask=inputs["causal_mask"], + ) + bridge_out, *_ = bridge( + inputs["hidden"], + position_bias=inputs["position_bias"], + attention_mask=inputs["causal_mask"], + ) + + max_diff = (bridge_out - hf_out).abs().max().item() + assert max_diff < 1e-5, f"Short-seq bridge vs HF diff = {max_diff:.2e}" + + +# --------------------------------------------------------------------------- +# Boolean mask enforces causal structure +# --------------------------------------------------------------------------- + + +class TestBoolMaskCausalStructure: + """Boolean causal mask must zero out attention to future tokens.""" + + def test_bool_mask_applies_causal_structure(self) -> None: + """Attention pattern must be lower-triangular under a strict causal bool mask. + + Uses batch_size=2 to also exercise the bool-mask path with batched input. 
+ """ + d_model, n_heads, seq_len, batch_size = 64, 2, 6, 2 + hf_attn = _make_tiny_mpt_attention(d_model=d_model, n_heads=n_heads) + hf_attn.eval() + bridge = _build_bridge(hf_attn) + + hidden = torch.randn(batch_size, seq_len, d_model) + position_bias = _build_mpt_alibi_tensor(n_heads, seq_len) + causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1) + causal_mask = causal.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1) + + pattern_captured: dict[str, torch.Tensor] = {} + + def capture_pattern(tensor, hook): + pattern_captured["pattern"] = tensor.detach().clone() + return tensor + + bridge.hook_pattern.add_hook(capture_pattern) + + with torch.no_grad(): + bridge(hidden, position_bias=position_bias, attention_mask=causal_mask) + + assert "pattern" in pattern_captured + pattern = pattern_captured["pattern"] # [batch, n_heads, seq, seq] + # Upper triangle (above diagonal) must be zero after causal masking + upper = torch.triu(pattern[0, 0], diagonal=1) + assert (upper.abs() < 1e-6).all(), "Future positions must have zero attention weight" + + +# --------------------------------------------------------------------------- +# Hook coverage +# --------------------------------------------------------------------------- + + +class TestHooksFire: + """All required hooks must fire during a single forward pass.""" + + def _run_forward_with_hooks(self, hook_names: list[str]) -> dict[str, torch.Tensor]: + d_model, n_heads, seq_len = 64, 2, 8 + hf_attn = _make_tiny_mpt_attention(d_model=d_model, n_heads=n_heads) + hf_attn.eval() + bridge = _build_bridge(hf_attn) + + captured: dict[str, torch.Tensor] = {} + + def make_hook(name: str): + def fn(tensor, hook): + captured[name] = tensor.detach().clone() + return tensor + + return fn + + for hook_name in hook_names: + hook_obj: object = bridge + for part in hook_name.split("."): + hook_obj = getattr(hook_obj, part) + hook_obj.add_hook(make_hook(hook_name)) # type: ignore[union-attr] + + inputs = 
_make_inputs(d_model=d_model, n_heads=n_heads, max_seq_len=32, seq_len=seq_len) + with torch.no_grad(): + bridge( + inputs["hidden"], + position_bias=inputs["position_bias"], + attention_mask=inputs["causal_mask"], + ) + return captured + + def test_hook_q_fires(self) -> None: + captured = self._run_forward_with_hooks(["q.hook_out"]) + assert "q.hook_out" in captured + + def test_hook_k_fires(self) -> None: + captured = self._run_forward_with_hooks(["k.hook_out"]) + assert "k.hook_out" in captured + + def test_hook_v_fires(self) -> None: + captured = self._run_forward_with_hooks(["v.hook_out"]) + assert "v.hook_out" in captured + + def test_hook_attn_scores_fires(self) -> None: + captured = self._run_forward_with_hooks(["hook_attn_scores"]) + assert "hook_attn_scores" in captured + + def test_hook_pattern_fires(self) -> None: + captured = self._run_forward_with_hooks(["hook_pattern"]) + assert "hook_pattern" in captured diff --git a/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py b/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py new file mode 100644 index 000000000..695a6e21b --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_mpt_adapter.py @@ -0,0 +1,283 @@ +"""Unit tests for MPTArchitectureAdapter — Phase A (config + weight conversions), +Phase B-1 (component mapping + QKV split), and Phase D (factory registration). 
+ +Tests cover: +- Config attribute validation (all required attributes set correctly) +- Weight conversion keys (four standard QKVO keys with .weight suffix) +- LayerNorm with bias=None wraps without error (MptBlock sets norm.bias = None) +- Component mapping keys (embed/blocks/ln_final/unembed; no pos_embed/rotary_emb) +- Block/attn/mlp submodule keys +- _split_mpt_qkv: output shapes and round-trip correctness +- Factory resolves MPTForCausalLM -> MPTArchitectureAdapter (no download) +""" + +import pytest +import torch +import torch.nn as nn + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.model_bridge.supported_architectures.mpt import ( + MPTArchitectureAdapter, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_cfg( + n_heads: int = 2, + d_model: int = 64, + n_layers: int = 2, + d_mlp: int = 256, + d_vocab: int = 256, + n_ctx: int = 128, +) -> TransformerBridgeConfig: + """Return a minimal TransformerBridgeConfig for MPT adapter tests. + + Uses tiny dimensions — no HF Hub download required. 
+ """ + return TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + d_vocab=d_vocab, + d_mlp=d_mlp, + default_prepend_bos=False, + architecture="MPTForCausalLM", + ) + + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg() + + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> MPTArchitectureAdapter: + return MPTArchitectureAdapter(cfg) + + +# --------------------------------------------------------------------------- +# Config attribute tests +# --------------------------------------------------------------------------- + + +class TestMPTAdapterConfig: + """Verify all required config attributes are set correctly.""" + + def test_normalization_type_is_ln(self, adapter: MPTArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "LN" + + def test_positional_embedding_type_is_alibi(self, adapter: MPTArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "alibi" + + def test_gated_mlp_is_false(self, adapter: MPTArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is False + + def test_final_rms_is_false(self, adapter: MPTArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is False + + def test_attn_only_is_false(self, adapter: MPTArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_default_prepend_bos_is_false(self, adapter: MPTArchitectureAdapter) -> None: + assert adapter.cfg.default_prepend_bos is False + + +# --------------------------------------------------------------------------- +# Weight processing conversion tests +# --------------------------------------------------------------------------- + + +class TestMPTAdapterWeightConversions: + """Verify weight_processing_conversions has exactly the four QKVO keys.""" + + def test_q_weight_key_present(self, adapter: MPTArchitectureAdapter) -> None: + assert "blocks.{i}.attn.q.weight" in 
adapter.weight_processing_conversions + + def test_k_weight_key_present(self, adapter: MPTArchitectureAdapter) -> None: + assert "blocks.{i}.attn.k.weight" in adapter.weight_processing_conversions + + def test_v_weight_key_present(self, adapter: MPTArchitectureAdapter) -> None: + assert "blocks.{i}.attn.v.weight" in adapter.weight_processing_conversions + + def test_o_weight_key_present(self, adapter: MPTArchitectureAdapter) -> None: + assert "blocks.{i}.attn.o.weight" in adapter.weight_processing_conversions + + def test_exactly_four_conversion_keys(self, adapter: MPTArchitectureAdapter) -> None: + # No MLP conversions — up_proj/down_proj use standard [out, in] layout. + assert len(adapter.weight_processing_conversions) == 4 + + def test_no_mlp_conversion_keys(self, adapter: MPTArchitectureAdapter) -> None: + keys = adapter.weight_processing_conversions + assert not any("mlp" in k for k in keys), "MLP weights need no special conversion" + + +# --------------------------------------------------------------------------- +# LayerNorm with bias=None test +# --------------------------------------------------------------------------- + + +class TestMPTLayerNormBiasNone: + """Verify NormalizationBridge handles MPT's bias=None LayerNorm correctly.""" + + def test_layernorm_bias_none_wraps_without_error(self, cfg: TransformerBridgeConfig) -> None: + """NormalizationBridge must accept and forward through a bias=None LayerNorm. + + MptBlock.__init__ explicitly sets norm_1.bias = None for backward compatibility + with Hub weights. This test front-loads any surprise from that pattern. 
+ """ + from transformer_lens.model_bridge.generalized_components import ( + NormalizationBridge, + ) + + # Replicate what MptBlock does: LayerNorm then strip bias + ln = nn.LayerNorm(cfg.d_model, eps=1e-5) + ln.bias = None # exactly as MptBlock.__init__ does + + bridge = NormalizationBridge(name="norm_1", config=cfg) + bridge.set_original_component(ln) + + x = torch.randn(2, 4, cfg.d_model) + with torch.no_grad(): + out = bridge(x) + + assert out.shape == x.shape, "Output shape must match input shape" + assert not torch.isnan(out).any(), "Output must not contain NaN" + assert not torch.isinf(out).any(), "Output must not contain Inf" + + +# --------------------------------------------------------------------------- +# Component mapping structure tests (Phase B-1) +# --------------------------------------------------------------------------- + + +class TestMPTComponentMappingKeys: + """Verify top-level and nested component mapping keys are correct.""" + + def test_top_level_keys_present(self, adapter: MPTArchitectureAdapter) -> None: + keys = set(adapter.component_mapping.keys()) + assert {"embed", "blocks", "ln_final", "unembed"} <= keys + + def test_no_pos_embed_key(self, adapter: MPTArchitectureAdapter) -> None: + # ALiBi has no learnable positional embedding module. 
+ assert "pos_embed" not in adapter.component_mapping + + def test_no_rotary_emb_key(self, adapter: MPTArchitectureAdapter) -> None: + assert "rotary_emb" not in adapter.component_mapping + + def test_block_submodule_keys(self, adapter: MPTArchitectureAdapter) -> None: + block = adapter.component_mapping["blocks"] + subkeys = set(block.submodules.keys()) + assert {"ln1", "attn", "ln2", "mlp"} <= subkeys + + def test_attn_submodule_keys(self, adapter: MPTArchitectureAdapter) -> None: + attn = adapter.component_mapping["blocks"].submodules["attn"] + subkeys = set(attn.submodules.keys()) + # qkv and o are the projection submodules; q/k/v are created during split + assert {"qkv", "o"} <= subkeys + + def test_mlp_submodule_keys(self, adapter: MPTArchitectureAdapter) -> None: + mlp = adapter.component_mapping["blocks"].submodules["mlp"] + subkeys = set(mlp.submodules.keys()) + assert {"in", "out"} <= subkeys + + +# --------------------------------------------------------------------------- +# _split_mpt_qkv tests (Phase B-1) +# --------------------------------------------------------------------------- + + +class TestMPTSplitQKV: + """Verify _split_mpt_qkv correctly decomposes Wqkv [3*d_model, d_model].""" + + def _make_fake_attn_component(self, d_model: int) -> object: + """Return a stub object with a Wqkv Linear attribute (no bias, row-concat layout).""" + + class _FakeAttn(nn.Module): + def __init__(self) -> None: + super().__init__() + # Wqkv: [3*d_model, d_model] — MPT row-wise concat layout + self.Wqkv = nn.Linear(d_model, 3 * d_model, bias=False) + + return _FakeAttn() + + def test_split_returns_three_linears(self, adapter: MPTArchitectureAdapter) -> None: + d_model = adapter.cfg.d_model + fake_attn = self._make_fake_attn_component(d_model) + result = adapter._split_mpt_qkv(fake_attn) + assert len(result) == 3 + assert all(isinstance(lin, nn.Linear) for lin in result) + + def test_split_output_shapes(self, adapter: MPTArchitectureAdapter) -> None: + """Each output 
linear must have weight shape [d_model, d_model].""" + d_model = adapter.cfg.d_model + fake_attn = self._make_fake_attn_component(d_model) + q_lin, k_lin, v_lin = adapter._split_mpt_qkv(fake_attn) + for lin in (q_lin, k_lin, v_lin): + assert lin.weight.shape == ( + d_model, + d_model, + ), f"Expected ({d_model}, {d_model}), got {lin.weight.shape}" + + def test_split_roundtrip(self, adapter: MPTArchitectureAdapter) -> None: + """cat([q.weight, k.weight, v.weight], dim=0) must recover original Wqkv.weight. + + Uses batch_size=2 worth of distinct rows to surface any row/col transposition. + """ + d_model = adapter.cfg.d_model + fake_attn = self._make_fake_attn_component(d_model) + original_w = fake_attn.Wqkv.weight.detach().clone() # [3*d_model, d_model] + + q_lin, k_lin, v_lin = adapter._split_mpt_qkv(fake_attn) + recovered = torch.cat([q_lin.weight, k_lin.weight, v_lin.weight], dim=0) + + assert torch.allclose( + recovered, original_w + ), "Round-trip failed: cat(Q,K,V) != original Wqkv" + + +# --------------------------------------------------------------------------- +# Factory registration test (Phase D) +# --------------------------------------------------------------------------- + + +class TestMPTFactoryRegistration: + """ArchitectureAdapterFactory must resolve MPTForCausalLM -> MPTArchitectureAdapter.""" + + def test_factory_resolves_mpt_architecture(self) -> None: + """Factory returns an MPTArchitectureAdapter instance for MPTForCausalLM. + + Uses a fully programmatic config — no HF Hub download. 
+ """ + from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, + ) + + cfg = _make_cfg() + cfg.architecture = "MPTForCausalLM" + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, MPTArchitectureAdapter) + + def test_factory_unknown_architecture_raises(self) -> None: + """Factory raises ValueError for an unregistered architecture key.""" + from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, + ) + + cfg = _make_cfg() + cfg.architecture = "NonExistentForCausalLM" + with pytest.raises(ValueError, match="Unsupported architecture"): + ArchitectureAdapterFactory.select_architecture_adapter(cfg) + + def test_mpt_in_supported_architectures_dict(self) -> None: + """MPTForCausalLM must appear in the SUPPORTED_ARCHITECTURES mapping.""" + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert "MPTForCausalLM" in SUPPORTED_ARCHITECTURES + assert SUPPORTED_ARCHITECTURES["MPTForCausalLM"] is MPTArchitectureAdapter diff --git a/transformer_lens/benchmarks/component_outputs.py b/transformer_lens/benchmarks/component_outputs.py index dcbc85cca..ba2d03edd 100644 --- a/transformer_lens/benchmarks/component_outputs.py +++ b/transformer_lens/benchmarks/component_outputs.py @@ -415,15 +415,14 @@ def _test_component_recursive( ): return - # Skip BLOOM and T5 attention and MLP components - they have custom signatures that require - # residual connections, alibi bias, or cache_position from the full model context + # Skip models whose MLP/attn forward signatures require extra context from the block: + # - BLOOM: MLP requires residual and alibi bias + # - T5: requires cache_position for relative position embeddings + # - MPT: MLP.forward(hidden_states, residual) performs the residual addition internally if "attn" in component_path or "mlp" in component_path: - # Check if this is a BLOOM or T5 
model by looking at the HF model config hf_model_config = getattr(self.hf_model, "config", None) if hf_model_config and hasattr(hf_model_config, "model_type"): - # BLOOM requires residual and alibi bias - # T5 requires cache_position for relative position embeddings - if hf_model_config.model_type in ["bloom", "t5"]: + if hf_model_config.model_type in ["bloom", "t5", "mpt"]: return # Skip components that require specific shaped inputs from their parent modules diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index afeff43b5..098d055be 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -35,6 +35,7 @@ MingptArchitectureAdapter, MistralArchitectureAdapter, MixtralArchitectureAdapter, + MPTArchitectureAdapter, NanogptArchitectureAdapter, NeelSoluOldArchitectureAdapter, NeoArchitectureAdapter, @@ -89,6 +90,7 @@ "MambaForCausalLM": MambaArchitectureAdapter, "MixtralForCausalLM": MixtralArchitectureAdapter, "MistralForCausalLM": MistralArchitectureAdapter, + "MPTForCausalLM": MPTArchitectureAdapter, "NeoForCausalLM": NeoArchitectureAdapter, "NeoXForCausalLM": NeoxArchitectureAdapter, "NeelSoluOldForCausalLM": NeelSoluOldArchitectureAdapter, diff --git a/transformer_lens/model_bridge/compat.py b/transformer_lens/model_bridge/compat.py index 673784b6b..f711fbe17 100644 --- a/transformer_lens/model_bridge/compat.py +++ b/transformer_lens/model_bridge/compat.py @@ -2,6 +2,20 @@ These patches are applied lazily (only when missing) so they're safe to call from multiple adapters — the first caller wins, subsequent calls are no-ops. + +WARNING: patches here mutate classes from the installed `transformers` package +in place. They are process-global and persist for the entire Python session — +every model loaded afterward, including ones unrelated to the caller, sees the +patched class. 
This is acceptable because the shims only *add* v4-era methods +that v5 removed; they do not change v5 behavior. But it means a bug in a shim +affects the whole session, not just the adapter that invoked it. + +REMOVAL: drop the corresponding block (and its call sites) once the minimum +supported `transformers` version provides the method natively, or once all +remote-code models we support have been updated for v5. Track upstream status +against `transformers.cache_utils.DynamicCache` — when `from_legacy_cache`, +`to_legacy_cache`, and `get_usable_length` are restored or no longer needed, +`patch_dynamic_cache_v5` can be deleted outright. """ @@ -11,6 +25,9 @@ def patch_dynamic_cache_v5() -> None: Remote-code models written for transformers v4 call from_legacy_cache, to_legacy_cache, and get_usable_length which were removed in v5. Call this from any adapter's prepare_loading() that needs them. + + Side effect: mutates `transformers.cache_utils.DynamicCache` for the whole + process. See module docstring. 
""" try: from transformers.cache_utils import DynamicCache diff --git a/transformer_lens/model_bridge/generalized_components/__init__.py b/transformer_lens/model_bridge/generalized_components/__init__.py index 518baca09..fb789cc30 100644 --- a/transformer_lens/model_bridge/generalized_components/__init__.py +++ b/transformer_lens/model_bridge/generalized_components/__init__.py @@ -52,6 +52,9 @@ ) from transformer_lens.model_bridge.generalized_components.linear import LinearBridge from transformer_lens.model_bridge.generalized_components.mlp import MLPBridge +from transformer_lens.model_bridge.generalized_components.mpt_alibi_attention import ( + MPTALiBiAttentionBridge, +) from transformer_lens.model_bridge.generalized_components.moe import MoEBridge from transformer_lens.model_bridge.generalized_components.normalization import ( NormalizationBridge, @@ -108,6 +111,7 @@ "JointGateUpMLPBridge", "LinearBridge", "MLPBridge", + "MPTALiBiAttentionBridge", "GatedMLPBridge", "GatedRMSNormBridge", "MoEBridge", diff --git a/transformer_lens/model_bridge/generalized_components/mpt_alibi_attention.py b/transformer_lens/model_bridge/generalized_components/mpt_alibi_attention.py new file mode 100644 index 000000000..fdebcf0c3 --- /dev/null +++ b/transformer_lens/model_bridge/generalized_components/mpt_alibi_attention.py @@ -0,0 +1,134 @@ +"""MPT ALiBi attention bridge — MPT uses ``position_bias`` kwarg + bool causal mask.""" + +from __future__ import annotations + +import math +from typing import Any, Dict, Optional + +import torch +from packaging import version + +from transformer_lens.model_bridge.generalized_components.alibi_joint_qkv_attention import ( + ALiBiJointQKVAttentionBridge, +) + +try: + import transformers as _transformers + + _TRANSFORMERS_V5 = version.parse(_transformers.__version__) >= version.parse("5.0.0") +except Exception: + _TRANSFORMERS_V5 = False + + +def _build_mpt_alibi_tensor(num_heads: int, seq_len: int, alibi_bias_max: int = 8) -> torch.Tensor: + 
"""MPT ALiBi bias [num_heads, 1, seq_len] — mirrors HF's ``build_mpt_alibi_tensor``.""" + alibi = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len) + num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads)) + + base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.int64).float() + base = base * (alibi_bias_max / num_heads_power_of_2) + slopes = 1.0 / torch.pow(2, base) + slopes = slopes.view(1, num_heads_power_of_2, 1, 1) + + if num_heads_power_of_2 != num_heads: + slopes = torch.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], dim=1)[ + :, :num_heads, ... + ] + + alibi = alibi * slopes # [1, n_heads, 1, seq_len] + return alibi.squeeze(0) # [n_heads, 1, seq_len] + + +class MPTALiBiAttentionBridge(ALiBiJointQKVAttentionBridge): + """ALiBi bridge for MPT: overrides ALiBi kwarg name, bias shape, mask format, and clip_qkv.""" + + _clip_qkv: Optional[float] = None + + def forward( + self, *args: Any, **kwargs: Any + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, None]: + """2-tuple on transformers>=5, 3-tuple on <5 — MptBlock unpack arity changed in v5.""" + output, attn_weights = super().forward(*args, **kwargs) + if _TRANSFORMERS_V5: + return output, attn_weights + return output, attn_weights, None + + def set_original_component(self, original_component: torch.nn.Module) -> None: + super().set_original_component(original_component) + if hasattr(self, "o") and hasattr(original_component, "out_proj"): + self.o.set_original_component(original_component.out_proj) + clip = getattr(original_component, "clip_qkv", None) + self._clip_qkv = float(clip) if clip is not None else None + + def _reconstruct_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs: Any + ) -> tuple[torch.Tensor, torch.Tensor]: + # clip_qkv is post-projection, pre-head-split — must happen before reshape. 
+ if self._clip_qkv is not None: + q = q.clamp(min=-self._clip_qkv, max=self._clip_qkv) + k = k.clamp(min=-self._clip_qkv, max=self._clip_qkv) + v = v.clamp(min=-self._clip_qkv, max=self._clip_qkv) + + num_heads = self.config.n_heads if self.config else 32 + q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads( + q, k, v, num_heads, num_heads + ) + + softmax_scale = head_dim**-0.5 + attn_scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale + + # position_bias is [n_heads, 1, max_seq_len]; slice trailing kv_len, broadcast over batch. + position_bias = kwargs.get("position_bias", None) + if position_bias is not None: + kv_len = attn_scores.shape[-1] + pb = position_bias[:, :, -kv_len:] + attn_scores = attn_scores + pb.unsqueeze(0) + + # MPT passes a bool 4D mask (True = masked), not an additive float mask. + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + attn_scores = attn_scores.masked_fill( + attention_mask, torch.finfo(attn_scores.dtype).min + ) + + attn_scores = self.hook_attn_scores(attn_scores) + + attn_weights = self._softmax_dropout_pattern( + attn_scores, upcast_to_fp32=True, target_dtype=q.dtype + ) + + attn_output = torch.matmul(attn_weights, v) + attn_output = self._reshape_attn_output( + attn_output, batch_size, seq_len, num_heads, head_dim + ) + attn_output = self._apply_output_projection(attn_output) + return attn_output, attn_weights + + def get_random_inputs( + self, + batch_size: int = 2, + seq_len: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Dict[str, Any]: + """Test inputs using MPT's kwarg names: position_bias (no batch dim) + bool causal mask.""" + if device is None: + device = torch.device("cpu") + if dtype is None: + dtype = torch.float32 + + d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 2048 + num_heads = self.config.n_heads if self.config and hasattr(self.config, "n_heads") else 32 + + 
position_bias = _build_mpt_alibi_tensor(num_heads, seq_len).to(device=device, dtype=dtype) + + causal = torch.triu( + torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1 + ) + causal_mask = causal.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1) + + return { + "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype), + "position_bias": position_bias, + "attention_mask": causal_mask, + } diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 06647a62b..57bd677d2 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -90,6 +90,9 @@ from transformer_lens.model_bridge.supported_architectures.mixtral import ( MixtralArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.mpt import ( + MPTArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.nanogpt import ( NanogptArchitectureAdapter, ) @@ -184,6 +187,7 @@ "MingptArchitectureAdapter", "MistralArchitectureAdapter", "MixtralArchitectureAdapter", + "MPTArchitectureAdapter", "NanogptArchitectureAdapter", "NeelSoluOldArchitectureAdapter", "NeoArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/mpt.py b/transformer_lens/model_bridge/supported_architectures/mpt.py new file mode 100644 index 000000000..db50a028a --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/mpt.py @@ -0,0 +1,80 @@ +"""MPT (MPTForCausalLM) adapter — ALiBi, fused Wqkv, weight-only LayerNorm, no biases.""" + +from typing import Any + +import torch +import torch.nn as nn + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + LinearBridge, + MLPBridge, + 
NormalizationBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.mpt_alibi_attention import ( + MPTALiBiAttentionBridge, +) + + +class MPTArchitectureAdapter(ArchitectureAdapter): + """MPT adapter: ALiBi bias; all layers bias-free (no b_Q/b_K/b_V/b_O/b_in/b_out/ln bias).""" + + def __init__(self, cfg: Any) -> None: + super().__init__(cfg) + + self.cfg.normalization_type = "LN" + self.cfg.positional_embedding_type = "alibi" + self.cfg.final_rms = False + self.cfg.gated_mlp = False + self.cfg.attn_only = False + self.cfg.default_prepend_bos = False + + # Pure MHA: split_qkv yields [d_model, d_model] per head; standard rearrangements apply. + self.weight_processing_conversions = { + **self._qkvo_weight_conversions(), + } + + self.component_mapping = { + "embed": EmbeddingBridge(name="transformer.wte"), + "blocks": BlockBridge( + name="transformer.blocks", + submodules={ + "ln1": NormalizationBridge(name="norm_1", config=self.cfg), + "attn": MPTALiBiAttentionBridge( + name="attn", + config=self.cfg, + split_qkv_matrix=self._split_mpt_qkv, + submodules={ + "qkv": LinearBridge(name="Wqkv"), + "o": LinearBridge(name="out_proj"), + }, + ), + "ln2": NormalizationBridge(name="norm_2", config=self.cfg), + "mlp": MLPBridge( + name="ffn", + submodules={ + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": NormalizationBridge(name="transformer.norm_f", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head"), + } + + def _split_mpt_qkv(self, attn_component: Any) -> tuple[nn.Linear, nn.Linear, nn.Linear]: + """Split fused Wqkv into Q, K, V — row-wise chunk (NOT interleaved like BLOOM).""" + w = attn_component.Wqkv.weight.detach().clone() + w_q, w_k, w_v = torch.chunk(w, 3, dim=0) + d_model = self.cfg.d_model + + def make_linear(weight: torch.Tensor) -> nn.Linear: + lin = nn.Linear(d_model, d_model, bias=False, device=weight.device, dtype=weight.dtype) + lin.weight = 
nn.Parameter(weight.contiguous()) + return lin + + return make_linear(w_q), make_linear(w_k), make_linear(w_v) diff --git a/transformer_lens/tools/model_registry/data/supported_models.json b/transformer_lens/tools/model_registry/data/supported_models.json index c9b5d8a2b..494569436 100644 --- a/transformer_lens/tools/model_registry/data/supported_models.json +++ b/transformer_lens/tools/model_registry/data/supported_models.json @@ -94133,7 +94133,8 @@ "status": 0, "verified_date": null, "metadata": { - "downloads": 24750,"total_params": null + "downloads": 24750, + "total_params": null }, "note": null, "phase1_score": null, @@ -98599,23 +98600,6 @@ "phase7_score": null, "phase8_score": null }, - { - "architecture_id": "Qwen3MoeForCausalLM", - "model_id": "imdatta0/tiny_qwen3_moe_2.8B_0.7B", - "status": 1, - "verified_date": "2026-04-10", - "metadata": { - "downloads": 218, - "total_params": 2800000000 - }, - "note": "Full verification completed", - "phase1_score": 100.0, - "phase2_score": 100.0, - "phase3_score": 100.0, - "phase4_score": 70.4, - "phase7_score": null, - "phase8_score": null - }, { "architecture_id": "XGLMForCausalLM", "model_id": "facebook/xglm-564M", @@ -99159,6 +99143,414 @@ "phase4_score": null, "phase7_score": null, "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "vinai/PhoGPT-4B-Chat", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 6833, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "vinai/PhoGPT-4B", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 3211, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + 
"architecture_id": "MPTForCausalLM", + "model_id": "anas-awadalla/mpt-7b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 2850, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "gl198976/mpt-7b-instruct", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 2672, + "total_params": 6649286656 + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "replit/replit-code-v1-3b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 2307, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "echarlaix/tiny-mpt-random-remote-code", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1975, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "wtang06/mpt-125m-c4", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1877, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "lightblue/japanese-mpt-7b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1726, + "total_params": null + }, + "note": null, + "phase1_score": 
null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "gl198976/mpt-7b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1239, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "TehVenom/MPT-7b-Chat-Instruct-LongCTX-Merge", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1153, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "TehVenom/MPT-7b-InstructAndStorywriting-50_50-Merge", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1151, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "TehVenom/mpt-7b-InstructAndStorywriting-75_25-Merge", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1138, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "TehVenom/MPT-7b-storywriter-Apache-2.0", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1135, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + 
"architecture_id": "MPTForCausalLM", + "model_id": "Nethermind/Mpt-Instruct-DotNet-S", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1109, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "openaccess-ai-collective/mpt-7b-wizardlm", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1104, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "ethzanalytics/mpt-7b-storywriter-sharded", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1101, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "P1ayer-1/mpt-7b-instruct-base", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1094, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "TehVenom/MPT-7b-WizardLM_Uncensored-Storywriter-Merge", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1066, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "nomic-ai/gpt4all-mpt", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1059, + 
"total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "rwl4/mpt-7b-chat-extended", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 1047, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "gretelai/mpt-7b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 915, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "fxmarty/tiny-mpt-random-remote-code", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 893, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "jploski/mpt-mini-shakespeare", + "status": 1, + "verified_date": "2026-04-10", + "metadata": { + "downloads": 880, + "total_params": null + }, + "note": "Full verification completed", + "phase1_score": 100.0, + "phase2_score": 100.0, + "phase3_score": 100.0, + "phase4_score": 75.5, + "phase7_score": null, + "phase8_score": null + }, + { + "architecture_id": "MPTForCausalLM", + "model_id": "replit/replit-code-v1_5-3b", + "status": 0, + "verified_date": null, + "metadata": { + "downloads": 655, + "total_params": null + }, + "note": null, + "phase1_score": null, + "phase2_score": null, + "phase3_score": null, + "phase4_score": null, + "phase7_score": null, + "phase8_score": null + } ] } diff 
--git a/transformer_lens/tools/model_registry/data/verification_history.json b/transformer_lens/tools/model_registry/data/verification_history.json index 0deb099f1..a635a19f9 100644 --- a/transformer_lens/tools/model_registry/data/verification_history.json +++ b/transformer_lens/tools/model_registry/data/verification_history.json @@ -1,5 +1,5 @@ { - "last_updated": "2026-04-13T12:12:14.839053", + "last_updated": "2026-04-10T18:43:37.000957", "records": [ { "model_id": "Macropodus/macbert4mdcspell_v1", @@ -11202,119 +11202,59 @@ "invalidation_reason": null }, { - "model_id": "internlm/internlm2-chat-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", + "model_id": "jploski/mpt-mini-shakespeare", + "architecture_id": "MPTForCausalLM", + "verified_date": "2026-04-10", "verified_by": "verify_models", "transformerlens_version": null, - "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass); P2=8.3% < 75.0% (failed: g \u2014 74/123 components failed (74 critical)", + "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass_logits) \u2014 8/51 components failed (8 critical)", "invalidated": false, "invalidation_reason": null }, { - "model_id": "internlm/internlm2_5-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", + "model_id": "jploski/mpt-mini-shakespeare", + "architecture_id": "MPTForCausalLM", + "verified_date": "2026-04-10", "verified_by": "verify_models", "transformerlens_version": null, - "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass); P2=8.3% < 75.0% (failed: g \u2014 74/123 components failed (74 critical)", + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass_logits) \u2014 Tensors differ: max_diff=1.952242, mean_rel=0.113042", "invalidated": false, "invalidation_reason": null }, { - "model_id": "internlm/internlm2-chat-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": 
"2026-04-13", + "model_id": "jploski/mpt-mini-shakespeare", + "architecture_id": "MPTForCausalLM", + "verified_date": "2026-04-10", "verified_by": "verify_models", "transformerlens_version": null, - "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass); P2=8.3% < 75.0% (failed: g \u2014 74/123 components failed (74 critical)", + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass); P2=8.3% < 75.0% (failed: generation, gene \u2014 Forward pass failed: too many values to unpack (expected 2)", "invalidated": false, "invalidation_reason": null }, { - "model_id": "internlm/internlm2-chat-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", + "model_id": "jploski/mpt-mini-shakespeare", + "architecture_id": "MPTForCausalLM", + "verified_date": "2026-04-10", "verified_by": "verify_models", "transformerlens_version": null, - "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass_logits) \u2014 74/123 components failed (74 critical)", + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass); P2=8.3% < 75.0% (failed: generation, gene \u2014 Forward pass failed: too many values to unpack (expected 2)", "invalidated": false, "invalidation_reason": null }, { - "model_id": "internlm/internlm2_5-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", + "model_id": "jploski/mpt-mini-shakespeare", + "architecture_id": "MPTForCausalLM", + "verified_date": "2026-04-10", "verified_by": "verify_models", "transformerlens_version": null, - "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass_logits) \u2014 74/123 components failed (74 critical)", + "notes": "Below threshold: P1=50.0% < 100.0% (failed: forward_pass); P2=8.3% < 75.0% (failed: generation, gene \u2014 Forward pass failed: too many values to unpack (expected 2)", "invalidated": false, "invalidation_reason": null }, { - "model_id": 
"internlm/internlm2-chat-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", - "verified_by": "verify_models", - "transformerlens_version": null, - "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass_logits) \u2014 74/123 components failed (74 critical)", - "invalidated": false, - "invalidation_reason": null - }, - { - "model_id": "internlm/internlm2-chat-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", - "verified_by": "verify_models", - "transformerlens_version": null, - "notes": "Full verification completed", - "invalidated": false, - "invalidation_reason": null - }, - { - "model_id": "internlm/internlm2_5-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", - "verified_by": "verify_models", - "transformerlens_version": null, - "notes": "Below threshold: P1=0.0% < 100.0% (failed: all_components, forward_pass_logits) \u2014 74/123 components failed (74 critical)", - "invalidated": false, - "invalidation_reason": null - }, - { - "model_id": "internlm/internlm2_5-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", - "verified_by": "verify_models", - "transformerlens_version": null, - "notes": "Full verification completed", - "invalidated": false, - "invalidation_reason": null - }, - { - "model_id": "internlm/internlm2-chat-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", - "verified_by": "verify_models", - "transformerlens_version": null, - "notes": "Full verification completed", - "invalidated": false, - "invalidation_reason": null - }, - { - "model_id": "internlm/internlm2_5-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", - "verified_by": "verify_models", - "transformerlens_version": null, - "notes": "Full verification completed", - "invalidated": false, - "invalidation_reason": null - }, - { - "model_id": 
"internlm/internlm2-chat-1_8b", - "architecture_id": "InternLM2ForCausalLM", - "verified_date": "2026-04-13", + "model_id": "jploski/mpt-mini-shakespeare", + "architecture_id": "MPTForCausalLM", + "verified_date": "2026-04-10", "verified_by": "verify_models", "transformerlens_version": null, "notes": "Full verification completed",