-
Notifications
You must be signed in to change notification settings - Fork 51
Expand file tree
/
Copy pathprepare_roberta_data.py
More file actions
71 lines (60 loc) · 3.09 KB
/
prepare_roberta_data.py
File metadata and controls
71 lines (60 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import pickle
import argparse
from pytorch_transformers.tokenization_roberta import RobertaTokenizer
from mspan_roberta_gcn.drop_roberta_dataset import DropReader
from tag_mspan_robert_gcn.drop_roberta_mspan_dataset import DropReader as TDropReader
# Command-line configuration for the DROP RoBERTa preprocessing script.
parser = argparse.ArgumentParser(
    description="Tokenize DROP data with RoBERTa and cache it as pickle files."
)
parser.add_argument("--input_path", type=str)    # dir containing drop_dataset_*.json and roberta.large/
parser.add_argument("--output_dir", type=str)    # dir that receives the cached .pkl files
parser.add_argument("--passage_length_limit", type=int, default=463)
parser.add_argument("--question_length_limit", type=int, default=46)
parser.add_argument("--tag_mspan", action="store_true")  # use the multi-span tagging reader
args = parser.parse_args()

# os.path.join builds the vocab path portably; the original
# `args.input_path + "/roberta.large"` broke on inputs with a trailing
# slash and on non-POSIX path separators.
tokenizer = RobertaTokenizer.from_pretrained(os.path.join(args.input_path, "roberta.large"))
# Select the reader class, output-file prefix, and the answer types whose
# absence lets the *training* reader skip an example.  The tag_mspan
# variant additionally handles multi-span answers.  This replaces two
# near-identical 20-line branches from the original script.
if args.tag_mspan:
    reader_cls = TDropReader
    out_prefix = "tmspan_"
    skip_types = [
        "passage_span", "question_span", "addition_subtraction",
        "counting", "multi_span",
    ]
else:
    reader_cls = DropReader
    out_prefix = ""
    skip_types = [
        "passage_span", "question_span", "addition_subtraction", "counting",
    ]

# Dev keeps every example; train drops examples with no usable answer type.
dev_reader = reader_cls(
    tokenizer, args.passage_length_limit, args.question_length_limit
)
train_reader = reader_cls(
    tokenizer, args.passage_length_limit, args.question_length_limit,
    skip_when_all_empty=skip_types,
)

# Robustness: the original crashed with FileNotFoundError when the output
# directory did not exist yet.
os.makedirs(args.output_dir, exist_ok=True)


def _cache_split(reader, split):
    """Read drop_dataset_<split>.json with *reader* and pickle the result.

    Output file name and the printed message are byte-identical to the
    original script (e.g. ``tmspan_cached_roberta_train.pkl``).
    """
    src = os.path.join(args.input_path, "drop_dataset_{}.json".format(split))
    data = reader._read(src)
    dst = os.path.join(
        args.output_dir, "{}cached_roberta_{}.pkl".format(out_prefix, split)
    )
    print("Save data to {}.".format(dst))
    with open(dst, "wb") as f:
        pickle.dump(data, f)


_cache_split(train_reader, "train")
_cache_split(dev_reader, "dev")