-
Notifications
You must be signed in to change notification settings - Fork 885
Add LoRA support to StaticAttention for split_mha=False #18345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/lucylq/142/head
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -812,38 +812,46 @@ def __init__( | |
| [StaticVCache(layer_id, i) for i in range(self.n_kv_heads)] | ||
| ) | ||
| else: | ||
| self.wqs = nn.ModuleList( | ||
| [ | ||
| nn.Linear( | ||
| self.dim, | ||
| self.head_dim * self.n_heads, | ||
| bias=self.attention_qkv_bias, | ||
| ) | ||
| ] | ||
| ) | ||
| self.wks = nn.ModuleList( | ||
| [ | ||
| nn.Linear( | ||
| self.dim, | ||
| self.head_dim * self.n_kv_heads, | ||
| bias=self.attention_qkv_bias, | ||
| ) | ||
| ] | ||
| ) | ||
| self.wvs = nn.ModuleList( | ||
| [ | ||
| nn.Linear( | ||
| self.dim, | ||
| self.head_dim * self.n_kv_heads, | ||
| bias=self.attention_qkv_bias, | ||
| has_lora = config.target_modules is not None | ||
| _PROJ_TARGET = { | ||
| "wqs": ("q_proj", self.dim, self.head_dim * self.n_heads), | ||
| "wks": ("k_proj", self.dim, self.head_dim * self.n_kv_heads), | ||
| "wvs": ("v_proj", self.dim, self.head_dim * self.n_kv_heads), | ||
| } | ||
| for attr, (target, in_dim, out_dim) in _PROJ_TARGET.items(): | ||
| if has_lora and target in config.target_modules: | ||
| proj = LoRALinear( | ||
| in_dim=in_dim, | ||
| out_dim=out_dim, | ||
| rank=config.r, | ||
| alpha=config.lora_alpha, | ||
| use_bias=self.attention_qkv_bias, | ||
| ) | ||
| ] | ||
| ) | ||
| else: | ||
| proj = nn.Linear(in_dim, out_dim, bias=self.attention_qkv_bias) | ||
| setattr(self, attr, nn.ModuleList([proj])) | ||
|
|
||
|
Comment on lines
+815
to
833
|
||
| self.k_caches = nn.ModuleList([StaticKCache(layer_id, 0)]) | ||
| self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)]) | ||
|
|
||
| self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) | ||
| wo_use_lora = ( | ||
| not self.split_mha | ||
| and config.target_modules is not None | ||
| and ( | ||
| "output_proj" in config.target_modules | ||
| or "o_proj" in config.target_modules | ||
| ) | ||
| ) | ||
| if wo_use_lora: | ||
| self.wo = LoRALinear( | ||
| in_dim=self.n_heads * self.head_dim, | ||
| out_dim=self.dim, | ||
| rank=config.r, | ||
| alpha=config.lora_alpha, | ||
| use_bias=False, | ||
| ) | ||
| else: | ||
| self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) | ||
| self.rope = _Rope(rope.params) | ||
| self.layer_id = layer_id | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When
config.target_modulesis set butconfig.rand/orconfig.lora_alphaare left as None (both are Optional in ModelArgs), this path will attempt to constructLoRALinear(rank=None, alpha=None)and fail with a low-signal TypeError. Consider adding an explicit validation (ValueError with a clear message) before creating anyLoRALinearmodules, similar toLoRAFeedForward.