Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,39 @@ def save_circles_to(q_m, calib_inputs, save_circle_to_folder):
cm.save(save_path)


def save_layers_to(q_m, max_seq_len, save_layers_to_folder):
    """Export each quantized decoder layer to its own ``.q.circle`` file.

    Args:
        q_m: Quantized model. Must have been wrapped by the PTQ pipeline,
            i.e. expose ``wrapped.model.wrapped.layers`` and ``wrapped.config``.
        max_seq_len: Sequence length used to build the example inputs traced
            during conversion.
        save_layers_to_folder: Destination directory; created if it does not
            already exist.
    """
    q_m.eval()
    q_m.cpu()

    # Only PTQ-wrapped models carry the per-layer structure exported below.
    if not hasattr(q_m, "wrapped"):
        print("Saving layers currently is supported only for PTQ quantized model")
        return

    # Create the target folder up front so cm.save() does not fail later.
    out_dir = pathlib.Path(save_layers_to_folder)
    out_dir.mkdir(parents=True, exist_ok=True)

    layers = q_m.wrapped.model.wrapped.layers
    config = q_m.wrapped.config

    # The example hidden state has the same shape for every layer — build once.
    B, S, D = 1, max_seq_len, config.hidden_size
    example_hidden = torch.randn(B, S, D)

    for i, qlayer in enumerate(layers):
        save_path = out_dir / f"decoder_layer_{i}.q.circle"

        # Masks/embeddings are sliced from each layer's own cached templates,
        # so they stay inside the loop.
        attention_mask = qlayer.wrapped._slice_causal(S, "cpu").squeeze(0)
        pos_embeds = qlayer.wrapped._slice_rope(S, "cpu", example_hidden.dtype)

        print(f"Saving model layer_{i} to {save_path.resolve()}")
        with torch.no_grad():
            with SuppressWarning(UserWarning, ".*"):
                cm = tico.convert(
                    qlayer,
                    (example_hidden,),
                    kwargs={
                        "attention_mask": attention_mask,
                        "position_embeddings": pos_embeds,
                    },
                )
        cm.save(save_path)


def quantize_using_PTQ(q_m, calib_inputs, args):
print("Wrapping layers with PTQWrapper …")

Expand Down Expand Up @@ -235,6 +268,12 @@ def main():
default=None,
help="Save embedding/lm_head/all_layers/model.model/the_whole_model to the folder specified",
)
parser.add_argument(
"--save_layers_to_folder",
type=str,
default=None,
help="Save all layers to the folder specified",
)
parser.add_argument(
"--cache_dir",
type=str,
Expand Down Expand Up @@ -413,6 +452,9 @@ def main():
# after PTQ quantizer only fixed-length input sequences are valid
evaluate(q_m, tokenizer, dataset_test, args)

if args.save_layers_to_folder is not None:
save_layers_to(q_m, args.max_seq_len, args.save_layers_to_folder)

if args.save_circle_to_folder is not None:
calib_inputs = list(torch.stack(calib_inputs).reshape(-1, 1, args.max_seq_len))
save_circles_to(q_m, calib_inputs, args.save_circle_to_folder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,15 @@ def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor:
assert isinstance(self.causal_mask_template, torch.Tensor)
return self.causal_mask_template[..., :seq_len, :seq_len].to(device)

def _slice_rope(
self, seq_len: int, device: torch.device, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
assert isinstance(self.rope_cos_template, torch.Tensor)
assert isinstance(self.rope_sin_template, torch.Tensor)
cos = self.rope_cos_template[:, :seq_len, :].to(device=device, dtype=dtype)
sin = self.rope_sin_template[:, :seq_len, :].to(device=device, dtype=dtype)
return cos, sin

def forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -198,13 +207,8 @@ def forward(
) # let it be quantized immediately

if position_embeddings is None:
position_embeddings = (
self.rope_cos_template.to(
dtype=hidden_states.dtype, device=hidden_states.device
),
self.rope_sin_template.to(
dtype=hidden_states.dtype, device=hidden_states.device
),
position_embeddings = self._slice_rope(
hidden_states.size(1), hidden_states.device, hidden_states.dtype
)
cos, sin = position_embeddings
position_embeddings = (
Expand Down
Loading