Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ac8f6e3

Browse files
Ryan SepassiCopybara-Service
authored andcommitted
internal merge of PR #872
PiperOrigin-RevId: 207290554
1 parent 7967b44 commit ac8f6e3

File tree

6 files changed

+164
-237
lines changed

6 files changed

+164
-237
lines changed

.travis.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ script:
6060
# * visualization_test
6161
# * model_rl_experiment_test
6262
# * allen_brain_test
63-
# * allen_brain_utils_test
6463
# * model_rl_experiment_stochastic_test
6564
# * models/research
6665
# algorithmic_math_test: flaky
@@ -74,14 +73,12 @@ script:
7473
--ignore=tensor2tensor/models/research/universal_transformer_test.py
7574
--ignore=tensor2tensor/rl/model_rl_experiment_test.py
7675
--ignore=tensor2tensor/data_generators/allen_brain_test.py
77-
--ignore=tensor2tensor/data_generators/allen_brain_utils_test.py
7876
--ignore=tensor2tensor/rl/model_rl_experiment_stochastic_test.py
7977
--ignore=tensor2tensor/models/research
8078
- pytest tensor2tensor/utils/registry_test.py
8179
- pytest tensor2tensor/utils/trainer_lib_test.py
8280
- pytest tensor2tensor/visualization/visualization_test.py
8381
- pytest tensor2tensor/data_generators/allen_brain_test.py
84-
- pytest tensor2tensor/data_generators/allen_brain_utils_test.py
8582
- if [[ "$TF_VERSION" == "$TF_LATEST" ]] || [[ "$TF_VERSION" == "tf-nightly" ]];
8683
then
8784
pytest tensor2tensor/models/research;

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
# explicit pip install gym[atari] for the tests.
5959
# 'gym[atari]',
6060
],
61-
'allen': ['Pillow==5.1.0', 'pandas==0.23.0']
61+
'allen': ['Pillow==5.1.0', 'pandas==0.23.0'],
6262
},
6363
classifiers=[
6464
'Development Status :: 4 - Beta',

tensor2tensor/data_generators/allen_brain.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
# coding=utf-8
2+
# Copyright 2018 The Tensor2Tensor Authors.
3+
#
24
# Licensed under the Apache License, Version 2.0 (the "License");
35
# you may not use this file except in compliance with the License.
46
# You may obtain a copy of the License at
57
#
6-
# http://www.apache.org/licenses/LICENSE-2.0
8+
# http://www.apache.org/licenses/LICENSE-2.0
79
#
810
# Unless required by applicable law or agreed to in writing, software
911
# distributed under the License is distributed on an "AS IS" BASIS,
1012
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1113
# See the License for the specific language governing permissions and
1214
# limitations under the License.
13-
1415
"""Problem definitions for Allen Brain Atlas problems.
1516
1617
Notes:
@@ -28,18 +29,17 @@
2829

2930
from io import BytesIO
3031
import math
31-
import numpy as np
3232
import os
33+
34+
import numpy as np
3335
import requests
3436

3537
from tensor2tensor.data_generators import generator_utils
3638
from tensor2tensor.data_generators import image_utils
3739
from tensor2tensor.data_generators import problem
3840
from tensor2tensor.data_generators import text_encoder
39-
from tensor2tensor.utils import registry
4041
from tensor2tensor.utils import metrics
41-
42-
from tensor2tensor.data_generators.allen_brain_utils import try_importing_pil_image
42+
from tensor2tensor.utils import registry
4343

4444
import tensorflow as tf
4545

@@ -52,23 +52,28 @@
5252
# the steps described here: http://help.brain-map.org/display/api,
5353
# e.g. https://gist.github.com/cwbeitel/5dffe90eb561637e35cdf6aa4ee3e704
5454
_IMAGE_IDS = [
55-
'74887117', '71894997', '69443979', '79853548', '101371232', '77857182',
56-
'70446772', '68994990', '69141561', '70942310', '70942316', '68298378',
57-
'69690156', '74364867', '77874134', '75925043', '73854431', '69206601',
58-
'71771457', '101311379', '74777533', '70960269', '71604493', '102216720',
59-
'74776437', '75488723', '79815814', '77857132', '77857138', '74952778',
60-
'69068486', '648167', '75703410', '74486118', '77857098', '637407',
61-
'67849516', '69785503', '71547630', '69068504', '69184074', '74853078',
62-
'74890694', '74890698', '75488687', '71138602', '71652378', '68079764',
63-
'70619061', '68280153', '73527042', '69764608', '68399025', '244297',
64-
'69902658', '68234159', '71495521', '74488395', '73923026', '68280155',
65-
'75488747', '69589140', '71342189', '75119214', '79455452', '71774294',
66-
'74364957', '68031779', '71389422', '67937572', '69912671', '73854471',
67-
'75008183', '101371376', '75703290', '69533924', '79853544', '77343882',
68-
'74887133', '332587', '69758622', '69618413', '77929999', '244293',
69-
'334792', '75825136', '75008103', '70196678', '71883965', '74486130',
70-
'74693566', '76107119', '76043858', '70252433', '68928364', '74806345',
71-
'67848661', '75900326', '71773690', '75008171']
55+
"74887117", "71894997", "69443979", "79853548", "101371232", "77857182",
56+
"70446772", "68994990", "69141561", "70942310", "70942316", "68298378",
57+
"69690156", "74364867", "77874134", "75925043", "73854431", "69206601",
58+
"71771457", "101311379", "74777533", "70960269", "71604493", "102216720",
59+
"74776437", "75488723", "79815814", "77857132", "77857138", "74952778",
60+
"69068486", "648167", "75703410", "74486118", "77857098", "637407",
61+
"67849516", "69785503", "71547630", "69068504", "69184074", "74853078",
62+
"74890694", "74890698", "75488687", "71138602", "71652378", "68079764",
63+
"70619061", "68280153", "73527042", "69764608", "68399025", "244297",
64+
"69902658", "68234159", "71495521", "74488395", "73923026", "68280155",
65+
"75488747", "69589140", "71342189", "75119214", "79455452", "71774294",
66+
"74364957", "68031779", "71389422", "67937572", "69912671", "73854471",
67+
"75008183", "101371376", "75703290", "69533924", "79853544", "77343882",
68+
"74887133", "332587", "69758622", "69618413", "77929999", "244293",
69+
"334792", "75825136", "75008103", "70196678", "71883965", "74486130",
70+
"74693566", "76107119", "76043858", "70252433", "68928364", "74806345",
71+
"67848661", "75900326", "71773690", "75008171"]
72+
73+
74+
def PIL_Image(): # pylint: disable=invalid-name
75+
from PIL import Image # pylint: disable=g-import-not-at-top
76+
return Image
7277

7378

7479
def _get_case_file_paths(tmp_dir, case, training_fraction=0.95):
@@ -77,14 +82,17 @@ def _get_case_file_paths(tmp_dir, case, training_fraction=0.95):
7782
Args:
7883
tmp_dir: str, the root path to which raw images were written, at the
7984
top level having meta/ and raw/ subdirs.
80-
size: int, the size of sub-images to consider (`size`x`size`).
8185
case: bool, whether obtaining file paths for training (true) or eval
8286
(false).
8387
training_fraction: float, the fraction of the sub-image path list to
8488
consider as the basis for training examples.
8589
8690
Returns:
8791
list: A list of file paths.
92+
93+
Raises:
94+
ValueError: if images not found in tmp_dir, or if training_fraction would
95+
leave no examples for eval.
8896
"""
8997

9098
paths = tf.gfile.Glob("%s/*.jpg" % tmp_dir)
@@ -146,7 +154,7 @@ def maybe_download_image_dataset(image_ids, target_dir):
146154

147155
response.raise_for_status()
148156

149-
with open(tmp_destination, "w") as f:
157+
with tf.gfile.Open(tmp_destination, "w") as f:
150158
for block in response.iter_content(1024):
151159
f.write(block)
152160

@@ -159,7 +167,6 @@ def random_square_mask(shape, fraction):
159167
Args:
160168
shape: tuple, shape of the mask to create.
161169
fraction: float, fraction of the mask area to populate with `mask_scalar`.
162-
mask_scalar: float, the scalar to apply to the otherwise 1-valued mask.
163170
164171
Returns:
165172
numpy.array: A numpy array storing the mask.
@@ -191,6 +198,8 @@ def _generator(tmp_dir, training, size=_BASE_EXAMPLE_IMAGE_SIZE,
191198
alternatively, evaluation), determining whether examples in tmp_dir
192199
prefixed with train or dev will be used.
193200
size: int, the image size to add to the example annotation.
201+
training_fraction: float, the fraction of the sub-image path list to
202+
consider as the basis for training examples.
194203
195204
Yields:
196205
A dictionary representing the images with the following fields:
@@ -207,7 +216,7 @@ def _generator(tmp_dir, training, size=_BASE_EXAMPLE_IMAGE_SIZE,
207216
case=training,
208217
training_fraction=training_fraction)
209218

210-
image_obj = try_importing_pil_image()
219+
image_obj = PIL_Image()
211220

212221
tf.logging.info("Loaded case file paths (n=%s)" % len(image_files))
213222
height = size
@@ -230,8 +239,7 @@ def _generator(tmp_dir, training, size=_BASE_EXAMPLE_IMAGE_SIZE,
230239
v_end = v_offset + size - 1
231240

232241
# Extract a sub-image tile.
233-
# pylint: disable=invalid-sequence-index
234-
subimage = np.uint8(img[h_offset:h_end, v_offset:v_end])
242+
subimage = np.uint8(img[h_offset:h_end, v_offset:v_end]) # pylint: disable=invalid-sequence-index
235243

236244
# Filter images that are likely background (not tissue).
237245
if np.amax(subimage) < 230:

tensor2tensor/data_generators/allen_brain_test.py

Lines changed: 126 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,103 @@
11
# coding=utf-8
2+
# Copyright 2018 The Tensor2Tensor Authors.
3+
#
24
# Licensed under the Apache License, Version 2.0 (the "License");
35
# you may not use this file except in compliance with the License.
46
# You may obtain a copy of the License at
57
#
6-
# http://www.apache.org/licenses/LICENSE-2.0
8+
# http://www.apache.org/licenses/LICENSE-2.0
79
#
810
# Unless required by applicable law or agreed to in writing, software
911
# distributed under the License is distributed on an "AS IS" BASIS,
1012
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1113
# See the License for the specific language governing permissions and
1214
# limitations under the License.
13-
1415
"""Tests of the Allen Brain Atlas problems."""
1516

16-
import tensorflow as tf
17-
from tensorflow.contrib.eager.python import tfe
17+
import os
18+
import shutil
19+
import tempfile
20+
21+
import numpy as np
1822

1923
from tensor2tensor.data_generators import allen_brain
20-
from tensor2tensor.data_generators.allen_brain import _generator
21-
from tensor2tensor.data_generators.allen_brain_utils import mock_raw_data
22-
from tensor2tensor.data_generators.allen_brain_utils import TemporaryDirectory
2324
from tensor2tensor.models import image_transformer_2d
2425

26+
import tensorflow as tf
27+
28+
tfe = tf.contrib.eager
2529
tfe.enable_eager_execution()
26-
Modes = tf.estimator.ModeKeys
30+
Modes = tf.estimator.ModeKeys # pylint: disable=invalid-name
31+
32+
33+
def mock_raw_image(x_dim=1024, y_dim=1024, num_channels=3,
34+
output_path=None, write_image=True):
35+
"""Generate random `x_dim` by `y_dim`, optionally to `output_path`.
36+
37+
Args:
38+
x_dim: int, the x dimension of generated raw image.
39+
y_dim: int, the x dimension of generated raw image.
40+
num_channels: int, number of channels in image.
41+
output_path: str, path to which to write image.
42+
write_image: bool, whether to write the image to output_path.
43+
44+
Returns:
45+
numpy.array: The random `x_dim` by `y_dim` image (i.e. array).
46+
"""
47+
48+
rand_shape = (x_dim, y_dim, num_channels)
49+
50+
if num_channels != 3:
51+
raise NotImplementedError("mock_raw_image for channels != 3 not yet "
52+
"implemented.")
53+
54+
img = np.random.random(rand_shape)
55+
img = np.uint8(img*255)
56+
57+
if write_image:
58+
image_obj = allen_brain.PIL_Image()
59+
pil_img = image_obj.fromarray(img, mode="RGB")
60+
with tf.gfile.Open(output_path, "w") as f:
61+
pil_img.save(f, "jpeg")
62+
63+
return img
64+
65+
66+
def mock_raw_data(tmp_dir, raw_dim=1024, num_channels=3, num_images=1):
67+
"""Mock a raw data download directory with meta and raw subdirs.
68+
69+
Notes:
70+
71+
* This utility is shared by tests in both allen_brain_utils and
72+
allen_brain so kept here instead of in one of *_test.
73+
74+
Args:
75+
tmp_dir: str, temporary dir in which to mock data.
76+
raw_dim: int, the x and y dimension of generated raw imgs.
77+
num_channels: int, number of channels in image.
78+
num_images: int, number of images to mock.
79+
"""
80+
81+
tf.gfile.MakeDirs(tmp_dir)
82+
83+
for image_id in range(num_images):
84+
85+
raw_image_path = os.path.join(tmp_dir, "%s.jpg" % image_id)
86+
87+
mock_raw_image(x_dim=raw_dim, y_dim=raw_dim,
88+
num_channels=num_channels,
89+
output_path=raw_image_path)
90+
91+
92+
class TemporaryDirectory(object):
93+
"""For py2 support of `with tempfile.TemporaryDirectory() as name:`"""
94+
95+
def __enter__(self):
96+
self.name = tempfile.mkdtemp()
97+
return self.name
98+
99+
def __exit__(self, exc_type, exc_value, traceback):
100+
shutil.rmtree(self.name)
27101

28102

29103
class TestAllenBrain(tf.test.TestCase):
@@ -32,10 +106,6 @@ class TestAllenBrain(tf.test.TestCase):
32106
def setUp(self):
33107

34108
self.all_problems = [
35-
#allen_brain.Img2imgAllenBrain,
36-
#allen_brain.Img2imgAllenBrainDim48to64,
37-
#allen_brain.Img2imgAllenBrainDim8to32,
38-
#allen_brain.Img2imgAllenBrainDim16to32,
39109
allen_brain.Img2imgAllenBrainDim16to16Paint1
40110
]
41111

@@ -45,7 +115,7 @@ def test_generator_produces_examples(self):
45115
for is_training in [True, False]:
46116
with TemporaryDirectory() as tmp_dir:
47117
mock_raw_data(tmp_dir, raw_dim=256, num_images=100)
48-
for example in _generator(tmp_dir, is_training):
118+
for example in allen_brain._generator(tmp_dir, is_training):
49119
for key in ["image/encoded", "image/format",
50120
"image/height", "image/width"]:
51121
self.assertTrue(key in example.keys())
@@ -170,5 +240,48 @@ def loss_fn(features):
170240
256))
171241

