-
Notifications
You must be signed in to change notification settings - Fork 886
Update lora def #18342
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update lora def #18342
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,22 +26,31 @@ def __init__( | |
| self.rank = rank | ||
| self.alpha = alpha | ||
| self.use_bias = use_bias | ||
| self.dropout = dropout | ||
|
|
||
| linear = nn.Linear(in_dim, out_dim, bias=use_bias) | ||
| weight = linear.weight | ||
| bias = linear.bias if self.use_bias else None | ||
| self.register_parameter("weight", nn.Parameter(weight)) | ||
| self.register_parameter( | ||
| "bias", nn.Parameter(bias) if bias is not None else None | ||
| ) | ||
|
|
||
| self.linear = nn.Linear(in_dim, out_dim, bias=use_bias) | ||
| self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() | ||
| self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) | ||
| self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) | ||
|
|
||
| @property | ||
| def weight(self): | ||
| return self.linear.weight | ||
|
|
||
| @property | ||
| def bias(self): | ||
| return self.linear.bias | ||
|
|
||
| def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): | ||
| # Remap keys to "linear.*" | ||
| for attr in ("weight", "bias"): | ||
| old_key = prefix + attr | ||
| new_key = prefix + "linear." + attr | ||
| if old_key in state_dict and new_key not in state_dict: | ||
| state_dict[new_key] = state_dict.pop(old_key) | ||
| super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) | ||
|
Comment on lines +43 to +50
|
||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| out = torch.nn.functional.linear(x, self.weight, self.bias) | ||
| out = self.linear(x) | ||
| lora_out = self.lora_a(self.dropout(x)) | ||
| lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -144,30 +144,14 @@ def quantize( # noqa C901 | |
| from torchao.utils import unwrap_tensor_subclass | ||
|
|
||
| def filter_fn(m, fqn): | ||
| # Check if it's a regular nn.Linear | ||
| is_linear = isinstance(m, nn.Linear) | ||
|
|
||
| # Check if it's a LoRALinear (which has a base weight parameter to quantize) | ||
| is_lora_linear = False | ||
| try: | ||
| from executorch.examples.models.llama.lora import LoRALinear | ||
|
|
||
| is_lora_linear = isinstance(m, LoRALinear) | ||
| except ImportError: | ||
| pass | ||
|
|
||
| # Check if the weight shape is compatible with group size | ||
| has_shape_compatible_with_group_size = False | ||
| if is_linear or is_lora_linear: | ||
| if group_size == 0: | ||
| has_shape_compatible_with_group_size = True | ||
| else: | ||
| has_shape_compatible_with_group_size = ( | ||
| m.weight.shape[1] % group_size == 0 | ||
| ) | ||
| return ( | ||
| is_linear or is_lora_linear | ||
| ) and has_shape_compatible_with_group_size | ||
| if not isinstance(m, nn.Linear): | ||
| return False | ||
| parts = fqn.split(".") | ||
| if "lora_a" in parts or "lora_b" in parts: | ||
| return False | ||
| if group_size == 0: | ||
| return True | ||
| return m.weight.shape[1] % group_size == 0 | ||
|
Comment on lines 146 to +154
|
||
|
|
||
| weight_dtype = torch.int4 if qmode == "8da4w" else torch.int8 | ||
| quantize_( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing `named_key` to only the SHA256 digest makes keys content-addressed and no longer namespaced per tensor. When exporting multiple PTD/PTE named data maps that are loaded together (e.g., foundation + LoRA program-data separation), `MergedDataMap::load` rejects duplicate keys across maps even if the underlying bytes are identical, so identical constants (common for zero/one-initialized tensors) can now cause runtime load failures. Consider incorporating a stable per-parameter identifier into the key (e.g., `get_attr_node.target` / FQN) or otherwise namespacing by component while keeping determinism across exports.