-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
77 lines (61 loc) · 3.04 KB
/
evaluate.py
File metadata and controls
77 lines (61 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from data import SNLIDataset, snli_labels, create_collate_fn
from nli import BertNLI
import util
from pathlib import Path
import json
from torch.utils.data import ConcatDataset
from transformers import BertTokenizer, RobertaTokenizer
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
def get_best_epoch(model_save_dir):
    """Return the epoch whose validation accuracy is highest.

    Args:
        model_save_dir: Path to a training run directory containing
            'val_results.json', a JSON-lines file of records with at least
            'acc' (float) and 'epoch' keys.

    Returns:
        The 'epoch' value of the best-accuracy record. On accuracy ties the
        latest record in file order wins (matching the behavior of a stable
        sort followed by taking the last element).

    Raises:
        ValueError: if the results file contains no records.
    """
    val_results = util.load_jsonl(model_save_dir/'val_results.json')
    # Single O(n) pass instead of sorting the whole list (and mutating it)
    # just to read off the maximum. The index in the key keeps the original
    # "last of ties wins" tie-breaking of the stable sort.
    _, best = max(enumerate(val_results), key=lambda iv: (iv[1]['acc'], iv[0]))
    return best['epoch']
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("model_save_dir", type=Path)
parser.add_argument("data_dir", type=Path)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--device', type=str, default='cpu')
args = parser.parse_args()
print(args.model_save_dir)
preds_file = args.model_save_dir/'preds.jsonl'
items = util.load_jsonl(preds_file) if preds_file.exists() else []
items = [i for i in items if i['pairID'][-5:] != 'ab-ba']
existing_ids = {item['pairID'] for item in items}
train_args = json.load((args.model_save_dir/'args.json').open())
match train_args['architecture']:
case 'BERT':
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
roberta_se = False
case 'RoBERTa+SE':
tokenizer = RobertaTokenizer.from_pretrained('FacebookAI/roberta-base')
roberta_se = True
label_itos = snli_labels
label_stoi = {l:i for i,l in enumerate(label_itos)}
test_data = ConcatDataset([
SNLIDataset(args.data_dir/'snli_1.0/snli_1.0_test.jsonl', 'gold_label', exclude_ids=existing_ids),
SNLIDataset(args.data_dir/f'generated/llama3.2:3b_test.jsonl', 'model_label', exclude_ids=existing_ids),
SNLIDataset(args.data_dir/f'generated/llama3.3:70b_test.jsonl', 'model_label', exclude_ids=existing_ids),
SNLIDataset(args.data_dir/f'generated/deepseek-r1:70b_test.jsonl', 'model_label', exclude_ids=existing_ids),
SNLIDataset(args.data_dir/f'inferred/deepseek-r1:70b_test.jsonl', None, exclude_ids=existing_ids),
])
best_epoch = get_best_epoch(args.model_save_dir)
model = torch.load(args.model_save_dir/f'epoch-{best_epoch}', map_location=args.device)
model.eval()
collate_fn = create_collate_fn(tokenizer, label_stoi, args.device,
roberta_se=roberta_se,
hypothesis_only=train_args['hypothesis_only'])
test_loader = DataLoader(test_data, args.batch_size, collate_fn=collate_fn, shuffle=False)
with torch.no_grad():
for item_ids, x, y in tqdm(test_loader):
logits = model(x)
batch_pred_probs = F.softmax(logits, dim=1)
for item_id, pred_probs in zip(item_ids, batch_pred_probs):
items.append({'pairID': item_id} | dict(zip(snli_labels, pred_probs.tolist())))
with preds_file.open('w') as f:
for line in items:
util.write_jsonl(line, f)