Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 73f0be2

Browse files
authored
Merge pull request #217 from rsepassi/push
v1.1.7
2 parents c5e13db + af4f1e0 commit 73f0be2

37 files changed

+1826
-612
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name='tensor2tensor',
8-
version='1.1.6',
8+
version='1.1.7',
99
description='Tensor2Tensor',
1010
author='Google Inc.',
1111
author_email='no-reply@google.com',

tensor2tensor/bin/t2t-trainer

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ from __future__ import absolute_import
3030
from __future__ import division
3131
from __future__ import print_function
3232

33+
import os
34+
3335
# Dependency imports
3436

3537
from tensor2tensor.utils import registry
@@ -57,22 +59,25 @@ def main(_):
5759
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
5860
trainer_utils.log_registry()
5961
trainer_utils.validate_flags()
60-
tf.gfile.MakeDirs(FLAGS.output_dir)
62+
output_dir = os.path.expanduser(FLAGS.output_dir)
63+
tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
64+
data_dir = os.path.expanduser(FLAGS.data_dir)
65+
tf.gfile.MakeDirs(output_dir)
6166

6267
# Generate data if requested.
6368
if FLAGS.generate_data:
64-
tf.gfile.MakeDirs(FLAGS.data_dir)
65-
tf.gfile.MakeDirs(FLAGS.tmp_dir)
69+
tf.gfile.MakeDirs(data_dir)
70+
tf.gfile.MakeDirs(tmp_dir)
6671
for problem_name in FLAGS.problems.split("-"):
6772
tf.logging.info("Generating data for %s" % problem_name)
6873
problem = registry.problem(problem_name)
69-
problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)
74+
problem.generate_data(data_dir, tmp_dir)
7075

7176
# Run the trainer.
7277
trainer_utils.run(
73-
data_dir=FLAGS.data_dir,
78+
data_dir=data_dir,
7479
model=FLAGS.model,
75-
output_dir=FLAGS.output_dir,
80+
output_dir=output_dir,
7681
train_steps=FLAGS.train_steps,
7782
eval_steps=FLAGS.eval_steps,
7883
schedule=FLAGS.schedule)

tensor2tensor/data_generators/cipher.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def _gen(nbr_symbols, max_length, nbr_cases):
5656

5757
for plain, code in zip(indices, codes):
5858
yield {
59-
"X": plain,
60-
"Y": code,
59+
"inputs": plain,
60+
"targets": code,
6161
}
6262

6363
return _gen
@@ -99,8 +99,8 @@ def _gen(nbr_symbols, max_length, nbr_cases):
9999

100100
for plain, code in zip(indices, codes):
101101
yield {
102-
"X": plain,
103-
"Y": code,
102+
"inputs": plain,
103+
"targets": code,
104104
}
105105

106106
return _gen
@@ -148,7 +148,7 @@ def key(self):
148148
return [1, 3]
149149

150150

151-
class Layer(object):
151+
class ShiftEncryptionLayer(object):
152152
"""A single layer for shift."""
153153

154154
def __init__(self, vocab, shift):
@@ -211,7 +211,7 @@ def encipher_shift(plaintext, plain_vocab, shift):
211211
ciphertext (list of Strings): encrypted plain text.
212212
"""
213213
ciphertext = []
214-
cipher = Layer(plain_vocab, shift)
214+
cipher = ShiftEncryptionLayer(plain_vocab, shift)
215215

216216
for _, sentence in enumerate(plaintext):
217217
cipher_sentence = []
@@ -238,7 +238,7 @@ def encipher_vigenere(plaintext, plain_vocab, key):
238238
# generate Vigenere table
239239
layers = []
240240
for i in range(len(plain_vocab)):
241-
layers.append(Layer(plain_vocab, i))
241+
layers.append(ShiftEncryptionLayer(plain_vocab, i))
242242

243243
for i, sentence in enumerate(plaintext):
244244
cipher_sentence = []

tensor2tensor/data_generators/concatenate_examples.py

Lines changed: 0 additions & 180 deletions
This file was deleted.

0 commit comments

Comments
 (0)