From da7f4897d05a9f127f7f2e9a00586072113afa52 Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Thu, 19 Mar 2026 14:40:35 -0700 Subject: [PATCH 1/5] Update [ghstack-poisoned] --- .ci/scripts/test_lora.sh | 3 +- backends/xnnpack/operators/node_visitor.py | 2 +- examples/models/llama/lora.py | 29 ++++++++++++------- .../llama/source_transformation/quantize.py | 29 ++++--------------- 4 files changed, 26 insertions(+), 37 deletions(-) diff --git a/.ci/scripts/test_lora.sh b/.ci/scripts/test_lora.sh index 89a7e99460e..6a929518e02 100644 --- a/.ci/scripts/test_lora.sh +++ b/.ci/scripts/test_lora.sh @@ -139,8 +139,7 @@ Okay, so I need to calculate 15% of 80." EXPECTED_QUANT_LORA_PREFIX=" <|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant To calculate 15% of 80, we can multiply 80 by 15/100. -80 * 15/100 = 12. -So, 15% of 80 is 12. +So, 15% of 80 is equal to 80 * 15/100 = 12. #### 12 The answer is: 12<|im_end|>" diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 4643ada9336..a0f03205ed5 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -625,7 +625,7 @@ def get_serialized_buffer_index( f"Serializing constant data node {tensor} but tensor value has no bytes", ) sha256_hash = hashlib.sha256(bytes(array)) - named_key = tensor.name + "_" + sha256_hash.hexdigest() + named_key = sha256_hash.hexdigest() size = const_val.untyped_storage().nbytes() xnn_graph.constant_data.append( diff --git a/examples/models/llama/lora.py b/examples/models/llama/lora.py index 12c1c4e5d68..99d583f52dd 100644 --- a/examples/models/llama/lora.py +++ b/examples/models/llama/lora.py @@ -26,22 +26,31 @@ def __init__( self.rank = rank self.alpha = alpha self.use_bias = use_bias - self.dropout = dropout - - linear = nn.Linear(in_dim, out_dim, bias=use_bias) - weight = linear.weight - bias = linear.bias if self.use_bias else None - self.register_parameter("weight", nn.Parameter(weight)) - self.register_parameter( - "bias", nn.Parameter(bias) if bias is not None else None - ) + self.linear = nn.Linear(in_dim, out_dim, bias=use_bias) self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) + @property + def weight(self): + return self.linear.weight + + @property + def bias(self): + return self.linear.bias + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # Remap keys to "linear.*" + for attr in ("weight", "bias"): + old_key = prefix + attr + new_key = prefix + "linear." + attr + if old_key in state_dict and new_key not in state_dict: + state_dict[new_key] = state_dict.pop(old_key) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + def forward(self, x: torch.Tensor) -> torch.Tensor: - out = torch.nn.functional.linear(x, self.weight, self.bias) + out = self.linear(x) lora_out = self.lora_a(self.dropout(x)) lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py index dbd2caad5a0..04a67c800dd 100644 --- a/examples/models/llama/source_transformation/quantize.py +++ b/examples/models/llama/source_transformation/quantize.py @@ -144,30 +144,11 @@ def quantize( # noqa C901 from torchao.utils import unwrap_tensor_subclass def filter_fn(m, fqn): - # Check if it's a regular nn.Linear - is_linear = isinstance(m, nn.Linear) - - # Check if it's a LoRALinear (which has a base weight parameter to quantize) - is_lora_linear = False - try: - from executorch.examples.models.llama.lora import LoRALinear - - is_lora_linear = isinstance(m, LoRALinear) - except ImportError: - pass - - # Check if the weight shape is compatible with group size - has_shape_compatible_with_group_size = False - if is_linear or is_lora_linear: - if group_size == 0: - has_shape_compatible_with_group_size = True - else: - has_shape_compatible_with_group_size = ( - m.weight.shape[1] % group_size == 0 - ) - return ( - is_linear or is_lora_linear - ) and has_shape_compatible_with_group_size + if not isinstance(m, nn.Linear): + return False + if group_size == 0: + return True + return m.weight.shape[1] % group_size == 0 weight_dtype = torch.int4 if qmode == "8da4w" else torch.int8 quantize_( From ac42ef26a3233624c9a2afd2314d79edd151b29e Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Thu, 19 Mar 2026 14:40:39 -0700 Subject: [PATCH 2/5] Update [ghstack-poisoned] --- .../coreml/llama/export_static_llm_coreml.py | 127 +++++++++--------- 1 file changed, 64 insertions(+), 63 deletions(-) diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index 8c68af45f31..2aac3200dfb 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -309,6 +309,69 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) } +def _prepare_model(model, args, float_dtype): + """Apply splitting, quantization, and graph breaks to a model.""" + model = model.to(float_dtype).eval() + + if args.target_split_size is not None: + print(f"\nSplitting linear layers with target size {args.target_split_size}...") + replace_linear_with_split_linear( + model, + out_target_split_size=args.target_split_size, + out_max_splits=args.max_splits, + in_target_split_size=1, + in_max_splits=1, + ) + + if args.embedding_quantize: + bitwidth, group_size = args.embedding_quantize.split(",") + bitwidth = int(bitwidth) + group_size = int(group_size) + assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" + + print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...") + if group_size == 0: + granularity = PerAxis(0) + else: + granularity = PerGroup(group_size) + weight_dtype = getattr(torch, f"int{bitwidth}") + + quantize_( + model, + IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + if args.linear_quantize == "b4w": + print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...") + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(32), + ), + ) + elif args.linear_quantize == "c4w": + print("\nQuantizing linear layers: 4-bit channelwise...") + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerAxis(0), + ), + ) + + if not args.no_graph_breaks: + print("\nAdding graph breaks between before/after the transformer blocks...") + n_layers = len(model.layers) + model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True) + model.layers[n_layers - 1] = BlockWithGraphBreak( + model.layers[n_layers - 1], break_before=False + ) + + return model + + def main(): parser = argparse.ArgumentParser( description="Export static attention Llama model to CoreML" @@ -441,70 +504,8 @@ def main(): ) print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim") - # Set dtype float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype] - model = model.to(float_dtype).eval() - - # Apply linear splitting (before quantization) - if args.target_split_size is not None: - print(f"\nSplitting linear layers with target size {args.target_split_size}...") - replace_linear_with_split_linear( - model, - out_target_split_size=args.target_split_size, - out_max_splits=args.max_splits, - in_target_split_size=1, - in_max_splits=1, - ) - - # Apply embedding quantization - if args.embedding_quantize: - bitwidth, group_size = args.embedding_quantize.split(",") - bitwidth = int(bitwidth) - group_size = int(group_size) - assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" - - print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...") - if group_size == 0: - granularity = PerAxis(0) - else: - granularity = PerGroup(group_size) - weight_dtype = getattr(torch, f"int{bitwidth}") - - quantize_( - model, - IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), - lambda m, fqn: isinstance(m, torch.nn.Embedding), - ) - - # Apply linear quantization - if args.linear_quantize == "b4w": - print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...") - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=torch.int4, - granularity=PerGroup(32), - ), - ) - elif args.linear_quantize == "c4w": - print("\nQuantizing linear layers: 4-bit channelwise...") - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=torch.int4, - granularity=PerAxis(0), - ), - ) - - # Add graph breaks between transformer blocks - # Keeping model pieces smaller helps with ANE performance - if not args.no_graph_breaks: - print("\nAdding graph breaks between before/after the transformer blocks...") - n_layers = len(model.layers) - model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True) - model.layers[n_layers - 1] = BlockWithGraphBreak( - model.layers[n_layers - 1], break_before=False - ) + model = _prepare_model(model, args, float_dtype) if args.multifunction: # Multifunction mode: separate prefill and decode graphs with weight sharing From 2bb273de8ddaf0edef5d4cdab132aee9acd52909 Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Thu, 19 Mar 2026 14:53:16 -0700 Subject: [PATCH 3/5] Update [ghstack-poisoned] --- .ci/scripts/test_lora.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.ci/scripts/test_lora.sh b/.ci/scripts/test_lora.sh index 6a929518e02..89a7e99460e 100644 --- a/.ci/scripts/test_lora.sh +++ b/.ci/scripts/test_lora.sh @@ -139,7 +139,8 @@ Okay, so I need to calculate 15% of 80." EXPECTED_QUANT_LORA_PREFIX=" <|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant To calculate 15% of 80, we can multiply 80 by 15/100. -So, 15% of 80 is equal to 80 * 15/100 = 12. +80 * 15/100 = 12. +So, 15% of 80 is 12. #### 12 The answer is: 12<|im_end|>" From a8a23df11fdb7c6772da4c95e97cef893bd12f85 Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Thu, 19 Mar 2026 15:00:44 -0700 Subject: [PATCH 4/5] Update [ghstack-poisoned] --- examples/apple/coreml/llama/export_static_llm_coreml.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index 2aac3200dfb..66ef26d58c1 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -309,7 +309,7 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) } -def _prepare_model(model, args, float_dtype): +def _prepare_eager_model(model, args, float_dtype): """Apply splitting, quantization, and graph breaks to a model.""" model = model.to(float_dtype).eval() @@ -505,7 +505,7 @@ def main(): print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim") float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype] - model = _prepare_model(model, args, float_dtype) + model = _prepare_eager_model(model, args, float_dtype) if args.multifunction: # Multifunction mode: separate prefill and decode graphs with weight sharing From d9836888734b4df86830332a1ca383964fcc64c4 Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Fri, 20 Mar 2026 15:52:43 -0700 Subject: [PATCH 5/5] Update [ghstack-poisoned] --- examples/apple/coreml/llama/export_static_llm_coreml.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index 66ef26d58c1..276ff6d193a 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -309,7 +309,7 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) } -def _prepare_eager_model(model, args, float_dtype): +def _transform_eager_model(model, args, float_dtype): """Apply splitting, quantization, and graph breaks to a model.""" model = model.to(float_dtype).eval() @@ -505,7 +505,7 @@ def main(): print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim") float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype] - model = _prepare_eager_model(model, args, float_dtype) + model = _transform_eager_model(model, args, float_dtype) if args.multifunction: # Multifunction mode: separate prefill and decode graphs with weight sharing