31 changes: 31 additions & 0 deletions tests/mocks/models.py
@@ -71,3 +71,34 @@ def __init__(self):
layer.mlp.down_proj = nn.Linear(2048, 512, bias=False)
self.model.norm = nn.LayerNorm(512)
self.lm_head = nn.Linear(512, 1000, bias=False)


class MockOpenElmModel(nn.Module):
"""A mock implementation of the OpenELM model architecture for testing purposes.

Replicates the key architectural components of OpenELM:
- Embedding layer (token_embeddings) under 'transformer' root
- Multiple transformer layers with:
- RMSNorm for attention (attn_norm) and FFN (ffn_norm)
- Combined QKV attention (qkv_proj + out_proj)
- Combined gate+up MLP (proj_1 + proj_2)
- Final RMSNorm (transformer.norm)
- Synthetic lm_head (weight-tied to the embeddings in the real model; a separate Linear here)
"""

def __init__(self):
super().__init__()
self.transformer = nn.Module()
self.transformer.token_embeddings = nn.Embedding(1000, 512)
self.transformer.layers = nn.ModuleList([nn.Module() for _ in range(2)])
for layer in self.transformer.layers:
layer.attn_norm = nn.LayerNorm(512) # RMSNorm in real model
layer.ffn_norm = nn.LayerNorm(512) # RMSNorm in real model
layer.attn = nn.Module()
layer.attn.qkv_proj = nn.Linear(512, 1536, bias=False) # Combined Q+K+V
layer.attn.out_proj = nn.Linear(512, 512, bias=False)
layer.ffn = nn.Module()
layer.ffn.proj_1 = nn.Linear(512, 4096, bias=False) # Combined gate+up
layer.ffn.proj_2 = nn.Linear(2048, 512, bias=False) # Down projection
self.transformer.norm = nn.LayerNorm(512) # RMSNorm in real model
self.lm_head = nn.Linear(512, 1000, bias=False)
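A brief, hedged illustration (not part of the diff): a mock like this is typically exercised by asserting the fused projection sizes it declares. The test name and assertions below are hypothetical and only mirror the dimensions defined above.

```python
# Hypothetical structural test for the mock; mirrors the sizes defined above.
from tests.mocks.models import MockOpenElmModel


def test_mock_openelm_structure():
    model = MockOpenElmModel()
    layer = model.transformer.layers[0]
    assert layer.attn.qkv_proj.out_features == 1536  # fused Q+K+V = 3 * 512
    assert layer.ffn.proj_1.out_features == 4096  # fused gate + up = 2 * 2048
    assert layer.ffn.proj_2.in_features == 2048  # down projection input
    assert model.lm_head.out_features == 1000  # mock vocabulary size
```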
3 changes: 2 additions & 1 deletion transformer_lens/benchmarks/__init__.py
@@ -36,7 +36,7 @@
validate_hook_shape_compatibility,
)
from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite
from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity
from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity, PhaseReferenceData
from transformer_lens.benchmarks.weight_processing import (
benchmark_weight_modification,
benchmark_weight_processing,
@@ -49,6 +49,7 @@
# Result types
"BenchmarkResult",
"BenchmarkSeverity",
"PhaseReferenceData",
# Forward pass benchmarks
"benchmark_forward_pass",
"benchmark_logits_equivalence",
4 changes: 2 additions & 2 deletions transformer_lens/benchmarks/backward_gradients.py
@@ -145,7 +145,7 @@ def hook_fn(tensor, hook):
if cross_model:
# Use relaxed dimensional matching for cross-model comparison
is_compatible, error_msg = validate_hook_shape_compatibility(
bridge_grad.shape, reference_grad.shape, hook_name
bridge_grad.shape, reference_grad.shape, hook_name, cross_model=True
)
if not is_compatible:
mismatches.append(f"{hook_name}: {error_msg}")
@@ -410,7 +410,7 @@ def hook_fn(tensor, hook):
if cross_model:
# Use relaxed dimensional matching for cross-model comparison
is_compatible, error_msg = validate_hook_shape_compatibility(
bridge_grad.shape, reference_grad.shape, hook_name
bridge_grad.shape, reference_grad.shape, hook_name, cross_model=True
)
if not is_compatible:
mismatches.append(f"{hook_name}: {error_msg}")
31 changes: 23 additions & 8 deletions transformer_lens/benchmarks/forward_pass.py
@@ -135,6 +135,7 @@ def benchmark_loss_equivalence(
bridge: TransformerBridge,
test_text: str,
reference_model: Optional[HookedTransformer] = None,
reference_loss: Optional[float] = None,
atol: float = 1e-3,
) -> BenchmarkResult:
"""Benchmark loss computation between TransformerBridge and HookedTransformer.
@@ -143,6 +144,7 @@
bridge: TransformerBridge model to test
test_text: Input text for testing
reference_model: Optional HookedTransformer reference model
reference_loss: Optional pre-computed reference loss value (e.g., from Phase 1)
atol: Absolute tolerance for comparison

Returns:
@@ -152,7 +154,7 @@
# Run bridge loss computation
bridge_loss = bridge(test_text, return_type="loss")

if reference_model is None:
if reference_model is None and reference_loss is None:
# No reference - just verify loss is valid
if not isinstance(bridge_loss, torch.Tensor):
return BenchmarkResult(
@@ -178,12 +180,18 @@
details={"loss": loss_value},
)

# Compare with reference model
reference_loss = reference_model(test_text, return_type="loss")
# Get reference loss from model or pre-computed value
if reference_loss is not None:
ref_loss_val = reference_loss
elif reference_model is not None:
ref_loss_tensor = reference_model(test_text, return_type="loss")
ref_loss_val = ref_loss_tensor.item()
else:
raise ValueError("Either reference_logits or reference_model must be provided")

return compare_scalars(
bridge_loss.item(),
reference_loss.item(),
ref_loss_val,
atol=atol,
name="loss_equivalence",
)
@@ -201,6 +209,7 @@ def benchmark_logits_equivalence(
bridge: TransformerBridge,
test_text: str,
reference_model: Optional[HookedTransformer] = None,
reference_logits: Optional[torch.Tensor] = None,
atol: float = 3e-2,
rtol: float = 3e-2,
) -> BenchmarkResult:
@@ -213,6 +222,7 @@
bridge: TransformerBridge model to test
test_text: Input text for testing
reference_model: Optional HookedTransformer reference model
reference_logits: Optional pre-computed reference logits tensor (e.g., from Phase 1)
atol: Absolute tolerance for comparison
rtol: Relative tolerance for comparison

@@ -223,7 +233,7 @@
# Run bridge forward pass
bridge_logits = bridge(test_text, return_type="logits")

if reference_model is None:
if reference_model is None and reference_logits is None:
# No reference - just verify logits shape and validity
if not isinstance(bridge_logits, torch.Tensor):
return BenchmarkResult(
@@ -248,12 +258,17 @@
details={"output_shape": str(bridge_logits.shape)},
)

# Compare with reference model
reference_logits = reference_model(test_text, return_type="logits")
# Get reference logits from model or pre-computed tensor
if reference_logits is not None:
ref_logits = reference_logits.to(bridge_logits.device)
elif reference_model is not None:
ref_logits = reference_model(test_text, return_type="logits")
else:
raise ValueError("Either reference_logits or reference_model must be provided")

return compare_tensors(
bridge_logits,
reference_logits,
ref_logits,
atol=atol,
rtol=rtol,
name="logits_equivalence",
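Usage sketch (not part of the diff) for the new pre-computed-reference path: `bridge`, `phase1_logits`, and `phase1_loss` are placeholder names for objects the benchmark harness would already hold (for example, outputs captured during Phase 1); only the keyword arguments come from the signatures above.

```python
# Sketch only: reuse Phase 1 outputs instead of re-running the reference model.
from transformer_lens.benchmarks.forward_pass import (
    benchmark_logits_equivalence,
    benchmark_loss_equivalence,
)

logits_result = benchmark_logits_equivalence(
    bridge,  # placeholder TransformerBridge
    test_text="The quick brown fox",
    reference_logits=phase1_logits,  # tensor captured earlier; no reference_model needed
)
loss_result = benchmark_loss_equivalence(
    bridge,
    test_text="The quick brown fox",
    reference_loss=phase1_loss,  # float captured earlier
)
```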
93 changes: 82 additions & 11 deletions transformer_lens/benchmarks/hook_registration.py
@@ -17,6 +17,7 @@ def validate_hook_shape_compatibility(
target_shape: tuple,
reference_shape: tuple,
hook_name: str,
cross_model: bool = False,
) -> tuple[bool, Optional[str]]:
"""Validate that hook shapes have compatible structure across different models.

@@ -27,6 +28,8 @@
target_shape: Shape of the tensor from the target model
reference_shape: Shape of the tensor from the reference model
hook_name: Name of the hook (for error messages)
cross_model: If True, skip sequence dimension checks (different tokenizers
produce different token counts for the same text)

Returns:
Tuple of (is_compatible, error_message)
@@ -54,7 +57,7 @@
False,
f"Batch dimension mismatch: {target_shape[0]} vs {reference_shape[0]}",
)
if target_shape[1] != reference_shape[1]:
if not cross_model and target_shape[1] != reference_shape[1]:
return (
False,
f"Sequence dimension mismatch: {target_shape[1]} vs {reference_shape[1]}",
@@ -79,13 +82,14 @@
if target_dim <= 0 or ref_dim <= 0:
return False, f"Invalid n_heads dimension: {target_dim} vs {ref_dim}"
else:
# For other hooks, dimension 1 is sequence - should be same
if target_dim != ref_dim:
# For other hooks, dimension 1 is sequence
# Cross-model references may tokenize differently, so skip this check
if not cross_model and target_dim != ref_dim:
return False, f"Sequence dimension mismatch: {target_dim} vs {ref_dim}"
elif i >= 2 and is_attention_pattern_hook:
# For attention patterns, dimensions 2 and 3 are seq_q and seq_k
# Should be same (both use same test input)
if target_dim != ref_dim:
# Cross-model references may tokenize differently
if not cross_model and target_dim != ref_dim:
return False, f"Sequence dimension mismatch: {target_dim} vs {ref_dim}"
else: # Model-specific dimensions (d_model, n_heads, d_head, etc.)
# Can differ between models - just verify it's valid
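A worked example of the relaxed check (a sketch; behavior inferred from this hunk, with the hook name and shapes chosen for illustration): with `cross_model=True`, differing sequence lengths from two tokenizers no longer fail, while the default same-model mode still reports them.

```python
# Sketch: the same text tokenized to 12 vs 15 tokens by two different tokenizers.
from transformer_lens.benchmarks.hook_registration import validate_hook_shape_compatibility

ok, err = validate_hook_shape_compatibility(
    (1, 12, 768), (1, 15, 512), "blocks.0.hook_resid_pre", cross_model=True
)
# Expected: ok is True -- sequence and model dimensions may differ across models.

ok, err = validate_hook_shape_compatibility(
    (1, 12, 768), (1, 15, 512), "blocks.0.hook_resid_pre"
)
# Expected: ok is False -- the default mode still flags the sequence mismatch.
```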
@@ -261,7 +265,7 @@ def hook_fn(tensor, hook):
if cross_model:
# Use relaxed shape matching for cross-model comparison
is_compatible, error_msg = validate_hook_shape_compatibility(
bridge_tensor.shape, reference_tensor.shape, hook_name
bridge_tensor.shape, reference_tensor.shape, hook_name, cross_model=True
)
if not is_compatible:
shape_mismatches.append(f"{hook_name}: {error_msg}")
@@ -457,12 +461,14 @@ def hook_fn(tensor, hook):
def benchmark_hook_registry(
bridge: TransformerBridge,
reference_model: Optional[HookedTransformer] = None,
cross_model: bool = False,
) -> BenchmarkResult:
"""Benchmark hook registry completeness.

Args:
bridge: TransformerBridge model to test
reference_model: Optional HookedTransformer reference model
cross_model: If True, filter out expected architectural differences

Returns:
BenchmarkResult with registry comparison details
@@ -501,6 +507,26 @@ def benchmark_hook_registry(
missing_hooks = reference_hooks - bridge_hooks
extra_hooks = bridge_hooks - reference_hooks

# In cross-model mode, filter out hooks that are expected to differ
# due to architectural differences (e.g. fused QKV, rotary embeddings)
if cross_model and missing_hooks:
expected_missing_patterns = [
"hook_pos_embed",
"attn.hook_q",
"attn.hook_k",
"attn.hook_v",
"hook_q_input",
"hook_k_input",
"hook_v_input",
"attn.hook_attn_scores",
"attn.hook_pattern",
]
missing_hooks = {
h
for h in missing_hooks
if not any(pattern in h for pattern in expected_missing_patterns)
}

if missing_hooks:
return BenchmarkResult(
name="hook_registry",
@@ -660,6 +686,25 @@ def hook_fn(tensor, hook):
handle.remove()

# CRITICAL CHECK: Bridge must have all hooks that reference has
# In cross-model mode, filter out expected architectural differences
if cross_model and missing_from_bridge:
expected_missing_patterns = [
"hook_pos_embed",
"attn.hook_q",
"attn.hook_k",
"attn.hook_v",
"hook_q_input",
"hook_k_input",
"hook_v_input",
"attn.hook_attn_scores",
"attn.hook_pattern",
]
missing_from_bridge = [
h
for h in missing_from_bridge
if not any(pattern in h for pattern in expected_missing_patterns)
]

if missing_from_bridge:
return BenchmarkResult(
name="forward_hooks",
@@ -677,8 +722,17 @@ def hook_fn(tensor, hook):
# Filter out expected missing hooks in cross-model mode
if cross_model and hooks_that_didnt_fire:
# In cross-model mode, some hooks are expected to not fire due to architectural differences
# For example, rotary embedding models (Gemma, LLaMA) don't have hook_pos_embed
expected_missing_patterns = ["hook_pos_embed"]
expected_missing_patterns = [
"hook_pos_embed",
"attn.hook_q",
"attn.hook_k",
"attn.hook_v",
"hook_q_input",
"hook_k_input",
"hook_v_input",
"attn.hook_attn_scores",
"attn.hook_pattern",
]
actual_didnt_fire = [
h
for h in hooks_that_didnt_fire
@@ -711,7 +765,7 @@ def hook_fn(tensor, hook):
if cross_model:
# Use relaxed dimensional matching for cross-model comparison
is_compatible, error_msg = validate_hook_shape_compatibility(
bridge_tensor.shape, reference_tensor.shape, hook_name
bridge_tensor.shape, reference_tensor.shape, hook_name, cross_model=True
)
if not is_compatible:
mismatches.append(f"{hook_name}: {error_msg}")
@@ -911,7 +965,7 @@ def hook_fn(tensor, hook):
if cross_model:
# Use relaxed dimensional matching for cross-model comparison
is_compatible, error_msg = validate_hook_shape_compatibility(
bridge_tensor.shape, reference_tensor.shape, hook_name
bridge_tensor.shape, reference_tensor.shape, hook_name, cross_model=True
)
if not is_compatible:
mismatches.append(f"{hook_name}: {error_msg}")
@@ -936,7 +990,22 @@
if cross_model and bridge_missing:
# In cross-model mode, some hooks are expected to be missing due to architectural differences
# For example, rotary embedding models (Gemma, LLaMA) don't have hook_pos_embed
expected_missing_patterns = ["hook_pos_embed"]
# Hooks that may be missing due to architectural differences:
# - hook_pos_embed: rotary models don't have positional embeddings
# - hook_q/k/v: fused QKV architectures (maintain_native_attention)
# - hook_q/k/v_input: same reason
# - hook_attn_scores/pattern: native attention doesn't expose these
expected_missing_patterns = [
"hook_pos_embed",
"attn.hook_q",
"attn.hook_k",
"attn.hook_v",
"hook_q_input",
"hook_k_input",
"hook_v_input",
"attn.hook_attn_scores",
"attn.hook_pattern",
]
actual_missing = [
h
for h in bridge_missing
@@ -1055,6 +1124,8 @@ def benchmark_hook_functionality(

def ablation_hook(activation, hook):
# Zero out an attention head in layer 0
# Clone to avoid in-place modification of a view from a custom Function
activation = activation.clone()
# For GQA models, the head dimension may be smaller than n_heads
n_heads = activation.shape[2]
head_idx = min(head_to_ablate, n_heads - 1)
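For context (not part of the diff), the cross-model registry comparison would be driven roughly as below; `bridge` and `reference_model` are placeholders (for example, an OpenELM-style bridge compared against a GPT-2 HookedTransformer), and only the `cross_model` keyword is taken from the diff.

```python
# Sketch: hook-registry comparison across architectures, with expected gaps filtered.
from transformer_lens.benchmarks.hook_registration import benchmark_hook_registry

result = benchmark_hook_registry(
    bridge,  # placeholder TransformerBridge wrapping a fused-QKV model
    reference_model=reference_model,  # placeholder HookedTransformer reference
    cross_model=True,  # ignore hooks only the reference layout provides (fused QKV, rotary)
)
print(result)
```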