Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8a90658

Browse files
authored
Merge pull request #625 from johnglover/fix-ptb-data-generator
Fix PTB data generator.
2 parents fb3c08f + fa13c99 commit 8a90658

File tree

1 file changed

+43
-26
lines changed
  • tensor2tensor/data_generators

1 file changed

+43
-26
lines changed

tensor2tensor/data_generators/ptb.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,37 @@ def _get_token_encoder(vocab_dir, vocab_name, filename):
7777
return text_encoder.TokenTextEncoder(vocab_path)
7878

7979

80+
def _maybe_download_corpus(tmp_dir, vocab_type):
  """Download and unpack the PTB corpus unless it is already available.

  Args:
    tmp_dir: directory containing dataset.
    vocab_type: which vocabulary are we using; a
      `text_problems.VocabType` value. `VocabType.CHARACTER` selects the
      character-level PTB files, anything else selects the word-level files.

  Returns:
    The list of names of the relevant PTB files extracted into `tmp_dir`
    (character-level files for `VocabType.CHARACTER`, word-level otherwise).
  """
  filename = os.path.basename(PTB_URL)
  compressed_filepath = generator_utils.maybe_download(
      tmp_dir, filename, PTB_URL)
  ptb_files = []
  ptb_char_files = []

  with tarfile.open(compressed_filepath, "r:gz") as tgz:
    files = []
    # Selecting only relevant files: PTB text files, split into the
    # character-level ("char" in the name) and word-level variants.
    for m in tgz.getmembers():
      if "ptb" in m.name and ".txt" in m.name:
        if "char" in m.name:
          ptb_char_files += [m.name]
        else:
          ptb_files += [m.name]
        files += [m]

    # NOTE(review): extractall trusts member paths from a downloaded
    # archive (no path-traversal filtering) — see the tarfile docs on
    # extraction filters; confirm this is acceptable for this source URL.
    tgz.extractall(tmp_dir, members=files)

  if vocab_type == text_problems.VocabType.CHARACTER:
    return ptb_char_files
  else:
    return ptb_files
109+
110+
80111
@registry.register_problem
81112
class LanguagemodelPtb10k(text_problems.Text2SelfProblem):
82113
"""PTB, 10k vocab."""
@@ -91,6 +122,10 @@ def dataset_splits(self):
91122
"shards": 1,
92123
}]
93124

125+
@property
def is_generate_per_split(self):
  """Samples are generated separately for each dataset split."""
  return True
128+
94129
@property
95130
def vocab_filename(self):
96131
return "vocab.lmptb.10000"
@@ -100,28 +135,7 @@ def vocab_type(self):
100135
return text_problems.VocabType.TOKEN
101136

102137
def generate_samples(self, data_dir, tmp_dir, dataset_split):
103-
filename = os.path.basename(PTB_URL)
104-
compressed_filepath = generator_utils.maybe_download(
105-
tmp_dir, filename, PTB_URL)
106-
ptb_files = []
107-
ptb_char_files = []
108-
with tarfile.open(compressed_filepath, "r:gz") as tgz:
109-
files = []
110-
# Selecting only relevant files.
111-
for m in tgz.getmembers():
112-
if "ptb" in m.name and ".txt" in m.name:
113-
if "char" in m.name:
114-
ptb_char_files += [m.name]
115-
else:
116-
ptb_files += [m.name]
117-
files += [m]
118-
119-
tgz.extractall(tmp_dir, members=files)
120-
121-
if self.vocab_type == text_problems.VocabType.CHARACTER:
122-
files = ptb_char_files
123-
else:
124-
files = ptb_files
138+
files = _maybe_download_corpus(tmp_dir, self.vocab_type)
125139

126140
train_file, valid_file = None, None
127141
for filename in files:
@@ -138,10 +152,13 @@ def generate_samples(self, data_dir, tmp_dir, dataset_split):
138152
train = dataset_split == problem.DatasetSplit.TRAIN
139153
filepath = train_file if train else valid_file
140154

141-
with tf.gfile.GFile(filepath, "r") as f:
142-
for line in f:
143-
line = " ".join(line.replace("\n", " %s " % EOS).split())
144-
yield {"targets": line}
155+
def _generate_samples():
156+
with tf.gfile.GFile(filepath, "r") as f:
157+
for line in f:
158+
line = " ".join(line.replace("\n", " %s " % EOS).split())
159+
yield {"targets": line}
160+
161+
return _generate_samples()
145162

146163

147164
@registry.register_problem

0 commit comments

Comments
 (0)