Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b80bf4e

Browse files
Lukasz Kaiser authored and Ryan Sepassi committed
Clean up autoencoder code, correct modality bug, add autoregressive baselines and tests.
PiperOrigin-RevId: 193435665
1 parent acee384 commit b80bf4e

File tree

6 files changed

+238
-26
lines changed

6 files changed

+238
-26
lines changed

tensor2tensor/layers/common_layers.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,8 @@ def dropout_with_broadcast_dims(x, keep_prob, broadcast_dims=None, **kwargs):
7272
if broadcast_dims:
7373
shape = tf.shape(x)
7474
ndims = len(x.get_shape())
75+
# Allow dimensions like "-1" as well.
76+
broadcast_dims = [dim + ndims if dim < 0 else dim for dim in broadcast_dims]
7577
kwargs["noise_shape"] = [
7678
1 if i in broadcast_dims else shape[i] for i in xrange(ndims)]
7779
return tf.nn.dropout(x, keep_prob, **kwargs)
@@ -441,7 +443,7 @@ def conv2d_kernel(kernel_size_arg, name_suffix):
441443
return conv2d_kernel(kernel_size, "single")
442444

443445

444-
def conv(inputs, filters, kernel_size, dilation_rate=1, **kwargs):
446+
def conv(inputs, filters, kernel_size, dilation_rate=(1, 1), **kwargs):
445447
return conv_internal(
446448
tf.layers.conv2d,
447449
inputs,

tensor2tensor/layers/modalities.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -258,8 +258,7 @@ def loss(self, logits, targets):
258258
logits,
259259
targets,
260260
self._model_hparams.label_smoothing,
261-
weights_fn=self.targets_weights_fn,
262-
gaussian=True)
261+
weights_fn=self.targets_weights_fn)
263262

264263

265264
@registry.register_image_modality("image_channel_compress")
@@ -535,8 +534,7 @@ def loss(self, logits, targets):
535534
logits,
536535
targets,
537536
self._model_hparams.label_smoothing,
538-
weights_fn=self.targets_weights_fn,
539-
gaussian=True)
537+
weights_fn=self.targets_weights_fn)
540538

541539

542540
@registry.register_class_label_modality("default")

tensor2tensor/models/basic.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,7 @@ def basic_autoencoder():
206206
hparams.learning_rate_constant = 0.0002
207207
hparams.learning_rate_warmup_steps = 500
208208
hparams.learning_rate_schedule = "constant * linear_warmup"
209-
hparams.label_smoothing = 0.05
209+
hparams.label_smoothing = 0.0
210210
hparams.batch_size = 128
211211
hparams.hidden_size = 64
212212
hparams.num_hidden_layers = 5

tensor2tensor/models/basic_test.py

Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,70 @@
1+
# coding=utf-8
2+
# Copyright 2018 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Basic nets tests."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
# Dependency imports
23+
24+
import numpy as np
25+
26+
from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import
27+
from tensor2tensor.models import basic
28+
from tensor2tensor.utils import trainer_lib
29+
30+
import tensorflow as tf
31+
32+
33+
class BasicTest(tf.test.TestCase):
  """Smoke tests checking the logits shapes of the basic models on MNIST."""

  def testBasicFcRelu(self):
    """BasicFcRelu on image_mnist yields logits of shape (1, 1, 1, 1, 10)."""
    # np.random.randint has an exclusive upper bound, so 256 and 10 reproduce
    # the inclusive ranges [0, 255] and [0, 9] previously drawn with the
    # deprecated np.random.random_integers.
    x = np.random.randint(0, high=256, size=(1, 28, 28, 1))
    y = np.random.randint(0, high=10, size=(1, 1))
    hparams = trainer_lib.create_hparams(
        "basic_fc_small", problem_name="image_mnist", data_dir=".")
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(x, dtype=tf.int32),
          "targets": tf.constant(y, dtype=tf.int32),
      }
      model = basic.BasicFcRelu(hparams, tf.estimator.ModeKeys.TRAIN)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (1, 1, 1, 1, 10))

  def testBasicAutoencoder(self):
    """BasicAutoencoder on image_mnist_rev yields per-pixel 256-way logits."""
    # Same exclusive-upper-bound convention as above.
    x = np.random.randint(0, high=256, size=(1, 28, 28, 1))
    y = np.random.randint(0, high=10, size=(1, 1))
    hparams = trainer_lib.create_hparams(
        "basic_autoencoder", problem_name="image_mnist_rev", data_dir=".")
    with self.test_session() as session:
      # In the reversed problem the image is the target and the label the input.
      features = {
          "targets": tf.constant(x, dtype=tf.int32),
          "inputs": tf.constant(y, dtype=tf.int32),
      }
      tf.train.create_global_step()
      model = basic.BasicAutoencoder(hparams, tf.estimator.ModeKeys.TRAIN)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (1, 28, 28, 1, 256))
