-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdata.py
97 lines (86 loc) · 3.55 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import urllib.request
from pathlib import Path
from typing import Optional

import torch
from transformers import squad_convert_examples_to_features
from transformers.data.processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor
def data_directory(using_bind_mount: bool, rank: int, bind_mount_path: Optional[Path] = None):
    """Return the per-rank data directory.

    Args:
        using_bind_mount: When True, place data under ``bind_mount_path``;
            otherwise under ``/tmp``.
        rank: Worker rank; each rank gets its own ``data-rank{rank}`` subdir
            so concurrent workers do not clobber each other's files.
        bind_mount_path: Root of the bind mount. Only consulted when
            ``using_bind_mount`` is True (must not be None in that case).

    Returns:
        Path of the rank-specific data directory (not created here).
    """
    base_dir = bind_mount_path if using_bind_mount else Path("/tmp")
    return base_dir / f"data-rank{rank}"
def cache_dir(using_bind_mount: bool, rank: int, bind_mount_path: Optional[Path] = None):
    """Return the per-rank cache directory.

    Args:
        using_bind_mount: When True, place the cache under ``bind_mount_path``;
            otherwise under ``/tmp``.
        rank: Worker rank; each rank gets its own ``cache/{rank}`` subdir.
        bind_mount_path: Root of the bind mount. Only consulted when
            ``using_bind_mount`` is True (must not be None in that case).

    Returns:
        Path of the rank-specific cache directory (not created here).
    """
    base_dir = bind_mount_path if using_bind_mount else Path("/tmp")
    return base_dir / f"cache/{rank}"
def load_and_cache_examples(
    data_dir: Path,
    tokenizer,
    task,
    max_seq_length,
    doc_stride,
    max_query_length,
    evaluate=False,
    model_name=None,
):
    """Download a SQuAD split (if needed) and convert it to model features.

    Raw JSON is fetched into ``data_dir`` only when not already present, and
    the tokenized features are cached on disk so repeat calls are cheap.

    Args:
        data_dir: Directory holding the raw JSON files and the feature cache.
        tokenizer: HuggingFace tokenizer used for feature conversion.
        task: Dataset name, either ``"SQuAD1.1"`` or ``"SQuAD2.0"``.
        max_seq_length: Maximum total input sequence length after tokenization.
        doc_stride: Stride between chunks when splitting long documents.
        max_query_length: Maximum number of tokens for the question.
        evaluate: If True, use the dev split; otherwise the train split.
        model_name: Folded into the cache file name so features built with
            different tokenizers/models do not collide.

    Returns:
        Tuple of ``(dataset, examples, features)``.

    Raises:
        NameError: If ``task`` is not a recognized SQuAD version.
    """
    if task == "SQuAD1.1":
        train_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
        validation_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
        train_file = "train-v1.1.json"
        validation_file = "dev-v1.1.json"
        processor = SquadV1Processor()
    elif task == "SQuAD2.0":
        train_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json"
        validation_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"
        train_file = "train-v2.0.json"
        validation_file = "dev-v2.0.json"
        processor = SquadV2Processor()
    else:
        raise NameError("Incompatible dataset detected")

    data_dir.mkdir(parents=True, exist_ok=True)

    # Download the raw JSON only when it is not already on disk (resolves the
    # old TODO about unconditionally re-downloading on every call).
    source_url = validation_url if evaluate else train_url
    raw_path = data_dir / (validation_file if evaluate else train_file)
    if not raw_path.exists():
        with urllib.request.urlopen(source_url) as response:
            raw_path.write_text(response.read().decode())

    # Feature cache is keyed on split and model so features produced by
    # different tokenizers never get mixed up.
    cached_features_file = data_dir.absolute() / "cache_{}_{}".format(
        "dev" if evaluate else "train",
        model_name,
    )

    # Init features and dataset from cache if it exists.
    overwrite_cache = False  # Set to True to do a cache wipe (TODO: Make cache wipe configurable)
    if cached_features_file.exists() and not overwrite_cache:
        print(f"Loading features from cached file {cached_features_file}")
        features_and_dataset = torch.load(cached_features_file)
        features = features_and_dataset["features"]
        dataset = features_and_dataset["dataset"]
        examples = features_and_dataset["examples"]
    else:
        if evaluate:
            examples = processor.get_dev_examples(data_dir, filename=validation_file)
        else:
            examples = processor.get_train_examples(data_dir, filename=train_file)
        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
        )
        print(f"Saving features into cached file {cached_features_file}")
        torch.save(
            {"features": features, "dataset": dataset, "examples": examples},
            cached_features_file,
        )
    return dataset, examples, features