diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py
index 8c68af45f31..276ff6d193a 100644
--- a/examples/apple/coreml/llama/export_static_llm_coreml.py
+++ b/examples/apple/coreml/llama/export_static_llm_coreml.py
@@ -309,6 +309,69 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype)
 }
 
 
+def _transform_eager_model(model, args, float_dtype):
+    """Apply splitting, quantization, and graph breaks to a model."""
+    model = model.to(float_dtype).eval()
+
+    if args.target_split_size is not None:
+        print(f"\nSplitting linear layers with target size {args.target_split_size}...")
+        replace_linear_with_split_linear(
+            model,
+            out_target_split_size=args.target_split_size,
+            out_max_splits=args.max_splits,
+            in_target_split_size=1,
+            in_max_splits=1,
+        )
+
+    if args.embedding_quantize:
+        bitwidth, group_size = args.embedding_quantize.split(",")
+        bitwidth = int(bitwidth)
+        group_size = int(group_size)
+        assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
+
+        print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...")
+        if group_size == 0:
+            granularity = PerAxis(0)
+        else:
+            granularity = PerGroup(group_size)
+        weight_dtype = getattr(torch, f"int{bitwidth}")
+
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+
+    if args.linear_quantize == "b4w":
+        print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...")
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+        )
+    elif args.linear_quantize == "c4w":
+        print("\nQuantizing linear layers: 4-bit channelwise...")
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerAxis(0),
+            ),
+        )
+
+    if not args.no_graph_breaks:
+        print("\nAdding graph breaks between before/after the transformer blocks...")
+        n_layers = len(model.layers)
+        model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True)
+        model.layers[n_layers - 1] = BlockWithGraphBreak(
+            model.layers[n_layers - 1], break_before=False
+        )
+
+    return model
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Export static attention Llama model to CoreML"
@@ -441,70 +504,8 @@ def main():
     )
     print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim")
 
-    # Set dtype
     float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
-    model = model.to(float_dtype).eval()
-
-    # Apply linear splitting (before quantization)
-    if args.target_split_size is not None:
-        print(f"\nSplitting linear layers with target size {args.target_split_size}...")
-        replace_linear_with_split_linear(
-            model,
-            out_target_split_size=args.target_split_size,
-            out_max_splits=args.max_splits,
-            in_target_split_size=1,
-            in_max_splits=1,
-        )
-
-    # Apply embedding quantization
-    if args.embedding_quantize:
-        bitwidth, group_size = args.embedding_quantize.split(",")
-        bitwidth = int(bitwidth)
-        group_size = int(group_size)
-        assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
-
-        print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...")
-        if group_size == 0:
-            granularity = PerAxis(0)
-        else:
-            granularity = PerGroup(group_size)
-        weight_dtype = getattr(torch, f"int{bitwidth}")
-
-        quantize_(
-            model,
-            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
-            lambda m, fqn: isinstance(m, torch.nn.Embedding),
-        )
-
-    # Apply linear quantization
-    if args.linear_quantize == "b4w":
-        print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...")
-        quantize_(
-            model,
-            IntxWeightOnlyConfig(
-                weight_dtype=torch.int4,
-                granularity=PerGroup(32),
-            ),
-        )
-    elif args.linear_quantize == "c4w":
-        print("\nQuantizing linear layers: 4-bit channelwise...")
-        quantize_(
-            model,
-            IntxWeightOnlyConfig(
-                weight_dtype=torch.int4,
-                granularity=PerAxis(0),
-            ),
-        )
-
-    # Add graph breaks between transformer blocks
-    # Keeping model pieces smaller helps with ANE performance
-    if not args.no_graph_breaks:
-        print("\nAdding graph breaks between before/after the transformer blocks...")
-        n_layers = len(model.layers)
-        model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True)
-        model.layers[n_layers - 1] = BlockWithGraphBreak(
-            model.layers[n_layers - 1], break_before=False
-        )
+    model = _transform_eager_model(model, args, float_dtype)
 
     if args.multifunction:
         # Multifunction mode: separate prefill and decode graphs with weight sharing