67+
68+
69+
if __name__ == "__main__":
70+
tf.test.main()

tensor2tensor/models/research/autoencoders.py

Lines changed: 78 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,53 @@
3030

3131

3232
@registry.register_model
33-
class ResidualAutoencoder(basic.BasicAutoencoder):
33+
class AutoencoderAutoregressive(basic.BasicAutoencoder):
  """Autoencoder with an autoregressive part.

  The output of the basic autoencoder body is optionally refined by a small
  autoregressive network over the flattened targets. The behavior is
  controlled by these hparams:
    autoregressive_mode: "none" (return the basic output unchanged), "conv3",
      "conv5" or "sru".
    autoregressive_dropout: dropout rate applied to the targets before they
      are fed to the autoregressive part.
    autoregressive_forget_base: if true, the autoregressive part sees only the
      shifted targets and ignores the basic autoencoder output.
  """

  def body(self, features):
    """Runs the basic autoencoder followed by the autoregressive part.

    Args:
      features: dict of tensors; "targets" is indexed up to dimension 3, so it
        must have at least 4 dimensions.

    Returns:
      A (result, losses) pair: result has the shape of features["targets"]
      (or is the unchanged basic output in "none" mode) and losses are those
      returned by the basic autoencoder body.

    Raises:
      ValueError: if hparams.autoregressive_mode is not supported.
    """
    hparams = self._hparams
    mode = hparams.autoregressive_mode
    # Supported conv-based modes: (kernel width, variable scope name).
    conv_specs = {
        "conv3": (3, "autoregressive_conv3"),
        "conv5": (5, "autoregressive_conv5"),
        "sru": (3, "autoregressive_sru_conv3"),
    }
    # Fail before building any graph if the mode is not known.
    if mode != "none" and mode not in conv_specs:
      raise ValueError("Unsupported autoregressive mode: %s" % mode)

    # Run the basic autoencoder part first.
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    if mode == "none":
      # Nothing autoregressive to add, so no extra ops are built for this mode.
      assert not hparams.autoregressive_forget_base
      return basic_result, losses

    shape = common_layers.shape_list(features["targets"])
    # Flatten everything between the batch and the last (index 3) dimension.
    flat_shape = [shape[0], -1, shape[3]]
    if hparams.autoregressive_forget_base:
      # Purely-autoregressive mode, no autoencoder: feed only the right-shifted
      # targets (no dropout) and ignore the basic result.
      autoregressive_inputs = common_layers.shift_right_3d(
          tf.reshape(features["targets"], flat_shape))
    else:
      # Drop out the targets so the model cannot rely on them alone.
      targets_keep_prob = 1.0 - hparams.autoregressive_dropout
      targets_dropout = common_layers.dropout_with_broadcast_dims(
          features["targets"], targets_keep_prob, broadcast_dims=[-1])
      targets_shifted = common_layers.shift_right_3d(
          tf.reshape(targets_dropout, flat_shape))
      basic1d = tf.reshape(basic_result, flat_shape)
      autoregressive_inputs = tf.concat([basic1d, targets_shifted], axis=-1)

    # NOTE(review): padding="LEFT" presumably keeps the convolution causal --
    # confirm against common_layers.conv1d.
    kernel_width, conv_name = conv_specs[mode]
    res = common_layers.conv1d(autoregressive_inputs, shape[3], kernel_width,
                               padding="LEFT",
                               activation=common_layers.belu,
                               name=conv_name)
    if mode == "sru":
      res = common_layers.sru(res)
    return tf.reshape(res, shape), losses
76+
77+
78+
@registry.register_model
79+
class AutoencoderResidual(AutoencoderAutoregressive):
3480
"""Residual autoencoder."""
3581

3682
def encoder(self, x):
@@ -106,7 +152,7 @@ def decoder(self, x):
106152

