Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2be0cbb

Browse files
Lukasz Kaiser and Ryan Sepassi
authored and committed
Work on VAE Transformer
PiperOrigin-RevId: 179508117
1 parent 474545a commit 2be0cbb

File tree

1 file changed

+64
-50
lines changed

1 file changed

+64
-50
lines changed

tensor2tensor/models/transformer_vae.py

Lines changed: 64 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -324,38 +324,40 @@ def multinomial_sample(x, vocab_size, temperature):
324324
return tf.to_int32(reshaped_samples)
325325

326326

327-
def ae_latent_sample(t_c, inputs, ed, embed, iters, hparams):
327+
def ae_latent_sample(latents_dense, inputs, ed, embed, iters, hparams):
328328
"""Sample from the latent space in the autoencoder."""
329-
t_pred = decode_transformer(inputs, ed, t_c, hparams, "extra")
330-
t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
331-
t_bit = multinomial_sample(t_pred, 2**16, hparams.sampling_temp)
329+
latents_pred = decode_transformer(inputs, ed, latents_dense, hparams, "extra")
330+
latents_pred = tf.layers.dense(latents_pred, 2**16, name="extra_logits")
331+
latents_discrete = multinomial_sample(
332+
latents_pred, 2**16, hparams.sampling_temp)
332333

333-
def next_bit(t_bit, i):
334-
t_bit_prev = t_bit
334+
def next_bit(latents_discrete, i):
335+
latents_discrete_prev = latents_discrete
335336
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
336-
t_c = embed(t_bit)
337-
t_pred = decode_transformer(inputs, ed, t_c, hparams, "extra")
338-
t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
339-
t_bit = multinomial_sample(t_pred, 2**16, hparams.sampling_temp)
340-
return tf.concat([t_bit_prev[:, :(i+1), :],
341-
t_bit[:, (i+1):, :]], axis=1)
337+
latents_dense = embed(latents_discrete)
338+
latents_pred = decode_transformer(
339+
inputs, ed, latents_dense, hparams, "extra")
340+
latents_pred = tf.layers.dense(latents_pred, 2**16, name="extra_logits")
341+
latents_discrete = multinomial_sample(
342+
latents_pred, 2**16, hparams.sampling_temp)
343+
return tf.concat([latents_discrete_prev[:, :(i+1), :],
344+
latents_discrete[:, (i+1):, :]], axis=1)
342345

343346
for i in xrange(iters):
344-
t_bit = next_bit(t_bit, i)
345-
return t_bit
347+
latents_discrete = next_bit(latents_discrete, i)
348+
return latents_discrete
346349

347350

348351
def ae_transformer_internal(inputs, targets, target_space, hparams,
349-
beam_size, cache=None, predict_mask=1.0):
352+
cache=None, predict_mask=1.0):
350353
"""AE Transformer, main step used for training."""
351354
# Summaries break with the do_refine cond, turn them off in that case.
352355
global _DO_SUMMARIES
353356
if hparams.do_refine:
354357
_DO_SUMMARIES = False
355358

356359
# Prepare.
357-
orig_targets = targets
358-
batch_size = common_layers.shape_list(orig_targets)[0]
360+
batch_size = common_layers.shape_list(inputs)[0]
359361
targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
360362

