Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 45a787e

Browse files
authored
Merge pull request #222 from rsepassi/push
v1.1.8
2 parents b669110 + 8abc5d2 commit 45a787e

36 files changed

+944
-1656
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='tensor2tensor',
8-
version='1.1.7',
8+
version='1.1.8',
99
description='Tensor2Tensor',
1010
author='Google Inc.',
1111
author_email='no-reply@google.com',
@@ -19,6 +19,7 @@
1919
'tensor2tensor/bin/t2t-make-tf-configs',
2020
],
2121
install_requires=[
22+
'bz2file',
2223
'numpy',
2324
'requests',
2425
'sympy',

tensor2tensor/bin/t2t-datagen

100755100644
Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ from tensor2tensor.data_generators import generator_utils
4545
from tensor2tensor.data_generators import image
4646
from tensor2tensor.data_generators import lm1b
4747
from tensor2tensor.data_generators import snli
48-
from tensor2tensor.data_generators import wiki
4948
from tensor2tensor.data_generators import wmt
5049
from tensor2tensor.data_generators import wsj_parsing
5150
from tensor2tensor.utils import registry
@@ -105,10 +104,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
105104
lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True),
106105
lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True)
107106
),
108-
"wiki_32k": (
109-
lambda: wiki.generator(FLAGS.tmp_dir, True),
110-
1000
111-
),
112107
"image_celeba_tune": (
113108
lambda: image.celeba_generator(FLAGS.tmp_dir, 162770),
114109
lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)),
@@ -170,17 +165,14 @@ def main(_):
170165
# Remove parsing if paths are not given.
171166
if not FLAGS.parsing_path:
172167
problems = [p for p in problems if "parsing" not in p]
173-
# Remove en-de BPE if paths are not given.
174-
if not FLAGS.ende_bpe_path:
175-
problems = [p for p in problems if "ende_bpe" not in p]
176168

177169
if not problems:
178170
problems_str = "\n * ".join(
179171
sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
180172
error_msg = ("You must specify one of the supported problems to "
181173
"generate data for:\n * " + problems_str + "\n")
182-
error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
183-
"--timit_paths, --ende_bpe_path and --parsing_path.")
174+
error_msg += ("TIMIT and parsing need data_sets specified with "
175+
"--timit_paths and --parsing_path.")
184176
raise ValueError(error_msg)
185177

186178
if not FLAGS.data_dir:
@@ -203,34 +195,17 @@ def generate_data_for_problem(problem):
203195
"""Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS."""
204196
training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
205197

206-
if isinstance(dev_gen, int):
207-
# The dev set and test sets are generated as extra shards using the
208-
# training generator. The integer specifies the number of training
209-
# shards. FLAGS.num_shards is ignored.
210-
num_training_shards = dev_gen
211-
tf.logging.info("Generating data for %s.", problem)
212-
all_output_files = generator_utils.combined_data_filenames(
213-
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
214-
num_training_shards)
215-
generator_utils.generate_files(training_gen(), all_output_files,
216-
FLAGS.max_cases)
217-
else:
218-
# usual case - train data and dev data are generated using separate
219-
# generators.
220-
num_shards = FLAGS.num_shards or 10
221-
tf.logging.info("Generating training data for %s.", problem)
222-
train_output_files = generator_utils.train_data_filenames(
223-
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
224-
generator_utils.generate_files(training_gen(), train_output_files,
225-
FLAGS.max_cases)
226-
tf.logging.info("Generating development data for %s.", problem)
227-
dev_shards = 10 if "coco" in problem else 1
228-
dev_output_files = generator_utils.dev_data_filenames(
229-
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
230-
generator_utils.generate_files(dev_gen(), dev_output_files)
231-
all_output_files = train_output_files + dev_output_files
232-
233-
tf.logging.info("Shuffling data...")
198+
num_shards = FLAGS.num_shards or 10
199+
tf.logging.info("Generating training data for %s.", problem)
200+
train_output_files = generator_utils.train_data_filenames(
201+
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
202+
generator_utils.generate_files(training_gen(), train_output_files,
203+
FLAGS.max_cases)
204+
tf.logging.info("Generating development data for %s.", problem)
205+
dev_output_files = generator_utils.dev_data_filenames(
206+
problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, 1)
207+
generator_utils.generate_files(dev_gen(), dev_output_files)
208+
all_output_files = train_output_files + dev_output_files
234209
generator_utils.shuffle_dataset(all_output_files)
235210

236211

tensor2tensor/bin/t2t-trainer

100755100644
File mode changed.

0 commit comments

Comments
 (0)