diff --git a/tests/mocks/models.py b/tests/mocks/models.py index d1a8e0978..1310bcbe8 100644 --- a/tests/mocks/models.py +++ b/tests/mocks/models.py @@ -71,3 +71,34 @@ def __init__(self): layer.mlp.down_proj = nn.Linear(2048, 512, bias=False) self.model.norm = nn.LayerNorm(512) self.lm_head = nn.Linear(512, 1000, bias=False) + + +class MockOpenElmModel(nn.Module): + """A mock implementation of the OpenELM model architecture for testing purposes. + + Replicates the key architectural components of OpenELM: + - Embedding layer (token_embeddings) under 'transformer' root + - Multiple transformer layers with: + - RMSNorm for attention (attn_norm) and FFN (ffn_norm) + - Combined QKV attention (qkv_proj + out_proj) + - Combined gate+up MLP (proj_1 + proj_2) + - Final RMSNorm (transformer.norm) + - Synthetic lm_head (weight-tied to embeddings) + """ + + def __init__(self): + super().__init__() + self.transformer = nn.Module() + self.transformer.token_embeddings = nn.Embedding(1000, 512) + self.transformer.layers = nn.ModuleList([nn.Module() for _ in range(2)]) + for layer in self.transformer.layers: + layer.attn_norm = nn.LayerNorm(512) # RMSNorm in real model + layer.ffn_norm = nn.LayerNorm(512) # RMSNorm in real model + layer.attn = nn.Module() + layer.attn.qkv_proj = nn.Linear(512, 1536, bias=False) # Combined Q+K+V + layer.attn.out_proj = nn.Linear(512, 512, bias=False) + layer.ffn = nn.Module() + layer.ffn.proj_1 = nn.Linear(512, 4096, bias=False) # Combined gate+up + layer.ffn.proj_2 = nn.Linear(2048, 512, bias=False) # Down projection + self.transformer.norm = nn.LayerNorm(512) # RMSNorm in real model + self.lm_head = nn.Linear(512, 1000, bias=False) diff --git a/transformer_lens/benchmarks/__init__.py b/transformer_lens/benchmarks/__init__.py index 16ee56bd6..6996211c0 100644 --- a/transformer_lens/benchmarks/__init__.py +++ b/transformer_lens/benchmarks/__init__.py @@ -36,7 +36,7 @@ validate_hook_shape_compatibility, ) from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite -from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity +from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity, PhaseReferenceData from transformer_lens.benchmarks.weight_processing import ( benchmark_weight_modification, benchmark_weight_processing, @@ -49,6 +49,7 @@ # Result types "BenchmarkResult", "BenchmarkSeverity", + "PhaseReferenceData", # Forward pass benchmarks "benchmark_forward_pass", "benchmark_logits_equivalence", diff --git a/transformer_lens/benchmarks/backward_gradients.py b/transformer_lens/benchmarks/backward_gradients.py index e44ee06af..60e9e21b8 100644 --- a/transformer_lens/benchmarks/backward_gradients.py +++ b/transformer_lens/benchmarks/backward_gradients.py @@ -145,7 +145,7 @@ def hook_fn(tensor, hook): if cross_model: # Use relaxed dimensional matching for cross-model comparison is_compatible, error_msg = validate_hook_shape_compatibility( - bridge_grad.shape, reference_grad.shape, hook_name + bridge_grad.shape, reference_grad.shape, hook_name, cross_model=True ) if not is_compatible: mismatches.append(f"{hook_name}: {error_msg}") @@ -410,7 +410,7 @@ def hook_fn(tensor, hook): if cross_model: # Use relaxed dimensional matching for cross-model comparison is_compatible, error_msg = validate_hook_shape_compatibility( - bridge_grad.shape, reference_grad.shape, hook_name + bridge_grad.shape, reference_grad.shape, hook_name, cross_model=True ) if not is_compatible: mismatches.append(f"{hook_name}: {error_msg}") diff --git 
a/transformer_lens/benchmarks/forward_pass.py b/transformer_lens/benchmarks/forward_pass.py index a66b43a03..6b89d615c 100644 --- a/transformer_lens/benchmarks/forward_pass.py +++ b/transformer_lens/benchmarks/forward_pass.py @@ -135,6 +135,7 @@ def benchmark_loss_equivalence( bridge: TransformerBridge, test_text: str, reference_model: Optional[HookedTransformer] = None, + reference_loss: Optional[float] = None, atol: float = 1e-3, ) -> BenchmarkResult: """Benchmark loss computation between TransformerBridge and HookedTransformer. @@ -143,6 +144,7 @@ def benchmark_loss_equivalence( bridge: TransformerBridge model to test test_text: Input text for testing reference_model: Optional HookedTransformer reference model + reference_loss: Optional pre-computed reference loss value (e.g., from Phase 1) atol: Absolute tolerance for comparison Returns: @@ -152,7 +154,7 @@ def benchmark_loss_equivalence( # Run bridge loss computation bridge_loss = bridge(test_text, return_type="loss") - if reference_model is None: + if reference_model is None and reference_loss is None: # No reference - just verify loss is valid if not isinstance(bridge_loss, torch.Tensor): return BenchmarkResult( @@ -178,12 +180,18 @@ def benchmark_loss_equivalence( details={"loss": loss_value}, ) - # Compare with reference model - reference_loss = reference_model(test_text, return_type="loss") + # Get reference loss from model or pre-computed value + if reference_loss is not None: + ref_loss_val = reference_loss + elif reference_model is not None: + ref_loss_tensor = reference_model(test_text, return_type="loss") + ref_loss_val = ref_loss_tensor.item() + else: + raise ValueError("Either reference_loss or reference_model must be provided") return compare_scalars( bridge_loss.item(), - reference_loss.item(), + ref_loss_val, atol=atol, name="loss_equivalence", ) @@ -201,6 +209,7 @@ def benchmark_logits_equivalence( bridge: TransformerBridge, test_text: str, reference_model: Optional[HookedTransformer] = None, + reference_logits: Optional[torch.Tensor] = None, atol: float = 3e-2, rtol: float = 3e-2, ) -> BenchmarkResult: @@ -213,6 +222,7 @@ def benchmark_logits_equivalence( bridge: TransformerBridge model to test test_text: Input text for testing reference_model: Optional HookedTransformer reference model + reference_logits: Optional pre-computed reference logits tensor (e.g., from Phase 1) atol: Absolute tolerance for comparison rtol: Relative tolerance for comparison @@ -223,7 +233,7 @@ def benchmark_logits_equivalence( # Run bridge forward pass bridge_logits = bridge(test_text, return_type="logits") - if reference_model is None: + if reference_model is None and reference_logits is None: # No reference - just verify logits shape and validity if not isinstance(bridge_logits, torch.Tensor): return BenchmarkResult( @@ -248,12 +258,17 @@ def benchmark_logits_equivalence( details={"output_shape": str(bridge_logits.shape)}, ) - # Compare with reference model - reference_logits = reference_model(test_text, return_type="logits") + # Get reference logits from model or pre-computed tensor + if reference_logits is not None: + ref_logits = reference_logits.to(bridge_logits.device) + elif reference_model is not None: + ref_logits = reference_model(test_text, return_type="logits") + else: + raise ValueError("Either reference_logits or reference_model must be provided") return compare_tensors( bridge_logits, - reference_logits, + ref_logits, atol=atol, rtol=rtol, name="logits_equivalence", diff --git 
a/transformer_lens/benchmarks/hook_registration.py b/transformer_lens/benchmarks/hook_registration.py index f1f7dc937..f41ff60d2 100644 --- a/transformer_lens/benchmarks/hook_registration.py +++ b/transformer_lens/benchmarks/hook_registration.py @@ -17,6 +17,7 @@ def validate_hook_shape_compatibility( target_shape: tuple, reference_shape: tuple, hook_name: str, + cross_model: bool = False, ) -> tuple[bool, Optional[str]]: """Validate that hook shapes have compatible structure across different models. @@ -27,6 +28,8 @@ def validate_hook_shape_compatibility( target_shape: Shape of the tensor from the target model reference_shape: Shape of the tensor from the reference model hook_name: Name of the hook (for error messages) + cross_model: If True, skip sequence dimension checks (different tokenizers + produce different token counts for the same text) Returns: Tuple of (is_compatible, error_message) @@ -54,7 +57,7 @@ def validate_hook_shape_compatibility( False, f"Batch dimension mismatch: {target_shape[0]} vs {reference_shape[0]}", ) - if target_shape[1] != reference_shape[1]: + if not cross_model and target_shape[1] != reference_shape[1]: return ( False, f"Sequence dimension mismatch: {target_shape[1]} vs {reference_shape[1]}", @@ -79,13 +82,14 @@ def validate_hook_shape_compatibility( if target_dim <= 0 or ref_dim <= 0: return False, f"Invalid n_heads dimension: {target_dim} vs {ref_dim}" else: - # For other hooks, dimension 1 is sequence - should be same - if target_dim != ref_dim: + # For other hooks, dimension 1 is sequence + # Cross-model references may tokenize differently, so skip this check + if not cross_model and target_dim != ref_dim: return False, f"Sequence dimension mismatch: {target_dim} vs {ref_dim}" elif i >= 2 and is_attention_pattern_hook: # For attention patterns, dimensions 2 and 3 are seq_q and seq_k - # Should be same (both use same test input) - if target_dim != ref_dim: + # Cross-model references may tokenize differently + if not cross_model and target_dim != ref_dim: return False, f"Sequence dimension mismatch: {target_dim} vs {ref_dim}" else: # Model-specific dimensions (d_model, n_heads, d_head, etc.) # Can differ between models - just verify it's valid @@ -261,7 +265,7 @@ def hook_fn(tensor, hook): if cross_model: # Use relaxed shape matching for cross-model comparison is_compatible, error_msg = validate_hook_shape_compatibility( - bridge_tensor.shape, reference_tensor.shape, hook_name + bridge_tensor.shape, reference_tensor.shape, hook_name, cross_model=True ) if not is_compatible: shape_mismatches.append(f"{hook_name}: {error_msg}") @@ -457,12 +461,14 @@ def hook_fn(tensor, hook): def benchmark_hook_registry( bridge: TransformerBridge, reference_model: Optional[HookedTransformer] = None, + cross_model: bool = False, ) -> BenchmarkResult: """Benchmark hook registry completeness. Args: bridge: TransformerBridge model to test reference_model: Optional HookedTransformer reference model + cross_model: If True, filter out expected architectural differences Returns: BenchmarkResult with registry comparison details @@ -501,6 +507,26 @@ def benchmark_hook_registry( missing_hooks = reference_hooks - bridge_hooks extra_hooks = bridge_hooks - reference_hooks + # In cross-model mode, filter out hooks that are expected to differ + # due to architectural differences (e.g. 
fused QKV, rotary embeddings) + if cross_model and missing_hooks: + expected_missing_patterns = [ + "hook_pos_embed", + "attn.hook_q", + "attn.hook_k", + "attn.hook_v", + "hook_q_input", + "hook_k_input", + "hook_v_input", + "attn.hook_attn_scores", + "attn.hook_pattern", + ] + missing_hooks = { + h + for h in missing_hooks + if not any(pattern in h for pattern in expected_missing_patterns) + } + if missing_hooks: return BenchmarkResult( name="hook_registry", @@ -660,6 +686,25 @@ def hook_fn(tensor, hook): handle.remove() # CRITICAL CHECK: Bridge must have all hooks that reference has + # In cross-model mode, filter out expected architectural differences + if cross_model and missing_from_bridge: + expected_missing_patterns = [ + "hook_pos_embed", + "attn.hook_q", + "attn.hook_k", + "attn.hook_v", + "hook_q_input", + "hook_k_input", + "hook_v_input", + "attn.hook_attn_scores", + "attn.hook_pattern", + ] + missing_from_bridge = [ + h + for h in missing_from_bridge + if not any(pattern in h for pattern in expected_missing_patterns) + ] + if missing_from_bridge: return BenchmarkResult( name="forward_hooks", @@ -677,8 +722,17 @@ def hook_fn(tensor, hook): # Filter out expected missing hooks in cross-model mode if cross_model and hooks_that_didnt_fire: # In cross-model mode, some hooks are expected to not fire due to architectural differences - # For example, rotary embedding models (Gemma, LLaMA) don't have hook_pos_embed - expected_missing_patterns = ["hook_pos_embed"] + expected_missing_patterns = [ + "hook_pos_embed", + "attn.hook_q", + "attn.hook_k", + "attn.hook_v", + "hook_q_input", + "hook_k_input", + "hook_v_input", + "attn.hook_attn_scores", + "attn.hook_pattern", + ] actual_didnt_fire = [ h for h in hooks_that_didnt_fire @@ -711,7 +765,7 @@ def hook_fn(tensor, hook): if cross_model: # Use relaxed dimensional matching for cross-model comparison is_compatible, error_msg = validate_hook_shape_compatibility( - bridge_tensor.shape, reference_tensor.shape, hook_name + bridge_tensor.shape, reference_tensor.shape, hook_name, cross_model=True ) if not is_compatible: mismatches.append(f"{hook_name}: {error_msg}") @@ -911,7 +965,7 @@ def hook_fn(tensor, hook): if cross_model: # Use relaxed dimensional matching for cross-model comparison is_compatible, error_msg = validate_hook_shape_compatibility( - bridge_tensor.shape, reference_tensor.shape, hook_name + bridge_tensor.shape, reference_tensor.shape, hook_name, cross_model=True ) if not is_compatible: mismatches.append(f"{hook_name}: {error_msg}") @@ -936,7 +990,22 @@ def hook_fn(tensor, hook): if cross_model and bridge_missing: # In cross-model mode, some hooks are expected to be missing due to architectural differences # For example, rotary embedding models (Gemma, LLaMA) don't have hook_pos_embed - expected_missing_patterns = ["hook_pos_embed"] + # Hooks that may be missing due to architectural differences: + # - hook_pos_embed: rotary models don't have positional embeddings + # - hook_q/k/v: fused QKV architectures (maintain_native_attention) + # - hook_q/k/v_input: same reason + # - hook_attn_scores/pattern: native attention doesn't expose these + expected_missing_patterns = [ + "hook_pos_embed", + "attn.hook_q", + "attn.hook_k", + "attn.hook_v", + "hook_q_input", + "hook_k_input", + "hook_v_input", + "attn.hook_attn_scores", + "attn.hook_pattern", + ] actual_missing = [ h for h in bridge_missing @@ -1055,6 +1124,8 @@ def benchmark_hook_functionality( def ablation_hook(activation, hook): # Zero out an attention head in layer 0 + # Clone to 
avoid in-place modification of a view from a custom Function + activation = activation.clone() # For GQA models, the head dimension may be smaller than n_heads n_heads = activation.shape[2] head_idx = min(head_to_ablate, n_heads - 1) diff --git a/transformer_lens/benchmarks/hook_structure.py b/transformer_lens/benchmarks/hook_structure.py index 35b1e5bfb..1155213c8 100644 --- a/transformer_lens/benchmarks/hook_structure.py +++ b/transformer_lens/benchmarks/hook_structure.py @@ -18,6 +18,7 @@ def validate_hook_shape_compatibility( target_shape: tuple, reference_shape: tuple, hook_name: str, + cross_model: bool = False, ) -> tuple[bool, Optional[str]]: """Validate that hook shapes have compatible structure across different models. @@ -28,6 +29,8 @@ def validate_hook_shape_compatibility( target_shape: Shape of the tensor from the target model reference_shape: Shape of the tensor from the reference model hook_name: Name of the hook (for error messages) + cross_model: If True, skip sequence dimension checks (different tokenizers + produce different token counts for the same text) Returns: Tuple of (is_compatible, error_message) @@ -55,7 +58,7 @@ def validate_hook_shape_compatibility( False, f"Batch dimension mismatch: {target_shape[0]} vs {reference_shape[0]}", ) - if target_shape[1] != reference_shape[1]: + if not cross_model and target_shape[1] != reference_shape[1]: return ( False, f"Sequence dimension mismatch: {target_shape[1]} vs {reference_shape[1]}", @@ -80,13 +83,14 @@ def validate_hook_shape_compatibility( if target_dim <= 0 or ref_dim <= 0: return False, f"Invalid n_heads dimension: {target_dim} vs {ref_dim}" else: - # For other hooks, dimension 1 is sequence - should be same - if target_dim != ref_dim: + # For other hooks, dimension 1 is sequence + # Cross-model references may tokenize differently, so skip this check + if not cross_model and target_dim != ref_dim: return False, f"Sequence dimension mismatch: {target_dim} vs {ref_dim}" elif i >= 2 and is_attention_pattern_hook: # For attention patterns, dimensions 2 and 3 are seq_q and seq_k - # Should be same (both use same test input) - if target_dim != ref_dim: + # Cross-model references may tokenize differently + if not cross_model and target_dim != ref_dim: return False, f"Sequence dimension mismatch: {target_dim} vs {ref_dim}" else: # Model-specific dimensions (d_model, n_heads, d_head, etc.) 
# Can differ between models - just verify it's valid @@ -224,6 +228,25 @@ def hook_fn(tensor, hook): handle.remove() # CRITICAL CHECK: Bridge must have all hooks that reference has + # In cross-model mode, filter out expected architectural differences + if cross_model and missing_from_bridge: + expected_missing_patterns = [ + "hook_pos_embed", + "attn.hook_q", + "attn.hook_k", + "attn.hook_v", + "hook_q_input", + "hook_k_input", + "hook_v_input", + "attn.hook_attn_scores", + "attn.hook_pattern", + ] + missing_from_bridge = [ + h + for h in missing_from_bridge + if not any(pattern in h for pattern in expected_missing_patterns) + ] + if missing_from_bridge: return BenchmarkResult( name="forward_hooks_structure", @@ -262,7 +285,7 @@ def hook_fn(tensor, hook): if cross_model: # Use relaxed shape matching for cross-model comparison is_compatible, error_msg = validate_hook_shape_compatibility( - bridge_tensor.shape, reference_tensor.shape, hook_name + bridge_tensor.shape, reference_tensor.shape, hook_name, cross_model=True ) if not is_compatible: shape_mismatches.append(f"{hook_name}: {error_msg}") @@ -456,6 +479,25 @@ def hook_fn(grad): handle.remove() # CRITICAL CHECK: Bridge must have all backward hooks that reference has + # In cross-model mode, filter out expected architectural differences + if cross_model and missing_from_bridge: + expected_missing_patterns = [ + "hook_pos_embed", + "attn.hook_q", + "attn.hook_k", + "attn.hook_v", + "hook_q_input", + "hook_k_input", + "hook_v_input", + "attn.hook_attn_scores", + "attn.hook_pattern", + ] + missing_from_bridge = [ + h + for h in missing_from_bridge + if not any(pattern in h for pattern in expected_missing_patterns) + ] + if missing_from_bridge: return BenchmarkResult( name="backward_hooks_structure", @@ -494,7 +536,7 @@ def hook_fn(grad): if cross_model: # Use relaxed shape matching for cross-model comparison is_compatible, error_msg = validate_hook_shape_compatibility( - bridge_grad.shape, reference_grad.shape, hook_name + bridge_grad.shape, reference_grad.shape, hook_name, cross_model=True ) if not is_compatible: shape_mismatches.append(f"{hook_name}: {error_msg}") @@ -600,8 +642,20 @@ def benchmark_activation_cache_structure( # Filter out expected missing hooks in cross-model mode if cross_model and missing_keys: # In cross-model mode, some hooks are expected to be missing due to architectural differences - # For example, rotary embedding models (Gemma, LLaMA) don't have hook_pos_embed - expected_missing_patterns = ["hook_pos_embed"] + # - hook_pos_embed: rotary models don't have positional embeddings + # - hook_q/k/v: fused QKV architectures (maintain_native_attention) + # - hook_attn_scores/pattern: native attention doesn't expose these + expected_missing_patterns = [ + "hook_pos_embed", + "attn.hook_q", + "attn.hook_k", + "attn.hook_v", + "hook_q_input", + "hook_k_input", + "hook_v_input", + "attn.hook_attn_scores", + "attn.hook_pattern", + ] actual_missing = [ k for k in missing_keys @@ -633,7 +687,7 @@ def benchmark_activation_cache_structure( if cross_model: # Use relaxed shape matching for cross-model comparison is_compatible, error_msg = validate_hook_shape_compatibility( - bridge_tensor.shape, ref_tensor.shape, key + bridge_tensor.shape, ref_tensor.shape, key, cross_model=True ) if not is_compatible: shape_mismatches.append(f"{key}: {error_msg}") diff --git a/transformer_lens/benchmarks/main_benchmark.py b/transformer_lens/benchmarks/main_benchmark.py index 132ce69bb..16071f759 100644 --- 
a/transformer_lens/benchmarks/main_benchmark.py +++ b/transformer_lens/benchmarks/main_benchmark.py @@ -47,6 +47,8 @@ from transformer_lens.benchmarks.utils import ( BenchmarkResult, BenchmarkSeverity, + PhaseReferenceData, + compare_tensors, format_results, ) from transformer_lens.benchmarks.weight_processing import ( @@ -61,6 +63,10 @@ benchmark_weight_processing, benchmark_weight_sharing, ) +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, +) from transformer_lens.model_bridge import TransformerBridge # Architecture names that indicate encoder-decoder models @@ -75,17 +81,18 @@ ] -def is_encoder_decoder_model(model_name: str) -> bool: +def is_encoder_decoder_model(model_name: str, trust_remote_code: bool = False) -> bool: """Check if a model is an encoder-decoder architecture. Args: model_name: The HuggingFace model name or path + trust_remote_code: Whether to trust remote code for custom architectures. Returns: True if the model is encoder-decoder (like T5), False otherwise """ try: - config = AutoConfig.from_pretrained(model_name) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) # Check config attribute first if getattr(config, "is_encoder_decoder", False): return True @@ -96,7 +103,7 @@ def is_encoder_decoder_model(model_name: str) -> bool: return False -def get_auto_model_class(model_name: str): +def get_auto_model_class(model_name: str, trust_remote_code: bool = False): """Determine the correct AutoModel class for a given model. Some models (like T5) are encoder-decoder and need AutoModelForSeq2SeqLM @@ -108,11 +115,72 @@ def get_auto_model_class(model_name: str): Returns: The appropriate AutoModel class (AutoModelForCausalLM or AutoModelForSeq2SeqLM) """ - if is_encoder_decoder_model(model_name): + if is_encoder_decoder_model(model_name, trust_remote_code=trust_remote_code): return AutoModelForSeq2SeqLM return AutoModelForCausalLM +def _fixup_custom_model(hf_model) -> None: + """Apply post-load fixups for models with custom code. + + Some custom models (e.g., OpenELM) have non-persistent buffers (inv_freq, + causal_mask) that may be zeroed during HuggingFace's meta-device loading. + This function recomputes broken buffers to minimize forward pass divergence + against the bridge model. + + Note: The bridge model goes through a more thorough initialization via the + adapter's prepare_loading() + prepare_model() lifecycle hooks. Any remaining + forward pass divergence is an inherent consequence of different loading paths + for custom-code models, not a bridge correctness issue (all individual + components produce identical output, and hooks have zero numerical impact). + """ + # OpenELM fixups + if hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "layers"): + # Ensure use_cache is set (OpenELM custom config omits it) + if not hasattr(hf_model.config, "use_cache") or "use_cache" not in hf_model.config.__dict__: + hf_model.config.use_cache = False + + # Fix 1: Always recompute causal_mask (non-persistent buffer). + # After meta→real materialization, the buffer may contain garbage values + # rather than clean zeros, so we always recompute. 
+ if hasattr(hf_model.transformer, "causal_mask"): + cm = hf_model.transformer.causal_mask + if cm is not None and cm.numel() > 0: + seq_len = cm.shape[-1] + correct_mask = torch.triu( + torch.ones(seq_len, seq_len, dtype=cm.dtype, device=cm.device), + diagonal=1, + ) + hf_model.transformer.causal_mask = correct_mask + + # Fix 2: Always recompute RoPE inv_freq and sin/cos (non-persistent buffers). + rope_max = getattr(hf_model.config, "rope_max_length", None) + if rope_max is not None: + for layer in hf_model.transformer.layers: + if hasattr(layer, "attn") and hasattr(layer.attn, "pos_embedding"): + rope = layer.attn.pos_embedding + if hasattr(rope, "inv_freq"): + correct_inv_freq = 1.0 / ( + rope.freq_constant + ** ( + torch.arange(0, rope.model_dim, 2, dtype=torch.float32) + / rope.model_dim + ) + ) + rope.inv_freq = correct_inv_freq.to(rope.inv_freq.device) + # Force-recompute sin/cos + rope._cached_cos = None + rope._cached_sin = None + rope._compute_sin_cos_embeddings(rope_max) + + # Create synthetic lm_head for weight-tied models (share_input_output_layers) + if getattr(hf_model, "lm_head", None) is None: + embed = hf_model.transformer.token_embeddings + lm_head = torch.nn.Linear(embed.embedding_dim, embed.num_embeddings, bias=False) + lm_head.weight = embed.weight + hf_model.lm_head = lm_head + + def run_comparison_benchmarks( bridge_model: TransformerBridge, reference_model: Optional[HookedTransformer], @@ -121,6 +189,7 @@ def run_comparison_benchmarks( is_processed: bool, verbose: bool = True, gpt2_reference: Optional[HookedTransformer] = None, + phase1_reference: Optional[PhaseReferenceData] = None, ) -> List[BenchmarkResult]: """Run standardized comparison benchmarks between Bridge and reference model. @@ -135,6 +204,7 @@ def run_comparison_benchmarks( is_processed: Whether models have processed weights (for weight-specific tests) verbose: Whether to print detailed results gpt2_reference: Optional GPT-2 reference for cross-model validation + phase1_reference: Optional saved Phase 1 HF reference data for equivalence testing Returns: List of BenchmarkResult objects @@ -223,6 +293,54 @@ def add_result(result: BenchmarkResult) -> None: except Exception as e: if verbose: print(f"✗ Equivalence benchmark failed: {e}\n") + elif phase1_reference is not None and phase1_reference.hf_logits is not None: + # Use saved Phase 1 bridge logits/loss as ground truth. + # Weight processing should be mathematically equivalent, so the processed + # bridge should produce the same output as the unprocessed bridge. + # + # Important: center_unembed intentionally shifts raw logits by a per-position + # constant (softmax-invariant). We compare log_softmax to be invariant to this. + try: + if verbose: + print("Using saved Phase 1 bridge reference for equivalence comparison") + # Compare log_softmax instead of raw logits to be centering-invariant. + # center_unembed shifts all vocab logits at each position by a constant, + # which changes raw logits but preserves log-probabilities. 
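A minimal standalone sketch (assuming only torch, unrelated to the benchmark code itself) of the invariance relied on here: adding a per-position constant across the vocabulary dimension leaves log_softmax unchanged, so center_unembed's per-position logit shift cannot cause a mismatch in this comparison.

import torch

logits = torch.randn(1, 5, 100)
shifted = logits + torch.randn(1, 5, 1)  # same constant added to every vocab logit at each position
assert torch.allclose(
    torch.nn.functional.log_softmax(logits, dim=-1),
    torch.nn.functional.log_softmax(shifted, dim=-1),
    atol=1e-5,
)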
+ bridge_logits = bridge_model(test_text, return_type="logits") + ref_logits = phase1_reference.hf_logits.to(bridge_logits.device) + bridge_log_probs = torch.nn.functional.log_softmax(bridge_logits, dim=-1) + ref_log_probs = torch.nn.functional.log_softmax(ref_logits, dim=-1) + add_result( + compare_tensors( + bridge_log_probs, + ref_log_probs, + atol=1e-4, + rtol=1e-4, + name="logits_equivalence", + ) + ) + if phase1_reference.hf_loss is not None: + add_result( + benchmark_loss_equivalence( + bridge_model, + test_text, + reference_loss=phase1_reference.hf_loss, + atol=1e-3, + ) + ) + else: + add_result( + BenchmarkResult( + name="loss_equivalence", + severity=BenchmarkSeverity.SKIPPED, + message="Skipped (no Phase 1 loss reference available)", + passed=True, + ) + ) + gc.collect() + except Exception as e: + if verbose: + print(f"✗ Phase 1 reference comparison failed: {e}\n") else: if verbose: print("⏭️ Skipped (no HookedTransformer reference)\n") @@ -255,7 +373,11 @@ def add_result(result: BenchmarkResult) -> None: try: if verbose: print("Using GPT-2 for cross-model validation (dimensional matching)") - add_result(benchmark_hook_registry(bridge_model, reference_model=gpt2_reference)) + add_result( + benchmark_hook_registry( + bridge_model, reference_model=gpt2_reference, cross_model=True + ) + ) gc.collect() except Exception as e: if verbose: @@ -527,6 +649,7 @@ def run_benchmark_suite( track_memory: bool = False, test_weight_processing_individually: bool = False, phases: list[int] | None = None, + trust_remote_code: bool = False, ) -> List[BenchmarkResult]: """Run comprehensive benchmark suite for TransformerBridge. @@ -815,6 +938,7 @@ def cleanup_model(model, model_name_str: str): bridge_unprocessed = None hf_model = None + phase1_reference = PhaseReferenceData() # Load bridge without weights first to detect attn_implementation and dtype if verbose: @@ -823,7 +947,7 @@ def cleanup_model(model, model_name_str: str): attn_implementation = None try: # Load a lightweight version without weights to get config - bridge_config_only = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, load_weights=False) # type: ignore[attr-defined] + bridge_config_only = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, load_weights=False, trust_remote_code=trust_remote_code) # type: ignore[attr-defined] # Extract attn_implementation for HF model loading. # First check if adapter explicitly sets it (e.g. qwen3, gemma3). if hasattr(bridge_config_only.adapter.cfg, "attn_implementation"): @@ -841,6 +965,30 @@ def cleanup_model(model, model_name_str: str): except Exception as e: if verbose: print(f"⚠ Could not detect config (will use defaults): {str(e)}") + # For custom code models, the config-only bridge may fail. We still need to + # apply architecture-specific patches (e.g., OpenELM _init_weights fix) before + # loading any model, otherwise _init_weights may re-randomize loaded weights. 
+ if trust_remote_code: + try: + from transformer_lens.model_bridge.sources.transformers import ( + determine_architecture_from_hf_config, + map_default_transformer_lens_config, + ) + + hf_cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + tl_cfg = map_default_transformer_lens_config(hf_cfg) + arch = determine_architecture_from_hf_config(hf_cfg) + bridge_cfg = TransformerBridgeConfig.from_dict(tl_cfg.__dict__) + bridge_cfg.architecture = arch + bridge_cfg.model_name = model_name + adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_cfg) + adapter.prepare_loading(model_name, {}) + if verbose: + print("✓ Applied architecture patches for custom code model") + del adapter, bridge_cfg, tl_cfg, hf_cfg + except Exception as patch_err: + if verbose: + print(f"⚠ Could not apply architecture patches: {patch_err}") # Load HF model with matching attn_implementation if use_hf_reference: @@ -858,17 +1006,26 @@ def cleanup_model(model, model_name_str: str): if verbose: print(f"Using attn_implementation={attn_implementation}") # Use appropriate AutoModel class (e.g., AutoModelForSeq2SeqLM for T5) - auto_model_class = get_auto_model_class(model_name) + auto_model_class = get_auto_model_class(model_name, trust_remote_code=trust_remote_code) if verbose and auto_model_class != AutoModelForCausalLM: print(f"Using {auto_model_class.__name__} for encoder-decoder model") # Ensure pad_token_id exists on HF config. Transformers v5 raises # AttributeError for missing config attributes, which crashes models # like StableLM that access config.pad_token_id during __init__. - hf_config = AutoConfig.from_pretrained(model_name) + hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None) hf_kwargs["config"] = hf_config + if trust_remote_code: + hf_kwargs["trust_remote_code"] = True hf_model = auto_model_class.from_pretrained(model_name, **hf_kwargs) # type: ignore[arg-type] + # Post-load fixup for custom code models (e.g., OpenELM). + # NOTE: We intentionally use _fixup_custom_model instead of the adapter's + # prepare_model here. The adapter's prepare_model unconditionally recomputes + # non-persistent buffers (causal_mask, inv_freq) which is needed for the + # bridge path (meta-device loading), but the reference model loads normally + # on CPU with correct buffers. Recomputing them can introduce numeric drift. + _fixup_custom_model(hf_model) hf_model = hf_model.to(device) hf_model.eval() # Detect dtype from HF model @@ -888,7 +1045,7 @@ def cleanup_model(model, model_name_str: str): if verbose: print("Loading TransformerBridge (unprocessed)...") try: - bridge_unprocessed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype) # type: ignore[attr-defined] + bridge_unprocessed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code) # type: ignore[attr-defined] if verbose: print("✓ TransformerBridge loaded (unprocessed)\n") except Exception as e: @@ -941,6 +1098,28 @@ def cleanup_model(model, model_name_str: str): if verbose: print(f"✗ Forward pass benchmark failed: {e}\n") + # Capture unprocessed bridge reference data for Phase 3 reuse. + # We save the BRIDGE's logits/loss (not the HF model's), because the bridge + # forward path may differ slightly from HF. 
Phase 3 tests whether weight + # processing preserves the bridge's own output — comparing processed bridge + # vs unprocessed bridge. + if bridge_unprocessed is not None: + try: + with torch.no_grad(): + bridge_logits = bridge_unprocessed(test_text, return_type="logits") + phase1_reference.hf_logits = bridge_logits.detach().cpu().clone() + bridge_loss = bridge_unprocessed(test_text, return_type="loss") + phase1_reference.hf_loss = bridge_loss.item() + phase1_reference.test_text = test_text + if verbose: + print( + f"✓ Saved Phase 1 reference data " + f"(logits: {phase1_reference.hf_logits.shape})" + ) + except Exception as e: + if verbose: + print(f"⚠ Could not save Phase 1 reference data: {e}") + # Save bridge_dtype before cleaning up HF model (needed for Phase 3) saved_bridge_dtype = bridge_dtype @@ -1029,6 +1208,7 @@ def cleanup_model(model, model_name_str: str): ht_model_unprocessed = HookedTransformer.from_pretrained( model_name, device=device, + dtype=bridge_dtype, fold_ln=False, center_writing_weights=False, center_unembed=False, @@ -1063,19 +1243,30 @@ def cleanup_model(model, model_name_str: str): # Generation benchmarks already run above (before loading HT) - # Clean up unprocessed models - no longer needed + # Clean up unprocessed HT model - no longer needed if ht_model_unprocessed is not None: cleanup_model(ht_model_unprocessed, "HookedTransformer (unprocessed)") ht_model_unprocessed = None - if bridge_unprocessed is not None: - cleanup_model(bridge_unprocessed, "TransformerBridge (unprocessed)") - bridge_unprocessed = None + # NOTE: bridge_unprocessed is intentionally kept alive for Phase 3. + # Instead of loading a fresh bridge (which can produce non-deterministic + # outputs for some architectures like OpenELM), we reuse the same instance + # and process its weights in-place. This ensures Phase 3 tests purely + # measure the effect of weight processing, not loading variability. 
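The Phase 1 capture and Phase 3 reuse described above follow a simple pattern: record reference log-probs once, transform the same model instance in place, then verify the outputs are preserved. A hedged sketch with a hypothetical helper name (`check_outputs_preserved`), not the benchmark code itself:

import torch

def check_outputs_preserved(model, text, transform, atol=1e-4):
    """Capture reference log-probs, apply an in-place transform, then re-compare."""
    with torch.no_grad():
        ref = torch.nn.functional.log_softmax(model(text, return_type="logits"), dim=-1)
    transform(model)  # e.g. lambda m: m.enable_compatibility_mode(disable_warnings=True)
    with torch.no_grad():
        new = torch.nn.functional.log_softmax(model(text, return_type="logits"), dim=-1)
    return torch.allclose(ref, new, atol=atol)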
# ======================================================================== # PHASE 3: Bridge (processed) + HookedTransformer (processed) # ======================================================================== current_phase[0] = 3 + + def _cleanup_bridge_unprocessed(): + """Clean up the kept-alive bridge_unprocessed if Phase 3 is skipped.""" + nonlocal bridge_unprocessed + if bridge_unprocessed is not None: + cleanup_model(bridge_unprocessed, "TransformerBridge (unprocessed)") + bridge_unprocessed = None + if not enable_compatibility_mode: + _cleanup_bridge_unprocessed() if verbose: print("\n⚠ Compatibility mode disabled - skipping Phase 3\n") if verbose: @@ -1083,12 +1274,14 @@ def cleanup_model(model, model_name_str: str): return results if not should_run_phase(3): + _cleanup_bridge_unprocessed() if verbose: print("\n⚠ Phase 3 skipped (not in phases list)\n") return results # Skip Phase 3 for encoder-decoder models - weight processing is designed for decoder-only models if is_encoder_decoder_model(model_name): + _cleanup_bridge_unprocessed() if verbose: print("\n⚠ Phase 3 skipped (encoder-decoder model - weight processing not supported)\n") print("\n" + format_results(results)) @@ -1102,36 +1295,67 @@ def cleanup_model(model, model_name_str: str): bridge_processed = None ht_model_processed = None - # Load processed models for Phase 3 - try: - if verbose: - print("Loading TransformerBridge (processed)...") - # Use saved dtype from Phase 1 (HF model has been cleaned up) - bridge_dtype = saved_bridge_dtype - if verbose: - print(f"Using dtype={bridge_dtype} from Phase 1") - bridge_processed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype) # type: ignore[attr-defined] - bridge_processed.enable_compatibility_mode(disable_warnings=True) - if verbose: - print("✓ TransformerBridge compatibility mode enabled (processed)\n") - except Exception as e: - import traceback + # Reuse the Phase 1 bridge instance for Phase 3 instead of loading a fresh one. + # This avoids non-deterministic loading issues (some architectures like OpenELM + # produce different outputs across separate from_pretrained calls despite + # identical parameters and buffers). Processing weights in-place on the same + # instance ensures Phase 3 purely measures weight processing equivalence. 
+ if bridge_unprocessed is not None: + try: + if verbose: + print("Processing weights on existing bridge (reusing Phase 1 instance)...") + bridge_processed = bridge_unprocessed + bridge_unprocessed = None # Transfer ownership + bridge_processed.enable_compatibility_mode(disable_warnings=True) + if verbose: + print("✓ TransformerBridge compatibility mode enabled (processed)\n") + except Exception as e: + import traceback - error_trace = traceback.format_exc() - add_result( - BenchmarkResult( - name="load_bridge_processed", - severity=BenchmarkSeverity.ERROR, - message=f"Failed to load processed TransformerBridge: {str(e)}", - passed=False, - details={"error": str(e), "traceback": error_trace}, + error_trace = traceback.format_exc() + add_result( + BenchmarkResult( + name="process_bridge_weights", + severity=BenchmarkSeverity.ERROR, + message=f"Failed to process bridge weights: {str(e)}", + passed=False, + details={"error": str(e), "traceback": error_trace}, + ) ) - ) - if verbose: - print(f"✗ Failed to load processed TransformerBridge: {str(e)}") - print(f"\nStack trace:\n{error_trace}") + if verbose: + print(f"✗ Failed to process bridge weights: {str(e)}") + print(f"\nStack trace:\n{error_trace}") + else: + # Fallback: load a fresh bridge if Phase 1 bridge was not available + try: + if verbose: + print("Loading TransformerBridge (processed)...") + bridge_dtype = saved_bridge_dtype + if verbose: + print(f"Using dtype={bridge_dtype} from Phase 1") + bridge_processed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code) # type: ignore[attr-defined] + bridge_processed.enable_compatibility_mode(disable_warnings=True) + if verbose: + print("✓ TransformerBridge compatibility mode enabled (processed)\n") + except Exception as e: + import traceback - # Add failure results for all Phase 3 tests that would have been run + error_trace = traceback.format_exc() + add_result( + BenchmarkResult( + name="load_bridge_processed", + severity=BenchmarkSeverity.ERROR, + message=f"Failed to load processed TransformerBridge: {str(e)}", + passed=False, + details={"error": str(e), "traceback": error_trace}, + ) + ) + if verbose: + print(f"✗ Failed to load processed TransformerBridge: {str(e)}") + print(f"\nStack trace:\n{error_trace}") + + if bridge_processed is None: + # Add failure results for all Phase 3 tests phase3_tests = [ "no_nan_inf", "weight_magnitudes", @@ -1161,9 +1385,9 @@ def cleanup_model(model, model_name_str: str): BenchmarkResult( name=test_name, severity=BenchmarkSeverity.ERROR, - message=f"Skipped due to model load failure", + message=f"Skipped due to weight processing failure", passed=False, - details={"reason": "load_bridge_processed_failed"}, + details={"reason": "bridge_processing_failed"}, ) ) @@ -1178,6 +1402,7 @@ def cleanup_model(model, model_name_str: str): ht_model_processed = HookedTransformer.from_pretrained( model_name, device=device, + dtype=bridge_dtype, fold_ln=True, center_writing_weights=True, center_unembed=True, @@ -1225,6 +1450,7 @@ def cleanup_model(model, model_name_str: str): is_processed=True, # Processed mode - include weight processing tests verbose=verbose, gpt2_reference=gpt2_reference, # Use GPT-2 cross-model ref if no same-arch HT + phase1_reference=phase1_reference, # Saved HF logits/loss for equivalence testing ) # Tag all phase 3 results with phase number for result in phase3_results: @@ -1407,6 +1633,11 @@ def main(): action="store_true", help="Suppress verbose output", ) + parser.add_argument( 
+ "--trust-remote-code", + action="store_true", + help="Trust remote code for custom architectures (e.g., OpenELM)", + ) args = parser.parse_args() @@ -1417,6 +1648,7 @@ def main(): use_ht_reference=not args.no_ht_reference, enable_compatibility_mode=not args.no_compat, verbose=not args.quiet, + trust_remote_code=args.trust_remote_code, ) diff --git a/transformer_lens/benchmarks/utils.py b/transformer_lens/benchmarks/utils.py index c11c16066..50c1d0454 100644 --- a/transformer_lens/benchmarks/utils.py +++ b/transformer_lens/benchmarks/utils.py @@ -59,6 +59,19 @@ def print_immediate(self) -> None: print(str(self)) +@dataclass +class PhaseReferenceData: + """Reference data saved from Phase 1 for reuse in Phase 3. + + When a model has no HookedTransformer support, Phase 1 HF logits serve as + ground truth for verifying that weight processing doesn't alter model output. + """ + + hf_logits: Optional[torch.Tensor] = None # [batch, seq, vocab] from HF model + hf_loss: Optional[float] = None # scalar loss from bridge (unprocessed) + test_text: Optional[str] = None # text used (for verification) + + def compare_tensors( tensor1: torch.Tensor, tensor2: torch.Tensor, diff --git a/transformer_lens/benchmarks/weight_processing.py b/transformer_lens/benchmarks/weight_processing.py index 418610741..84b6875e6 100644 --- a/transformer_lens/benchmarks/weight_processing.py +++ b/transformer_lens/benchmarks/weight_processing.py @@ -296,10 +296,10 @@ def benchmark_weight_modification( bridge.blocks[0].attn.W_V.copy_(original_w_v) # Some models (e.g., models with complex attention mechanisms) may have - # forward pass issues after weight modification. Report as a known limitation. + # forward pass issues after weight modification. Report as skipped. return BenchmarkResult( name="weight_modification", - severity=BenchmarkSeverity.INFO, + severity=BenchmarkSeverity.SKIPPED, message=f"Weight modification not testable for this architecture: {str(forward_error)}", details={"error": str(forward_error), "architecture_limitation": True}, ) @@ -327,15 +327,18 @@ def benchmark_weight_modification( ) except Exception as e: - # Some architectures (e.g., Gemma 3 with complex attention) may have forward pass - # issues after weight modification. Report as INFO (passed) for these known limitations. - if "cannot be multiplied" in str(e) or "shape" in str(e).lower(): + # Some architectures (e.g., Gemma 3 with complex attention, OpenELM with + # combined QKV) don't expose W_V. Report as skipped, not passed. + if ( + "cannot be multiplied" in str(e) + or "shape" in str(e).lower() + or "has no attribute" in str(e) + ): return BenchmarkResult( name="weight_modification", - severity=BenchmarkSeverity.INFO, - message=f"Weight modification not testable for this architecture (shape incompatibility)", + severity=BenchmarkSeverity.SKIPPED, + message=f"Weight modification not testable for this architecture: {str(e)}", details={"error": str(e), "architecture_limitation": True}, - passed=True, ) return BenchmarkResult( name="weight_modification", @@ -364,46 +367,68 @@ def benchmark_layer_norm_folding( # Get state dict from bridge (should return TransformerLens format keys) state_dict = bridge.state_dict() - # Check first layer normalization weights in TransformerLens format - ln_key = "blocks.0.ln1.weight" + # Check both ln1 (attention LN) and ln2 (MLP LN) in TransformerLens format. + # Models with combined QKV projections (e.g., OpenELM's qkv_proj) cannot + # fold ln1 into attention weights, but ln2 should always be foldable. 
+ tolerance = 0.01 + expected_val = 1.0 + folded = [] + not_folded = [] - # Fallback: if TL format key doesn't exist, try common HF format patterns - if ln_key not in state_dict: - # Try GPT-2 HF format - if "transformer.h.0.ln_1.weight" in state_dict: - ln_key = "transformer.h.0.ln_1.weight" - # Try Gemma HF format - elif "model.layers.0.input_layernorm.weight" in state_dict: - ln_key = "model.layers.0.input_layernorm.weight" + for ln_name in ["ln1", "ln2"]: + ln_key = f"blocks.0.{ln_name}.weight" + if ln_key not in state_dict: + continue + ln_weight = state_dict[ln_key] + mean_val = torch.mean(ln_weight).item() + if abs(mean_val - expected_val) < tolerance: + folded.append((ln_name, ln_key, mean_val)) else: - return BenchmarkResult( - name="layer_norm_folding", - severity=BenchmarkSeverity.WARNING, - message="Could not find layer norm weights in state dict", - passed=False, - ) - - # Get the layer norm weight tensor - ln_weight = state_dict[ln_key] + not_folded.append((ln_name, ln_key, mean_val)) - # Check if weights are close to identity (all ones for LayerNorm/RMSNorm) - mean_val = torch.mean(ln_weight).item() - expected_val = 1.0 - tolerance = 0.1 + if not folded and not not_folded: + return BenchmarkResult( + name="layer_norm_folding", + severity=BenchmarkSeverity.WARNING, + message="Could not find layer norm weights in state dict", + passed=False, + ) - if abs(mean_val - expected_val) < tolerance: + if folded and not not_folded: + # All LN weights are folded + names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in folded) return BenchmarkResult( name="layer_norm_folding", severity=BenchmarkSeverity.INFO, - message=f"Layer norm folding verified (mean={mean_val:.6f}, expected={expected_val})", - details={"mean": mean_val, "expected": expected_val, "key": ln_key}, + message=f"Layer norm folding verified: {names}", + details={"folded": [n for n, _, _ in folded]}, + ) + elif folded and not_folded: + # Partial folding — some LN weights folded, some not. + # This is expected for models with combined QKV (ln1 can't fold). 
+ folded_names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in folded) + unfolded_names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in not_folded) + return BenchmarkResult( + name="layer_norm_folding", + severity=BenchmarkSeverity.WARNING, + message=( + f"Partial LN folding: {folded_names} folded; " + f"{unfolded_names} preserved (expected for combined QKV models)" + ), + details={ + "folded": [n for n, _, _ in folded], + "not_folded": [n for n, _, _ in not_folded], + }, + passed=True, ) else: + # No LN weights folded + names = ", ".join(f"{n} (mean={m:.6f})" for n, _, m in not_folded) return BenchmarkResult( name="layer_norm_folding", severity=BenchmarkSeverity.WARNING, - message=f"Layer norm weights not identity after folding (mean={mean_val:.6f}, expected={expected_val})", - details={"mean": mean_val, "expected": expected_val, "key": ln_key}, + message=f"Layer norm weights not identity after folding: {names}", + details={"not_folded": [n for n, _, _ in not_folded]}, passed=False, ) @@ -582,7 +607,7 @@ def benchmark_unembed_centering( # Compute mean along vocabulary dimension (dim 0) mean_abs = torch.mean(torch.abs(torch.mean(w_u, dim=0))).item() - tolerance = 0.1 # 10% tolerance (unembed centering is less strict) + tolerance = 0.01 # 1% tolerance (consistent with attn/mlp centering) if mean_abs < tolerance: return BenchmarkResult( diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index f3fba33f4..67f00a0b2 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -27,6 +27,7 @@ Olmo3ArchitectureAdapter, OlmoArchitectureAdapter, OlmoeArchitectureAdapter, + OpenElmArchitectureAdapter, OptArchitectureAdapter, Phi3ArchitectureAdapter, PhiArchitectureAdapter, @@ -59,6 +60,7 @@ "Olmo2ForCausalLM": Olmo2ArchitectureAdapter, "Olmo3ForCausalLM": Olmo3ArchitectureAdapter, "OlmoeForCausalLM": OlmoeArchitectureAdapter, + "OpenELMForCausalLM": OpenElmArchitectureAdapter, "OPTForCausalLM": OptArchitectureAdapter, "PhiForCausalLM": PhiArchitectureAdapter, "Phi3ForCausalLM": Phi3ArchitectureAdapter, diff --git a/transformer_lens/model_bridge/architecture_adapter.py b/transformer_lens/model_bridge/architecture_adapter.py index 650985a8a..1928c8b76 100644 --- a/transformer_lens/model_bridge/architecture_adapter.py +++ b/transformer_lens/model_bridge/architecture_adapter.py @@ -645,6 +645,30 @@ def convert_hf_key_to_tl_key(self, hf_key: str) -> str: return f"blocks.{layer_idx}.{tl_subname}.{tl_nested_name}.{param}" return hf_key + def prepare_loading(self, model_name: str, model_kwargs: dict) -> None: + """Called before HuggingFace model loading to apply architecture-specific patches. + + Override this to patch HF model classes before from_pretrained() is called. + For example, patching custom model code that is incompatible with transformers v5 + meta device initialization. + + Args: + model_name: The HuggingFace model name/path + model_kwargs: The kwargs dict that will be passed to from_pretrained() + """ + pass + + def prepare_model(self, hf_model: Any) -> None: + """Called after HuggingFace model loading but before bridge creation. + + Override this to fix up the loaded model (e.g., create synthetic modules, + re-initialize deferred computations, apply post-load patches). 
+ + Args: + hf_model: The loaded HuggingFace model instance + """ + pass + def setup_component_testing(self, hf_model: RemoteModel, bridge_model: Any = None) -> None: """Set up model-specific references needed for component testing. diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 3214cbe55..dd78fa53c 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -145,6 +145,7 @@ def boot_transformers( dtype: torch.dtype = torch.float32, tokenizer: Optional[Any] = None, load_weights: bool = True, + trust_remote_code: bool = False, ) -> "TransformerBridge": """Boot a model from HuggingFace (alias for sources.transformers.boot). @@ -155,6 +156,7 @@ def boot_transformers( dtype: The dtype to use for the model. tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. load_weights: If False, load model without weights (on meta device) for config inspection only. + trust_remote_code: Whether to trust remote code for custom model architectures. Returns: The bridge to the loaded model. @@ -168,6 +170,7 @@ def boot_transformers( dtype=dtype, tokenizer=tokenizer, load_weights=load_weights, + trust_remote_code=trust_remote_code, ) @property @@ -1686,6 +1689,7 @@ def generate( top_p: Optional[float] = None, temperature: float = 1.0, freq_penalty: float = 0.0, + repetition_penalty: float = 1.0, use_past_kv_cache: bool = True, prepend_bos: Optional[bool] = None, padding_side: Optional[str] = None, @@ -1710,6 +1714,9 @@ def generate( top_p: Probability mass to sample from. If 1.0, sample from all tokens temperature: Temperature for sampling. Higher values will make the model more random freq_penalty: Frequency penalty for sampling - how much to penalise previous tokens + repetition_penalty: HuggingFace-style repetition penalty. Values > 1.0 discourage + repetition by dividing positive logits and multiplying negative logits for + previously seen tokens. Default 1.0 (no penalty). 
use_past_kv_cache: Not used in Bridge (kept for API compatibility) prepend_bos: Not used in Bridge (kept for API compatibility) padding_side: Not used in Bridge (kept for API compatibility) @@ -1790,10 +1797,16 @@ def generate( top_p=top_p, temperature=temperature, freq_penalty=freq_penalty, + repetition_penalty=repetition_penalty, tokens=current_tokens, ).to(self.cfg.device) else: - sampled_tokens = final_logits.argmax(-1).to(self.cfg.device) + sampled_tokens = utils.sample_logits( + final_logits, + temperature=0.0, + repetition_penalty=repetition_penalty, + tokens=current_tokens, + ).to(self.cfg.device) sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index d297bd1f8..41f2a6581 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -50,6 +50,8 @@ def map_default_transformer_lens_config(hf_config): tl_config.d_model = hf_config.n_embd elif hasattr(hf_config, "hidden_size"): tl_config.d_model = hf_config.hidden_size + elif hasattr(hf_config, "model_dim"): + tl_config.d_model = hf_config.model_dim elif hasattr(hf_config, "d_model"): tl_config.d_model = hf_config.d_model if hasattr(hf_config, "n_head"): @@ -58,9 +60,30 @@ def map_default_transformer_lens_config(hf_config): tl_config.n_heads = hf_config.num_attention_heads elif hasattr(hf_config, "num_heads"): tl_config.n_heads = hf_config.num_heads + elif hasattr(hf_config, "num_query_heads") and isinstance(hf_config.num_query_heads, list): + tl_config.n_heads = max(hf_config.num_query_heads) if hasattr(hf_config, "num_key_value_heads") and hf_config.num_key_value_heads is not None: try: num_kv_heads = hf_config.num_key_value_heads + # Handle per-layer lists (e.g., OpenELM) by taking the max + if isinstance(num_kv_heads, list): + num_kv_heads = max(num_kv_heads) + if hasattr(num_kv_heads, "item"): + num_kv_heads = num_kv_heads.item() + num_kv_heads = int(num_kv_heads) + num_heads = tl_config.n_heads + if hasattr(num_heads, "item"): + num_heads = num_heads.item() + num_heads = int(num_heads) + if num_kv_heads != num_heads: + tl_config.n_key_value_heads = num_kv_heads + except (TypeError, ValueError, AttributeError): + pass + elif hasattr(hf_config, "num_kv_heads") and hf_config.num_kv_heads is not None: + try: + num_kv_heads = hf_config.num_kv_heads + if isinstance(num_kv_heads, list): + num_kv_heads = max(num_kv_heads) if hasattr(num_kv_heads, "item"): num_kv_heads = num_kv_heads.item() num_kv_heads = int(num_kv_heads) @@ -76,6 +99,8 @@ def map_default_transformer_lens_config(hf_config): tl_config.n_layers = hf_config.n_layer elif hasattr(hf_config, "num_hidden_layers"): tl_config.n_layers = hf_config.num_hidden_layers + elif hasattr(hf_config, "num_transformer_layers"): + tl_config.n_layers = hf_config.num_transformer_layers elif hasattr(hf_config, "num_layers"): tl_config.n_layers = hf_config.num_layers if hasattr(hf_config, "vocab_size"): @@ -84,6 +109,8 @@ def map_default_transformer_lens_config(hf_config): tl_config.n_ctx = hf_config.n_positions elif hasattr(hf_config, "max_position_embeddings"): tl_config.n_ctx = hf_config.max_position_embeddings + elif hasattr(hf_config, "max_context_length"): + tl_config.n_ctx = hf_config.max_context_length elif hasattr(hf_config, "max_length"): tl_config.n_ctx = hf_config.max_length elif hasattr(hf_config, "seq_length"): @@ -154,6 +181,7 @@ def determine_architecture_from_hf_config(hf_config): 
"qwen": "QwenForCausalLM", "qwen2": "Qwen2ForCausalLM", "qwen3": "Qwen3ForCausalLM", + "openelm": "OpenELMForCausalLM", "stablelm": "StableLmForCausalLM", "t5": "T5ForConditionalGeneration", } @@ -211,6 +239,7 @@ def boot( dtype: torch.dtype = torch.float32, tokenizer: PreTrainedTokenizerBase | None = None, load_weights: bool = True, + trust_remote_code: bool = False, ) -> TransformerBridge: """Boot a model from HuggingFace. @@ -232,7 +261,9 @@ def boot( ) model_name = official_name break - hf_config = AutoConfig.from_pretrained(model_name, output_attentions=True) + hf_config = AutoConfig.from_pretrained( + model_name, output_attentions=True, trust_remote_code=trust_remote_code + ) if hf_config_overrides: hf_config.__dict__.update(hf_config_overrides) tl_config = map_default_transformer_lens_config(hf_config) @@ -252,15 +283,22 @@ def boot( if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None) model_kwargs = {"config": hf_config, "torch_dtype": dtype} + if trust_remote_code: + model_kwargs["trust_remote_code"] = True if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None: model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation + adapter.prepare_loading(model_name, model_kwargs) if not load_weights: + from_config_kwargs = {} + if trust_remote_code: + from_config_kwargs["trust_remote_code"] = True with contextlib.redirect_stdout(None): - hf_model = model_class.from_config(hf_config) + hf_model = model_class.from_config(hf_config, **from_config_kwargs) else: hf_model = model_class.from_pretrained(model_name, **model_kwargs) if device is not None: hf_model = hf_model.to(device) + adapter.prepare_model(hf_model) tokenizer = tokenizer default_padding_side = getattr(adapter.cfg, "default_padding_side", None) use_fast = getattr(adapter.cfg, "use_fast", True) @@ -269,21 +307,28 @@ def boot( else: huggingface_token = os.environ.get("HF_TOKEN", "") token_arg = huggingface_token if len(huggingface_token) > 0 else None + # Determine tokenizer source: use adapter's tokenizer_name if the model + # doesn't ship its own tokenizer (e.g., OpenELM uses LLaMA tokenizer) + tokenizer_source = model_name + if hasattr(adapter.cfg, "tokenizer_name") and adapter.cfg.tokenizer_name is not None: + tokenizer_source = adapter.cfg.tokenizer_name # Try to load tokenizer with add_bos_token=True first # (encoder-decoder models like T5 don't have BOS tokens and will raise ValueError) try: base_tokenizer = AutoTokenizer.from_pretrained( - model_name, + tokenizer_source, add_bos_token=True, use_fast=use_fast, token=token_arg, + trust_remote_code=trust_remote_code, ) except ValueError: # Model doesn't have a BOS token, load without add_bos_token base_tokenizer = AutoTokenizer.from_pretrained( - model_name, + tokenizer_source, use_fast=use_fast, token=token_arg, + trust_remote_code=trust_remote_code, ) tokenizer = setup_tokenizer( base_tokenizer, diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 0afe524a0..768e71bef 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -64,6 +64,9 @@ from transformer_lens.model_bridge.supported_architectures.neox import ( NeoxArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.openelm import ( + 
OpenElmArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.opt import ( OptArchitectureAdapter, ) @@ -109,6 +112,7 @@ "NeelSoluOldArchitectureAdapter", "NeoArchitectureAdapter", "NeoxArchitectureAdapter", + "OpenElmArchitectureAdapter", "OlmoArchitectureAdapter", "Olmo2ArchitectureAdapter", "Olmo3ArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/openelm.py b/transformer_lens/model_bridge/supported_architectures/openelm.py new file mode 100644 index 000000000..99b024c94 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/openelm.py @@ -0,0 +1,274 @@ +"""OpenELM architecture adapter.""" + +import sys +from typing import Any + +import torch + +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + LinearBridge, + MLPBridge, + RMSNormalizationBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.attention import ( + AttentionBridge, +) + + +class OpenElmArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for Apple OpenELM models. + + OpenELM uses a unique architecture with per-layer varying head counts and FFN + dimensions. Key characteristics: + + - Combined QKV projection (qkv_proj) with per-layer varying Q/KV head counts + - Gated MLP with combined gate+up projection (proj_1) and per-layer FFN sizes + - RMSNorm normalization + - Full rotary embeddings (per-layer, not shared) + - Optional Q/K RMSNorm (normalize_qk_projections=True) + - Weight tying (share_input_output_layers=True typically) + - Model root is 'transformer' (not 'model') + - Requires trust_remote_code=True (custom HF code) + + The native HF attention handles all per-layer dimension variations, RoPE, + GQA group repeat, and Q/K normalization internally. The bridge delegates + to the native forward for correct computation. + + Note: Individual Q/K/V hooks are not available since the model uses a combined + QKV projection. Attention-level hooks (hook_attn_in, hook_attn_out) are provided. + """ + + def __init__(self, cfg: Any) -> None: + """Initialize the OpenELM architecture adapter.""" + super().__init__(cfg) + + # Set config variables for weight processing + self.cfg.normalization_type = "RMS" + self.cfg.positional_embedding_type = "rotary" + self.cfg.final_rms = True + self.cfg.gated_mlp = True + self.cfg.attn_only = False + self.cfg.uses_rms_norm = True + + self.default_config = { + "d_model": cfg.d_model, + "d_head": getattr(cfg, "head_dim", cfg.d_model // cfg.n_heads), + "n_heads": cfg.n_heads, + "n_layers": cfg.n_layers, + "d_vocab": cfg.d_vocab, + } + + # GQA support + if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.default_config["n_key_value_heads"] = cfg.n_key_value_heads + self.cfg.n_key_value_heads = cfg.n_key_value_heads + + # OpenELM doesn't ship its own tokenizer — uses LLaMA tokenizer. + # Use NousResearch mirror (ungated) to avoid access restrictions. 
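+        # (boot() in sources/transformers.py passes cfg.tokenizer_name, when set, to
+        # AutoTokenizer.from_pretrained in place of the model name.)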
+ self.cfg.tokenizer_name = "NousResearch/Llama-2-7b-hf" + + # No weight processing conversions needed - native attention handles all + # per-layer dimension variations internally + self.weight_processing_conversions = {} + + # Store reference for RoPE patching + self._original_rope_compute = None + self._rope_class = None + + self.component_mapping = { + "embed": EmbeddingBridge(name="transformer.token_embeddings"), + "blocks": BlockBridge( + name="transformer.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="attn_norm", config=self.cfg), + "ln2": RMSNormalizationBridge(name="ffn_norm", config=self.cfg), + "attn": AttentionBridge( + name="attn", + config=self.cfg, + submodules={ + "qkv": LinearBridge(name="qkv_proj"), + "o": LinearBridge(name="out_proj"), + }, + maintain_native_attention=True, + requires_attention_mask=True, + ), + "mlp": MLPBridge( + name="ffn", + config=self.cfg, + submodules={ + "in": LinearBridge(name="proj_1"), + "out": LinearBridge(name="proj_2"), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="transformer.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), + } + + def prepare_loading(self, model_name: str, model_kwargs: dict) -> None: + """Patch OpenELM for compatibility with transformers v5. + + Two patches are needed: + 1. RotaryEmbedding: Custom _compute_sin_cos_embeddings fails on meta device + because it calls .cos() on meta tensors. We wrap it to catch NotImplementedError. + 2. Weight re-initialization: OpenELM's _init_weights re-randomizes ALL weights + after they've been loaded from safetensors because transformers v5's + _finalize_load_state_dict calls initialize_weights() on modules lacking the + _is_hf_initialized flag. We patch _init_weights to skip real (non-meta) tensors. + + Args: + model_name: The HuggingFace model name/path + model_kwargs: The kwargs dict for from_pretrained() + """ + # Force-import the modeling module so we can patch it + try: + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + get_class_from_dynamic_module( + "modeling_openelm.OpenELMForCausalLM", + model_name, + ) + except Exception: + return + + # Find ALL imported OpenELM modules and apply patches. + # Each model variant (e.g., OpenELM-1_1B vs OpenELM-1_1B-Instruct) gets its own + # module in sys.modules with a different cache path, so we patch all of them. + for key in list(sys.modules.keys()): + if "openelm" in key.lower() and "modeling" in key.lower(): + module = sys.modules[key] + if hasattr(module, "OpenELMRotaryEmbedding"): + rope_class = module.OpenELMRotaryEmbedding + # Skip if already patched (avoid wrapping safe_compute in safe_compute) + if getattr(rope_class, "_tl_patched", False): + continue + # Patch 1: RoPE meta device fix + original_compute = rope_class._compute_sin_cos_embeddings + + def safe_compute( + self, + key_len, + key_device="cpu", + key_dtype=torch.float32, + _original=original_compute, + ): + try: + _original(self, key_len, key_device, key_dtype) + except NotImplementedError: + pass # Deferred: re-initialized in prepare_model() + + rope_class._compute_sin_cos_embeddings = safe_compute + rope_class._tl_patched = True + self._original_rope_compute = original_compute + self._rope_class = rope_class + + if hasattr(module, "OpenELMPreTrainedModel"): + pretrained_class = module.OpenELMPreTrainedModel + if getattr(pretrained_class, "_tl_patched", False): + continue + # Patch 2: Prevent _init_weights from re-randomizing loaded weights. 
+ # transformers v5 calls _init_weights on all modules after weight + # materialization. For modules with real (non-meta) tensors, we must + # skip re-initialization to preserve the loaded checkpoint values. + original_init_weights = pretrained_class._init_weights + + def safe_init_weights( + self, + mod, + _original=original_init_weights, + ): + # Only initialize modules still on meta device (pre-loading) + first_param = next(mod.parameters(), None) + if first_param is not None and first_param.device.type != "meta": + return # Already loaded from checkpoint — don't re-randomize + _original(self, mod) + + pretrained_class._init_weights = safe_init_weights + pretrained_class._tl_patched = True + + def prepare_model(self, hf_model: Any) -> None: + """Post-load fixes for non-persistent buffers zeroed during meta materialization. + + Transformers v5 creates models on meta device then materializes weights from + checkpoint. Non-persistent buffers (registered with persistent=False) are NOT + in the checkpoint, so they materialize as zeros. OpenELM has two critical + non-persistent buffers that must be recomputed: + + 1. RoPE inv_freq — zeroed inv_freq produces cos=1, sin=0 for all positions, + destroying positional information entirely. + 2. causal_mask — zeroed mask means no causal masking, allowing all positions + to attend to future tokens. Single forward passes appear correct (no future + tokens to leak) but autoregressive generation degenerates immediately. + + We also create a synthetic lm_head for weight-tied models. + + Note: We intentionally do NOT restore the original _compute_sin_cos_embeddings. + The safe_compute wrapper is functionally equivalent for real (non-meta) tensors, + and keeping it avoids issues when multiple models are loaded in the same process + (e.g., benchmark suite loading both HF reference and bridge models). + + Args: + hf_model: The loaded HuggingFace OpenELM model + """ + # Ensure use_cache is set on config (transformers v5 raises AttributeError + # for missing config attributes, and OpenELM's custom config omits use_cache) + if not hasattr(hf_model.config, "use_cache") or "use_cache" not in hf_model.config.__dict__: + hf_model.config.use_cache = False + + # Fix 1: Always recompute causal_mask (non-persistent buffer). + # After meta→real materialization, the buffer may contain garbage values + # (not all zeros) depending on the materializer's memory state. The old + # check `not cm.any()` only recomputed when all zeros, missing cases where + # garbage values are non-zero. Always recompute to guarantee correctness. + if hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "causal_mask"): + cm = hf_model.transformer.causal_mask + if cm is not None: + seq_len = cm.shape[-1] + correct_mask = torch.triu( + torch.ones(seq_len, seq_len, dtype=cm.dtype, device=cm.device), + diagonal=1, + ) + hf_model.transformer.causal_mask = correct_mask + + # Fix 2: Recompute RoPE inv_freq on all layers (non-persistent buffer zeroed + # during materialization), then force-recompute sin/cos embeddings. + if hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "layers"): + rope_max = getattr(hf_model.config, "rope_max_length", 4096) + for layer in hf_model.transformer.layers: + if hasattr(layer, "attn") and hasattr(layer.attn, "pos_embedding"): + rope = layer.attn.pos_embedding + # Always recompute inv_freq (non-persistent buffer). + # Like causal_mask, inv_freq may contain garbage after meta + # materialization rather than clean zeros. 
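+                    # The recomputation below follows the usual RoPE schedule,
+                    # inv_freq[i] = freq_constant ** (-2 * i / model_dim); for example,
+                    # with freq_constant=10000 and model_dim=64 the values run from 1.0
+                    # down to 10000 ** (-62 / 64) ≈ 1.3e-4 (illustrative numbers only).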
+ correct_inv_freq = 1.0 / ( + rope.freq_constant + ** ( + torch.arange(0, rope.model_dim, 2, dtype=torch.float32) / rope.model_dim + ) + ) + rope.inv_freq = correct_inv_freq.to(rope.inv_freq.device) + # Force-recompute sin/cos (may have been computed with zero inv_freq) + rope._cached_cos = None + rope._cached_sin = None + rope._compute_sin_cos_embeddings(rope_max) + + # Create synthetic lm_head when embeddings are shared + if getattr(hf_model, "lm_head", None) is None and hasattr(hf_model, "transformer"): + embed = hf_model.transformer.token_embeddings + lm_head = torch.nn.Linear(embed.embedding_dim, embed.num_embeddings, bias=False) + lm_head.weight = embed.weight + hf_model.lm_head = lm_head + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Set up references for OpenELM component testing. + + Args: + hf_model: The HuggingFace OpenELM model instance + bridge_model: The TransformerBridge model (if available) + """ + pass diff --git a/transformer_lens/supported_models.py b/transformer_lens/supported_models.py index 8f1af48ff..6a3f4b767 100644 --- a/transformer_lens/supported_models.py +++ b/transformer_lens/supported_models.py @@ -25,12 +25,6 @@ "codellama/CodeLlama-7b-hf", "codellama/CodeLlama-7b-Instruct-hf", "codellama/CodeLlama-7b-Python-hf", - "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", - "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", - "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", - "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", - "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", - "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "distilgpt2", "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neo-1.3B", @@ -232,19 +226,10 @@ "roneneldan/TinyStories-Instruct-3M", "roneneldan/TinyStories-Instruct-8M", "roneneldan/TinyStories-Instuct-1Layer-21M", - "stabilityai/stable-code-3b", - "stabilityai/stable-code-instruct-3b", - "stabilityai/stablelm-2-12b", - "stabilityai/stablelm-2-12b-chat", - "stabilityai/stablelm-2-1_6b", - "stabilityai/stablelm-2-1_6b-chat", - "stabilityai/stablelm-2-zephyr-1_6b", - "stabilityai/stablelm-3b-4e1t", "stabilityai/stablelm-base-alpha-3b", "stabilityai/stablelm-base-alpha-7b", "stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-tuned-alpha-7b", - "stabilityai/stablelm-zephyr-3b", "stanford-crfm/alias-gpt2-small-x21", "stanford-crfm/arwen-gpt2-medium-x21", "stanford-crfm/battlestar-gpt2-small-x49", @@ -289,30 +274,6 @@ "codellama/CodeLlama-7b-Instruct-hf", ], "codellama/CodeLlama-7b-Python-hf": ["CodeLlama-7b-python", "codellama/CodeLlama-7b-Python-hf"], - "deepseek-ai/DeepSeek-R1-Distill-Llama-70B": [ - "deepseek-r1-distill-llama-70b", - "deepseek-r1-distill-llama-70b-chat", - ], - "deepseek-ai/DeepSeek-R1-Distill-Llama-8B": [ - "deepseek-r1-distill-llama-8b", - "deepseek-r1-distill-llama-8b-chat", - ], - "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B": [ - "deepseek-r1-distill-qwen-1.5b", - "deepseek-r1-distill-qwen-1.5b-chat", - ], - "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B": [ - "deepseek-r1-distill-qwen-14b", - "deepseek-r1-distill-qwen-14b-chat", - ], - "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B": [ - "deepseek-r1-distill-qwen-32b", - "deepseek-r1-distill-qwen-32b-chat", - ], - "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B": [ - "deepseek-r1-distill-qwen-7b", - "deepseek-r1-distill-qwen-7b-chat", - ], "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"], "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"], "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"], @@ -605,19 +566,10 @@ 
"roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"], "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"], "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"], - "stabilityai/stable-code-3b": ["stable-code-3b"], - "stabilityai/stable-code-instruct-3b": ["stable-code-instruct-3b"], - "stabilityai/stablelm-2-12b": ["stablelm-2-12b"], - "stabilityai/stablelm-2-12b-chat": ["stablelm-2-12b-chat"], - "stabilityai/stablelm-2-1_6b": ["stablelm-2-1.6b"], - "stabilityai/stablelm-2-1_6b-chat": ["stablelm-2-1.6b-chat"], - "stabilityai/stablelm-2-zephyr-1_6b": ["stablelm-2-zephyr-1.6b"], - "stabilityai/stablelm-3b-4e1t": ["stablelm-3b-4e1t", "stablelm-3b"], "stabilityai/stablelm-base-alpha-3b": ["stablelm-base-alpha-3b", "stablelm-base-3b"], "stabilityai/stablelm-base-alpha-7b": ["stablelm-base-alpha-7b", "stablelm-base-7b"], "stabilityai/stablelm-tuned-alpha-3b": ["stablelm-tuned-alpha-3b", "stablelm-tuned-3b"], "stabilityai/stablelm-tuned-alpha-7b": ["stablelm-tuned-alpha-7b", "stablelm-tuned-7b"], - "stabilityai/stablelm-zephyr-3b": ["stablelm-zephyr-3b"], "stanford-crfm/alias-gpt2-small-x21": [ "stanford-gpt2-small-a", "alias-gpt2-small-x21", diff --git a/transformer_lens/utilities/logits_utils.py b/transformer_lens/utilities/logits_utils.py index 083196fd0..34fbce4e7 100644 --- a/transformer_lens/utilities/logits_utils.py +++ b/transformer_lens/utilities/logits_utils.py @@ -11,12 +11,41 @@ from jaxtyping import Float, Int +def _apply_repetition_penalty( + logits: Float[torch.Tensor, "batch d_vocab"], + tokens: Int[torch.Tensor, "batch pos"], + penalty: float, +) -> Float[torch.Tensor, "batch d_vocab"]: + """Apply HuggingFace-style repetition penalty to logits. + + For each token that has appeared in the sequence, positive logits are divided + by the penalty and negative logits are multiplied by it. + + Args: + logits: Logits tensor of shape [batch, d_vocab] + tokens: Token IDs of shape [batch, pos] + penalty: Repetition penalty value (1.0 = no penalty) + + Returns: + Modified logits tensor + """ + logits = logits.clone() + for batch_idx in range(logits.shape[0]): + # Get unique tokens that have appeared in this sequence + unique_tokens = tokens[batch_idx].unique() + score = logits[batch_idx, unique_tokens] + # Divide positive logits, multiply negative logits + logits[batch_idx, unique_tokens] = torch.where(score > 0, score / penalty, score * penalty) + return logits + + def sample_logits( final_logits: Float[torch.Tensor, "batch d_vocab"], top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: float = 1.0, freq_penalty: float = 0.0, + repetition_penalty: float = 1.0, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, ) -> Int[torch.Tensor, "batch"]: """ @@ -28,17 +57,25 @@ def sample_logits( Frequency penalty is a penalty on the probability of a token, proportional to the number of times it has been generated so far. This encourages the model to generate new tokens, rather than repeating itself. It is a hyperparameter, and should be tuned. It is applied to the logits before sampling. If this is non-zero it is required to input the input_tokens + Repetition penalty (HuggingFace-style) divides positive logits by the penalty value and multiplies negative logits by it for any token that has appeared in the sequence. A value of 1.0 means no penalty. Values > 1.0 discourage repetition. This is applied before temperature scaling. + #! TODO: Finish testing all the edge cases here. 
Useful testing code: logits = torch.randn(4) print(logits) np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True) """ if temperature == 0.0: - # Greedy sampling + # Greedy sampling - still apply repetition penalty before argmax + if repetition_penalty != 1.0 and tokens is not None: + final_logits = _apply_repetition_penalty(final_logits, tokens, repetition_penalty) return final_logits.argmax(dim=-1) else: # Sample from the distribution + # Apply repetition penalty before temperature scaling + if repetition_penalty != 1.0 and tokens is not None: + final_logits = _apply_repetition_penalty(final_logits, tokens, repetition_penalty) + final_logits = final_logits / temperature if freq_penalty > 0: assert tokens is not None, "Must provide input_tokens if applying a frequency penalty" diff --git a/transformer_lens/weight_processing.py b/transformer_lens/weight_processing.py index c06351e84..8abdd0a8a 100644 --- a/transformer_lens/weight_processing.py +++ b/transformer_lens/weight_processing.py @@ -192,31 +192,32 @@ def fold_layer_norm_biases( bk_tensor: Optional[torch.Tensor], bv_tensor: Optional[torch.Tensor], ln_bias: torch.Tensor, - ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Fold LayerNorm bias into attention biases. + When QKV biases don't exist (e.g., GPT-Neo), creates zero-initialized biases + to absorb the LN bias contribution, similar to how MLP folding handles missing biases. + Args: wq_tensor, wk_tensor, wv_tensor: Weight tensors [n_heads, d_model, d_head] bq_tensor, bk_tensor, bv_tensor: Bias tensors [n_heads, d_head] or None if no bias ln_bias: LayerNorm bias [d_model] Returns: - Tuple of (new_bq, new_bk, new_bv) with folded biases (None if input bias was None) + Tuple of (new_bq, new_bk, new_bv) with folded biases (always non-None) """ - new_bq = ( - ProcessWeights.fold_layer_norm_bias_single(wq_tensor, bq_tensor, ln_bias) - if bq_tensor is not None - else None + + def _zero_bias(w: torch.Tensor) -> torch.Tensor: + return torch.zeros(w.shape[0], w.shape[2], dtype=w.dtype, device=w.device) + + new_bq = ProcessWeights.fold_layer_norm_bias_single( + wq_tensor, bq_tensor if bq_tensor is not None else _zero_bias(wq_tensor), ln_bias ) - new_bk = ( - ProcessWeights.fold_layer_norm_bias_single(wk_tensor, bk_tensor, ln_bias) - if bk_tensor is not None - else None + new_bk = ProcessWeights.fold_layer_norm_bias_single( + wk_tensor, bk_tensor if bk_tensor is not None else _zero_bias(wk_tensor), ln_bias ) - new_bv = ( - ProcessWeights.fold_layer_norm_bias_single(wv_tensor, bv_tensor, ln_bias) - if bv_tensor is not None - else None + new_bv = ProcessWeights.fold_layer_norm_bias_single( + wv_tensor, bv_tensor if bv_tensor is not None else _zero_bias(wv_tensor), ln_bias ) return (new_bq, new_bk, new_bv) @@ -381,89 +382,89 @@ def _fold_layer( ln1_b = tensors["ln1_b"] ln1_w = tensors["ln1_w"] keys = tensors["keys"] - if wq_tensor is None: - return state_dict - assert isinstance(wq_tensor, torch.Tensor) - assert isinstance(keys, dict) - if wk_tensor is not None: - assert isinstance(wk_tensor, torch.Tensor) - if wv_tensor is not None: - assert isinstance(wv_tensor, torch.Tensor) - if bq_tensor is not None: - assert isinstance(bq_tensor, torch.Tensor) - if bk_tensor is not None: - assert isinstance(bk_tensor, torch.Tensor) - if bv_tensor is not None: - assert isinstance(bv_tensor, torch.Tensor) - # CRITICAL FIX: For RMS norm (Gemma), ln1_b is None. 
We must still fold ln1_w! - # Only require ln1_w to be non-None for folding - if ln1_w is not None: - assert isinstance(ln1_w, torch.Tensor) - # Only fold biases if they exist (LayerNorm). RMS norm has no biases. - if fold_biases and ln1_b is not None: - assert isinstance(ln1_b, torch.Tensor) - if all( - ( - t is not None - for t in [wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor] - ) - ): - # Type narrowing for mypy + + # Fold attention LN into QKV weights (only if separate Q/K/V weights exist). + # Models with combined QKV (e.g., OpenELM's qkv_proj) won't have separate + # Q/K/V weights — skip attention folding but still proceed to MLP folding. + if wq_tensor is not None: + assert isinstance(wq_tensor, torch.Tensor) + assert isinstance(keys, dict) + if wk_tensor is not None: + assert isinstance(wk_tensor, torch.Tensor) + if wv_tensor is not None: + assert isinstance(wv_tensor, torch.Tensor) + if bq_tensor is not None: + assert isinstance(bq_tensor, torch.Tensor) + if bk_tensor is not None: + assert isinstance(bk_tensor, torch.Tensor) + if bv_tensor is not None: + assert isinstance(bv_tensor, torch.Tensor) + # CRITICAL FIX: For RMS norm (Gemma), ln1_b is None. We must still fold ln1_w! + # Only require ln1_w to be non-None for folding + if ln1_w is not None: + assert isinstance(ln1_w, torch.Tensor) + # Only fold biases if they exist (LayerNorm). RMS norm has no biases. + if fold_biases and ln1_b is not None: + assert isinstance(ln1_b, torch.Tensor) + # fold_layer_norm_biases handles missing QKV biases by creating + # zero-initialized ones, so we always fold (no all(...) guard needed). assert wq_tensor is not None assert wk_tensor is not None assert wv_tensor is not None bq_tensor, bk_tensor, bv_tensor = ProcessWeights.fold_layer_norm_biases( wq_tensor, wk_tensor, wv_tensor, bq_tensor, bk_tensor, bv_tensor, ln1_b ) - if keys["ln1_b"] in state_dict: - state_dict[keys["ln1_b"]] = torch.zeros_like(ln1_b) - alternate_b_key = ( - keys["ln1_b"].replace("ln_1", "ln1") - if "ln_1" in keys["ln1_b"] - else keys["ln1_b"].replace("ln1", "ln_1") + if keys["ln1_b"] in state_dict: + state_dict[keys["ln1_b"]] = torch.zeros_like(ln1_b) + alternate_b_key = ( + keys["ln1_b"].replace("ln_1", "ln1") + if "ln_1" in keys["ln1_b"] + else keys["ln1_b"].replace("ln1", "ln_1") + ) + if alternate_b_key != keys["ln1_b"] and alternate_b_key in state_dict: + state_dict[alternate_b_key] = torch.zeros_like(ln1_b) + # Fold ln1_w into QKV weights (works for both LayerNorm and RMS norm) + if wk_tensor is not None and wv_tensor is not None: + wq_tensor, wk_tensor, wv_tensor = ProcessWeights.fold_layer_norm_weights( + wq_tensor, wk_tensor, wv_tensor, ln1_w + ) + # After folding, set ln1.w to identity (all 1.0). + # For HookedTransformer with Pre normalization (LNPre/RMSNormPre), load_state_dict + # will ignore these weights since those layers have no weight parameters. + # For TransformerBridge and other models, the weights must be 1.0 after folding. 
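+                # Both key spellings ("ln1" and "ln_1") are reset below, so the identity
+                # weight is written regardless of which naming the state dict uses.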
+ if keys["ln1_w"] in state_dict: + state_dict[keys["ln1_w"]] = torch.ones_like(ln1_w) + alternate_w_key = ( + keys["ln1_w"].replace("ln_1", "ln1") + if "ln_1" in keys["ln1_w"] + else keys["ln1_w"].replace("ln1", "ln_1") ) - if alternate_b_key != keys["ln1_b"] and alternate_b_key in state_dict: - state_dict[alternate_b_key] = torch.zeros_like(ln1_b) - # Fold ln1_w into QKV weights (works for both LayerNorm and RMS norm) - if wk_tensor is not None and wv_tensor is not None: - wq_tensor, wk_tensor, wv_tensor = ProcessWeights.fold_layer_norm_weights( - wq_tensor, wk_tensor, wv_tensor, ln1_w + if alternate_w_key != keys["ln1_w"] and alternate_w_key in state_dict: + state_dict[alternate_w_key] = torch.ones_like(ln1_w) + if center_weights and wk_tensor is not None and (wv_tensor is not None): + wq_tensor, wk_tensor, wv_tensor = ProcessWeights.center_attention_weights( + wq_tensor, wk_tensor, wv_tensor ) - # After folding, set ln1.w to identity (all 1.0). - # For HookedTransformer with Pre normalization (LNPre/RMSNormPre), load_state_dict - # will ignore these weights since those layers have no weight parameters. - # For TransformerBridge and other models, the weights must be 1.0 after folding. - if keys["ln1_w"] in state_dict: - state_dict[keys["ln1_w"]] = torch.ones_like(ln1_w) - alternate_w_key = ( - keys["ln1_w"].replace("ln_1", "ln1") - if "ln_1" in keys["ln1_w"] - else keys["ln1_w"].replace("ln1", "ln_1") + state_dict = ProcessWeights._store_processed_attention_tensors( + state_dict, + keys, + wq_tensor, + wk_tensor, + wv_tensor, + bq_tensor, + bk_tensor, + bv_tensor, + adapter, + cfg, + layer, ) - if alternate_w_key != keys["ln1_w"] and alternate_w_key in state_dict: - state_dict[alternate_w_key] = torch.ones_like(ln1_w) - if center_weights and wk_tensor is not None and (wv_tensor is not None): - wq_tensor, wk_tensor, wv_tensor = ProcessWeights.center_attention_weights( - wq_tensor, wk_tensor, wv_tensor - ) - state_dict = ProcessWeights._store_processed_attention_tensors( - state_dict, - keys, - wq_tensor, - wk_tensor, - wv_tensor, - bq_tensor, - bk_tensor, - bv_tensor, - adapter, - cfg, - layer, - ) # NOTE: For Gemma 2/3 with use_normalization_before_and_after=True, ln1_post.w exists # and should KEEP its original values (not be set to 1.0). It applies normalization # AFTER the attention output, which is independent of the ln1 folding we just did. + # Always fold MLP layer norm, even if attention QKV weights weren't available. + # MLP folding is independent of attention folding. state_dict = ProcessWeights._fold_mlp_layer_norm( state_dict, cfg, layer, fold_biases, center_weights, adapter ) @@ -577,11 +578,14 @@ def _fold_mlp_layer_norm( mlp_W_gate = ProcessWeights.convert_tensor_to_tl_format( mlp_W_gate_key, state_dict, state_dict.get(mlp_W_gate_key), cfg, adapter, layer ) - assert mlp_W_gate is not None, f"MLP W_gate not found at key {mlp_W_gate_key}" - new_mlp_W_gate = mlp_W_gate * ln2_w_broadcast - state_dict[mlp_W_gate_key] = ProcessWeights.convert_tensor_to_hf_format( - mlp_W_gate_key, new_mlp_W_gate, cfg, adapter, layer - ) + # For models with combined gate+up projections (e.g., OpenELM's proj_1), + # the separate gate weight won't exist — LN was already folded into the + # combined "in" weight above. + if mlp_W_gate is not None: + new_mlp_W_gate = mlp_W_gate * ln2_w_broadcast + state_dict[mlp_W_gate_key] = ProcessWeights.convert_tensor_to_hf_format( + mlp_W_gate_key, new_mlp_W_gate, cfg, adapter, layer + ) # After folding, set ln2.w to identity (all 1.0). 
# For HookedTransformer with Pre normalization, load_state_dict will ignore these. # For TransformerBridge and other models, the weights must be 1.0 after folding. @@ -1063,7 +1067,15 @@ def center_writing_weights( mlp_W_out_key, state_dict, state_dict.get(mlp_W_out_key), cfg, adapter, l ) assert mlp_W_out is not None, f"MLP W_out not found at key {mlp_W_out_key}" - mlp_W_out = mlp_W_out - mlp_W_out.mean(-1, keepdim=True) + # Center along d_model dimension. In TL format W_out is [d_mlp, d_model] + # so d_model is dim=-1. But bridge adapters may keep HF format + # [d_model, d_mlp] where d_model is dim=0. Detect via cfg.d_model. + if mlp_W_out.shape[-1] == cfg.d_model: + mlp_W_out = mlp_W_out - mlp_W_out.mean(-1, keepdim=True) + elif mlp_W_out.shape[0] == cfg.d_model: + mlp_W_out = mlp_W_out - mlp_W_out.mean(0, keepdim=True) + else: + mlp_W_out = mlp_W_out - mlp_W_out.mean(-1, keepdim=True) state_dict[mlp_W_out_key] = ProcessWeights.convert_tensor_to_hf_format( mlp_W_out_key, mlp_W_out, cfg, adapter, l ) @@ -1085,7 +1097,7 @@ def center_writing_weights( @staticmethod def center_unembed( - state_dict: Dict[str, torch.Tensor], adapter=None + state_dict: Dict[str, torch.Tensor], cfg=None, adapter=None ) -> Dict[str, torch.Tensor]: """Center the unembedding weights W_U. @@ -1097,6 +1109,7 @@ def center_unembed( Args: state_dict (Dict[str, torch.Tensor]): State dict of the model. + cfg: Model configuration (used to determine d_vocab for correct centering dimension). adapter: Optional architecture adapter for parameter key translation. Returns: @@ -1116,7 +1129,20 @@ def center_unembed( unembed_W_U_key, state_dict, state_dict.get(unembed_W_U_key), None, adapter, None ) assert W_U is not None, f"Unembed weight not found at key {unembed_W_U_key}" - W_U = W_U - W_U.mean(-1, keepdim=True) + + # Determine which dimension is d_vocab to center along. + # In TL format W_U is [d_model, d_vocab], so we center along dim=-1. + # But if convert_tensor_to_tl_format was a no-op (empty weight_processing_conversions), + # W_U may still be in HF format [d_vocab, d_model]. Centering along the wrong + # dimension is NOT softmax-invariant and corrupts model output. + vocab_dim = -1 # Default: TL format [d_model, d_vocab] + if cfg is not None: + d_vocab = getattr(cfg, "d_vocab", None) + if d_vocab is not None: + if W_U.shape[0] == d_vocab and W_U.shape[-1] != d_vocab: + # HF format [d_vocab, d_model] — center along dim=0 + vocab_dim = 0 + W_U = W_U - W_U.mean(vocab_dim, keepdim=True) state_dict[unembed_W_U_key] = ProcessWeights.convert_tensor_to_hf_format( unembed_W_U_key, W_U, None, adapter, None ) @@ -1318,22 +1344,19 @@ def process_weights( state_dict = ProcessWeights.fold_layer_norm( state_dict, cfg, fold_biases=False, center_weights=False, adapter=adapter ) - # For RMS normalization, set all layer norm weights to 1.0 after folding - # since RMS folding doesn't result in identity weights like LayerNorm does - for layer_idx in range(cfg.n_layers): - for ln_name in ["ln1", "ln2"]: - ln_w_key = ProcessWeights._get_param_key( - f"blocks.{layer_idx}.{ln_name}.w", adapter - ) - if ln_w_key in state_dict: - state_dict[ln_w_key] = torch.ones_like(state_dict[ln_w_key]) + # Note: Each folding function (_fold_layer for attention, _fold_mlp_layer_norm + # for MLP) sets its own LN weights to 1.0 after successful folding. 
+ # We must NOT unconditionally set all LN weights to 1.0 here, because + # models with combined QKV projections (e.g., OpenELM's qkv_proj) may + # not be able to fold attention LN — setting ln1.w=1.0 without folding + # destroys the RMS scaling. if center_writing_weights: if getattr(cfg, "normalization_type", "LN") in ["LN", "LNPre"] and ( not getattr(cfg, "final_rms", False) ): state_dict = ProcessWeights.center_writing_weights(state_dict, cfg, adapter=adapter) if center_unembed: - state_dict = ProcessWeights.center_unembed(state_dict, adapter=adapter) + state_dict = ProcessWeights.center_unembed(state_dict, cfg=cfg, adapter=adapter) if fold_value_biases: state_dict = ProcessWeights.fold_value_biases(state_dict, cfg, adapter=adapter) if center_writing_weights and getattr(cfg, "normalization_type", "LN") in [ @@ -1587,7 +1610,23 @@ def convert_tensor_to_tl_format( # Skip conversion for optional parameters that don't exist (e.g. biases) if tensor is None and param_name not in model_state_dict: return None - # Let ParamProcessingConversion handle the fetching and conversion + # Try ParamProcessingConversion.convert() first (uses source_key + # to fetch from state dict — needed for split conversions like + # GPT-2's QKV). If source_key resolves to a missing key and we + # already have the tensor, fall back to applying the tensor + # conversion directly (needed for adapters like GPT-Neo whose + # source_key references HF keys not in the bridge state dict). + if ( + hasattr(param_conversion, "source_key") + and param_conversion.source_key is not None + ): + resolved_key = param_conversion._resolve_key( + param_name, param_conversion.source_key + ) + if resolved_key not in model_state_dict and tensor is not None: + return param_conversion.tensor_conversion.convert( + tensor, model_state_dict + ) return param_conversion.convert(model_state_dict, param_name) else: # No conversion defined, return tensor as-is (may be None for optional params)
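Illustrative sketch (reviewer note, not part of the patch): the repetition-penalty path added to sample_logits can be exercised in isolation. The import path follows the modified transformer_lens/utilities/logits_utils.py; the tensor values are made up for demonstration.

    import torch

    from transformer_lens.utilities.logits_utils import sample_logits

    # Batch of 1, vocab of 4; token 2 has already been generated.
    logits = torch.tensor([[2.0, -1.0, 3.0, 0.5]])
    tokens = torch.tensor([[2]])

    # temperature=0.0 takes the greedy path, but the penalty is applied first:
    # the positive logit of token 2 is divided by 2.0 (3.0 -> 1.5), so token 0
    # (logit 2.0) is selected instead of token 2.
    picked = sample_logits(logits, temperature=0.0, repetition_penalty=2.0, tokens=tokens)
    assert picked.item() == 0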