Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b31b3ae

Browse files
Lukasz Kaiser authored and Ryan Sepassi committed
Play more with VAE, small corrections elsewhere.
PiperOrigin-RevId: 165031077
1 parent d1f9bb2 commit b31b3ae

File tree

5 files changed

+98
-53
lines changed

5 files changed

+98
-53
lines changed

tensor2tensor/layers/modalities.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,10 +406,11 @@ def top(self, body_output, _):
406406
# Assume input is a square with self._body_input_depth channels.
407407
if self._is_2d:
408408
length_float = tf.to_float(tf.shape(x)[1])
409+
length_float *= tf.to_float(tf.shape(x)[2])
409410
spatial_dim_float = tf.sqrt(length_float)
410411
spatial_dim = tf.to_int32(spatial_dim_float)
411-
x = tf.reshape(x,
412-
[-1, spatial_dim, spatial_dim, self._body_input_depth])
412+
x_depth = int(x.get_shape()[3])
413+
x = tf.reshape(x, [-1, spatial_dim, spatial_dim, x_depth])
413414
x = common_layers.conv_block_downsample(x, self._kernel, self._strides,
414415
self._padding)
415416
x = tf.nn.relu(x)

tensor2tensor/models/cycle_gan.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def discriminator(x, compress, hparams, name, reuse=None):
3939
with tf.variable_scope(name, reuse=reuse):
4040
x = tf.stop_gradient(2 * x) - x # Reverse gradient.
4141
if compress:
42-
x = transformer_vae.compress(x, hparams, "compress")
42+
x = transformer_vae.compress(x, None, hparams, "compress")
4343
else:
4444
x = transformer_vae.residual_conv(x, 1, hparams, "compress_rc")
4545
y = tf.reduce_mean(x, axis=1)
@@ -144,12 +144,12 @@ def cycle_vae_gan_internal(inputs, targets, _, hparams):
144144

145145
# Input-input part.
146146
inp1_back, kl_loss1, inp1_mu, inp1_log_sigma = transformer_vae.vae_compress(
147-
inputs1, hparams, "inp2hyp", "hyp2inp")
147+
inputs1, None, hparams, "inp2hyp", "hyp2inp")
148148
inp1_hyp = tf.concat([inp1_mu, inp1_log_sigma], axis=3)
149149

150150
# Target-target part.
151151
tgt2_back, kl_loss2, tgt2_mu, tgt2_log_sigma = transformer_vae.vae_compress(
152-
targets2, hparams, "tgt2hyp", "hyp2tgt")
152+
targets2, None, hparams, "tgt2hyp", "hyp2tgt")
153153
tgt2_hyp = tf.concat([tgt2_mu, tgt2_log_sigma], axis=3)
154154

155155
# Reconstruction losses.
@@ -165,16 +165,16 @@ def cycle_vae_gan_internal(inputs, targets, _, hparams):
165165

166166
# Reconstruct targets from inputs.
167167
tgt, _, _, _ = transformer_vae.vae_compress(
168-
inputs, hparams, "inp2hyp", "hyp2tgt", reuse=True)
168+
inputs, None, hparams, "inp2hyp", "hyp2tgt", reuse=True)
169169
tgt = tf.layers.dense(tgt, hparams.vocab_size, name="softmax", reuse=True)
170170
# We use the reconstruction only for tracking progress, no gradients here!
171171
tgt = tf.stop_gradient(tf.expand_dims(tgt, axis=2))
172172

173173
kl_rev_decay = common_layers.inverse_exp_decay(hparams.kl_warmup_steps)
174174
losses = {"input_input": hparams.cycle_loss_multiplier * inp1_loss,
175175
"target_target": hparams.cycle_loss_multiplier * tgt2_loss,
176-
"input_kl": kl_loss1 * kl_rev_decay,
177-
"target_kl": kl_loss2 * kl_rev_decay,
176+
"input_kl": kl_loss1 * kl_rev_decay * 15.0,
177+
"target_kl": kl_loss2 * kl_rev_decay * 15.0,
178178
"discriminator": dloss}
179179
return tgt, losses
180180

