
[quantization] Add save_layers_to_folder option#565

Merged
mhs4670go merged 1 commit into Samsung:main from stamalakhov:save_layers on Mar 20, 2026

Conversation

@stamalakhov
Contributor

@stamalakhov stamalakhov commented Mar 19, 2026

This PR adds a save_layers_to_folder option to make it possible to save all layers of a quantized model to separate files.

Log of python tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py --model Maykeye/TinyLLama-v0 --gptq_mse=mse --save_layers_to_folder='.':

Namespace(model='Maykeye/TinyLLama-v0', device='cuda', dtype='float32', seed=42, trust_remote_code=False, hf_token=None, no_tqdm=False, no_GPTQ=False, no_PTQ=False, save_circle_to_folder='.', save_layers_to_folder='.', cache_dir='/mnt/storage/transformers_cache', nsamples_for_qcalibration=128, linear_weight_bits=4, gptq_mse='mse', max_seq_len=2048, calibrate_seq_len=2048, embedding_weight_bits=8, lm_head_weight_bits=4, eval_tasks=None, sensitivity_path=None)
=== Config ===
Model            : Maykeye/TinyLLama-v0
Device           : cuda
DType            : float32

Loading FP model …
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.

Calculating original perplexities …
Token indices sequence length is longer than the specified maximum sequence length for this model (324381 > 2048). Running this sequence through the model will result in indexing errors
PPL:  99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 158/159 [00:06<00:00, 22.90it/s]

┌── Wikitext-2 test perplexity ─────────────
│ FP32 :  7584.31
└───────────────────────────────────────────
Applying GPTQ …
Quantizing layers: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:08<00:00,  1.04s/layer]
Wrapping layers with PTQWrapper …                                                                                                                                                                                                                                              
Calibrating PTQ observers…
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:39<00:00,  3.20it/s]

Calculating perplexities …
PPL:  99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 158/159 [00:28<00:00,  5.59it/s]

┌── Wikitext-2 test perplexity ─────────────
│ int16 :  7297.10
└───────────────────────────────────────────
Saving model layer_0 to /mnt/storage/slow_repos/TICO/decoder_layer_0.q.circle
Saving model layer_1 to /mnt/storage/slow_repos/TICO/decoder_layer_1.q.circle
Saving model layer_2 to /mnt/storage/slow_repos/TICO/decoder_layer_2.q.circle
Saving model layer_3 to /mnt/storage/slow_repos/TICO/decoder_layer_3.q.circle
Saving model layer_4 to /mnt/storage/slow_repos/TICO/decoder_layer_4.q.circle
Saving model layer_5 to /mnt/storage/slow_repos/TICO/decoder_layer_5.q.circle
Saving model layer_6 to /mnt/storage/slow_repos/TICO/decoder_layer_6.q.circle
Saving model layer_7 to /mnt/storage/slow_repos/TICO/decoder_layer_7.q.circle
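The per-layer save loop visible in the log can be sketched roughly as below. This is a hedged illustration, not the actual TICO export API: only the decoder_layer_&lt;i&gt;.q.circle naming scheme and the 8-layer count are taken from the log above; the helper name is invented for the sketch.

```python
import os

def layer_circle_paths(folder: str, num_layers: int) -> list:
    """Build one output path per decoder layer, matching the
    decoder_layer_<i>.q.circle naming scheme seen in the log."""
    return [
        os.path.join(folder, f"decoder_layer_{i}.q.circle")
        for i in range(num_layers)
    ]

# For the 8-layer TinyLLama-v0 run above, saving into '.':
for path in layer_circle_paths(".", 8):
    print(f"Saving model layer to {path}")
```

In the real run each path is resolved against the repository root (here /mnt/storage/slow_repos/TICO), which is why the log shows absolute paths.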

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>

@stamalakhov stamalakhov self-assigned this Mar 19, 2026
@stamalakhov stamalakhov requested a review from a team March 19, 2026 07:27
Comment on lines +134 to +137
pos_embeds = (
qlayer.wrapped.rope_cos_template.cpu().to(dtype),
qlayer.wrapped.rope_sin_template.cpu().to(dtype),
)
Contributor

save_layers_to() currently passes full static RoPE tables created at prepare-time (bounded by calibrate_seq_len / max_position_embeddings) while the export sample may use a shorter max_seq_len. This can cause shape mismatches during export.

Suggested change:
- pos_embeds = (
-     qlayer.wrapped.rope_cos_template.cpu().to(dtype),
-     qlayer.wrapped.rope_sin_template.cpu().to(dtype),
- )
+ pos_embeds = qlayer.wrapped._slice_rope(S, torch.device("cpu"), example_hidden.dtype)
pos_embeds = qlayer.wrapped._slice_rope(S, torch.device("cpu"), example_hidden.dtype)

And we should slice rope_cos_template / rope_sin_template to the example sequence length, just as we already do for the causal mask.

# quant_decoder_layer
def _slice_rope(
    self, seq_len: int, device: torch.device, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    assert isinstance(self.rope_cos_template, torch.Tensor)
    assert isinstance(self.rope_sin_template, torch.Tensor)
    cos = self.rope_cos_template[:, :seq_len, :].to(device=device, dtype=dtype)
    sin = self.rope_sin_template[:, :seq_len, :].to(device=device, dtype=dtype)
    return cos, sin

# ..
if position_embeddings is None:
    position_embeddings = self._slice_rope(
        hidden_states.size(1), hidden_states.device, hidden_states.dtype
    )
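As a standalone illustration of the suggested fix, the sketch below shows how slicing the full prepare-time RoPE templates yields cos/sin tensors that match the export sample's sequence length. The stand-in class, random templates, and tensor shapes are assumptions for the demo, not the real quant_decoder_layer; only the _slice_rope logic mirrors the snippet above.

```python
import torch

class RopeHolder:
    """Hypothetical stand-in for the wrapped decoder layer, holding
    full-length RoPE templates of shape (1, max_seq_len, head_dim)."""

    def __init__(self, max_seq_len: int = 2048, head_dim: int = 64):
        self.rope_cos_template = torch.randn(1, max_seq_len, head_dim)
        self.rope_sin_template = torch.randn(1, max_seq_len, head_dim)

    def _slice_rope(
        self, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> tuple:
        # Trim the static tables to the export sample's sequence length,
        # mirroring the causal-mask handling described in the review.
        cos = self.rope_cos_template[:, :seq_len, :].to(device=device, dtype=dtype)
        sin = self.rope_sin_template[:, :seq_len, :].to(device=device, dtype=dtype)
        return cos, sin

layer = RopeHolder()
# Export sample is shorter (128) than the calibrate-time table (2048):
cos, sin = layer._slice_rope(128, torch.device("cpu"), torch.float32)
print(cos.shape, sin.shape)
```

Without the slice, the exported graph would carry (1, 2048, 64) tables against a (1, 128, hidden) sample, which is the shape mismatch the comment describes.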

Contributor Author

Ahhh. Sorry. You're right. Thank you! I'll fix it.

Contributor Author

@mhs4670go
Fixed.

This PR adds `save_layers_to_folder` option to make it possible to save all layers of quantized model to separate files.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
Contributor

@mhs4670go mhs4670go left a comment

LGTM

@mhs4670go mhs4670go merged commit 5d838b3 into Samsung:main Mar 20, 2026
7 checks passed
@stamalakhov stamalakhov deleted the save_layers branch March 20, 2026 04:49