@@ -324,38 +324,40 @@ def multinomial_sample(x, vocab_size, temperature):
   return tf.to_int32(reshaped_samples)


-def ae_latent_sample(t_c, inputs, ed, embed, iters, hparams):
+def ae_latent_sample(latents_dense, inputs, ed, embed, iters, hparams):
   """Sample from the latent space in the autoencoder."""
-  t_pred = decode_transformer(inputs, ed, t_c, hparams, "extra")
-  t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
-  t_bit = multinomial_sample(t_pred, 2**16, hparams.sampling_temp)
+  latents_pred = decode_transformer(inputs, ed, latents_dense, hparams, "extra")
+  latents_pred = tf.layers.dense(latents_pred, 2**16, name="extra_logits")
+  latents_discrete = multinomial_sample(
+      latents_pred, 2**16, hparams.sampling_temp)

-  def next_bit(t_bit, i):
-    t_bit_prev = t_bit
+  def next_bit(latents_discrete, i):
+    latents_discrete_prev = latents_discrete
     with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-      t_c = embed(t_bit)
-      t_pred = decode_transformer(inputs, ed, t_c, hparams, "extra")
-      t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
-      t_bit = multinomial_sample(t_pred, 2**16, hparams.sampling_temp)
-      return tf.concat([t_bit_prev[:, :(i + 1), :],
-                        t_bit[:, (i + 1):, :]], axis=1)
+      latents_dense = embed(latents_discrete)
+      latents_pred = decode_transformer(
+          inputs, ed, latents_dense, hparams, "extra")
+      latents_pred = tf.layers.dense(latents_pred, 2**16, name="extra_logits")
+      latents_discrete = multinomial_sample(
+          latents_pred, 2**16, hparams.sampling_temp)
+      return tf.concat([latents_discrete_prev[:, :(i + 1), :],
+                        latents_discrete[:, (i + 1):, :]], axis=1)

   for i in xrange(iters):
-    t_bit = next_bit(t_bit, i)
-  return t_bit
+    latents_discrete = next_bit(latents_discrete, i)
+  return latents_discrete


 def ae_transformer_internal(inputs, targets, target_space, hparams,
-                            beam_size, cache=None, predict_mask=1.0):
+                            cache=None, predict_mask=1.0):
   """AE Transformer, main step used for training."""
   # Summaries break with the do_refine cond, turn them off in that case.
   global _DO_SUMMARIES
   if hparams.do_refine:
     _DO_SUMMARIES = False

   # Prepare.
-  orig_targets = targets
-  batch_size = common_layers.shape_list(orig_targets)[0]
+  batch_size = common_layers.shape_list(inputs)[0]
   targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

   # Encoder.
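For reference, a minimal NumPy sketch of the sampling scheme the renamed ae_latent_sample implements: predict logits for every latent position, sample all of them, then on each iteration keep one more leading position fixed and resample the rest. Here predict_logits is a hypothetical stand-in for the decode_transformer call plus the "extra_logits" dense projection; the function name, arguments, and shapes are illustrative, not part of the patch.

import numpy as np

def iterative_latent_sample(predict_logits, init_latents, iters, temperature=1.0):
  """Sketch: resample every position, freezing one more leading position per round."""
  latents = init_latents                              # int array, [batch, length]
  for i in range(iters):
    logits = predict_logits(latents) / temperature    # [batch, length, vocab]
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)             # softmax over the vocab axis
    resampled = np.array([[np.random.choice(p.shape[-1], p=p) for p in row]
                          for row in probs])          # [batch, length]
    # Positions up to i stay as they were; later positions take the new samples.
    latents = np.concatenate([latents[:, :i + 1], resampled[:, i + 1:]], axis=1)
  return latents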
@@ -375,22 +377,24 @@ def ae_transformer_internal(inputs, targets, target_space, hparams,
   targets_c = compress(targets, False, hparams, "compress")
   if hparams.mode != tf.estimator.ModeKeys.PREDICT:
     # Compress and bottleneck.
-    t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2 * 2048, "vc")
+    latents_dense, latents_discrete, extra_loss, _ = bottleneck(
+        targets_c, hparams, 2 * 2048, "vc")
     if _DO_SUMMARIES:
-      tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1]))
+      tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
     pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
     pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
     cond = tf.less(tf.random_uniform([batch_size]), pc)
-    t_c = tf.where(cond, t_c, targets_c)
+    latents_dense = tf.where(cond, latents_dense, targets_c)
     # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
-    losses["extra"] = vc_loss * tf.reduce_mean(tf.to_float(cond))
+    losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
     # Extra loss predicting latent code from input. Discrete only.
     if hparams.bottleneck_kind not in ["dense", "vae"]:
-      t_pred = decode_transformer(
-          inputs, ed, tf.stop_gradient(t_c), hparams, "extra")
-      t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
+      latents_pred = decode_transformer(
+          tf.stop_gradient(inputs), tf.stop_gradient(ed),
+          tf.stop_gradient(latents_dense), hparams, "extra")
+      latents_pred = tf.layers.dense(latents_pred, 2**16, name="extra_logits")
       losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
-          labels=t_bit, logits=t_pred)
+          labels=latents_discrete, logits=latents_pred)
       losses["latent_pred"] = tf.reduce_mean(
           losses["latent_pred"] * 0.5 * tf.to_float(cond))
     else:
@@ -405,27 +409,25 @@ def bn_inputs():
                          bn_inputs, lambda: inputs_c)
       ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
       ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
-      t_c = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
-                     t_c, inputs_c)
+      latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
+                               latents_dense, inputs_c)
   else:
     if hparams.bottleneck_kind in ["dense", "vae"]:
       inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
-      t_c, _, _, _ = bottleneck(inputs_c, hparams, 2 * 2048, "vc")
+      latents_dense, _, _, _ = bottleneck(inputs_c, hparams, 2 * 2048, "vc")
     else:
       latent_len = common_layers.shape_list(targets_c)[1]
       _, _, _, embed = bottleneck(targets_c, hparams, 2 * 2048, "vc")
-      t_c = tf.zeros_like(targets_c[:, :latent_len, :, :])
+      latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
       if cache is None:
-        cache = ae_latent_sample(t_c, inputs, ed, embed, 8, hparams)
-        cache = cache[0, :, :]
-        cache = tf.reshape(cache, [1, latent_len, 1])
-        cache = tf.tile(cache, [beam_size, 1, 1])
-      t_c = embed(cache)
+        cache = ae_latent_sample(latents_dense, inputs, ed, embed, 8, hparams)
+      latents_dense = embed(cache)

   # Postprocess.
-  d = t_c
+  d = latents_dense
   pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
-  pos = pos[:, :common_layers.shape_list(t_c)[1] + 1, :, :]
-  t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
+  pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
+  latents_dense = tf.pad(latents_dense,
+                         [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

   # Masking.
   if hparams.do_mask:
@@ -444,23 +446,26 @@ def bn_inputs():
       d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
       d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
     targets = mask * targets + (1.0 - mask) * d
-    targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)
+    targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)

   res = decode_transformer(inputs, ed, targets, hparams, "decoder")
   if hparams.do_ae:
-    res = res[:, common_layers.shape_list(t_c)[1]:, :, :]
+    res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
     if hparams.do_mask and hparams.do_refine:
       def refine_res():
         return residual_conv(res, 1, (5, 1), hparams, "refine")
       masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
       all_masked = tf.less(masked_batches, 0.1)
       res = tf.where(all_masked, refine_res(), res)
-    latent_time = tf.less(200000, tf.to_int32(tf.train.get_global_step()))
+    # We'll start training only the extra model of latents after 400K steps.
+    # Before we train only this, we decrease lr for other weights.
+    latent_time = tf.less(300000, tf.to_int32(tf.train.get_global_step()))
+    decreased_lr = common_layers.inverse_lin_decay(400000)
     losses["latent_pred"] *= tf.to_float(latent_time)
     losses["extra"] *= 1.0 - tf.to_float(latent_time)
-    res = tf.cond(latent_time,
-                  lambda: tf.stop_gradient(0.7 * res) + 0.3 * res,
-                  lambda: res)
+    decreased_lr_res = tf.stop_gradient(decreased_lr * res)
+    decreased_lr_res += (1.0 - decreased_lr) * res
+    res = tf.cond(latent_time, lambda: decreased_lr_res, lambda: res)
   return res, losses, cache

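Both the replaced tf.cond branch and the new decreased_lr_res lines rely on the same identity: stop_gradient(a * x) + (1 - a) * x has the same forward value as x but scales the gradient reaching x by (1 - a). A minimal standalone check of that identity, written in TF 1.x graph style to match the file; the constants and variable names are illustrative.

import tensorflow as tf

x = tf.constant(3.0)
a = 0.75  # fraction of the gradient to block

# Forward value is unchanged; only the gradient is scaled by (1 - a).
y = tf.stop_gradient(a * x) + (1.0 - a) * x
grad = tf.gradients(y, x)[0]  # dy/dx == 1 - a

with tf.Session() as sess:
  print(sess.run([y, grad]))  # [3.0, 0.25]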
@@ -481,27 +486,26 @@ def body(self, features):
     if self._hparams.drop_inputs:
       inputs = None
     reuse = "cache_raw" in features
-    beam_size = self._decode_hparams.beam_size
     with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
       res, loss, _ = ae_transformer_internal(
           inputs, features["targets"], features["target_space_id"],
-          self._hparams, beam_size, features.get("cache_raw", None),
+          self._hparams, features.get("cache_raw", None),
           predict_mask=self.predict_mask)
       return res, loss

   def prepare_features_for_infer(self, features):
     if not self._hparams.do_ae:
       return features
-    beam_size = self._decode_hparams.beam_size
-    inputs = tf.zeros([beam_size, 1, 1, self._hparams.hidden_size])
+    beam_batch_size = self._decode_hparams.beam_size
+    beam_batch_size *= self._decode_hparams.batch_size
+    inputs = tf.zeros([beam_batch_size, 1, 1, self._hparams.hidden_size])
     inputs = inputs if "inputs" in features else None
     if self._hparams.drop_inputs or not self.has_input:
       inputs = None
-    targets = tf.zeros([beam_size, 1, 1, self._hparams.hidden_size])
+    targets = tf.zeros([beam_batch_size, 1, 1, self._hparams.hidden_size])
     with tf.variable_scope("body"):
       _, _, cache = ae_transformer_internal(
-          inputs, targets, features["target_space_id"],
-          self._hparams, beam_size)
+          inputs, targets, features["target_space_id"], self._hparams)
     features["cache_raw"] = cache

   def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
@@ -531,6 +535,16 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
     logits, _ = self(features)  # pylint: disable=not-callable
     samples = tf.argmax(logits, axis=-1)

+    # More steps.
+    self.predict_mask = 0.0  # Use the provided targets this time.
+    how_many_more_steps = 0  # Set to 1 or more for Gibbs-like sampling.
+    for _ in xrange(how_many_more_steps):
+      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+        features["targets"] = samples
+        logits, _ = self(features)  # pylint: disable=not-callable
+        samples = tf.argmax(logits, axis=-1)
+
+    self.predict_mask = 1.0
     if inputs_old is not None:  # Restore to not confuse Estimator.
       features["inputs"] = inputs_old
     return samples
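The block added to infer is a Gibbs-like refinement loop: after the first decode, the model's own argmax samples can be fed back in as targets (with predict_mask at 0.0 so the provided targets are used) for a configurable number of extra passes under reused variables. A framework-neutral sketch of that control flow, with model() and refine_samples() as hypothetical stand-ins for the reused-variable forward pass, not the library's API.

def refine_samples(model, features, extra_steps=1):
  """Sketch: re-decode a few times, conditioning each pass on the previous samples."""
  logits = model(features)
  samples = logits.argmax(axis=-1)
  for _ in range(extra_steps):
    features["targets"] = samples  # condition on the last round's output
    logits = model(features)       # same weights reused on every pass
    samples = logits.argmax(axis=-1)
  return samples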