@@ -196,7 +196,7 @@ def cycle_gan_small():
196196
hparams.input_modalities = "inputs:symbol:identity"
197197
hparams.target_modality = "symbol:identity"
198198
hparams.weight_decay = 3.0
199-
hparams.learning_rate = 0.005
199+
hparams.learning_rate = 0.05
200200
hparams.kl_warmup_steps = 5000
201201
hparams.learning_rate_warmup_steps = 3000
202202
hparams.add_hparam("vocab_size", 32) # Vocabulary size, need to set here.

tensor2tensor/models/shake_shake.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ class ShakeShake(t2t_model.T2TModel):
100100

101101
def model_fn_body(self, features):
102102
hparams = self._hparams
103-
print(hparams.learning_rate)
104-
105103
inputs = features["inputs"]
106104
assert (hparams.num_hidden_layers - 2) % 6 == 0
107105
blocks_per_stage = (hparams.num_hidden_layers - 2) // 6

tensor2tensor/models/transformer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ def transformer_decoder(decoder_input,
244244
hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
245245
x = common_layers.layer_postprocess(x, y, hparams)
246246
if encoder_output is not None:
247-
assert encoder_decoder_attention_bias is not None
248247
with tf.variable_scope("encdec_attention"):
249248
y = common_attention.multihead_attention(
250249
common_layers.layer_preprocess(

tensor2tensor/models/transformer_vae.py

Lines changed: 88 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from six.moves import xrange # pylint: disable=redefined-builtin
2525

26+
from tensor2tensor.layers import common_attention
2627
from tensor2tensor.layers import common_layers
2728
from tensor2tensor.models import transformer
2829
from tensor2tensor.utils import registry
@@ -49,13 +50,43 @@ def residual_conv(x, repeat, hparams, name, reuse=None):
4950
return x
5051

5152

52-
def decompress_step(source, hparams, first_relu, name):
53+
def attend(x, source, hparams, name):
54+
with tf.variable_scope(name):
55+
x = tf.squeeze(x, axis=2)
56+
if len(source.get_shape()) > 3:
57+
source = tf.squeeze(source, axis=2)
58+
source = common_attention.add_timing_signal_1d(source)
59+
y = common_attention.multihead_attention(
60+
common_layers.layer_preprocess(x, hparams), source, None,
61+
hparams.attention_key_channels or hparams.hidden_size,
62+
hparams.attention_value_channels or hparams.hidden_size,
63+
hparams.hidden_size, hparams.num_heads,
64+
hparams.attention_dropout)
65+
res = common_layers.layer_postprocess(x, y, hparams)
66+
return tf.expand_dims(res, axis=2)
67+
68+
69+
def interleave(x, y, axis=1):
70+
x = tf.expand_dims(x, axis=axis+1)
71+
y = tf.expand_dims(y, axis=axis+1)
72+
return tf.concat([x, y], axis=axis+1)
73+
74+
75+
def decompress_step(source, c, hparams, first_relu, name):
5376
"""Decompression function."""
5477
with tf.variable_scope(name):
5578
shape = tf.shape(source)
56-
thicker = common_layers.conv_block(
57-
source, hparams.hidden_size * 2, [((1, 1), (1, 1))],
58-
first_relu=first_relu, name="decompress_conv")
79+
if c is not None:
80+
source = attend(source, c, hparams, "decompress_attend")
81+
first = common_layers.conv_block(
82+
source,
83+
hparams.hidden_size, [((1, 1), (3, 1)), ((1, 1), (3, 1))],
84+
first_relu=first_relu, padding="SAME", name="decompress_conv1")
85+
second = common_layers.conv_block(
86+
tf.concat([source, first], axis=3),
87+
hparams.hidden_size, [((1, 1), (3, 1)), ((1, 1), (3, 1))],
88+
first_relu=first_relu, padding="SAME", name="decompress_conv2")
89+
thicker = interleave(first, second)
5990
return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])
6091

6192

@@ -71,23 +102,25 @@ def vae(x, hparams, name):
71102
return z, tf.reduce_mean(kl), mu, log_sigma
72103

73104

74-
def compress(inputs, hparams, name):
105+
def compress(x, c, hparams, name):
75106
"""Compress."""
76107
with tf.variable_scope(name):
77108
# Run compression by strided convs.
78-
cur = inputs
109+
cur = x
79110
for i in xrange(hparams.num_compress_steps):
111+
if c is not None:
112+
cur = attend(cur, c, hparams, "compress_attend_%d" % i)
80113
cur = residual_conv(cur, 1, hparams, "compress_rc_%d" % i)
81114
cur = common_layers.conv_block(
82115
cur, hparams.hidden_size, [((1, 1), (2, 1))],
83116
strides=(2, 1), name="compress_%d" % i)
84117
return cur
85118

86119

87-
def vae_compress(inputs, hparams, compress_name, decompress_name, reuse=None):
120+
def vae_compress(x, c, hparams, compress_name, decompress_name, reuse=None):
88121
"""Compress, then VAE."""
89122
with tf.variable_scope(compress_name, reuse=reuse):
90-
cur = compress(inputs, hparams, "compress")
123+
cur = compress(x, c, hparams, "compress")
91124
# Convolve and ReLu to get state.
92125
cur = common_layers.conv_block(
93126
cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
@@ -100,7 +133,7 @@ def vae_compress(inputs, hparams, compress_name, decompress_name, reuse=None):
100133
for i in xrange(hparams.num_compress_steps):
101134
j = hparams.num_compress_steps - i - 1
102135
z = residual_conv(z, 1, hparams, "decompress_rc_%d" % j)
103-
z = decompress_step(z, hparams, i > 0, "decompress__step_%d" % j)
136+
z = decompress_step(z, c, hparams, i > 0, "decompress__step_%d" % j)
104137
return z, kl_loss, mu, log_sigma
105138

106139

@@ -124,6 +157,13 @@ def dropmask(targets, targets_dropout_max, is_training):
124157
return targets * keep_mask
125158

126159

160+
def ffn(x, hparams, name):
161+
with tf.variable_scope(name):
162+
y = transformer.transformer_ffn_layer(
163+
common_layers.layer_preprocess(x, hparams), hparams)
164+
return common_layers.layer_postprocess(x, y, hparams)
165+
166+
127167
def vae_transformer_internal(inputs, targets, target_space, hparams):
128168
"""VAE Transformer, main step used for training."""
129169
with tf.variable_scope("vae_transformer"):
@@ -140,36 +180,40 @@ def vae_transformer_internal(inputs, targets, target_space, hparams):
140180
inputs = encode(inputs, target_space, hparams, "input_enc")
141181

142182
# Dropout targets or swap for zeros 5% of the time.
183+
targets_nodrop = targets
143184
max_prestep = hparams.kl_warmup_steps
144185
prob_targets = 0.95 if is_training else 1.0
145186
targets_dropout_max = common_layers.inverse_lin_decay(max_prestep) - 0.01
146187
targets = dropmask(targets, targets_dropout_max * 0.7, is_training)
147188
targets = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
148189
lambda: targets, lambda: tf.zeros_like(targets))
149-
150-
# Join targets with inputs, run encoder.
151-
# to_encode = common_layers.conv_block(
152-
# tf.expand_dims(tf.concat([targets, inputs], axis=2), axis=2),
153-
# hparams.hidden_size, [((1, 1), (1, 1))],
154-
# first_relu=False, name="join_targets")
155-
# to_compress = encode(tf.squeeze(to_encode, axis=2),
156-
# target_space, hparams, "enc")
190+
targets = targets_nodrop
157191

158192
# Compress and vae.
159-
z, kl_loss, _, _ = vae_compress(tf.expand_dims(targets, axis=2), hparams,
160-
"vae_compress", "vae_decompress")
193+
z = tf.get_variable("z", [hparams.hidden_size])
194+
z = tf.reshape(z, [1, 1, 1, -1])
195+
z = tf.tile(z, [tf.shape(inputs)[0], 1, 1, 1])
196+
197+
z = attend(z, inputs, hparams, "z_attendsi")
198+
z = ffn(z, hparams, "zff2")
199+
z = attend(z, targets, hparams, "z_attendst2")
200+
z = ffn(z, hparams, "zff3")
201+
z, kl_loss, _, _ = vae(z, hparams, name="vae")
202+
z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense")
203+
204+
# z, kl_loss, _, _ = vae_compress(
205+
# tf.expand_dims(targets, axis=2), tf.expand_dims(inputs, axis=2),
206+
# hparams, "vae_compress", "vae_decompress")
161207

162-
# Join z with inputs, run decoder.
163-
to_decode = common_layers.conv_block(
164-
tf.concat([z, tf.expand_dims(inputs, axis=2)], axis=3),
165-
hparams.hidden_size, [((1, 1), (1, 1))], name="join_z")
166-
ret = encode(tf.squeeze(to_decode, axis=2), target_space, hparams, "dec")
167-
# to_decode = residual_conv(to_decode, 2, hparams, "dec_conv")
168-
# ret = tf.squeeze(to_decode, axis=2)
208+
decoder_in = tf.squeeze(z, axis=2) + tf.zeros_like(targets)
209+
(decoder_input, decoder_self_attention_bias) = (
210+
transformer.transformer_prepare_decoder(decoder_in, hparams))
211+
ret = transformer.transformer_decoder(
212+
decoder_input, inputs, decoder_self_attention_bias, None, hparams)
169213

170-
# Randomize decoder inputs..
171-
kl_loss *= common_layers.inverse_exp_decay(max_prestep) * 10.0
172-
return tf.expand_dims(ret, axis=2), kl_loss
214+
kl_loss *= common_layers.inverse_exp_decay(int(max_prestep * 1.5)) * 5.0
215+
losses = {"kl": kl_loss}
216+
return tf.expand_dims(ret, axis=2), losses
173217

174218

175219
@registry.register_model
@@ -203,13 +247,15 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
203247
sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
204248
samples = tf.concat(sharded_samples, 0)
205249

206-
# 2nd step.
207-
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
208-
features["targets"] = samples
209-
sharded_logits, _ = self.model_fn(
210-
features, False, last_position_only=last_position_only)
211-
sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
212-
samples = tf.concat(sharded_samples, 0)
250+
# More steps.
251+
how_many_more_steps = 20
252+
for _ in xrange(how_many_more_steps):
253+
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
254+
features["targets"] = samples
255+
sharded_logits, _ = self.model_fn(
256+
features, False, last_position_only=last_position_only)
257+
sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
258+
samples = tf.concat(sharded_samples, 0)
213259

214260
if inputs_old is not None: # Restore to not confuse Estimator.
215261
features["inputs"] = inputs_old
@@ -221,9 +267,10 @@ def transformer_vae_small():
221267
"""Set of hyperparameters."""
222268
hparams = transformer.transformer_small()
223269
hparams.batch_size = 2048
270+
hparams.learning_rate_warmup_steps = 16000
224271
hparams.add_hparam("z_size", 128)
225272
hparams.add_hparam("num_compress_steps", 4)
226-
hparams.add_hparam("kl_warmup_steps", 50000)
273+
hparams.add_hparam("kl_warmup_steps", 60000)
227274
return hparams
228275

229276

@@ -233,9 +280,9 @@ def transformer_vae_base():
233280
hparams = transformer_vae_small()
234281
hparams.hidden_size = 512
235282
hparams.filter_size = 2048
236-
hparams.attention_dropout = 0.1
237-
hparams.relu_dropout = 0.1
238-
hparams.dropout = 0.1
239-
hparams.num_hidden_layers = 4
283+
hparams.attention_dropout = 0.0
284+
hparams.relu_dropout = 0.0
285+
hparams.dropout = 0.0
286+
hparams.num_hidden_layers = 3
240287
hparams.z_size = 256
241288
return hparams

0 commit comments

Comments (0)