-
Notifications
You must be signed in to change notification settings - Fork 886
Refactor: extract _transform_eager_model() from CoreML export main() #18343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
da7f489
ac42ef2
bb400fa
367970a
2bb273d
a4598d9
a8a23df
5b130e1
108fa03
06a59b2
d983688
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -309,6 +309,69 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _transform_eager_model(model, args, float_dtype): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Apply splitting, quantization, and graph breaks to a model.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = model.to(float_dtype).eval() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
lucylq marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if args.target_split_size is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"\nSplitting linear layers with target size {args.target_split_size}...") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| replace_linear_with_split_linear( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_target_split_size=args.target_split_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_max_splits=args.max_splits, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| in_target_split_size=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| in_max_splits=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if args.embedding_quantize: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bitwidth, group_size = args.embedding_quantize.split(",") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bitwidth = int(bitwidth) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| group_size = int(group_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+327
to
+331
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bitwidth, group_size = args.embedding_quantize.split(",") | |
| bitwidth = int(bitwidth) | |
| group_size = int(group_size) | |
| assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" | |
| try: | |
| bitwidth_str, group_size_str = args.embedding_quantize.split(",") | |
| except ValueError as e: | |
| raise ValueError( | |
| f"Invalid value for --embedding_quantize: {args.embedding_quantize!r}. " | |
| "Expected format is 'BITWIDTH,GROUP_SIZE', e.g. '4,32'." | |
| ) from e | |
| try: | |
| bitwidth = int(bitwidth_str) | |
| group_size = int(group_size_str) | |
| except ValueError as e: | |
| raise ValueError( | |
| f"Invalid value for --embedding_quantize: {args.embedding_quantize!r}. " | |
| "BITWIDTH and GROUP_SIZE must be integers, e.g. '4,32'." | |
| ) from e | |
| if bitwidth not in (4, 8): | |
| raise ValueError( | |
| f"Unsupported BITWIDTH {bitwidth} in --embedding_quantize. " | |
| "CoreML only supports 4-bit and 8-bit quantization." | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
_prepare_modeldocstring says it applies splitting/quantization/graph breaks, but the function also changes dtype and sets eval mode. Consider updating the docstring to reflect the full set of side effects (and that it mutates the module in-place).