This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c35c7a3

Merge pull request #201 from rsepassi/push
v1.1.4
2 parents 0df0f50 + 41bca68

Note: this is a large commit, so some changed files are hidden by default and only a subset of the diff is shown below.

43 files changed: +1012, -657 lines

.gitignore

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@ __pycache__/
 # PyPI distribution artifacts.
 build/
 dist/
-data/

 # Sublime project files
 *.sublime-project

setup.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.1.3',
+    version='1.1.4',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',

tensor2tensor/bin/t2t-datagen

Lines changed: 0 additions & 31 deletions
@@ -118,40 +118,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: wiki.generator(FLAGS.tmp_dir, True),
         1000
     ),
-    "image_mnist_tune": (
-        lambda: image.mnist_generator(FLAGS.tmp_dir, True, 55000),
-        lambda: image.mnist_generator(FLAGS.tmp_dir, True, 5000, 55000)),
-    "image_mnist_test": (
-        lambda: image.mnist_generator(FLAGS.tmp_dir, True, 60000),
-        lambda: image.mnist_generator(FLAGS.tmp_dir, False, 10000)),
-    "image_cifar10_tune": (
-        lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 48000),
-        lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 2000, 48000)),
-    "image_cifar10_test": (
-        lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000),
-        lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)),
-    "image_mscoco_characters_test": (
-        lambda: image.mscoco_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 80000),
-        lambda: image.mscoco_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 40000)),
     "image_celeba_tune": (
         lambda: image.celeba_generator(FLAGS.tmp_dir, 162770),
         lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)),
-    "image_mscoco_tokens_8k_test": (
-        lambda: image.mscoco_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 80000,
-            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
-        lambda: image.mscoco_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 40000,
-            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13)),
-    "image_mscoco_tokens_32k_test": (
-        lambda: image.mscoco_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 80000,
-            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
-        lambda: image.mscoco_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 40000,
-            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
     "snli_32k": (
         lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
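For context, each entry in `_SUPPORTED_PROBLEM_GENERATORS` maps a problem name to a pair of callables: one producing training data and one producing dev data. Below is a minimal standalone sketch of that registry pattern; the toy generator and driver are illustrative, not tensor2tensor code.

# Toy registry in the style of _SUPPORTED_PROBLEM_GENERATORS: each problem
# name maps to a (training generator, dev generator) pair of callables.

def _toy_generator(count):
  # Generators yield dicts mapping feature names to lists of ints.
  for i in range(count):
    yield {"inputs": [i], "targets": [i + 1]}

_TOY_PROBLEM_GENERATORS = {
    "toy_copy": (
        lambda: _toy_generator(1000),  # training data
        lambda: _toy_generator(100)),  # dev data
}

def generate_problem(problem_name):
  # A datagen driver looks the problem up and runs both generators.
  training_gen, dev_gen = _TOY_PROBLEM_GENERATORS[problem_name]
  return list(training_gen()), list(dev_gen())

train, dev = generate_problem("toy_copy")
print("toy_copy: %d train examples, %d dev examples" % (len(train), len(dev)))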

tensor2tensor/data_generators/README.md

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ for an example.

 `Problem`s support data generation, training, and decoding.

-Data generation is handles by `Problem.generate_data` which should produce 2
+Data generation is handled by `Problem.generate_data` which should produce 2
 datasets, training and dev, which should be named according to
 `Problem.training_filepaths` and `Problem.dev_filepaths`.
 `Problem.generate_data` should also produce any other files that may be required
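As a rough illustration of the contract described in the changed README text, here is a standalone sketch of a `generate_data` implementation that writes a training set and a dev set under names taken from `training_filepaths` and `dev_filepaths`. The class, filename patterns, and serialization are simplified stand-ins, not the library's actual helpers.

import os
import tempfile

class ToyProblem(object):
  """Simplified stand-in for a Problem; illustrates the naming contract."""

  name = "toy_copy"

  def training_filepaths(self, data_dir, num_shards):
    return [os.path.join(data_dir, "%s-train-%05d-of-%05d" % (self.name, i, num_shards))
            for i in range(num_shards)]

  def dev_filepaths(self, data_dir, num_shards):
    return [os.path.join(data_dir, "%s-dev-%05d-of-%05d" % (self.name, i, num_shards))
            for i in range(num_shards)]

  def generate_data(self, data_dir, tmp_dir):
    # Produce the two datasets, named by the filepath helpers above.
    del tmp_dir  # unused in this toy sketch
    for path in self.training_filepaths(data_dir, num_shards=2):
      with open(path, "w") as f:
        f.write("serialized training examples would go here\n")
    for path in self.dev_filepaths(data_dir, num_shards=1):
      with open(path, "w") as f:
        f.write("serialized dev examples would go here\n")

data_dir = tempfile.mkdtemp()
ToyProblem().generate_data(data_dir, tmp_dir=None)
print(sorted(os.listdir(data_dir)))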

tensor2tensor/data_generators/gene_expression.py

Lines changed: 2 additions & 1 deletion
@@ -163,8 +163,9 @@ def example_reading_spec(self):
     data_items_to_decoders = None
     return (data_fields, data_items_to_decoders)

-  def preprocess_examples(self, examples, mode):
+  def preprocess_examples(self, examples, mode, hparams):
     del mode
+    del hparams

     # Reshape targets
     examples["targets"] = tf.reshape(examples["targets"],
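The change above threads the model hparams into `preprocess_examples`; this particular problem does not need them and simply deletes the argument. Below is a toy standalone sketch of why the extra argument can matter; the classes are illustrative, not tensor2tensor's.

class ToyHParams(object):
  """Minimal stand-in for a model hparams object."""

  def __init__(self, max_length):
    self.max_length = max_length

class ToyProblem(object):

  def preprocess_examples(self, examples, mode, hparams):
    # With hparams available, preprocessing can depend on model settings.
    del mode  # unused in this sketch
    examples["targets"] = examples["targets"][:hparams.max_length]
    return examples

examples = {"targets": list(range(10))}
print(ToyProblem().preprocess_examples(examples, "train", ToyHParams(4)))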

tensor2tensor/data_generators/generator_utils.py

Lines changed: 51 additions & 0 deletions
@@ -28,6 +28,7 @@

 # Dependency imports

+import requests
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import six.moves.urllib_request as urllib  # Imports urllib on Python2, urllib.request on Python3
@@ -196,6 +197,56 @@ def maybe_download(directory, filename, url):
   return filepath


+def maybe_download_from_drive(directory, filename, url):
+  """Download filename from google drive unless it's already in directory.
+
+  Args:
+    directory: path to the directory that will be used.
+    filename: name of the file to download to (do nothing if it already exists).
+    url: URL to download from.
+
+  Returns:
+    The path to the downloaded file.
+  """
+  if not tf.gfile.Exists(directory):
+    tf.logging.info("Creating directory %s" % directory)
+    os.mkdir(directory)
+  filepath = os.path.join(directory, filename)
+  confirm_token = None
+  if tf.gfile.Exists(filepath):
+    tf.logging.info("Not downloading, file already found: %s" % filepath)
+    return filepath
+
+  # Since the file is big, drive will scan it for virus and take it to a
+  # warning page. We find the confirm token on this page and append it to the
+  # URL to start the download process.
+  confirm_token = None
+  session = requests.Session()
+  response = session.get(url, stream=True)
+  for k, v in response.cookies.items():
+    if k.startswith("download_warning"):
+      confirm_token = v
+
+  if confirm_token:
+    url = url + "&confirm=" + confirm_token
+  tf.logging.info("Downloading %s to %s" % (url, filepath))
+
+  response = session.get(url, stream=True)
+  # Now begin the download.
+  chunk_size = 16 * 1024
+  with open(filepath, "wb") as f:
+    for chunk in response.iter_content(chunk_size):
+      if chunk:
+        f.write(chunk)
+
+  # Print newline to clear the carriage return from the download progress
+  print()
+  statinfo = os.stat(filepath)
+  tf.logging.info("Succesfully downloaded %s, %s bytes." % (filename,
+                                                            statinfo.st_size))
+  return filepath
+
+
 def gunzip_file(gz_path, new_path):
   """Unzips from gz_path into new_path.

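One possible way to call the new helper, assuming this version of tensor2tensor and the `requests` package are installed; the Drive file id and local paths below are placeholders, not real data.

from tensor2tensor.data_generators import generator_utils

# Placeholder Drive URL; the helper appends &confirm=<token> itself when
# Drive answers with its large-file virus-scan warning page.
DRIVE_URL = ("https://drive.google.com/uc?export=download"
             "&id=PLACEHOLDER_FILE_ID")

path = generator_utils.maybe_download_from_drive(
    "/tmp/t2t_datagen", "my_corpus.tar.gz", DRIVE_URL)
print("File available at %s" % path)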

tensor2tensor/data_generators/generator_utils_test.py

Lines changed: 14 additions & 0 deletions
@@ -64,6 +64,20 @@ def testMaybeDownload(self):
     os.remove(tmp_file_path + ".http")
     os.remove(tmp_file_path)

+  def testMaybeDownloadFromDrive(self):
+    tmp_dir = self.get_temp_dir()
+    (_, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir)
+    tmp_file_name = os.path.basename(tmp_file_path)
+
+    # Download Google index to the temporary file.http.
+    res_path = generator_utils.maybe_download_from_drive(
+        tmp_dir, tmp_file_name + ".http", "http://drive.google.com")
+    self.assertEqual(res_path, tmp_file_path + ".http")
+
+    # Clean up.
+    os.remove(tmp_file_path + ".http")
+    os.remove(tmp_file_path)
+
   def testGunzipFile(self):
     tmp_dir = self.get_temp_dir()
     (_, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir)
