Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 64 additions & 63 deletions examples/apple/coreml/llama/export_static_llm_coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,69 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype)
}


def _transform_eager_model(model, args, float_dtype):
    """
    Prepare an eager model for export.

    Mutates the given ``model`` in-place and returns the same instance:
      * moves it to ``float_dtype`` and sets evaluation mode,
      * optionally splits linear layers (``args.target_split_size``),
      * optionally quantizes embeddings (``args.embedding_quantize``) and
        linear layers (``args.linear_quantize``), and
      * optionally wraps the first/last transformer blocks with graph breaks
        (unless ``args.no_graph_breaks``).

    Args:
        model: The eager Llama model to transform (mutated in-place).
        args: Parsed CLI arguments controlling the transforms.
        float_dtype: Target floating-point dtype (e.g. torch.float16).

    Returns:
        The (mutated) ``model`` instance.

    Raises:
        ValueError: If ``args.embedding_quantize`` is malformed or requests
            an unsupported bitwidth.
    """
    model = model.to(float_dtype).eval()

    if args.target_split_size is not None:
        print(f"\nSplitting linear layers with target size {args.target_split_size}...")
        replace_linear_with_split_linear(
            model,
            out_target_split_size=args.target_split_size,
            out_max_splits=args.max_splits,
            in_target_split_size=1,
            in_max_splits=1,
        )

    if args.embedding_quantize:
        # Explicit validation (not assert): this is user-controlled CLI input,
        # and assert is stripped under ``python -O``. Raise with an actionable
        # message instead of an opaque unpacking/int() traceback.
        try:
            bitwidth_str, group_size_str = args.embedding_quantize.split(",")
            bitwidth = int(bitwidth_str)
            group_size = int(group_size_str)
        except ValueError as e:
            raise ValueError(
                f"Invalid value for --embedding_quantize: {args.embedding_quantize!r}. "
                "Expected format is 'BITWIDTH,GROUP_SIZE' with integer values, e.g. '4,32'."
            ) from e
        if bitwidth not in (4, 8):
            raise ValueError(
                f"Unsupported bitwidth {bitwidth} in --embedding_quantize. "
                "CoreML only supports 4-bit and 8-bit quantization."
            )

        print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...")
        # group_size == 0 selects per-channel (axis 0) granularity.
        if group_size == 0:
            granularity = PerAxis(0)
        else:
            granularity = PerGroup(group_size)
        weight_dtype = getattr(torch, f"int{bitwidth}")

        quantize_(
            model,
            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
            lambda m, fqn: isinstance(m, torch.nn.Embedding),
        )

    if args.linear_quantize == "b4w":
        print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...")
        quantize_(
            model,
            IntxWeightOnlyConfig(
                weight_dtype=torch.int4,
                granularity=PerGroup(32),
            ),
        )
    elif args.linear_quantize == "c4w":
        print("\nQuantizing linear layers: 4-bit channelwise...")
        quantize_(
            model,
            IntxWeightOnlyConfig(
                weight_dtype=torch.int4,
                granularity=PerAxis(0),
            ),
        )

    if not args.no_graph_breaks:
        # Wrap only the first and last transformer blocks; keeping exported
        # pieces smaller helps with ANE performance (per original comment).
        print("\nAdding graph breaks between before/after the transformer blocks...")
        n_layers = len(model.layers)
        model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True)
        model.layers[n_layers - 1] = BlockWithGraphBreak(
            model.layers[n_layers - 1], break_before=False
        )

    return model


def main():
parser = argparse.ArgumentParser(
description="Export static attention Llama model to CoreML"
Expand Down Expand Up @@ -441,70 +504,8 @@ def main():
)
print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim")

# Set dtype
float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
model = model.to(float_dtype).eval()

# Apply linear splitting (before quantization)
if args.target_split_size is not None:
print(f"\nSplitting linear layers with target size {args.target_split_size}...")
replace_linear_with_split_linear(
model,
out_target_split_size=args.target_split_size,
out_max_splits=args.max_splits,
in_target_split_size=1,
in_max_splits=1,
)

# Apply embedding quantization
if args.embedding_quantize:
bitwidth, group_size = args.embedding_quantize.split(",")
bitwidth = int(bitwidth)
group_size = int(group_size)
assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"

print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...")
if group_size == 0:
granularity = PerAxis(0)
else:
granularity = PerGroup(group_size)
weight_dtype = getattr(torch, f"int{bitwidth}")

quantize_(
model,
IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# Apply linear quantization
if args.linear_quantize == "b4w":
print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...")
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerGroup(32),
),
)
elif args.linear_quantize == "c4w":
print("\nQuantizing linear layers: 4-bit channelwise...")
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerAxis(0),
),
)

# Add graph breaks between transformer blocks
# Keeping model pieces smaller helps with ANE performance
if not args.no_graph_breaks:
print("\nAdding graph breaks between before/after the transformer blocks...")
n_layers = len(model.layers)
model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True)
model.layers[n_layers - 1] = BlockWithGraphBreak(
model.layers[n_layers - 1], break_before=False
)
model = _transform_eager_model(model, args, float_dtype)

if args.multifunction:
# Multifunction mode: separate prefill and decode graphs with weight sharing
Expand Down
Loading