
Commit ccce13e

Fix HookedTransformerConfig rotary_base types
rotary_base is frequently set to floats in the code but was typed as an int: https://github.com/TransformerLensOrg/TransformerLens/blob/9c5a2a81674d5bcefa641c816b66e9827ccdf637/transformer_lens/loading_from_pretrained.py#L1984

HF configs always have rope_theta as a float: https://github.com/huggingface/transformers/blob/c38b2fb78eaedd4261a0e446f7976345cd1c7f1b/src/transformers/modeling_rope_utils.py#L645

This updates the type to float and, since beartype doesn't treat int as a subtype of float, also updates all of the places that hard-coded ints to use floats instead. See: beartype/beartype#66
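For illustration, a minimal sketch of the beartype behavior the message refers to (the function here is hypothetical, not TransformerLens code, and this assumes beartype's default configuration, which does not enable the PEP 484 numeric tower):

from beartype import beartype
from beartype.roar import BeartypeException  # base class of beartype's type-check errors


@beartype
def set_rotary_base(rotary_base: float) -> float:
    """Hypothetical helper, used only to illustrate the check."""
    return rotary_base


set_rotary_base(10000.0)  # fine: the argument is a float

try:
    set_rotary_base(10000)  # an int: PEP 484's numeric tower would allow this; beartype's default does not
except BeartypeException as err:
    print(f"rejected: {err}")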

3 files changed: 18 additions & 18 deletions


transformer_lens/HookedTransformerConfig.py

Lines changed: 3 additions & 3 deletions

@@ -194,7 +194,7 @@ class HookedTransformerConfig:
             Defaults to 8.0.
         use_qk_norm (bool): Whether to apply RMSNorm to the query and key projections before
             computing attention scores. Used by Gemma 3 models. Defaults to False.
-        rotary_base_local (int, *optional*): The base for rotary positional embeddings in local
+        rotary_base_local (float, *optional*): The base for rotary positional embeddings in local
             attention layers. Used by models with hybrid local/global attention (e.g., Gemma 3)
             which use different RoPE bases for local (10k) and global (1M) attention. Defaults
             to None, which means the standard rotary_base is used for all layers.
@@ -252,9 +252,9 @@ class HookedTransformerConfig:
     tokenizer_prepends_bos: Optional[bool] = None
     n_key_value_heads: Optional[int] = None
     post_embedding_ln: bool = False
-    rotary_base: int = 10000
+    rotary_base: float = 10000.0
     rotary_base_local: Optional[
-        int
+        float
     ] = None  # For models with different RoPE bases per attention type (e.g., Gemma 3)
     trust_remote_code: bool = False
     rotary_adjacent_pairs: bool = False
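For illustration, a minimal sketch of a config built with float RoPE bases after this change (only rotary_base and rotary_base_local relate to the commit; the remaining field values are arbitrary illustrative numbers):

from transformer_lens import HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    n_ctx=128,
    d_head=16,
    attn_only=True,  # keeps the sketch small; no act_fn needed
    positional_embedding_type="rotary",
    rotary_dim=16,
    rotary_base=1_000_000.0,     # global attention layers
    rotary_base_local=10_000.0,  # local attention layers (e.g. Gemma 3)
)
print(cfg.rotary_base, cfg.rotary_base_local)  # 1000000.0 10000.0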

transformer_lens/components/abstract_attention.py

Lines changed: 1 addition & 1 deletion

@@ -532,7 +532,7 @@ def calculate_sin_cos_rotary(
         self,
         rotary_dim: int,
         n_ctx: int,
-        base: int = 10000,
+        base: float = 10000.0,
         dtype: torch.dtype = torch.float32,
     ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
         """

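For context, a minimal sketch of the standard RoPE sin/cos computation that a method like calculate_sin_cos_rotary performs; this is the textbook formula, not the library's exact implementation (which, for instance, also supports an adjacent-pairs layout):

import torch

def sin_cos_rotary(rotary_dim: int, n_ctx: int, base: float = 10000.0, dtype=torch.float32):
    # One inverse frequency per pair of rotary dimensions: base ** (-2i / rotary_dim).
    inv_freq = base ** (-torch.arange(0, rotary_dim, 2, dtype=dtype) / rotary_dim)
    # Angle for every (position, frequency) pair, duplicated to cover all rotary dims.
    angles = torch.arange(n_ctx, dtype=dtype)[:, None] * inv_freq[None, :]
    angles = torch.cat([angles, angles], dim=-1)  # shape (n_ctx, rotary_dim)
    return torch.sin(angles), torch.cos(angles)

A float base such as 1e6 drops straight into this formula, which is why nothing downstream requires an int.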
transformer_lens/loading_from_pretrained.py

Lines changed: 14 additions & 14 deletions

@@ -903,7 +903,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "rotary_dim": 4096 // 32,
             "final_rms": True,
             "gated_mlp": True,
-            "rotary_base": 1000000,
+            "rotary_base": 1000000.0,
         }
         if "python" in official_model_name.lower():
             # The vocab size of python version of CodeLlama-7b is 32000
@@ -1474,7 +1474,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "initializer_range": hf_config.initializer_range,
             "normalization_type": "RMS",
             "positional_embedding_type": "rotary",
-            "rotary_base": int(hf_config.rope_theta),
+            "rotary_base": hf_config.rope_theta,
             "rotary_adjacent_pairs": False,
             "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
             "tokenizer_prepends_bos": True,
@@ -1508,7 +1508,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "initializer_range": hf_config.initializer_range,
             "normalization_type": "RMS",
             "positional_embedding_type": "rotary",
-            "rotary_base": int(hf_config.rope_theta),
+            "rotary_base": hf_config.rope_theta,
             "rotary_adjacent_pairs": False,
             "rotary_dim": (
                 hf_config.head_dim
@@ -1624,8 +1624,8 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "act_fn": "gelu_pytorch_tanh",
             "initializer_range": 0.02,
             "normalization_type": "RMS",
-            "rotary_base": 1000000,  # Global attention layers
-            "rotary_base_local": 10000,  # Local attention layers (per Gemma 3 paper)
+            "rotary_base": 1000000.0,  # Global attention layers
+            "rotary_base_local": 10000.0,  # Local attention layers (per Gemma 3 paper)
             "positional_embedding_type": "rotary",
             "use_attn_scale": True,
             "n_key_value_heads": 1,
@@ -1670,8 +1670,8 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "act_fn": "gelu_pytorch_tanh",
             "initializer_range": 0.02,
             "normalization_type": "RMS",
-            "rotary_base": 1000000,  # Global attention layers
-            "rotary_base_local": 10000,  # Local attention layers (per Gemma 3 paper)
+            "rotary_base": 1000000.0,  # Global attention layers
+            "rotary_base_local": 10000.0,  # Local attention layers (per Gemma 3 paper)
             "positional_embedding_type": "rotary",
             "use_attn_scale": True,
             "n_key_value_heads": 1,
@@ -1726,8 +1726,8 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "act_fn": "gelu_pytorch_tanh",
             "initializer_range": 0.02,
             "normalization_type": "RMS",
-            "rotary_base": 1000000,  # Global attention layers
-            "rotary_base_local": 10000,  # Local attention layers (per Gemma 3 paper)
+            "rotary_base": 1000000.0,  # Global attention layers
+            "rotary_base_local": 10000.0,  # Local attention layers (per Gemma 3 paper)
             "positional_embedding_type": "rotary",
             "use_attn_scale": True,
             "n_key_value_heads": 4,
@@ -1788,8 +1788,8 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "act_fn": "gelu_pytorch_tanh",
             "initializer_range": 0.02,
             "normalization_type": "RMS",
-            "rotary_base": 1000000,  # Global attention layers
-            "rotary_base_local": 10000,  # Local attention layers (per Gemma 3 paper)
+            "rotary_base": 1000000.0,  # Global attention layers
+            "rotary_base_local": 10000.0,  # Local attention layers (per Gemma 3 paper)
             "positional_embedding_type": "rotary",
             "use_attn_scale": True,
             "n_key_value_heads": 8,
@@ -1869,8 +1869,8 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "act_fn": "gelu_pytorch_tanh",
             "initializer_range": 0.02,
             "normalization_type": "RMS",
-            "rotary_base": 1000000,  # Global attention layers
-            "rotary_base_local": 10000,  # Local attention layers (per Gemma 3 paper)
+            "rotary_base": 1000000.0,  # Global attention layers
+            "rotary_base_local": 10000.0,  # Local attention layers (per Gemma 3 paper)
             "positional_embedding_type": "rotary",
             "use_attn_scale": True,
             "n_key_value_heads": 16,
@@ -1959,7 +1959,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
             "act_fn": "gelu_new",
             "initializer_range": 0.02,
             "normalization_type": "RMS",
-            "rotary_base": 10000,
+            "rotary_base": 10000.0,
             "rotary_dim": 256,
             "positional_embedding_type": "rotary",
             "use_attn_scale": True,

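The HF side of the claim in the commit message can be spot-checked directly; a small sketch (the model name is only an example of a RoPE-based model whose config exposes rope_theta):

from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B")  # example RoPE-based model
print(type(hf_config.rope_theta))  # expected per modeling_rope_utils: <class 'float'>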