Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 94eca0c

Browse files
Lukasz Kaiser authored and Ryan Sepassi committed
Rename train_generator to just generator and port wiki_32k to Problem. Also cleaning and speeding up vocab generation, algorithmic problems, wmt_zhen and BPE download.
PiperOrigin-RevId: 165015579
1 parent 35416da commit 94eca0c

19 files changed

+418
-437
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
'tensor2tensor/bin/t2t-make-tf-configs',
2020
],
2121
install_requires=[
22+
'bz2file',
2223
'numpy',
2324
'requests',
2425
'sympy',

tensor2tensor/bin/t2t-datagen

Lines changed: 13 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ from tensor2tensor.data_generators import generator_utils
4545
from tensor2tensor.data_generators import image
4646
from tensor2tensor.data_generators import lm1b
4747
from tensor2tensor.data_generators import snli
48-
from tensor2tensor.data_generators import wiki
4948
from tensor2tensor.data_generators import wmt
5049
from tensor2tensor.data_generators import wsj_parsing
5150
from tensor2tensor.utils import registry
@@ -82,16 +81,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
8281
"algorithmic_algebra_inverse": (
8382
lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
8483
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
85-
"ice_parsing_tokens": (
86-
lambda: wmt.tabbed_parsing_token_generator(
87-
FLAGS.data_dir, FLAGS.tmp_dir, True, "ice", 2**13, 2**8),
88-
lambda: wmt.tabbed_parsing_token_generator(
89-
FLAGS.data_dir, FLAGS.tmp_dir, False, "ice", 2**13, 2**8)),
90-
"ice_parsing_characters": (
91-
lambda: wmt.tabbed_parsing_character_generator(
92-
FLAGS.data_dir, FLAGS.tmp_dir, True),
93-
lambda: wmt.tabbed_parsing_character_generator(
94-
FLAGS.data_dir, FLAGS.tmp_dir, False)),
9584
"wmt_parsing_tokens_8k": (
9685
lambda: wmt.parsing_token_generator(
9786
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
@@ -115,10 +104,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
115104
lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True),
116105
lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True)
117106
),
118-
"wiki_32k": (
119-
lambda: wiki.generator(FLAGS.tmp_dir, True),
120-
1000
121-
),
122107
"image_celeba_tune": (
123108
lambda: image.celeba_generator(FLAGS.tmp_dir, 162770),
124109
lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)),
@@ -180,17 +165,14 @@ def main(_):
180165
# Remove parsing if paths are not given.
181166
if not FLAGS.parsing_path:
182167
problems = [p for p in problems if "parsing" not in p]
183-
# Remove en-de BPE if paths are not given.
184-
if not FLAGS.ende_bpe_path:
185-
problems = [p for p in problems if "ende_bpe" not in p]
186168

187169
if not problems:
188170
problems_str = "\n * ".join(
189171
sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
190172
error_msg = ("You must specify one of the supported problems to "
191173
"generate data for:\n * " + problems_str + "\n")
192-
error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
193-
"--timit_paths, --ende_bpe_path and --parsing_path.")
174+
error_msg += ("TIMIT and parsing need data_sets specified with "
175+
"--timit_paths and --parsing_path.")
194176
raise ValueError(error_msg)
195177

196178
if not FLAGS.data_dir:
@@ -213,34 +195,17 @@ def generate_data_for_problem(problem):
213195
"""Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS."""
214196
training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
215197

216-
if isinstance(dev_gen, int):
217-
# The dev set and test sets are generated as extra shards using the
218-
# training generator. The integer specifies the number of training
219-
# shards. FLAGS.num_shards is ignored.
220-
num_training_shards = dev_gen
221-
tf.logging.info("Generating data for %s.", problem)
222-
all_output_files = generator_utils.combined_data_filenames(
223-
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
224-
num_training_shards)
225-
generator_utils.generate_files(training_gen(), all_output_files,
226-
FLAGS.max_cases)
227-
else:
228-
# usual case - train data and dev data are generated using separate
229-
# generators.
230-
num_shards = FLAGS.num_shards or 10
231-
tf.logging.info("Generating training data for %s.", problem)
232-
train_output_files = generator_utils.train_data_filenames(
233-
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
234-
generator_utils.generate_files(training_gen(), train_output_files,
235-
FLAGS.max_cases)
236-
tf.logging.info("Generating development data for %s.", problem)
237-
dev_shards = 10 if "coco" in problem else 1
238-
dev_output_files = generator_utils.dev_data_filenames(
239-
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
240-
generator_utils.generate_files(dev_gen(), dev_output_files)
241-
all_output_files = train_output_files + dev_output_files
242-
243-
tf.logging.info("Shuffling data...")
198+
num_shards = FLAGS.num_shards or 10
199+
tf.logging.info("Generating training data for %s.", problem)
200+
train_output_files = generator_utils.train_data_filenames(
201+
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
202+
generator_utils.generate_files(training_gen(), train_output_files,
203+
FLAGS.max_cases)
204+
tf.logging.info("Generating development data for %s.", problem)
205+
dev_output_files = generator_utils.dev_data_filenames(
206+
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, 1)
207+
generator_utils.generate_files(dev_gen(), dev_output_files)
208+
all_output_files = train_output_files + dev_output_files
244209
generator_utils.shuffle_dataset(all_output_files)
245210

246211

0 commit comments

Comments (0)