172242

243+
class TestImageMock(tf.test.TestCase):
244+
"""Tests of image mocking utility."""
245+
246+
def test_image_mock_produces_expected_shape(self):
247+
"""Test that the image mocking utility produces expected shape output."""
248+
249+
with TemporaryDirectory() as tmp_dir:
250+
251+
cases = [
252+
{
253+
"x_dim": 8,
254+
"y_dim": 8,
255+
"num_channels": 3,
256+
"output_path": "/foo",
257+
"write_image": True
258+
}
259+
]
260+
261+
for cid, case in enumerate(cases):
262+
output_path = os.path.join(tmp_dir, "dummy%s.jpg" % cid)
263+
img = mock_raw_image(x_dim=case["x_dim"],
264+
y_dim=case["y_dim"],
265+
num_channels=case["num_channels"],
266+
output_path=output_path,
267+
write_image=case["write_image"])
268+
269+
self.assertEqual(img.shape, (case["x_dim"], case["y_dim"],
270+
case["num_channels"]))
271+
if case["write_image"]:
272+
self.assertTrue(tf.gfile.Exists(output_path))
273+
274+
275+
class TestMockRawData(tf.test.TestCase):
276+
"""Tests of raw data mocking utility."""
277+
278+
def test_runs(self):
279+
"""Test that data mocking utility runs for cases expected to succeed."""
280+
281+
with TemporaryDirectory() as tmp_dir:
282+
283+
mock_raw_data(tmp_dir, raw_dim=256, num_channels=3, num_images=40)
284+
285+
173286
if __name__ == "__main__":
174287
tf.test.main()

0 commit comments

Comments
 (0)