Skip to content

Update lora def#18342

Merged
lucylq merged 4 commits intomainfrom
gh/lucylq/141/head
Mar 20, 2026
Merged

Update lora def#18342
lucylq merged 4 commits intomainfrom
gh/lucylq/141/head

Conversation

@lucylq
Copy link
Contributor

@lucylq lucylq commented Mar 19, 2026

Summary

Update lora def to use nn.linear instead of torch.nn.functional.linear

  1. Remove debug string from FQN. Otherwise we get different quant indices for lora and non-lora (because lora_a is quantized), after this change, and the foundation weights are no longer shareable.
  2. Having nn.Linear as a submodule allows torchao quant to capture it, and we don't need custom logic filtering LoraLinear. as lora_a and lora_b weren't captured before, exclude these to maintain current behavior. Later we can remove this, optionally.
  3. Having nn.Linear as submodule also means we need to remap lora weights, as weight names have an extra 'linear' in them.
  4. Add @Property for weight, bias so that it remains BC and can be treated as a regular linear module. This is used in [load_weights_from_attention.py](
    def load_weights_from_attention_mha(

Test plan

CI

[ghstack-poisoned]
@lucylq lucylq requested a review from digantdesai as a code owner March 19, 2026 21:40
Copilot AI review requested due to automatic review settings March 19, 2026 21:40
@pytorch-bot
Copy link

pytorch-bot bot commented Mar 19, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18342

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 5b130e1 with merge base 02bad9d (image):

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 19, 2026
@github-actions
Copy link

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Updates the Llama LoRA linear module definition and related export/quantization behavior, aiming to improve module structure/state_dict compatibility and make serialized constant keys more deterministic.

Changes:

  • Refactors LoRALinear to wrap an internal nn.Linear submodule (with backward-compatible state_dict loading).
  • Simplifies the 8da{4,8}w quantization filter to only target nn.Linear modules.
  • Changes XNNPACK constant named_key generation to be SHA256-only (no tensor-name prefix) and updates the LoRA CI expected output text.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

File Description
examples/models/llama/source_transformation/quantize.py Simplifies the module filter used during TorchAO 8da quantization.
examples/models/llama/lora.py Refactors LoRALinear to contain self.linear and adds state_dict key remapping for backward compatibility.
backends/xnnpack/operators/node_visitor.py Changes serialized constant named_key to be content-hash only.
.ci/scripts/test_lora.sh Updates expected output string for quantized LoRA test case.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

)
sha256_hash = hashlib.sha256(bytes(array))
named_key = tensor.name + "_" + sha256_hash.hexdigest()
named_key = sha256_hash.hexdigest()
Copy link

Copilot AI Mar 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing named_key to only the SHA256 digest makes keys content-addressed and no longer namespaced per tensor. When exporting multiple PTD/PTE named data maps that are loaded together (e.g., foundation + LoRA program-data separation), MergedDataMap::load rejects duplicate keys across maps even if the underlying bytes are identical, so identical constants (common for zero/one-initialized tensors) can now cause runtime load failures. Consider incorporating a stable per-parameter identifier into the key (e.g., get_attr_node.target/FQN) or otherwise namespacing by component while keeping determinism across exports.

Suggested change
named_key = sha256_hash.hexdigest()
# Use a per-tensor namespace in the key to avoid collisions across
# different PTD/PTE maps that may contain identical constant bytes.
tensor_name = getattr(tensor, "name", None)
if tensor_name is not None:
named_key = f"{tensor_name}:{sha256_hash.hexdigest()}"
else:
named_key = sha256_hash.hexdigest()

Copilot uses AI. Check for mistakes.
lucylq added 2 commits March 19, 2026 14:50
[ghstack-poisoned]
[ghstack-poisoned]
Copilot AI review requested due to automatic review settings March 19, 2026 21:53
@lucylq lucylq mentioned this pull request Mar 19, 2026
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (1)

backends/xnnpack/operators/node_visitor.py:633

  • named_key is now only the SHA256 digest of the constant bytes. If constants are split across multiple external named-data files (via delegate_constant_tag / external_tag), different tensors that happen to have identical bytes (e.g., many zero-filled bias tensors) will produce the same key in different files. Module merges those files into a MergedDataMap, which rejects duplicate keys across maps and will fail to load. Consider keeping per-tensor uniqueness in the key (e.g., include tensor.name/FQN as a suffix/prefix) and rely on NamedDataStore's internal buffer deduplication to avoid duplicated storage, or otherwise ensure keys remain unique across external shards.
        sha256_hash = hashlib.sha256(bytes(array))
        named_key = sha256_hash.hexdigest()

        size = const_val.untyped_storage().nbytes()
        xnn_graph.constant_data.append(
            ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
        )

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +43 to +50
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# Remap keys to "linear.*"
for attr in ("weight", "bias"):
old_key = prefix + attr
new_key = prefix + "linear." + attr
if old_key in state_dict and new_key not in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
Copy link

Copilot AI Mar 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LoRALinear now changes its state_dict surface from weight/bias to linear.weight/linear.bias and adds a compatibility remap in _load_from_state_dict. This is a behavior that’s easy to regress (especially for loading older checkpoints through the full model), but there’s no test covering the backward-compat load path. Please add a unit test that constructs an old-format state_dict (with ...weight/...bias) and verifies it loads into the updated module and produces identical outputs/parameters.

Copilot uses AI. Check for mistakes.
Comment on lines 146 to +154
def filter_fn(m, fqn):
# Check if it's a regular nn.Linear
is_linear = isinstance(m, nn.Linear)

# Check if it's a LoRALinear (which has a base weight parameter to quantize)
is_lora_linear = False
try:
from executorch.examples.models.llama.lora import LoRALinear

is_lora_linear = isinstance(m, LoRALinear)
except ImportError:
pass

# Check if the weight shape is compatible with group size
has_shape_compatible_with_group_size = False
if is_linear or is_lora_linear:
if group_size == 0:
has_shape_compatible_with_group_size = True
else:
has_shape_compatible_with_group_size = (
m.weight.shape[1] % group_size == 0
)
return (
is_linear or is_lora_linear
) and has_shape_compatible_with_group_size
if not isinstance(m, nn.Linear):
return False
parts = fqn.split(".")
if "lora_a" in parts or "lora_b" in parts:
return False
if group_size == 0:
return True
return m.weight.shape[1] % group_size == 0
Copy link

Copilot AI Mar 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new filter_fn behavior for 8da4w/8da8w quantization relies on module FQNs to skip LoRA adapter layers (lora_a/lora_b) while still quantizing the base projection (*.linear). This is subtle and currently untested; please add a regression test that runs this quantization path on a small model with LoRALinear and asserts that lora_a/lora_b weights remain unquantized while the base linear weight is quantized (and group_size filtering behaves as expected).

Copilot uses AI. Check for mistakes.
[ghstack-poisoned]
@meta-codesync
Copy link
Contributor

meta-codesync bot commented Mar 19, 2026

@lucylq has imported this pull request. If you are a Meta employee, you can view this in D97393585.

@lucylq lucylq merged commit dd7464a into main Mar 20, 2026
143 of 145 checks passed
@lucylq lucylq deleted the gh/lucylq/141/head branch March 20, 2026 17:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants