 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""VAE Transformer."""
+"""AE Transformer."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -32,10 +32,9 @@
 import tensorflow as tf
 
 
-def residual_conv(x, repeat, hparams, name, reuse=None):
+def residual_conv(x, repeat, k, hparams, name, reuse=None):
   """A stack of convolution blocks with residual connections."""
   with tf.variable_scope(name, reuse=reuse):
-    k = (3, 1)
     dilations_and_kernels = [((1, 1), k) for _ in xrange(3)]
     for i in xrange(repeat):
       with tf.variable_scope("repeat_%d" % i):
@@ -72,15 +71,19 @@ def interleave(x, y, axis=1):
   return tf.concat([x, y], axis=axis + 1)
 
 
-def decompress_step(source, c, hparams, first_relu, name):
+def decompress_step(source, c, hparams, first_relu, is_2d, name):
   """Decompression function."""
   with tf.variable_scope(name):
     shape = tf.shape(source)
     if c is not None:
       source = attend(source, c, hparams, "decompress_attend")
+    multiplier = 4 if is_2d else 2
+    kernel = (1, 1) if is_2d else (1, 1)
     thicker = common_layers.conv_block(
-        source, hparams.hidden_size * 2, [((1, 1), (1, 1))],
+        source, hparams.hidden_size * multiplier, [((1, 1), kernel)],
         first_relu=first_relu, name="decompress_conv")
+    if is_2d:
+      return tf.depth_to_space(thicker, 2)
     return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])
 
 
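The 2-D branch added above regains resolution with `tf.depth_to_space` rather than a reshape: `decompress_conv` widens the channel dimension 4x, and `depth_to_space` then folds those channels back into a 2x2 spatial block, while the 1-D path keeps doubling the length via the reshape. A minimal standalone sketch of that shape bookkeeping (sizes are made up, not taken from the diff):

import tensorflow as tf

hidden_size = 64
# Stand-in for the "decompress_conv" output in the 2-D case: 4x the channels.
thicker = tf.zeros([8, 16, 16, 4 * hidden_size])
upsampled = tf.depth_to_space(thicker, 2)  # shape [8, 32, 32, 64]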
@@ -90,7 +93,7 @@ def gumbel_sample(shape):
   return -tf.log(-tf.log(uniform_samples))
 
 
-def dvae(x, hparams, name):
+def dae(x, hparams, name):
   with tf.variable_scope(name):
     m = tf.layers.dense(x, hparams.v_size, name="mask")
     logsm = tf.nn.log_softmax(m)
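For context on the `gumbel_sample` helper visible at the top of this hunk: adding Gumbel noise to logits and taking an argmax draws a sample from the corresponding softmax distribution (the Gumbel-max trick), which is what lets `dae` sample discrete codes. A tiny self-contained sketch, with made-up logits and a clipping constant chosen only for illustration:

import tensorflow as tf

def gumbel_sample_sketch(shape):
  # Uniform noise kept away from 0 and 1 so the double log stays finite.
  u = tf.random_uniform(shape, minval=1e-9, maxval=1.0 - 1e-9)
  return -tf.log(-tf.log(u))

logits = tf.constant([[2.0, 0.5, -1.0]])
sample = tf.argmax(logits + gumbel_sample_sketch(tf.shape(logits)), axis=-1)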
@@ -128,7 +131,7 @@ def nearest(x, means, hparams):
   _, nearest_idx = tf.nn.top_k(-dist, k=1)
   nearest_hot = tf.one_hot(tf.squeeze(nearest_idx, axis=1), hparams.v_size)
   nearest_hot = tf.reshape(nearest_hot, [tf.shape(x)[0], tf.shape(x)[1],
-                                         1, hparams.v_size])
+                                         tf.shape(x)[2], hparams.v_size])
   return tf.stop_gradient(nearest_hot)
 
 
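`nearest` assigns every position to its single closest codebook vector and returns that assignment as a stop-gradiented one-hot tensor; `kmeans` in the next hunk then snaps the activations to the chosen means and turns the squared distance into a loss. A minimal standalone sketch of the same hard assignment (hypothetical shapes, written with broadcasting rather than whatever distance form the file itself uses):

import tensorflow as tf

v_size, hidden_size = 8, 4
x_flat = tf.random_normal([10, hidden_size])     # flattened activations
means = tf.random_normal([v_size, hidden_size])  # learned codebook

# Squared distance of every position to every mean, then a hard argmin.
dist = tf.reduce_sum(tf.square(x_flat[:, None, :] - means[None, :, :]), axis=-1)
nearest_hot = tf.one_hot(tf.argmin(dist, axis=-1), v_size)  # [10, v_size]
quantized = tf.matmul(nearest_hot, means)                   # snap to codebook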
@@ -137,21 +140,23 @@ def kmeans(x, means, hparams, name):
     x_means_hot = nearest(x, means, hparams)
     x_means = tf.gather(means, tf.argmax(x_means_hot, axis=-1))
     kl = tf.reduce_sum(tf.square(x - x_means), axis=-1)
-    return x_means_hot, x_means_hot, tf.reduce_mean(kl) * 10.0
+    return x_means_hot, tf.reduce_mean(kl) * 10.0
 
 
-def compress(x, c, hparams, name):
+def compress(x, c, is_2d, hparams, name):
   """Compress."""
   with tf.variable_scope(name):
     # Run compression by strided convs.
     cur = x
+    k1 = (3, 3) if is_2d else (3, 1)
+    k2 = (2, 2) if is_2d else (2, 1)
     for i in xrange(hparams.num_compress_steps):
       if c is not None:
         cur = attend(cur, c, hparams, "compress_attend_%d" % i)
-      cur = residual_conv(cur, 1, hparams, "compress_rc_%d" % i)
+      cur = residual_conv(cur, 1, k1, hparams, "compress_rc_%d" % i)
       cur = common_layers.conv_block(
-          cur, hparams.hidden_size, [((1, 1), (2, 1))],
-          strides=(2, 1), name="compress_%d" % i)
+          cur, hparams.hidden_size, [((1, 1), k2)],
+          strides=k2, name="compress_%d" % i)
     return cur
 
 
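Each pass through the loop in `compress` is a stride-2 convolution along the time axis (stride (2, 2) in the 2-D case), so the sequence shrinks by a factor of 2**num_compress_steps. A trivial sketch of the length bookkeeping that `ae_decompress` and `ae_transformer_internal` rely on, assuming the default of 4 compress steps set further down:

num_compress_steps = 4
k = 2 ** num_compress_steps         # 16: one latent position per 16 target positions
target_length = 64                  # targets are padded to a multiple of k
latent_length = target_length // k  # 4 discrete latents for this example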
@@ -188,7 +193,7 @@ def decode(cond_vec, cond_add, gold, c, ed, hparams):
   decoder_input = tf.squeeze(decoder_input, axis=2)
   decoder_input = common_attention.add_timing_signal_1d(decoder_input)
   bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1])
-  if c is not None:
+  if c is not None and len(c.get_shape()) > 3:
     c = tf.squeeze(c, axis=2)
   return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams)
 
@@ -205,69 +210,62 @@ def expand_batch(x, mul):
   return tf.reshape(cx, res_shape)
 
 
-def vae_compress(x, c, ed, hparams, compress_name, decompress_name, reuse=None):
-  """Compress, then VAE."""
-  with tf.variable_scope(compress_name, reuse=reuse):
-    cur = compress(x, None, hparams, "compress")
+def ae_compress(x, is_2d, hparams, name, reuse=None):
+  """Compress, then AE."""
+  with tf.variable_scope(name, reuse=reuse):
+    cur = compress(x, None, is_2d, hparams, "compress")
     # Convolve and ReLu to get state.
     cur = common_layers.conv_block(
         cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
     cur = tf.nn.l2_normalize(cur, dim=3)
     cur_n = hparams.kmeans_lr_factor * cur
     cur_n += (1.0 - hparams.kmeans_lr_factor) * tf.stop_gradient(cur)
     means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size])
-    # z, kl_loss, mu, log_sigma = vae(cur, hparams, name="vae")
-    # z_true, z_sample, kl_loss = dvae(cur, hparams, name="dvae")
-    z_true, z_sample, kl_loss = kmeans(cur_n, means, hparams, name="kmeans")
-
-  # Compress context.
-  with tf.variable_scope(compress_name, reuse=reuse):
-    compress_c = compress(c, None, hparams, "compress_context")
-    dec_c = decode(None, compress_c, cur, None, None, hparams)
-    c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
-    reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
-        labels=z_true, logits=c_z)
+    hot, loss = kmeans(cur_n, means, hparams, name="kmeans")
+    # We need a linear layer to undo the l2-normalization.
+    cur = tf.layers.dense(cur, hparams.hidden_size, name="unnormalize")
+    return cur, hot, loss
 
-  # If not training, use the predicted z instead of the autoregressive one.
-  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
-    z = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)
 
-  with tf.variable_scope(decompress_name, reuse=reuse):
-    # Decompress.
-    z_sample_flat = tf.reshape(z_sample, [-1, hparams.v_size])
-    z = tf.matmul(z_sample_flat, means)
-    z = tf.reshape(z, [tf.shape(z_sample)[0], tf.shape(z_sample)[1],
-                       1, hparams.hidden_size])
+def ae_embed(hot, hparams, name, reuse=None):
+  with tf.variable_scope(name, reuse=reuse):
+    means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size])
+    hot_flat = tf.reshape(hot, [-1, hparams.v_size])
+    emb = tf.matmul(hot_flat, means)
+    emb = tf.reshape(emb, [tf.shape(hot)[0], tf.shape(hot)[1],
+                           tf.shape(hot)[2], hparams.hidden_size])
+    return tf.layers.dense(emb, hparams.hidden_size,
+                           name="unnormalize", reuse=reuse)
+
 
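`ae_embed` maps discrete one-hot codes back to dense vectors by multiplying with the shared `z_to_dense` codebook and then applying the same "unnormalize" dense layer as `ae_compress` (the shared variable scope plus `reuse` ties the two directions together). A tiny standalone sketch of the lookup, with hypothetical sizes:

import tensorflow as tf

v_size, hidden_size = 8, 4
means = tf.random_normal([v_size, hidden_size])  # stands in for "z_to_dense"
codes = tf.constant([3, 1, 7])                   # discrete latent codes
hot = tf.one_hot(codes, v_size)                  # [3, v_size]
emb = tf.matmul(hot, means)                      # [3, hidden_size] row lookup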
+def ae_decompress(z, ae, x, is_2d, hparams, name, reuse=None):
+  """Decompress from z, leaking from ae."""
+  with tf.variable_scope(name + "_decompress", reuse=reuse):
     # Leak at the beginning to help train.
-    z = mix(z, cur, hparams.startup_steps)
+    z = mix(z, ae, hparams.startup_steps)
     prob_z = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.8
-    prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0
+    prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 1.0
     z = tf.cond(tf.less(tf.random_uniform([]), prob_z),
-                lambda: z, lambda: cur)
-    z = tf.layers.dense(z, hparams.hidden_size, name="unnormalize")
+                lambda: z, lambda: ae)
 
     # Dropout for better autoencoding.
-    z = tf.nn.dropout(z, keep_prob=0.9)
+    z = tf.nn.dropout(z, keep_prob=1.0 - hparams.z_dropout)
 
     # Decompress.
     d = z
     for i in xrange(hparams.num_compress_steps):
       j = hparams.num_compress_steps - i - 1
-      d = residual_conv(d, 1, hparams, "decompress_rc_%d" % j)
-      d = decompress_step(d, c, hparams, i > 0, "decompress_step_%d" % j)
+      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
+      d = decompress_step(d, None, hparams, i > 0, is_2d, "decompress_%d" % j)
 
     k = 2**hparams.num_compress_steps
     z_batch = tf.reshape(z, [-1, 1, 1, hparams.hidden_size])
     x_batch = tf.reshape(x, [-1, k, 1, hparams.hidden_size])
     d_batch = tf.reshape(d, [-1, k, 1, hparams.hidden_size])
-    # dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams)
-    c = expand_batch(c, tf.shape(x_batch)[0] / tf.shape(x)[0])
-    ed = expand_batch(ed, tf.shape(x_batch)[0] / tf.shape(x)[0])
-    dec_batch = decode(z_batch, d_batch, x_batch, c, ed, hparams)
+    dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams)
     z = tf.reshape(dec_batch, [-1, tf.shape(x)[1], 1, hparams.hidden_size])
 
-  return z, kl_loss, reconstruct_loss
+  return z
 
 
 def ffn(x, hparams, name):
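The reshapes near the end of `ae_decompress` split the target into windows of length k = 2**num_compress_steps and fold the window count into the batch dimension, so `decode` runs autoregressively over many short windows in parallel, each seeded by its latent `z` and the matching slice of the decompressed features `d`. A shape-only sketch with assumed sizes:

import tensorflow as tf

batch, length, hidden = 2, 32, 8
k = 16                                  # 2 ** num_compress_steps for 4 steps
x = tf.zeros([batch, length, 1, hidden])
x_batch = tf.reshape(x, [-1, k, 1, hidden])        # [4, 16, 1, 8]: 4 windows
latents = tf.zeros([batch, length // k, 1, hidden])
z_batch = tf.reshape(latents, [-1, 1, 1, hidden])  # [4, 1, 1, 8]: one seed each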
@@ -277,35 +275,42 @@ def ffn(x, hparams, name):
     return common_layers.layer_postprocess(x, y, hparams)
 
 
-def vae_transformer_internal(inputs, targets, target_space, hparams):
-  """VAE Transformer, main step used for training."""
-  with tf.variable_scope("vae_transformer"):
-    # Prepare inputs, targets, and k.
-    inputs = common_layers.flatten4d3d(inputs)
-    input_len = tf.shape(inputs)[1]  # Double input size to cover targets.
-    inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
-    inputs.set_shape([None, None, hparams.hidden_size])
-    targets = common_layers.flatten4d3d(targets)
+def ae_transformer_internal(inputs, targets, target_space, hparams):
+  """AE Transformer, main step used for training."""
+  with tf.variable_scope("ae_transformer"):
+    # Prepare inputs, targets, k.
     k = 2**hparams.num_compress_steps
-    inputs, targets = common_layers.pad_to_same_length(
-        inputs, targets, final_length_divisible_by=k)
-    inputs, ed_bias = encode(inputs, target_space, hparams, "input_enc")
-
-    # Compress and vae.
-    z, kl, r = vae_compress(tf.expand_dims(targets, axis=2),
-                            tf.expand_dims(inputs, axis=2),
-                            ed_bias, hparams, "vae_compress", "vae_decompress")
+    _, targets = common_layers.pad_to_same_length(
+        targets, targets, final_length_divisible_by=k)
+    inputs = common_layers.flatten4d3d(inputs)
+    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
+
+    # Compress and ae.
+    ae, hot, kl = ae_compress(targets, False, hparams, "ae")
+    emb = ae_embed(hot, hparams, "ae", reuse=True)
+
+    # Compress context and run autoregressive decoder on emb-hot.
+    dec_c = decode(None, None, emb, inputs, ed, hparams)
+    c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
+    reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
+        labels=hot, logits=c_z)
+    # If not training, use the predicted z instead of the autoregressive one.
+    if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
+      hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)
+
+    # Decompress, pass for ae loss.
+    z = ae_decompress(emb, ae, targets, False, hparams, "ae")
     kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5))
-    r *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5))
-    losses = {"kl": kl, "reconstruction": r}
+    reconstruct_loss *= common_layers.inverse_exp_decay(hparams.startup_steps)
+    losses = {"kl": kl, "reconstruction": reconstruct_loss}
     return z, losses
 
 
 @registry.register_model
-class TransformerVAE(t2t_model.T2TModel):
+class TransformerAE(t2t_model.T2TModel):
 
   def model_fn_body(self, features):
-    return vae_transformer_internal(
+    return ae_transformer_internal(
         features["inputs"], features["targets"], features["target_space_id"],
         self._hparams)
 
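The reconstruction term above trains a latent prediction model: `dec_c` attends to the encoded inputs and predicts the k-means code at every compressed position, scored by softmax cross-entropy against the stop-gradiented one-hot `hot`; in INFER mode the argmax of those predictions replaces the codes obtained by compressing the targets. A small sketch of that predict-and-replace pattern with made-up shapes:

import tensorflow as tf

v_size = 8
c_z = tf.random_normal([2, 5, v_size])                        # predicted logits
hot = tf.one_hot([[1, 2, 3, 4, 5], [0, 0, 7, 7, 2]], v_size)  # true codes
loss = tf.nn.softmax_cross_entropy_with_logits(labels=hot, logits=c_z)
predicted_hot = tf.one_hot(tf.argmax(c_z, axis=-1), v_size)   # used at inference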
@@ -348,7 +353,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
 
 
 @registry.register_hparams
-def transformer_vae_small():
+def transformer_ae_small():
   """Set of hyperparameters."""
   hparams = transformer.transformer_small()
   hparams.batch_size = 2048
@@ -358,19 +363,20 @@ def transformer_vae_small():
   hparams.add_hparam("num_compress_steps", 4)
   hparams.add_hparam("kl_warmup_steps", 60000)
   hparams.add_hparam("startup_steps", 30000)
+  hparams.add_hparam("kmeans_lr_factor", 0.002)
+  hparams.add_hparam("z_dropout", 0.1)
   return hparams
 
 
 @registry.register_hparams
-def transformer_vae_base():
+def transformer_ae_base():
   """Set of hyperparameters."""
-  hparams = transformer_vae_small()
+  hparams = transformer_ae_small()
   hparams.hidden_size = 512
   hparams.filter_size = 2048
   hparams.attention_dropout = 0.0
   hparams.relu_dropout = 0.0
   hparams.dropout = 0.0
   hparams.num_hidden_layers = 4
-  hparams.kmeans_lr_factor = 0.002
   hparams.z_size = 256
   return hparams
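New hyperparameter sets follow the registration pattern shown above: start from a registered base, tweak a few fields, and register the result under a new name. A hypothetical variant, purely illustrative and not part of this commit:

@registry.register_hparams
def transformer_ae_small_c3():
  """Hypothetical: like transformer_ae_small but compressing 8x instead of 16x."""
  hparams = transformer_ae_small()
  hparams.num_compress_steps = 3
  return hparams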