This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit e96ca3a

royaurko authored and Copybara-Service committed

Bring back latent pred masking as it is somewhat unstable without it

PiperOrigin-RevId: 200764166

1 parent 6fb1537 · commit e96ca3a

File tree

1 file changed (+5, -1 lines)

tensor2tensor/models/research/transformer_nat.py

Lines changed: 5 additions & 1 deletion
@@ -299,6 +299,9 @@ def ae_transformer_internal(inputs, targets, target_space, hparams, cache=None):
     targets = d
   res = decode_transformer(inputs, ed, d, hparams, "decoder")
+  latent_time = tf.less(hparams.mask_startup_steps,
+                        tf.to_int32(tf.train.get_global_step()))
+  losses["latent_pred"] *= tf.to_float(latent_time)
   return res, losses, cache

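The added lines gate the latent prediction loss on the training step: the multiplier is 0.0 until the global step exceeds hparams.mask_startup_steps and 1.0 afterwards, so the latent_pred term only starts contributing once training has warmed up. A minimal TF1-style sketch of that gate (the wrapper name mask_latent_pred_loss is hypothetical, not part of the file):

import tensorflow as tf  # TF1-era API, matching the calls in the diff above


def mask_latent_pred_loss(losses, mask_startup_steps=50000):
  """Hypothetical wrapper around the gating logic added in this commit."""
  # False until global_step exceeds mask_startup_steps, True afterwards.
  latent_time = tf.less(mask_startup_steps,
                        tf.to_int32(tf.train.get_global_step()))
  # Casting the boolean to float zeroes the loss during the startup window.
  losses["latent_pred"] *= tf.to_float(latent_time)
  return losses
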
@@ -385,14 +388,15 @@ def transformer_nat_small():
   hparams.optimizer = "Adam"
   hparams.optimizer_adam_epsilon = 1e-9
   hparams.optimizer_adam_beta1 = 0.9
-  hparams.optimizer_adam_beta2 = 0.997  # Needs tuning, try 0.98 to 0.999.
+  hparams.optimizer_adam_beta2 = 0.997
   hparams.add_hparam("bottleneck_kind", "vq")
   hparams.add_hparam("bottleneck_bits", 12)
   hparams.add_hparam("num_compress_steps", 3)
   hparams.add_hparam("beta", 0.25)
   hparams.add_hparam("epsilon", 1e-5)
   hparams.add_hparam("decay", 0.999)
   hparams.add_hparam("num_samples", 10)
+  hparams.add_hparam("mask_startup_steps", 50000)
   return hparams

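The second hunk registers mask_startup_steps as a hyperparameter with a default of 50000 steps. Any hparams set derived from transformer_nat_small can override it like a normal field; a hedged sketch, assuming tensor2tensor's registry as in this repo (the set name transformer_nat_small_short_mask is hypothetical):

from tensor2tensor.models.research import transformer_nat
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_nat_small_short_mask():
  """Hypothetical variant that masks the latent_pred loss for fewer steps."""
  hparams = transformer_nat.transformer_nat_small()
  hparams.mask_startup_steps = 10000  # default in this commit is 50000
  return hparams
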