diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 99809a063bc..c36f5a70358 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -812,38 +812,46 @@ def __init__(
                 [StaticVCache(layer_id, i) for i in range(self.n_kv_heads)]
             )
         else:
-            self.wqs = nn.ModuleList(
-                [
-                    nn.Linear(
-                        self.dim,
-                        self.head_dim * self.n_heads,
-                        bias=self.attention_qkv_bias,
-                    )
-                ]
-            )
-            self.wks = nn.ModuleList(
-                [
-                    nn.Linear(
-                        self.dim,
-                        self.head_dim * self.n_kv_heads,
-                        bias=self.attention_qkv_bias,
-                    )
-                ]
-            )
-            self.wvs = nn.ModuleList(
-                [
-                    nn.Linear(
-                        self.dim,
-                        self.head_dim * self.n_kv_heads,
-                        bias=self.attention_qkv_bias,
+            has_lora = config.target_modules is not None
+            _PROJ_TARGET = {
+                "wqs": ("q_proj", self.dim, self.head_dim * self.n_heads),
+                "wks": ("k_proj", self.dim, self.head_dim * self.n_kv_heads),
+                "wvs": ("v_proj", self.dim, self.head_dim * self.n_kv_heads),
+            }
+            for attr, (target, in_dim, out_dim) in _PROJ_TARGET.items():
+                if has_lora and target in config.target_modules:
+                    proj = LoRALinear(
+                        in_dim=in_dim,
+                        out_dim=out_dim,
+                        rank=config.r,
+                        alpha=config.lora_alpha,
+                        use_bias=self.attention_qkv_bias,
                     )
-                ]
-            )
+                else:
+                    proj = nn.Linear(in_dim, out_dim, bias=self.attention_qkv_bias)
+                setattr(self, attr, nn.ModuleList([proj]))
             self.k_caches = nn.ModuleList([StaticKCache(layer_id, 0)])
             self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)])
-        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
+        wo_use_lora = (
+            not self.split_mha
+            and config.target_modules is not None
+            and (
+                "output_proj" in config.target_modules
+                or "o_proj" in config.target_modules
+            )
+        )
+        if wo_use_lora:
+            self.wo = LoRALinear(
+                in_dim=self.n_heads * self.head_dim,
+                out_dim=self.dim,
+                rank=config.r,
+                alpha=config.lora_alpha,
+                use_bias=False,
+            )
+        else:
+            self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         self.rope = _Rope(rope.params)
         self.layer_id = layer_id