-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample.py
More file actions
153 lines (124 loc) · 4.79 KB
/
example.py
File metadata and controls
153 lines (124 loc) · 4.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#!/usr/bin/env python3
"""
Example usage of rouge-torch package.
This script demonstrates the basic functionality of the rouge-torch package
including computing ROUGE scores and using the loss function.
"""
import torch
from rouge_torch import ROUGEScoreTorch, create_vocab_and_tokenizer, text_to_logits
def _print_section(title):
    """Print *title* between two 40-dash rules, preceded by a blank line."""
    print("\n" + "-" * 40)
    print(title)
    print("-" * 40)


def _print_scores(label, scores, index=0):
    """Print precision, recall and F1 for one example of a ROUGE score dict.

    Args:
        label: Heading printed (followed by a colon) before the scores,
            e.g. ``"ROUGE-1"``.
        scores: Mapping with ``"precision"``, ``"recall"`` and ``"f1"`` keys,
            each indexable by example position — the shape of the dicts
            returned by ``rouge_n_batch`` / ``rouge_l_batch`` as used in
            ``main``.
        index: Position of the example to print.
    """
    print(f"{label}:")
    print(f" Precision: {scores['precision'][index]:.3f}")
    print(f" Recall: {scores['recall'][index]:.3f}")
    print(f" F1: {scores['f1'][index]:.3f}")


def main():
    """Run example usage of rouge-torch.

    Demonstrates ROUGE-1/2/L scoring of one candidate against two references,
    the ROUGE loss with different ``rouge_types`` / ``reduction`` settings,
    and scoring a small batch of candidate/reference pairs.
    """
    print("=" * 60)
    print("ROUGE-Torch Example Usage")
    print("=" * 60)

    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create simple tokenizer (in practice, you'd use your own).
    # Only the vocab and the tokenizer are needed below; the reverse mapping
    # and detokenizer are unpacked just to show what the helper returns.
    word_to_id, _id_to_word, tokenize, _detokenize = create_vocab_and_tokenizer()
    vocab_size = len(word_to_id)
    print(f"Vocabulary size: {vocab_size}")

    # Initialize ROUGE scorer
    rouge_scorer = ROUGEScoreTorch(vocab_size, device)

    # Example texts
    candidate_text = "the cat sat on the mat"
    reference_texts = [
        "the cat sat on the mat",  # Perfect match
        "a cat was sitting on the mat",  # Similar meaning
    ]
    print(f"\nCandidate: '{candidate_text}'")
    for number, ref_text in enumerate(reference_texts, start=1):
        print(f"Reference {number}: '{ref_text}'")

    # Convert texts to logits
    def text_to_logits_helper(text, max_len=20):
        # Bind tokenizer, vocab size and device so call sites pass only text.
        return text_to_logits(text, tokenize, vocab_size, device, max_len)

    cand_logits = text_to_logits_helper(candidate_text)
    # One logits tensor per reference text.
    ref_logits = [text_to_logits_helper(ref) for ref in reference_texts]
    print("\nTensor shapes:")
    print(f" Candidate logits: {cand_logits.shape}")
    print(f" Reference logits: {[ref.shape for ref in ref_logits]}")

    # Compute ROUGE scores; index 0 of each score entry is printed.
    _print_section("ROUGE Scores:")
    rouge_1 = rouge_scorer.rouge_n_batch(cand_logits, ref_logits, n=1)
    _print_scores("ROUGE-1", rouge_1)
    rouge_2 = rouge_scorer.rouge_n_batch(cand_logits, ref_logits, n=2)
    _print_scores("ROUGE-2", rouge_2)
    rouge_l = rouge_scorer.rouge_l_batch(cand_logits, ref_logits)
    _print_scores("ROUGE-L", rouge_l)

    # Loss function example
    _print_section("Loss Function:")

    # Single ROUGE type loss
    loss_r1 = rouge_scorer.compute_rouge_loss(
        cand_logits, ref_logits, rouge_types=["rouge_1"], reduction="mean"
    )
    print(f"ROUGE-1 Loss: {loss_r1.item():.6f}")

    # Combined loss
    loss_combined = rouge_scorer.compute_rouge_loss(
        cand_logits, ref_logits, rouge_types=["rouge_1", "rouge_l"], reduction="mean"
    )
    print(f"Combined Loss: {loss_combined.item():.6f}")

    # Different reduction modes: "none" yields per-sample losses.
    loss_none = rouge_scorer.compute_rouge_loss(
        cand_logits, ref_logits, rouge_types=["rouge_1"], reduction="none"
    )
    print(f"Per-sample Loss: {loss_none[0].item():.6f}")

    # Batch example
    _print_section("Batch Processing Example:")

    # Multiple candidates; references[i] is paired with candidates[i].
    candidates = [
        "the cat sat on the mat",
        "a dog ran in the park",
        "the quick brown fox jumps",
    ]
    references = [
        "the cat sat on the mat",
        "the dog ran fast in the park",
        "quick brown fox jumps over lazy dog",
    ]

    # Create batch tensors by concatenating per-text logits along dim 0.
    batch_cand_logits = torch.cat(
        [text_to_logits_helper(cand, 12) for cand in candidates], dim=0
    )
    # A list holding a single batched reference tensor.
    batch_ref_logits = [
        torch.cat([text_to_logits_helper(ref, 12) for ref in references], dim=0)
    ]
    print(f"Batch candidate shape: {batch_cand_logits.shape}")
    print(f"Batch reference shape: {batch_ref_logits[0].shape}")

    # Compute batch ROUGE scores
    batch_rouge_1 = rouge_scorer.rouge_n_batch(batch_cand_logits, batch_ref_logits, n=1)
    batch_loss = rouge_scorer.compute_rouge_loss(
        batch_cand_logits,
        batch_ref_logits,
        rouge_types=["rouge_1"],
        reduction="none",  # Get per-sample losses
    )

    print("\nBatch Results:")
    for i, (cand, ref) in enumerate(zip(candidates, references)):
        print(f" Example {i+1}:")
        print(f" Candidate: '{cand}'")
        print(f" Reference: '{ref}'")
        print(f" ROUGE-1 F1: {batch_rouge_1['f1'][i]:.3f}")
        # .item() (as used for the losses above) extracts the Python float;
        # formatting the tensor directly only works for 0-dim tensors.
        print(f" Loss: {batch_loss[i].item():.3f}")

    print("\n" + "=" * 60)
    print("Example completed successfully!")
    print("=" * 60)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()