diff --git a/examples/apple/coreml/llama/utils.py b/examples/apple/coreml/llama/utils.py index 1e5a842fed5..b282be8c4f4 100644 --- a/examples/apple/coreml/llama/utils.py +++ b/examples/apple/coreml/llama/utils.py @@ -91,9 +91,20 @@ def forward(self, x): def replace_linear_with_split_linear( - model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1 + model, + out_target_split_size, + out_max_splits, + in_target_split_size, + in_max_splits=1, + skip_names=None, ): + from executorch.examples.models.llama.lora import LoRALinear + for name, module in model.named_children(): + if skip_names and name in skip_names: + continue + if isinstance(module, LoRALinear): + continue if isinstance(module, torch.nn.Linear): assert module.bias is None, "SplitLinearModule does not support bias" new_module = SplitLinearModule( @@ -113,4 +124,5 @@ def replace_linear_with_split_linear( out_max_splits, in_target_split_size, in_max_splits, + skip_names=skip_names, )