107153

108154
@registry.register_model
109-
class BasicDiscreteAutoencoder(basic.BasicAutoencoder):
155+
class AutoencoderBasicDiscrete(AutoencoderAutoregressive):
110156
"""Discrete autoencoder."""
111157

112158
def bottleneck(self, x):
@@ -132,7 +178,7 @@ def sample(self):
132178

133179

134180
@registry.register_model
135-
class ResidualDiscreteAutoencoder(ResidualAutoencoder):
181+
class AutoencoderResidualDiscrete(AutoencoderResidual):
136182
"""Discrete residual autoencoder."""
137183

138184
def bottleneck(self, x, bottleneck_size=None):
@@ -160,13 +206,15 @@ def sample(self):
160206
size = [hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y,
161207
hp.bottleneck_size]
162208
rand = tf.random_uniform(size)
163-
res1 = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
164-
res2 = tf.zeros_like(rand) - 1.0
165-
return tf.concat([res2[:, :, :, :2], res1[:, :, :, 2:]], axis=-1)
209+
res = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
210+
# If you want to set some first bits to a fixed value, do this:
211+
# fixed = tf.zeros_like(rand) - 1.0
212+
# res = tf.concat([fixed[:, :, :, :2], res[:, :, :, 2:]], axis=-1)
213+
return res
166214

167215

168216
@registry.register_model
169-
class OrderedDiscreteAutoencoder(ResidualDiscreteAutoencoder):
217+
class AutoencoderOrderedDiscrete(AutoencoderResidualDiscrete):
170218
"""Ordered discrete autoencoder."""
171219

172220
def bottleneck(self, x):
@@ -195,7 +243,7 @@ def bottleneck(self, x):
195243

196244

197245
@registry.register_model
198-
class StackedAutoencoder(ResidualDiscreteAutoencoder):
246+
class AutoencoderStacked(AutoencoderResidualDiscrete):
199247
"""A stacked autoencoder."""
200248

201249
def stack(self, b, size, bottleneck_size, name):
@@ -290,9 +338,19 @@ def body(self, features):
290338

291339

292340
@registry.register_hparams
293-
def residual_autoencoder():
294-
"""Residual autoencoder model."""
341+
def autoencoder_autoregressive():
  """Autoregressive autoencoder model.

  Extends the basic autoencoder hparams with the three settings of the
  autoregressive part: forget_base, mode and dropout.
  """
  hparams = basic.basic_autoencoder()
  autoregressive_defaults = (
      ("autoregressive_forget_base", False),
      ("autoregressive_mode", "conv3"),
      ("autoregressive_dropout", 0.4),
  )
  for hparam_name, default_value in autoregressive_defaults:
    hparams.add_hparam(hparam_name, default_value)
  return hparams
348+
349+
350+
@registry.register_hparams
351+
def autoencoder_residual():
352+
"""Residual autoencoder model."""
353+
hparams = autoencoder_autoregressive()
296354
hparams.optimizer = "Adam"
297355
hparams.learning_rate_constant = 0.0001
298356
hparams.learning_rate_warmup_steps = 500
@@ -311,9 +369,9 @@ def residual_autoencoder():
311369

312370

313371
@registry.register_hparams
314-
def basic_discrete_autoencoder():
372+
def autoencoder_basic_discrete():
315373
"""Basic autoencoder model."""
316-
hparams = basic.basic_autoencoder()
374+
hparams = autoencoder_autoregressive()
317375
hparams.num_hidden_layers = 5
318376
hparams.hidden_size = 64
319377
hparams.bottleneck_size = 4096
@@ -324,9 +382,9 @@ def basic_discrete_autoencoder():
324382

325383

326384
@registry.register_hparams
327-
def residual_discrete_autoencoder():
385+
def autoencoder_residual_discrete():
328386
"""Residual discrete autoencoder model."""
329-
hparams = residual_autoencoder()
387+
hparams = autoencoder_residual()
330388
hparams.bottleneck_size = 4096
331389
hparams.bottleneck_noise = 0.1
332390
hparams.bottleneck_warmup_steps = 3000
@@ -339,9 +397,9 @@ def residual_discrete_autoencoder():
339397

340398

