This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit cd222d3

Merge pull request #188 from rsepassi/push
v1.1.3
2 parents: a55c4cf + 7c072d7

17 files changed: +323 −322 lines

setup.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.1.2',
+    version='1.1.3',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',

tensor2tensor/bin/t2t-datagen

Lines changed: 0 additions & 7 deletions
@@ -43,7 +43,6 @@ from tensor2tensor.data_generators import audio
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import image
 from tensor2tensor.data_generators import lm1b
-from tensor2tensor.data_generators import ptb
 from tensor2tensor.data_generators import snli
 from tensor2tensor.data_generators import wiki
 from tensor2tensor.data_generators import wmt
@@ -176,12 +175,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     lambda: audio.timit_generator(
         FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
         vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
-    "lmptb_10k": (
-        lambda: ptb.train_generator(
-            FLAGS.tmp_dir,
-            FLAGS.data_dir,
-            False),
-        ptb.valid_generator),
 }
 
 # pylint: enable=g-long-lambda
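The removed "lmptb_10k" entry also shows the shape of _SUPPORTED_PROBLEM_GENERATORS: each key names a problem and each value is a pair of callables, one building the training generator and one the dev generator. Below is a minimal, self-contained sketch of that structure; the "toy_copy" problem and its generator are hypothetical stand-ins, not part of this repository.

```python
# Hypothetical sketch of the (training_generator, dev_generator) pair
# structure used by _SUPPORTED_PROBLEM_GENERATORS in t2t-datagen.

def _toy_generator(is_training):
  """Yields example dicts, mirroring the generator contract t2t-datagen expects."""
  num_examples = 100 if is_training else 10
  for i in range(num_examples):
    tokens = [i % 7 + 1, i % 5 + 1]
    yield {"inputs": tokens, "targets": tokens}


_SUPPORTED_PROBLEM_GENERATORS = {
    "toy_copy": (
        lambda: _toy_generator(is_training=True),   # training generator
        lambda: _toy_generator(is_training=False),  # dev generator
    ),
}

if __name__ == "__main__":
  train_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS["toy_copy"]
  print(next(iter(train_gen())))  # {'inputs': [1, 1], 'targets': [1, 1]}
```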

tensor2tensor/data_generators/gene_expression.py

Lines changed: 8 additions & 8 deletions
@@ -110,10 +110,10 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
     # Collect created shard processes to start and join
     processes = []
 
-    datasets = [(self.training_filepaths, self.num_shards, "train",
-                 num_train_examples), (self.dev_filepaths, 1, "valid",
-                                       num_dev_examples),
-                (self.test_filepaths, 1, "test", num_test_examples)]
+    datasets = [
+        (self.training_filepaths, self.num_shards, "train", num_train_examples),
+        (self.dev_filepaths, 10, "valid", num_dev_examples),
+        (self.test_filepaths, 10, "test", num_test_examples)]
     for fname_fn, nshards, key_prefix, num_examples in datasets:
       outfiles = fname_fn(data_dir, nshards, shuffled=False)
       all_filepaths.extend(outfiles)
@@ -125,8 +125,8 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
                        start_idx, end_idx))
       processes.append(p)
 
-    # 1 per training shard + dev + test
-    assert len(processes) == self.num_shards + 2
+    # 1 per training shard + 10 for dev + 10 for test
+    assert len(processes) == self.num_shards + 20
 
     # Start and wait for processes in batches
     num_batches = int(
@@ -168,8 +168,8 @@ def preprocess_examples(self, examples, mode):
 
     # Reshape targets
     examples["targets"] = tf.reshape(examples["targets"],
-                                     [-1, 1, self.num_output_predictions])
-    examples["targets_mask"] = tf.reshape(examples["targets_mask"], [-1, 1, 1])
+                                     [-1, self.num_output_predictions])
+    examples["targets_mask"] = tf.reshape(examples["targets_mask"], [-1, 1])
 
     # Set masked targets to 0 (i.e. pad) so that loss and metrics ignore them.
     # Add epsilon because some unmasked labels are actually 0.

tensor2tensor/data_generators/generator_utils.py

Lines changed: 1 addition & 1 deletion
@@ -305,7 +305,7 @@ def generate():
 
     # Use Tokenizer to count the word occurrences.
     with tf.gfile.GFile(filepath, mode="r") as source_file:
-      file_byte_budget = 3.5e5 if "en" in filepath else 7e5
+      file_byte_budget = 3.5e5 if filepath.endswith("en") else 7e5
       for line in source_file:
         if file_byte_budget <= 0:
           break
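The old substring test gives the smaller English byte budget to any path that merely contains "en" somewhere, which is easy to trip on directory or corpus names; the new check only matches files whose names end in the English suffix. A small illustration with hypothetical paths:

```python
# Hypothetical file paths for illustration; not paths used by the library.
paths = [
    "/tmp/t2t_datagen/news-commentary-v12.de-en.en",  # English side
    "/tmp/t2t_datagen/news-commentary-v12.de-en.de",  # German side
]
for p in paths:
  old_rule = "en" in p         # True for both: "datagen", "commentary", "de-en" all contain "en"
  new_rule = p.endswith("en")  # True only for the file ending in the English suffix
  print(p, old_rule, new_rule)
```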

tensor2tensor/data_generators/problem.py

Lines changed: 107 additions & 7 deletions
@@ -18,11 +18,14 @@
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 # Dependency imports
 
-from tensor2tensor.data_generators import generator_utils as utils
+from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import text_encoder
 from tensor2tensor.utils import metrics
+from tensor2tensor.utils import registry
 
 import tensorflow as tf
 
@@ -176,20 +179,23 @@ def eval_metrics(self):
   def training_filepaths(self, data_dir, num_shards, shuffled):
     file_basename = self.dataset_filename()
     if not shuffled:
-      file_basename += utils.UNSHUFFLED_SUFFIX
-    return utils.train_data_filenames(file_basename, data_dir, num_shards)
+      file_basename += generator_utils.UNSHUFFLED_SUFFIX
+    return generator_utils.train_data_filenames(
+        file_basename, data_dir, num_shards)
 
   def dev_filepaths(self, data_dir, num_shards, shuffled):
     file_basename = self.dataset_filename()
     if not shuffled:
-      file_basename += utils.UNSHUFFLED_SUFFIX
-    return utils.dev_data_filenames(file_basename, data_dir, num_shards)
+      file_basename += generator_utils.UNSHUFFLED_SUFFIX
+    return generator_utils.dev_data_filenames(
+        file_basename, data_dir, num_shards)
 
   def test_filepaths(self, data_dir, num_shards, shuffled):
     file_basename = self.dataset_filename()
     if not shuffled:
-      file_basename += utils.UNSHUFFLED_SUFFIX
-    return utils.test_data_filenames(file_basename, data_dir, num_shards)
+      file_basename += generator_utils.UNSHUFFLED_SUFFIX
+    return generator_utils.test_data_filenames(
+        file_basename, data_dir, num_shards)
 
   def __init__(self, was_reversed=False, was_copy=False):
     """Create a Problem.
@@ -323,3 +329,97 @@ def _default_hparams():
       # class.
       input_space_id=SpaceID.GENERIC,
       target_space_id=SpaceID.GENERIC)
+
+
+class Text2TextProblem(Problem):
+  """Base class for text-to-text problems."""
+
+  @property
+  def is_character_level(self):
+    raise NotImplementedError()
+
+  @property
+  def targeted_vocab_size(self):
+    raise NotImplementedError()  # Not needed if self.is_character_level.
+
+  def train_generator(self, data_dir, tmp_dir, is_training):
+    """Generator of the training data."""
+    raise NotImplementedError()
+
+  def dev_generator(self, data_dir, tmp_dir):
+    """Generator of the development data."""
+    return self.train_generator(data_dir, tmp_dir, False)
+
+  @property
+  def input_space_id(self):
+    raise NotImplementedError()
+
+  @property
+  def target_space_id(self):
+    raise NotImplementedError()
+
+  @property
+  def num_shards(self):
+    raise NotImplementedError()
+
+  @property
+  def vocab_name(self):
+    raise NotImplementedError()
+
+  @property
+  def vocab_file(self):
+    return "%s.%d" % (self.vocab_name, self.targeted_vocab_size)
+
+  @property
+  def use_subword_tokenizer(self):
+    raise NotImplementedError()
+
+  @property
+  def has_inputs(self):
+    return True  # Set to False for language models.
+
+  def generate_data(self, data_dir, tmp_dir, task_id=-1):
+    generator_utils.generate_dataset_and_shuffle(
+        self.train_generator(data_dir, tmp_dir, True),
+        self.training_filepaths(data_dir, self.num_shards, shuffled=False),
+        self.dev_generator(data_dir, tmp_dir),
+        self.dev_filepaths(data_dir, 1, shuffled=False))
+
+  def feature_encoders(self, data_dir):
+    vocab_filename = os.path.join(data_dir, self.vocab_file)
+    if self.is_character_level:
+      encoder = text_encoder.ByteTextEncoder(),
+    elif self.use_subword_tokenizer:
+      encoder = text_encoder.SubwordTextEncoder(vocab_filename)
+    else:
+      encoder = text_encoder.TokenTextEncoder(vocab_filename)
+    if self.has_inputs:
+      return {"inputs": encoder, "targets": encoder}
+    return {"targets": encoder}
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    if self.is_character_level:
+      source_vocab_size = 256
+      target_vocab_size = 256
+    else:
+      target_vocab_size = self._encoders["targets"].vocab_size
+      if self.has_inputs:
+        source_vocab_size = self._encoders["inputs"].vocab_size
+
+    if self.has_inputs:
+      p.input_modality = {"inputs": (registry.Modalities.SYMBOL,
+                                     source_vocab_size)}
+    p.target_modality = (registry.Modalities.SYMBOL, target_vocab_size)
+    if self.has_inputs:
+      p.input_space_id = self.input_space_id
+    p.target_space_id = self.target_space_id
+    if self.is_character_level:
+      p.loss_multiplier = 2.0
+
+  def eval_metrics(self):
+    return [
+        metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5,
+        metrics.Metrics.ACC_PER_SEQ, metrics.Metrics.NEG_LOG_PERPLEXITY,
+        metrics.Metrics.APPROX_BLEU
+    ]
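The bulk of this commit is the new Text2TextProblem base class: a subclass supplies the vocabulary and sharding properties plus a training generator, and inherits generate_data, feature_encoders, hparams, and eval_metrics from the base. A hypothetical subclass sketch (not part of this commit) showing how those hooks fit together:

```python
from tensor2tensor.data_generators import problem


class MyCopyProblem(problem.Text2TextProblem):
  """Toy token-level copy task, for illustration only."""

  @property
  def is_character_level(self):
    return False

  @property
  def targeted_vocab_size(self):
    return 2**13

  @property
  def vocab_name(self):
    return "vocab.my_copy"

  @property
  def use_subword_tokenizer(self):
    return True

  @property
  def num_shards(self):
    return 10

  @property
  def input_space_id(self):
    return problem.SpaceID.GENERIC

  @property
  def target_space_id(self):
    return problem.SpaceID.GENERIC

  def train_generator(self, data_dir, tmp_dir, is_training):
    # Yield {"inputs": ..., "targets": ...} token-id dicts, the contract the
    # base class passes to generator_utils.generate_dataset_and_shuffle.
    num_examples = 1000 if is_training else 100
    for i in range(num_examples):
      tokens = [i % 7 + 2, i % 5 + 2]
      yield {"inputs": tokens, "targets": tokens}
```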

tensor2tensor/data_generators/problem_hparams.py

Lines changed: 0 additions & 16 deletions
@@ -368,21 +368,6 @@ def wiki_32k(model_hparams):
   return p
 
 
-def lmptb_10k(model_hparams):
-  """Penn Tree Bank language-modeling benchmark, 10k token vocabulary."""
-  p = default_problem_hparams()
-  p.input_modality = {}
-  p.target_modality = (registry.Modalities.SYMBOL, 10000)
-  vocabulary = text_encoder.TokenTextEncoder(
-      os.path.join(model_hparams.data_dir, "lmptb_10k.vocab"))
-  p.vocabulary = {
-      "targets": vocabulary,
-  }
-  p.input_space_id = 3
-  p.target_space_id = 3
-  return p
-
-
 def wmt_ende_bpe32k(model_hparams):
   """English to German translation benchmark."""
   p = default_problem_hparams()
@@ -642,7 +627,6 @@ def image_celeba(unused_model_hparams):
     "lm1b_characters": lm1b_characters,
     "lm1b_32k": lm1b_32k,
     "wiki_32k": wiki_32k,
-    "lmptb_10k": lmptb_10k,
     "ice_parsing_characters": wmt_parsing_characters,
    "ice_parsing_tokens": lambda p: ice_parsing_tokens(p, 2**13),
    "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13),
