
Add LoRA support to StaticAttention for split_mha=False #18345

Open
lucylq wants to merge 6 commits into gh/lucylq/142/head from gh/lucylq/144/head

Conversation


@lucylq lucylq commented Mar 19, 2026

When ModelArgs.target_modules is set, create LoRALinear instead of
nn.Linear for targeted q/k/v/o projections. Only applies to
split_mha=False path. Existing behavior unchanged when target_modules
is None.
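For context, a LoRA linear layer computes the frozen base projection plus a scaled low-rank update. A minimal sketch of the idea (class and attribute names here are illustrative, not the actual LoRALinear in this codebase):

```python
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Minimal LoRA linear: y = W x + (alpha / rank) * B(A(x)).

    Illustrative only; the real LoRALinear may differ in naming,
    initialization, and bias handling.
    """

    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float, use_bias: bool = False):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=use_bias)  # frozen base projection
        self.lora_a = nn.Linear(in_dim, rank, bias=False)      # down-projection A
        self.lora_b = nn.Linear(rank, out_dim, bias=False)     # up-projection B
        self.scaling = alpha / rank
        # Zero-init B so the layer initially matches the base projection.
        nn.init.zeros_(self.lora_b.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))
```

With B zero-initialized, swapping a targeted nn.Linear for a LoRA layer leaves the forward pass unchanged until the adapters are trained.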

Authored with Claude.

[ghstack-poisoned]

pytorch-bot bot commented Mar 19, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18345

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Cancelled Job, 7 Pending, 2 Unrelated Failures

As of commit 10f6022 with merge base dd7464a:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

lucylq added 2 commits March 19, 2026 14:50
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]

Copilot AI left a comment


Pull request overview

This PR adds LoRA-aware projection construction to the StaticAttention implementation when using the non-split MHA path (split_mha=False), so that q/k/v/o projections can become LoRALinear based on ModelArgs.target_modules, while keeping existing behavior unchanged when target_modules is None.

Changes:

  • For split_mha=False, conditionally instantiate LoRALinear for q/k/v projections when their corresponding target names are present in config.target_modules.
  • For split_mha=False, conditionally instantiate LoRALinear for the output projection (wo) when output_proj or o_proj is targeted.


Comment on lines +815 to 829
has_lora = config.target_modules is not None
_PROJ_TARGET = {
    "wqs": ("q_proj", self.dim, self.head_dim * self.n_heads),
    "wks": ("k_proj", self.dim, self.head_dim * self.n_kv_heads),
    "wvs": ("v_proj", self.dim, self.head_dim * self.n_kv_heads),
}
for attr, (target, in_dim, out_dim) in _PROJ_TARGET.items():
    if has_lora and target in config.target_modules:
        proj = LoRALinear(
            in_dim=in_dim,
            out_dim=out_dim,
            rank=config.r,
            alpha=config.lora_alpha,
            use_bias=self.attention_qkv_bias,
        )

Copilot AI Mar 19, 2026


When config.target_modules is set but config.r and/or config.lora_alpha are left as None (both are Optional in ModelArgs), this path will attempt to construct LoRALinear(rank=None, alpha=None) and fail with a low-signal TypeError. Consider adding an explicit validation (ValueError with a clear message) before creating any LoRALinear modules, similar to LoRAFeedForward.

Copilot uses AI. Check for mistakes.
Comment on lines +815 to 833
has_lora = config.target_modules is not None
_PROJ_TARGET = {
    "wqs": ("q_proj", self.dim, self.head_dim * self.n_heads),
    "wks": ("k_proj", self.dim, self.head_dim * self.n_kv_heads),
    "wvs": ("v_proj", self.dim, self.head_dim * self.n_kv_heads),
}
for attr, (target, in_dim, out_dim) in _PROJ_TARGET.items():
    if has_lora and target in config.target_modules:
        proj = LoRALinear(
            in_dim=in_dim,
            out_dim=out_dim,
            rank=config.r,
            alpha=config.lora_alpha,
            use_bias=self.attention_qkv_bias,
        )
    else:
        proj = nn.Linear(in_dim, out_dim, bias=self.attention_qkv_bias)
    setattr(self, attr, nn.ModuleList([proj]))



Copilot AI Mar 19, 2026


New behavior is introduced here (direct StaticAttention(..., split_mha=False) now conditionally builds LoRALinear based on config.target_modules), but existing tests in test_static_attention.py only exercise LoRA via from_attention_mha. Please add a unit test that directly constructs StaticAttention with split_mha=False and target_modules set, and asserts the expected projection types and a forward equivalence check.
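The requested test might follow the pattern below. This is a self-contained sketch using stand-in classes, since the real StaticAttention and LoRALinear constructors are not shown here; the actual test would import them from the codebase:

```python
import torch
import torch.nn as nn


class TinyLoRALinear(nn.Module):
    # Stand-in for the repo's LoRALinear; B is zero-initialized so the
    # layer matches its base projection at init.
    def __init__(self, in_dim, out_dim, rank=4, alpha=8.0):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.zeros_(self.lora_b.weight)
        self.scaling = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


class TinyAttentionProj(nn.Module):
    # Stand-in for StaticAttention's split_mha=False projection construction.
    def __init__(self, dim, target_modules=None):
        super().__init__()
        has_lora = target_modules is not None
        for attr, target in [("wqs", "q_proj"), ("wks", "k_proj"), ("wvs", "v_proj")]:
            if has_lora and target in target_modules:
                proj = TinyLoRALinear(dim, dim)
            else:
                proj = nn.Linear(dim, dim, bias=False)
            setattr(self, attr, nn.ModuleList([proj]))


def test_lora_projection_types_and_equivalence():
    torch.manual_seed(0)
    plain = TinyAttentionProj(8)
    lora = TinyAttentionProj(8, target_modules=["q_proj", "v_proj"])
    # Targeted projections become LoRA layers; untargeted ones stay nn.Linear.
    assert isinstance(lora.wqs[0], TinyLoRALinear)
    assert isinstance(lora.wks[0], nn.Linear)
    assert isinstance(lora.wvs[0], TinyLoRALinear)
    # With matching base weights and zero-init B, forward outputs agree at init.
    lora.wqs[0].base.load_state_dict(plain.wqs[0].state_dict())
    x = torch.randn(2, 8)
    assert torch.allclose(lora.wqs[0](x), plain.wqs[0](x))
```

The same two assertions (projection types, forward equivalence at init) would apply when constructing the real StaticAttention directly with split_mha=False and target_modules set.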


Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
