
Commit 4a36fb8

Lukasz Kaiser authored and Ryan Sepassi committed
Fix for issue #215 on github, update transformer_vae.
PiperOrigin-RevId: 164771762
1 parent a9826de commit 4a36fb8

2 files changed (+4, -4 lines)

tensor2tensor/models/transformer_vae.py

Lines changed: 3 additions & 3 deletions
@@ -109,7 +109,7 @@ def encode(x, x_space, hparams, name):
   with tf.variable_scope(name):
     (encoder_input, encoder_self_attention_bias,
      _) = transformer.transformer_prepare_encoder(x, x_space, hparams)
-    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
+    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
     return transformer.transformer_encoder(
         encoder_input, encoder_self_attention_bias, hparams)
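
The fix swaps hparams.residual_dropout for hparams.dropout, presumably because residual_dropout is not defined in this model's hparams set (the commit message points to issue #215). Note that TF1's tf.nn.dropout takes a keep probability, hence the 1.0 - ... argument. Below is a minimal NumPy sketch of that inverted-dropout convention; the standalone dropout_rate argument is hypothetical and stands in for hparams.dropout.

import numpy as np

def dropout(x, dropout_rate, rng=None):
  # TF1's tf.nn.dropout(x, keep_prob) keeps each unit with probability
  # keep_prob and scales survivors by 1/keep_prob, so the expected
  # activation is unchanged; this mirrors that behavior in NumPy.
  if rng is None:
    rng = np.random.default_rng(0)
  keep_prob = 1.0 - dropout_rate
  mask = rng.random(np.shape(x)) < keep_prob
  return np.where(mask, np.asarray(x) / keep_prob, 0.0)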

@@ -143,7 +143,7 @@ def vae_transformer_internal(inputs, targets, target_space, hparams):
   max_prestep = hparams.kl_warmup_steps
   prob_targets = 0.95 if is_training else 1.0
   targets_dropout_max = common_layers.inverse_lin_decay(max_prestep) - 0.01
-  targets = dropmask(targets, targets_dropout_max, is_training)
+  targets = dropmask(targets, targets_dropout_max * 0.7, is_training)
   targets = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                     lambda: targets, lambda: tf.zeros_like(targets))
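
This change lowers the ceiling of the scheduled target dropout by a factor of 0.7, so fewer target tokens are masked even late in the warm-up. A rough standalone sketch of the schedule follows, assuming common_layers.inverse_lin_decay ramps linearly from ~0 to 1.0 over max_step global steps; the real function reads the global step internally, so the explicit step argument here is a simplification.

def inverse_lin_decay(step, max_step):
  # Linear ramp: ~0 at step 0, saturating at 1.0 once step >= max_step.
  return min(step / float(max_step), 1.0)

def targets_dropout_rate(step, kl_warmup_steps):
  targets_dropout_max = inverse_lin_decay(step, kl_warmup_steps) - 0.01
  return targets_dropout_max * 0.7  # this commit scales the ceiling by 0.7

print(targets_dropout_rate(10000, 10000))  # ~0.69 at end of warm-up, vs ~0.99 before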

@@ -168,7 +168,7 @@ def vae_transformer_internal(inputs, targets, target_space, hparams):
   # ret = tf.squeeze(to_decode, axis=2)

   # Randomize decoder inputs..
-  kl_loss *= common_layers.inverse_exp_decay(max_prestep) * 3.0
+  kl_loss *= common_layers.inverse_exp_decay(max_prestep) * 10.0
   return tf.expand_dims(ret, axis=2), kl_loss
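
Here the KL term's warm-up weight grows from a 3.0x to a 10.0x multiplier once the decay saturates, pushing the posterior harder toward the prior. A sketch of the weighting, with inverse_exp_decay written out as an exponential ramp from min_value up to 1.0 at max_step; the exact shape in common_layers is assumed rather than copied, and the explicit step argument is again a simplification.

import math

def inverse_exp_decay(step, max_step, min_value=0.01):
  # Exponential ramp: ~min_value at step 0, exactly 1.0 for step >= max_step.
  inv_base = math.exp(math.log(min_value) / max_step)
  return inv_base ** max(max_step - step, 0)

def kl_weight(step, kl_warmup_steps):
  return inverse_exp_decay(step, kl_warmup_steps) * 10.0  # was * 3.0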

tensor2tensor/utils/devices.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ def _replica_device_setter(worker_device):
   if FLAGS.schedule == "local_run":
     assert not FLAGS.sync
     datashard_devices = ["gpu:%d" % d for d in _gpu_order(FLAGS.worker_gpu)]
-    if FLAGS.locally_shard_to_cpu:
+    if FLAGS.locally_shard_to_cpu or FLAGS.worker_gpu < 1:
       datashard_devices += ["cpu:0"]
     caching_devices = None
   elif FLAGS.sync:
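
With zero GPUs, the old condition left datashard_devices empty unless --locally_shard_to_cpu was set; the new "or FLAGS.worker_gpu < 1" clause makes CPU-only local runs work out of the box. A standalone sketch of the resulting device-list logic, where plain arguments replace the FLAGS object and _gpu_order is assumed to default to natural order:

def datashard_devices(worker_gpu, locally_shard_to_cpu=False):
  # Mirrors the local_run branch above: one shard per GPU, falling back
  # to the CPU when CPU sharding is requested or no GPU is available.
  devices = ["gpu:%d" % d for d in range(worker_gpu)]
  if locally_shard_to_cpu or worker_gpu < 1:
    devices += ["cpu:0"]
  return devices

assert datashard_devices(0) == ["cpu:0"]           # CPU-only run now gets a shard
assert datashard_devices(2) == ["gpu:0", "gpu:1"]  # GPU runs are unchanged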
