
Commit 9598fa2

Add word dropout to transformer_vae

PiperOrigin-RevId: 200761864

Authored by royaurko, committed by Copybara-Service
1 parent 49f7f58, commit 9598fa2

File changed: tensor2tensor/models/research/transformer_vae.py (9 additions, 1 deletion)
@@ -360,7 +360,14 @@ def ae_transformer_internal(inputs,
     targets, _ = common_layers.pad_to_same_length(
         targets, max_targets_len_from_inputs,
         final_length_divisible_by=2**hparams.num_compress_steps)
-    targets_c = compress(targets, inputs, False, hparams, "compress")
+    if hparams.word_dropout:
+      mask = tf.random_uniform(shape=common_layers.shape_list(targets),
+                               minval=0.0, maxval=1.0)
+      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
+                               tf.zeros_like(targets))
+    else:
+      targets_noisy = targets
+    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
     if hparams.mode != tf.estimator.ModeKeys.PREDICT:
       # Compress and bottleneck.
       latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
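For reference, the added branch implements word dropout by drawing an independent uniform sample for every element of targets and zeroing the positions where the draw does not exceed hparams.word_dropout. A minimal standalone sketch of the same masking logic, using the TF1-style ops seen in the diff (the helper name apply_word_dropout is illustrative, not part of the commit):

import tensorflow as tf

def apply_word_dropout(targets, dropout_rate):
  # Illustrative helper, not part of the commit: zero each element of
  # `targets` independently with probability `dropout_rate`.
  mask = tf.random_uniform(shape=tf.shape(targets), minval=0.0, maxval=1.0)
  # Keep a value where the uniform draw exceeds the rate; zero it otherwise.
  return tf.where(mask > dropout_rate, targets, tf.zeros_like(targets))

Note that the mask in the diff is drawn over the full shape of targets, so elements are zeroed independently rather than whole tokens at once.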
@@ -668,6 +675,7 @@ def transformer_ae_small():
   hparams.add_hparam("noise_dev", 0.5)
   hparams.add_hparam("d_mix", 0.5)
   hparams.add_hparam("logit_normalization", True)
+  hparams.add_hparam("word_dropout", 0.1)
   # Bottleneck kinds supported: dense, vae, semhash, gumbel-softmax, dvq.
   hparams.add_hparam("bottleneck_kind", "semhash")
   hparams.add_hparam("num_blocks", 1)
