Refactor: extract _prepare_eager_model() from CoreML export main() #18343
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18343

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 3 Pending, 4 Unrelated Failures — as of commit 06a59b2 with merge base dd7464a:
- FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
- BROKEN TRUNK - The following job failed but was already failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
Refactors the CoreML static Llama export script by extracting model-preparation steps into a reusable helper, enabling consistent application of dtype conversion, splitting, quantization, and graph-break insertion (useful for multi-method exports).
Changes:
- Added `_prepare_model(model, args, float_dtype)` to encapsulate dtype + transformation steps.
- Replaced inlined preparation logic in `main()` with a single call to `_prepare_model(...)`.
```python
def _prepare_eager_model(model, args, float_dtype):
    """Apply splitting, quantization, and graph breaks to a model."""
```
The `_prepare_model` docstring says it applies splitting/quantization/graph breaks, but the function also changes dtype and sets eval mode. Consider updating the docstring to reflect the full set of side effects (and that it mutates the module in-place).
Suggested change:

```python
    """
    Prepare an eager model for export by applying dtype conversion, eval mode,
    splitting, quantization, and optional graph breaks.

    This function mutates the given ``model`` in-place:

    * moves it to ``float_dtype``,
    * sets it to evaluation mode,
    * optionally splits linear layers,
    * optionally quantizes embeddings and linear layers, and
    * optionally wraps the first/last transformer blocks with graph breaks.

    The same (mutated) ``model`` instance is returned.
    """
```
```python
bitwidth, group_size = args.embedding_quantize.split(",")
bitwidth = int(bitwidth)
group_size = int(group_size)
assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
```
`args.embedding_quantize` is parsed via `.split(",")` and validated with `assert`. Since this is user-controlled CLI input, consider doing explicit validation and raising a `ValueError`/`argparse.ArgumentTypeError` with a clear message (and avoid `assert`, which can be stripped with `python -O`).
Suggested change:

```python
try:
    bitwidth_str, group_size_str = args.embedding_quantize.split(",")
except ValueError as e:
    raise ValueError(
        f"Invalid value for --embedding_quantize: {args.embedding_quantize!r}. "
        "Expected format is 'BITWIDTH,GROUP_SIZE', e.g. '4,32'."
    ) from e
try:
    bitwidth = int(bitwidth_str)
    group_size = int(group_size_str)
except ValueError as e:
    raise ValueError(
        f"Invalid value for --embedding_quantize: {args.embedding_quantize!r}. "
        "BITWIDTH and GROUP_SIZE must be integers, e.g. '4,32'."
    ) from e
if bitwidth not in (4, 8):
    raise ValueError(
        f"Unsupported BITWIDTH {bitwidth} in --embedding_quantize. "
        "CoreML only supports 4-bit and 8-bit quantization."
    )
```
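The `argparse.ArgumentTypeError` route the review mentions can also be sketched as a `type=` callable, so argparse prints a clean usage error instead of a traceback. The function name `embedding_quantize_type` is illustrative, not from the PR.

```python
import argparse


def embedding_quantize_type(value: str):
    """Parse 'BITWIDTH,GROUP_SIZE' (e.g. '4,32'), raising
    argparse.ArgumentTypeError on malformed input."""
    parts = value.split(",")
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(
            f"expected 'BITWIDTH,GROUP_SIZE', e.g. '4,32'; got {value!r}"
        )
    try:
        bitwidth, group_size = int(parts[0]), int(parts[1])
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"BITWIDTH and GROUP_SIZE must be integers; got {value!r}"
        )
    if bitwidth not in (4, 8):
        raise argparse.ArgumentTypeError(
            f"CoreML only supports 4-bit and 8-bit quantization; got {bitwidth}"
        )
    return bitwidth, group_size


parser = argparse.ArgumentParser()
parser.add_argument("--embedding_quantize", type=embedding_quantize_type)
```

With this, `--embedding_quantize 16,32` fails at argument-parsing time with a one-line error, and downstream code receives an already-validated `(bitwidth, group_size)` tuple.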
This PR needs a
Extract eager model preparation logic (dtype conversion, linear splitting,
quantization, graph breaks) into a reusable _prepare_model() helper.
No functional change — pure refactor.
Generated with Claude.
Used for multi-method export; extract out the eager model instantiation and transforms so that we can apply them to each method.
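That multi-method motivation can be sketched as follows: instead of preparing one model inline in `main()`, instantiate a fresh eager model per method and run the same preparation on each. All names here (`_make_model`, `prepare_methods`, the method names) are hypothetical stand-ins, not the PR's actual API.

```python
from types import SimpleNamespace


def _make_model(method_name):
    """Hypothetical per-method factory; stands in for eager model instantiation."""
    return SimpleNamespace(name=method_name, prepared=False, dtype=None)


def _prepare(model, args, float_dtype):
    """Stand-in for the extracted helper: the same transforms run per method."""
    model.prepared = True
    model.dtype = float_dtype
    return model


def prepare_methods(method_names, args, float_dtype):
    """Build and prepare one eager model per export method."""
    return {
        name: _prepare(_make_model(name), args, float_dtype)
        for name in method_names
    }


models = prepare_methods(["prefill", "decode"], SimpleNamespace(), "float16")
```

The design point is that dtype conversion, splitting, quantization, and graph breaks stay consistent across methods because they all flow through the one helper.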