341399
@registry.register_hparams
342-
def residual_discrete_autoencoder_big():
400+
def autoencoder_residual_discrete_big():
343401
"""Residual discrete autoencoder model, big version."""
344-
hparams = residual_discrete_autoencoder()
402+
hparams = autoencoder_residual_discrete()
345403
hparams.hidden_size = 128
346404
hparams.max_hidden_size = 4096
347405
hparams.bottleneck_noise = 0.1
@@ -351,15 +409,15 @@ def residual_discrete_autoencoder_big():
351409

352410

353411
@registry.register_hparams
354-
def ordered_discrete_autoencoder():
412+
def autoencoder_ordered_discrete():
  """Ordered discrete autoencoder model.

  Uses the residual discrete autoencoder hparams without any changes.
  """
  return autoencoder_residual_discrete()
358416

359417

360418
@registry.register_hparams
361-
def stacked_autoencoder():
419+
def autoencoder_stacked():
  """Stacked autoencoder model.

  Starts from the residual discrete autoencoder hparams and sets
  bottleneck_size to 128.
  """
  stacked_hparams = autoencoder_residual_discrete()
  stacked_hparams.bottleneck_size = 128
  return stacked_hparams
tensor2tensor/models/research/autoencoders_test.py

Lines changed: 84 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,84 @@
1+
# coding=utf-8
2+
# Copyright 2018 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Autoencoders tests."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
# Dependency imports
23+
24+
import numpy as np
25+
26+
from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import
27+
from tensor2tensor.models.research import autoencoders # pylint: disable=unused-import
28+
from tensor2tensor.utils import registry
29+
from tensor2tensor.utils import trainer_lib
30+
31+
import tensorflow as tf
32+
33+
34+
class AutoencoderTest(tf.test.TestCase):
  """Smoke tests checking the logits shape of each registered autoencoder."""

  def getMnistRandomOutput(self, model_name, hparams_set=None,
                           mode=tf.estimator.ModeKeys.TRAIN):
    """Runs a registered model on one random MNIST-shaped example.

    Args:
      model_name: name under which the model is registered.
      hparams_set: name of the hparams set; defaults to model_name.
      mode: estimator mode the model is constructed with.

    Returns:
      The evaluated logits of the model.
    """
    hparams_set = hparams_set or model_name
    # np.random.randint has an exclusive upper bound, so 256 and 10 reproduce
    # the inclusive ranges [0, 255] and [0, 9] previously drawn with the
    # deprecated np.random.random_integers.
    x = np.random.randint(0, high=256, size=(1, 28, 28, 1))
    y = np.random.randint(0, high=10, size=(1, 1))
    hparams = trainer_lib.create_hparams(
        hparams_set, problem_name="image_mnist_rev", data_dir=".")
    with self.test_session() as session:
      # In the reversed problem the image is the target and the label the input.
      features = {
          "targets": tf.constant(x, dtype=tf.int32),
          "inputs": tf.constant(y, dtype=tf.int32),
      }
      tf.train.create_global_step()
      model = registry.model(model_name)(hparams, mode)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    return res

  @property
  def mnistOutputShape(self):
    """Logits shape every test expects for the single random example."""
    return (1, 28, 28, 1, 256)

  def testAutoencoderAutoregressive(self):
    res = self.getMnistRandomOutput("autoencoder_autoregressive")
    self.assertEqual(res.shape, self.mnistOutputShape)

  def testAutoencoderResidual(self):
    res = self.getMnistRandomOutput("autoencoder_residual")
    self.assertEqual(res.shape, self.mnistOutputShape)

  def testAutoencoderBasicDiscrete(self):
    res = self.getMnistRandomOutput("autoencoder_basic_discrete")
    self.assertEqual(res.shape, self.mnistOutputShape)

  def testAutoencoderResidualDiscrete(self):
    res = self.getMnistRandomOutput("autoencoder_residual_discrete")
    self.assertEqual(res.shape, self.mnistOutputShape)

  def testAutoencoderOrderedDiscrete(self):
    res = self.getMnistRandomOutput("autoencoder_ordered_discrete")
    self.assertEqual(res.shape, self.mnistOutputShape)

  def testAutoencoderStacked(self):
    res = self.getMnistRandomOutput("autoencoder_stacked")
    self.assertEqual(res.shape, self.mnistOutputShape)
82+
83+
if __name__ == "__main__":
84+
tf.test.main()

0 commit comments

Comments
 (0)