diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index 276ff6d193a..384f2e5cc49 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -111,6 +111,8 @@ def load_model( params_path: str, max_context_len: int, generate_full_logits: bool = True, + adapter_checkpoint: str = None, + adapter_config: str = None, ): """Load the model from checkpoint with static_mha attention type. @@ -121,6 +123,8 @@ def load_model( generate_full_logits: If True, output logits for all tokens (needed for lookahead decoding). If False, only output logits for the last token (more efficient for standard autoregressive generation). + adapter_checkpoint: Path to LoRA adapter weights (.safetensors) + adapter_config: Path to adapter_config.json """ with open(params_path, "r") as f: params = json.loads(f.read()) @@ -133,6 +137,13 @@ def load_model( args.attention_type = "static_mha" args.attention_kwargs = {"decompose_sdpa_in_mha": True} + if adapter_config is not None: + with open(adapter_config, "r") as f: + lora_config = json.loads(f.read()) + args.r = lora_config["r"] + args.lora_alpha = lora_config["lora_alpha"] + args.target_modules = lora_config["target_modules"] + with torch.device("meta"): model = construct_transformer(args) @@ -142,7 +153,9 @@ def load_model( if "model" in checkpoint: checkpoint = checkpoint["model"] - # Rename attention weight keys for static attention + # Rename attention weight keys for static attention: + # wq.weight -> wqs.0.weight, wk.weight -> wks.0.weight, wv.weight -> wvs.0.weight + # LoRALinear._load_from_state_dict remaps weight -> linear.weight automatically. for i in range(len(model.layers)): if f"layers.{i}.attention.wq.weight" in checkpoint: checkpoint[f"layers.{i}.attention.wqs.0.weight"] = checkpoint.pop( @@ -157,6 +170,23 @@ def load_model( f"layers.{i}.attention.wv.weight" ) + if adapter_checkpoint is not None: + from executorch.examples.models.llama.convert_weights import ( + load_and_convert_unsloth_to_meta, + ) + + adapter_weights = load_and_convert_unsloth_to_meta(adapter_checkpoint) + # Rename adapter keys: wq.lora_*.weight -> wqs.0.lora_*.weight + for i in range(len(model.layers)): + for old_proj, new_proj in [("wq", "wqs.0"), ("wk", "wks.0"), ("wv", "wvs.0")]: + for suffix in ["lora_a.weight", "lora_b.weight"]: + old_key = f"layers.{i}.attention.{old_proj}.{suffix}" + if old_key in adapter_weights: + new_key = f"layers.{i}.attention.{new_proj}.{suffix}" + adapter_weights[new_key] = adapter_weights.pop(old_key) + + checkpoint.update(adapter_weights) + missing, unexpected = model.load_state_dict( checkpoint, strict=False, @@ -310,7 +340,13 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) def _transform_eager_model(model, args, float_dtype): - """Apply splitting, quantization, and graph breaks to a model.""" + """Apply splitting, quantization, and graph breaks to a model. + + This is shared across base and adapter models so the same transformations + are applied consistently. + """ + from executorch.examples.models.llama.lora import LoRALinear + model = model.to(float_dtype).eval() if args.target_split_size is not None: @@ -342,6 +378,20 @@ def _transform_eager_model(model, args, float_dtype): lambda m, fqn: isinstance(m, torch.nn.Embedding), ) + has_lora_modules = any( + isinstance(m, LoRALinear) for m in model.modules() + ) + + def _exclude_lora(m, fqn): + if isinstance(m, LoRALinear): + return False + parts = fqn.split(".") + if "lora_a" in parts or "lora_b" in parts: + return False + return isinstance(m, nn.Linear) + + linear_filter = _exclude_lora if has_lora_modules else None + if args.linear_quantize == "b4w": print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...") quantize_( @@ -350,6 +400,7 @@ def _transform_eager_model(model, args, float_dtype): weight_dtype=torch.int4, granularity=PerGroup(32), ), + linear_filter, ) elif args.linear_quantize == "c4w": print("\nQuantizing linear layers: 4-bit channelwise...") @@ -359,6 +410,7 @@ def _transform_eager_model(model, args, float_dtype): weight_dtype=torch.int4, granularity=PerAxis(0), ), + linear_filter, ) if not args.no_graph_breaks: @@ -464,9 +516,19 @@ def main(): "and generate_full_logits=True for lookahead decoding support.", ) + # LoRA adapter options + parser.add_argument( + "--adapter", + nargs=3, + action="append", + metavar=("NAME", "CHECKPOINT", "CONFIG"), + help="LoRA adapter: method name, path to adapter.safetensors, " + "path to adapter_config.json. Can be repeated for multiple adapters.", + ) + args = parser.parse_args() - # Compute cache length + has_adapters = args.adapter is not None print("Export mode:") if args.multifunction: @@ -475,12 +537,15 @@ def main(): ) else: print("\tSingle method: fixed seqlen, generate_full_logits=True (lookahead)") + if has_adapters: + print(f"\tAdapters: {[a[0] for a in args.adapter]}") print("\nQuantization and datatype:") print(f"\tEmbedding quantize: {args.embedding_quantize}") print(f"\tLinear quantize: {args.linear_quantize}") print(f"\tDtype: {args.dtype}") + # Compute cache length cache_len = args.max_context_len - args.input_len print("\nGeneration configuration:") print(f"\tMax context length: {args.max_context_len}") @@ -491,7 +556,7 @@ def main(): print(f"\tTarget split size: {args.target_split_size}") print(f"\tMax splits: {args.max_splits}") - # Load model + # Load base model # For multifunction: generate_full_logits=False (efficient, only last token) # For single method: generate_full_logits=True (needed for lookahead decoding) generate_full_logits = not args.multifunction @@ -505,8 +570,39 @@ def main(): print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim") float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype] + model = _transform_eager_model(model, args, float_dtype) + # Load adapter models + lora_models = {} + if has_adapters: + for name, adapter_ckpt, adapter_cfg in args.adapter: + print(f"\nLoading adapter '{name}' from {adapter_ckpt}...") + lora_model, _ = load_model( + args.checkpoint, + args.params, + args.max_context_len, + generate_full_logits=generate_full_logits, + adapter_checkpoint=adapter_ckpt, + adapter_config=adapter_cfg, + ) + lora_model = _transform_eager_model(lora_model, args, float_dtype) + lora_models[name] = lora_model + + def _export_model(m, inputs, label="model"): + print(f"\nTesting eager execution ({label})...") + with torch.no_grad(): + m(*inputs) + print(f"Eager execution successful ({label})!") + + print(f"\nExporting {label}...") + ep = torch.export.export(m, inputs) + print(f"Export successful ({label})!") + print(ep) + return ep + + use_multimethod = args.multifunction or has_adapters + if args.multifunction: # Multifunction mode: separate prefill and decode graphs with weight sharing # Both methods use the same cache_len (decode's cache size) so they can share @@ -537,31 +633,22 @@ def main(): cache_len=shared_cache_len, ) - # Test eager execution for both - print("\nTesting eager execution (decode, seqlen=1)...") - with torch.no_grad(): - model(*decode_inputs) - print("Decode eager execution successful!") + # Export base model + methods = { + "forward": _export_model(model, decode_inputs, "base decode"), + "prefill": _export_model(model, prefill_inputs, "base prefill"), + } - print(f"\nTesting eager execution (prefill, seqlen={prefill_input_len})...") - with torch.no_grad(): - model(*prefill_inputs) - print("Prefill eager execution successful!") - - # Export both graphs - print("\nExporting decode model (seqlen=1)...") - decode_ep = torch.export.export(model, decode_inputs) - print("Decode export successful!") - print(decode_ep) - - print(f"\nExporting prefill model (seqlen={prefill_input_len})...") - prefill_ep = torch.export.export(model, prefill_inputs) - print("Prefill export successful!") - print(prefill_ep) - - # Generate metadata for C++ runner - # constant_methods are shared across all methods, so we prefix method-specific - # metadata with the method name + # Export adapter models + for name, lora_model in lora_models.items(): + methods[f"{name}_forward"] = _export_model( + lora_model, decode_inputs, f"{name} decode" + ) + methods[f"{name}_prefill"] = _export_model( + lora_model, prefill_inputs, f"{name} prefill" + ) + + # Generate metadata print("\nGenerating metadata for C++ runner...") decode_metadata = _get_metadata( model_args, decode_inputs, decode_input_len, decode_cache_len, float_dtype @@ -574,7 +661,8 @@ def main(): float_dtype, ) - # Combine metadata - shared values go without prefix, method-specific values get prefixed + # Combine metadata - shared values go without prefix, + # method-specific values get prefixed. constant_methods = { # Shared metadata (same for both methods) "vocab_size": decode_metadata["vocab_size"], @@ -595,50 +683,26 @@ def main(): "prefill_mask_specs": prefill_metadata["mask_specs"], "prefill_kv_cache_specs": prefill_metadata["kv_cache_specs"], } - - # Setup CoreML partitioner with multimethod weight sharing - print("\nSetting up CoreML partitioner (multifunction with weight sharing)...") - compile_specs = CoreMLBackend.generate_compile_specs( - minimum_deployment_target=ct.target.iOS18, - compute_precision={ - torch.float16: ct.precision.FLOAT16, - torch.float32: ct.precision.FLOAT32, - }[float_dtype], - compute_unit=ct.ComputeUnit.CPU_AND_NE, - model_type=CoreMLBackend.MODEL_TYPE.MODEL, - ) - compile_specs.append( - CoreMLBackend.generate_multimethod_weight_sharing_strategy_compile_spec( - MULTIMETHOD_WEIGHT_SHARING_STRATEGY.POSITIONAL - ) - ) - partitioner = CoreMLPartitioner( - compile_specs=compile_specs, - take_over_mutable_buffer=False, - skip_ops_for_coreml_delegation=[], - ) - - # Lower to edge with both decode and prefill methods - print("\nLowering to edge (multi-method: decode + prefill)...") - edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) - - # Create multi-method edge manager with decode as "forward" and prefill as "prefill" - edge_manager = to_edge_transform_and_lower( - {"forward": decode_ep, "prefill": prefill_ep}, - partitioner=[partitioner], - constant_methods=constant_methods, - compile_config=edge_compile_config, + if has_adapters: + constant_methods["has_lora"] = True + elif has_adapters: + # Adapter-only mode (no multifunction): base + adapter methods, same seqlen + print(f"\nCreating example inputs (seqlen={args.input_len})...") + example_inputs, example_cache_len = _create_example_inputs( + model_args, args.input_len, args.max_context_len, float_dtype ) - print("\nDelegated program (decode/forward):") - print(format_delegated_graph(edge_manager.exported_program().graph_module)) + methods = { + "forward": _export_model(model, example_inputs, "base"), + } + for name, lora_model in lora_models.items(): + methods[name] = _export_model(lora_model, example_inputs, name) - print("\nDelegated program (prefill):") - print( - format_delegated_graph( - edge_manager.exported_program("prefill").graph_module - ) + print("\nGenerating metadata for C++ runner...") + constant_methods = _get_metadata( + model_args, example_inputs, args.input_len, example_cache_len, float_dtype ) + constant_methods["has_lora"] = True else: # Single method mode: fixed seqlen with generate_full_logits=True for lookahead print(f"\nCreating example inputs (seqlen={args.input_len})...") @@ -646,51 +710,60 @@ def main(): model_args, args.input_len, args.max_context_len, float_dtype ) - # Test eager execution - print("\nTesting eager execution...") - with torch.no_grad(): - model(*example_inputs) - print("Eager execution successful!") - - # Export the model - print("\nExporting model...") - ep = torch.export.export(model, example_inputs) - print("Export successful!") - print(ep) + ep = _export_model(model, example_inputs, "model") - # Generate metadata for C++ runner print("\nGenerating metadata for C++ runner...") constant_methods = _get_metadata( model_args, example_inputs, args.input_len, example_cache_len, float_dtype ) - # Setup CoreML partitioner - print("\nSetting up CoreML partitioner...") - compile_specs = CoreMLBackend.generate_compile_specs( - minimum_deployment_target=ct.target.iOS18, - compute_precision={ - torch.float16: ct.precision.FLOAT16, - torch.float32: ct.precision.FLOAT32, - }[float_dtype], - compute_unit=ct.ComputeUnit.CPU_AND_NE, - model_type=CoreMLBackend.MODEL_TYPE.MODEL, - ) - partitioner = CoreMLPartitioner( - compile_specs=compile_specs, - take_over_mutable_buffer=False, - skip_ops_for_coreml_delegation=[], + # Setup CoreML partitioner + print("\nSetting up CoreML partitioner...") + compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18, + compute_precision={ + torch.float16: ct.precision.FLOAT16, + torch.float32: ct.precision.FLOAT32, + }[float_dtype], + compute_unit=ct.ComputeUnit.CPU_AND_NE, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, + ) + if use_multimethod: + compile_specs.append( + CoreMLBackend.generate_multimethod_weight_sharing_strategy_compile_spec( + MULTIMETHOD_WEIGHT_SHARING_STRATEGY.POSITIONAL + ) ) + partitioner = CoreMLPartitioner( + compile_specs=compile_specs, + take_over_mutable_buffer=False, + skip_ops_for_coreml_delegation=[], + ) - # Lower to edge with constant methods for C++ runner - print("\nLowering to edge...") - edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + # Lower to edge + print("\nLowering to edge...") + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + if use_multimethod: + edge_manager = to_edge_transform_and_lower( + methods, + partitioner=[partitioner], + constant_methods=constant_methods, + compile_config=edge_compile_config, + ) + for method_name in methods: + print(f"\nDelegated program ({method_name}):") + print( + format_delegated_graph( + edge_manager.exported_program(method_name).graph_module + ) + ) + else: edge_manager = to_edge_transform_and_lower( ep, partitioner=[partitioner], constant_methods=constant_methods, compile_config=edge_compile_config, ) - print("\nDelegated program:") print(format_delegated_graph(edge_manager.exported_program().graph_module)) diff --git a/examples/apple/coreml/llama/utils.py b/examples/apple/coreml/llama/utils.py index 1e5a842fed5..3ecc4458901 100644 --- a/examples/apple/coreml/llama/utils.py +++ b/examples/apple/coreml/llama/utils.py @@ -91,7 +91,11 @@ def forward(self, x): def replace_linear_with_split_linear( - model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1 + model, + out_target_split_size, + out_max_splits, + in_target_split_size, + in_max_splits=1, ): for name, module in model.named_children(): if isinstance(module, torch.nn.Linear):