
Commit 160bed3

Lukasz Kaiser authored and Ryan Sepassi committed
Improvements to basic_conv_gen and autoencoder hparams.
PiperOrigin-RevId: 191776372
1 parent 6eea0e2 commit 160bed3

2 files changed: 22 additions (+22) and 15 deletions (-15)

tensor2tensor/models/research/autoencoders.py

Lines changed: 4 additions & 5 deletions

@@ -316,8 +316,8 @@ def basic_discrete_autoencoder():
   hparams = basic.basic_autoencoder()
   hparams.num_hidden_layers = 5
   hparams.hidden_size = 64
-  hparams.bottleneck_size = 2048
-  hparams.bottleneck_noise = 0.2
+  hparams.bottleneck_size = 4096
+  hparams.bottleneck_noise = 0.1
   hparams.bottleneck_warmup_steps = 3000
   hparams.add_hparam("discretize_warmup_steps", 5000)
   return hparams
@@ -327,8 +327,8 @@ def basic_discrete_autoencoder():
 def residual_discrete_autoencoder():
   """Residual discrete autoencoder model."""
   hparams = residual_autoencoder()
-  hparams.bottleneck_size = 2048
-  hparams.bottleneck_noise = 0.2
+  hparams.bottleneck_size = 4096
+  hparams.bottleneck_noise = 0.1
   hparams.bottleneck_warmup_steps = 3000
   hparams.add_hparam("discretize_warmup_steps", 5000)
   hparams.add_hparam("bottleneck_kind", "tanh_discrete")
@@ -344,7 +344,6 @@ def residual_discrete_autoencoder_big():
   hparams = residual_discrete_autoencoder()
   hparams.hidden_size = 128
   hparams.max_hidden_size = 4096
-  hparams.bottleneck_size = 8192
   hparams.bottleneck_noise = 0.1
   hparams.dropout = 0.1
   hparams.residual_dropout = 0.4
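
As a quick sanity check of the new autoencoder defaults, the hparams functions touched above can be called directly; a minimal sketch, assuming a tensor2tensor checkout at or after this revision (the printed values simply restate the diff):

from tensor2tensor.models.research import autoencoders

# Both discrete autoencoder hparams sets now default to a larger,
# less noisy bottleneck.
hparams = autoencoders.basic_discrete_autoencoder()
print(hparams.bottleneck_size)   # 4096 (was 2048)
print(hparams.bottleneck_noise)  # 0.1 (was 0.2)

# residual_discrete_autoencoder_big no longer overrides bottleneck_size,
# so it now inherits the 4096 default set in residual_discrete_autoencoder.
hparams_big = autoencoders.residual_discrete_autoencoder_big()
print(hparams_big.bottleneck_size)  # 4096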

tensor2tensor/models/research/basic_conv_gen.py

Lines changed: 18 additions & 10 deletions

@@ -40,26 +40,33 @@ def body(self, features):
     # Concat frames and down-stride.
     cur_frame = tf.to_float(features["inputs"])
     prev_frame = tf.to_float(features["inputs_prev"])
-    frames = tf.concat([cur_frame, prev_frame], axis=-1)
-    x = tf.layers.conv2d(frames, filters, kernel2, activation=tf.nn.relu,
-                         strides=(2, 2), padding="SAME")
+    x = tf.concat([cur_frame, prev_frame], axis=-1)
+    for _ in xrange(hparams.num_compress_steps):
+      x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
+                           strides=(2, 2), padding="SAME")
+      x = common_layers.layer_norm(x)
+      filters *= 2
     # Add embedded action.
-    action = tf.reshape(features["action"], [-1, 1, 1, filters])
-    x = tf.concat([x, action + tf.zeros_like(x)], axis=-1)
+    action = tf.reshape(features["action"], [-1, 1, 1, hparams.hidden_size])
+    zeros = tf.zeros(common_layers.shape_list(x)[:-1] + [hparams.hidden_size])
+    x = tf.concat([x, action + zeros], axis=-1)

     # Run a stack of convolutions.
     for i in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer%d" % i):
-        y = tf.layers.conv2d(x, 2 * filters, kernel1, activation=tf.nn.relu,
+        y = tf.layers.conv2d(x, filters, kernel1, activation=common_layers.belu,
                              strides=(1, 1), padding="SAME")
         if i == 0:
           x = y
         else:
           x = common_layers.layer_norm(x + y)
     # Up-convolve.
-    x = tf.layers.conv2d_transpose(
-        x, filters, kernel2, activation=tf.nn.relu,
-        strides=(2, 2), padding="SAME")
+    for _ in xrange(hparams.num_compress_steps):
+      filters //= 2
+      x = tf.layers.conv2d_transpose(
+          x, filters, kernel2, activation=common_layers.belu,
+          strides=(2, 2), padding="SAME")
+      x = common_layers.layer_norm(x)

     # Reward prediction.
     reward_pred_h1 = tf.reduce_mean(x, axis=[1, 2], keep_dims=True)
@@ -78,7 +85,7 @@ def basic_conv():
   hparams = common_hparams.basic_params1()
   hparams.hidden_size = 64
   hparams.batch_size = 8
-  hparams.num_hidden_layers = 2
+  hparams.num_hidden_layers = 3
   hparams.optimizer = "Adam"
   hparams.learning_rate_constant = 0.0002
   hparams.learning_rate_warmup_steps = 500
@@ -87,6 +94,7 @@ def basic_conv():
   hparams.initializer = "uniform_unit_scaling"
   hparams.initializer_gain = 1.0
   hparams.weight_decay = 0.0
+  hparams.add_hparam("num_compress_steps", 2)
   return hparams
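The new num_compress_steps hparam controls how many stride-2 down-convolutions (and matching transposed convolutions) the model applies, doubling the filter count after each encoding step and halving it before each decoding step. A standalone sketch of the resulting shape progression, assuming a hypothetical 64x64 input frame and a starting filter count equal to hidden_size (64); the numbers are only illustrative:

def encoder_shapes(height, width, filters, num_compress_steps):
  """Spatial size and channel count after each stride-2 conv in the new loop.

  Each step's conv2d outputs `filters` channels and halves height and width
  (with "SAME" padding); `filters` is then doubled for the next step.
  """
  shapes = []
  for _ in range(num_compress_steps):
    height = (height + 1) // 2
    width = (width + 1) // 2
    shapes.append((height, width, filters))
    filters *= 2
  return shapes

# With the basic_conv default num_compress_steps=2 and a 64x64 frame:
print(encoder_shapes(64, 64, 64, 2))  # [(32, 32, 64), (16, 16, 128)]
# The up-convolve loop mirrors this: it halves `filters` before each
# conv2d_transpose, so the output returns to the 64x64 input resolution.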