361363
# Encoder.
@@ -375,22 +377,24 @@ def ae_transformer_internal(inputs, targets, target_space, hparams,
375377
targets_c = compress(targets, False, hparams, "compress")
376378
if hparams.mode != tf.estimator.ModeKeys.PREDICT:
377379
# Compress and bottleneck.
378-
t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2*2048, "vc")
380+
latents_dense, latents_discrete, extra_loss, _ = bottleneck(
381+
targets_c, hparams, 2*2048, "vc")
379382
if _DO_SUMMARIES:
380-
tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1]))
383+
tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
381384
pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
382385
pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
383386
cond = tf.less(tf.random_uniform([batch_size]), pc)
384-
t_c = tf.where(cond, t_c, targets_c)
387+
latents_dense = tf.where(cond, latents_dense, targets_c)
385388
# TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
386-
losses["extra"] = vc_loss * tf.reduce_mean(tf.to_float(cond))
389+
losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
387390
# Extra loss predicting latent code from input. Discrete only.
388391
if hparams.bottleneck_kind not in ["dense", "vae"]:
389-
t_pred = decode_transformer(
390-
inputs, ed, tf.stop_gradient(t_c), hparams, "extra")
391-
t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
392+
latents_pred = decode_transformer(
393+
tf.stop_gradient(inputs), tf.stop_gradient(ed),
394+
tf.stop_gradient(latents_dense), hparams, "extra")
395+
latents_pred = tf.layers.dense(latents_pred, 2**16, name="extra_logits")
392396
losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
393-
labels=t_bit, logits=t_pred)
397+
labels=latents_discrete, logits=latents_pred)
394398
losses["latent_pred"] = tf.reduce_mean(
395399
losses["latent_pred"] * 0.5 * tf.to_float(cond))
396400
else:
@@ -405,27 +409,25 @@ def bn_inputs():
405409
bn_inputs, lambda: inputs_c)
406410
ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
407411
ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
408-
t_c = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
409-
t_c, inputs_c)
412+
latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
413+
latents_dense, inputs_c)
410414
else:
411415
if hparams.bottleneck_kind in ["dense", "vae"]:
412416
inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
413-
t_c, _, _, _ = bottleneck(inputs_c, hparams, 2*2048, "vc")
417+
latents_dense, _, _, _ = bottleneck(inputs_c, hparams, 2*2048, "vc")
414418
else:
415419
latent_len = common_layers.shape_list(targets_c)[1]
416420
_, _, _, embed = bottleneck(targets_c, hparams, 2*2048, "vc")
417-
t_c = tf.zeros_like(targets_c[:, :latent_len, :, :])
421+
latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
418422
if cache is None:
419-
cache = ae_latent_sample(t_c, inputs, ed, embed, 8, hparams)
420-
cache = cache[0, :, :]
421-
cache = tf.reshape(cache, [1, latent_len, 1])
422-
cache = tf.tile(cache, [beam_size, 1, 1])
423-
t_c = embed(cache)
423+
cache = ae_latent_sample(latents_dense, inputs, ed, embed, 8, hparams)
424+
latents_dense = embed(cache)
424425
# Postprocess.
425-
d = t_c
426+
d = latents_dense
426427
pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
427-
pos = pos[:, :common_layers.shape_list(t_c)[1] + 1, :, :]
428-
t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
428+
pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
429+
latents_dense = tf.pad(latents_dense,
430+
[[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
429431

430432
# Masking.
431433
if hparams.do_mask:
@@ -444,23 +446,26 @@ def bn_inputs():
444446
d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
445447
d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
446448
targets = mask * targets + (1.0 - mask) * d
447-
targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)
449+
targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)
448450

449451
res = decode_transformer(inputs, ed, targets, hparams, "decoder")
450452
if hparams.do_ae:
451-
res = res[:, common_layers.shape_list(t_c)[1]:, :, :]
453+
res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
452454
if hparams.do_mask and hparams.do_refine:
453455
def refine_res():
454456
return residual_conv(res, 1, (5, 1), hparams, "refine")
455457
masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
456458
all_masked = tf.less(masked_batches, 0.1)
457459
res = tf.where(all_masked, refine_res(), res)
458-
latent_time = tf.less(200000, tf.to_int32(tf.train.get_global_step()))
460+
# We'll start training only the extra model of latents after 400K steps.
461+
# Before we train only this, we decrease lr for other weights.
462+
latent_time = tf.less(300000, tf.to_int32(tf.train.get_global_step()))
463+
decreased_lr = common_layers.inverse_lin_decay(400000)
459464
losses["latent_pred"] *= tf.to_float(latent_time)
460465
losses["extra"] *= 1.0 - tf.to_float(latent_time)
461-
res = tf.cond(latent_time,
462-
lambda: tf.stop_gradient(0.7 * res) + 0.3 * res,
463-
lambda: res)
466+
decreased_lr_res = tf.stop_gradient(decreased_lr * res)
467+
decreased_lr_res += (1.0 - decreased_lr) * res
468+
res = tf.cond(latent_time, lambda: decreased_lr_res, lambda: res)
464469
return res, losses, cache
465470

466471

@@ -481,27 +486,26 @@ def body(self, features):
481486
if self._hparams.drop_inputs:
482487
inputs = None
483488
reuse = "cache_raw" in features
484-
beam_size = self._decode_hparams.beam_size
485489
with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
486490
res, loss, _ = ae_transformer_internal(
487491
inputs, features["targets"], features["target_space_id"],
488-
self._hparams, beam_size, features.get("cache_raw", None),
492+
self._hparams, features.get("cache_raw", None),
489493
predict_mask=self.predict_mask)
490494
return res, loss
491495

492496
def prepare_features_for_infer(self, features):
493497
if not self._hparams.do_ae:
494498
return features
495-
beam_size = self._decode_hparams.beam_size
496-
inputs = tf.zeros([beam_size, 1, 1, self._hparams.hidden_size])
499+
beam_batch_size = self._decode_hparams.beam_size
500+
beam_batch_size *= self._decode_hparams.batch_size
501+
inputs = tf.zeros([beam_batch_size, 1, 1, self._hparams.hidden_size])
497502
inputs = inputs if "inputs" in features else None
498503
if self._hparams.drop_inputs or not self.has_input:
499504
inputs = None
500-
targets = tf.zeros([beam_size, 1, 1, self._hparams.hidden_size])
505+
targets = tf.zeros([beam_batch_size, 1, 1, self._hparams.hidden_size])
501506
with tf.variable_scope("body"):
502507
_, _, cache = ae_transformer_internal(
503-
inputs, targets, features["target_space_id"],
504-
self._hparams, beam_size)
508+
inputs, targets, features["target_space_id"], self._hparams)
505509
features["cache_raw"] = cache
506510

507511
def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
@@ -531,6 +535,16 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
531535
logits, _ = self(features) # pylint: disable=not-callable
532536
samples = tf.argmax(logits, axis=-1)
533537

538+
# More steps.
539+
self.predict_mask = 0.0 # Use the provided targets this time.
540+
how_many_more_steps = 0 # Set to 1 or more for Gibbs-like sampling.
541+
for _ in xrange(how_many_more_steps):
542+
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
543+
features["targets"] = samples
544+
logits, _ = self(features) # pylint: disable=not-callable
545+
samples = tf.argmax(logits, axis=-1)
546+
547+
self.predict_mask = 1.0
534548
if inputs_old is not None: # Restore to not confuse Estimator.
535549
features["inputs"] = inputs_old
536550
return samples

0 commit comments

Comments (0)