[quantization] Add save_layers_to_folder option #565
Merged
mhs4670go merged 1 commit into Samsung:main on Mar 20, 2026
Conversation
mhs4670go reviewed on Mar 19, 2026
Comment on lines +134 to +137

```python
pos_embeds = (
    qlayer.wrapped.rope_cos_template.cpu().to(dtype),
    qlayer.wrapped.rope_sin_template.cpu().to(dtype),
)
```
Contributor
save_layers_to() currently passes the full static RoPE tables created at prepare time (bounded by calibrate_seq_len / max_position_embeddings), while the export sample may use a shorter max_seq_len. This can cause shape mismatches during export.
Suggested change

```diff
- pos_embeds = (
-     qlayer.wrapped.rope_cos_template.cpu().to(dtype),
-     qlayer.wrapped.rope_sin_template.cpu().to(dtype),
- )
+ pos_embeds = qlayer.wrapped._slice_rope(S, torch.device("cpu"), example_hidden.dtype)
```
And we should slice rope_cos_template / rope_sin_template to the example sequence length, just as we already do for the causal mask:
```python
# quant_decoder_layer
def _slice_rope(
    self, seq_len: int, device: torch.device, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    assert isinstance(self.rope_cos_template, torch.Tensor)
    assert isinstance(self.rope_sin_template, torch.Tensor)
    cos = self.rope_cos_template[:, :seq_len, :].to(device=device, dtype=dtype)
    sin = self.rope_sin_template[:, :seq_len, :].to(device=device, dtype=dtype)
    return cos, sin

# ..
if position_embeddings is None:
    position_embeddings = self._slice_rope(
        hidden_states.size(1), hidden_states.device, hidden_states.dtype
    )
```
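
For context, here is a minimal, self-contained sketch of the failure mode being described; the tensor sizes and the template layout `[1, seq_len, head_dim]` are illustrative assumptions, not the actual shapes in the PR:

```python
import torch

# Illustrative sizes: RoPE templates are built at prepare time for the
# calibration length, while the export example uses a shorter sequence.
calibrate_seq_len, export_seq_len, head_dim = 512, 128, 64

rope_cos_template = torch.randn(1, calibrate_seq_len, head_dim)
hidden_states = torch.randn(1, export_seq_len, head_dim)

# Passing the full template leaves the sequence dimensions mismatched ...
assert rope_cos_template.size(1) != hidden_states.size(1)

# ... while slicing to the example length (what _slice_rope does) aligns them.
cos = rope_cos_template[:, : hidden_states.size(1), :]
assert cos.shape == hidden_states.shape
```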
Contributor (Author)
Ahhh. Sorry. You're right. Thank you! I'll fix it.
This PR adds a `save_layers_to_folder` option to make it possible to save all layers of a quantized model to separate files. TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
Force-pushed from 10e59e2 to e617bf9 (Compare)
This PR adds a `save_layers_to_folder` option to make it possible to save all layers of a quantized model to separate files.

Log of `python tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py --model Maykeye/TinyLLama-v0 --gptq_mse=mse save_layers_to_folder='.'`:

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
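
For readers unfamiliar with the feature, a minimal sketch of the per-layer saving pattern the description implies is shown below; the helper name, folder layout, and `model.layers` container are assumptions for illustration, not this PR's actual implementation:

```python
import os

import torch

def save_layers_to_folder(model: torch.nn.Module, folder: str) -> None:
    # Hypothetical helper: write each quantized layer's weights to its own
    # file so layers can be inspected or exported independently.
    os.makedirs(folder, exist_ok=True)
    for idx, layer in enumerate(model.layers):  # assumed layer container
        torch.save(layer.state_dict(), os.path.join(folder, f"layer_{idx}.pt"))
```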