-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: utils.py
More file actions
53 lines (41 loc) · 1.77 KB
/
utils.py
File metadata and controls
53 lines (41 loc) · 1.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from typing import List, Tuple, Any, Iterable
def preprocess_text_batch(raw_texts, tokenizer, device, target_len=256):
    """Tokenize a batch of strings and pad/truncate each to a fixed length.

    Args:
        raw_texts: iterable of input strings.
        tokenizer: object exposing ``encode(text)`` (token ids as a list,
            array, or 1-D tensor), plus ``word2idx`` and ``pad_token`` used
            to resolve the padding id.
        device: target device for the resulting tensor.
        target_len: fixed output sequence length (default 256).

    Returns:
        Tensor of shape ``(batch, target_len)`` on ``device``; dtype follows
        whatever ``tokenizer.encode`` produces (int64 for plain int lists).
    """
    pad_id = tokenizer.word2idx[tokenizer.pad_token]
    # torch.as_tensor is a no-op for tensors but also accepts plain int
    # lists, which pad_sequence alone would reject with a TypeError.
    encoded_list = [torch.as_tensor(tokenizer.encode(t)) for t in raw_texts]
    if not encoded_list:
        # Empty batch: pad_sequence would crash; return a well-shaped tensor.
        return torch.empty((0, target_len), dtype=torch.long, device=device)
    text_tokens = torch.nn.utils.rnn.pad_sequence(
        encoded_list,
        batch_first=True,
        padding_value=pad_id,
    ).to(device)
    current_len = text_tokens.size(1)
    if current_len < target_len:
        # Right-pad with pad_id up to the fixed length.
        text_tokens = torch.nn.functional.pad(
            text_tokens, (0, target_len - current_len),
            value=pad_id,
        )
    else:
        # Truncate (a no-op when current_len == target_len).
        text_tokens = text_tokens[:, :target_len]
    return text_tokens
def decode_sequences(
    logits: torch.Tensor,
    tokenizer: Any,
    special_tokens_to_filter: Iterable[str],
    label_tokens: torch.Tensor
) -> Tuple[List[str], List[str]]:
    """Turn model logits and reference token ids into cleaned-up strings.

    Args:
        logits: model output of shape ``(..., vocab)``; argmax over the last
            dim yields predicted token ids.
        tokenizer: object exposing ``idx2word`` (id -> token string) and
            ``unk_token`` (fallback for unknown ids).
        special_tokens_to_filter: container of token strings (e.g. ``[PAD]``,
            ``[CLS]``) dropped from the output text.
        label_tokens: ground-truth token ids, one row per example.

    Returns:
        ``(decoded_labels, decoded_preds)`` — two lists of space-joined
        token strings, one entry per example.
    """
    # Hoist tokenizer attribute lookups out of the per-token loop.
    lookup = tokenizer.idx2word
    fallback = tokenizer.unk_token

    def _render(batch) -> List[str]:
        # id -> token string (unknown ids fall back), skip specials, join.
        texts: List[str] = []
        for ids in batch:
            words = (lookup.get(int(i), fallback) for i in ids)
            kept = (w for w in words if w not in special_tokens_to_filter)
            texts.append(" ".join(kept))
        return texts

    pred_ids = torch.argmax(logits, dim=-1).detach().cpu().numpy()
    return _render(label_tokens), _render(pred_ids)