
Commit ae62ed6

T2T Team authored and Ryan Sepassi committed

Remove the extra KL loss term from the VQ-VAE loss.

PiperOrigin-RevId: 179490021

1 parent 69e4b36 · commit ae62ed6

File tree

1 file changed (+6, −3 lines)


tensor2tensor/models/transformer_vae.py

Lines changed: 6 additions & 3 deletions
@@ -163,11 +163,9 @@ def kmeans(x, means, hparams, name):
   with tf.variable_scope(name):
     x_means_hot = nearest(x, means, hparams)
     x_means = tf.gather(means, tf.argmax(x_means_hot, axis=-1))
-    x_flat = tf.reshape(x, [-1, hparams.hidden_size])
-    kl = tf.reduce_mean(tf.reduce_sum(tf.square(x_flat - x_means), axis=-1))
     reg_loss1 = tf.nn.l2_loss((tf.stop_gradient(x) - x_means))
     reg_loss2 = hparams.beta * tf.nn.l2_loss((x - tf.stop_gradient(x_means)))
-    l = kl + reg_loss1 + reg_loss2
+    l = reg_loss1 + reg_loss2
     return x_means_hot, x_means, l
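For context: with the kl term gone, the objective keeps only the two standard VQ-VAE terms — a codebook loss that pulls the selected codebook vectors toward the (gradient-frozen) encoder outputs, and a beta-weighted commitment loss that pulls the encoder outputs toward the (gradient-frozen) codebook vectors. The deleted kl term measured the same squared quantization distance again, just without the stop-gradient split, which is why it could be dropped as redundant. A minimal standalone sketch of the remaining loss (hypothetical tensors; beta stands in for hparams.beta):

import tensorflow as tf

def vq_loss(x, x_means, beta=0.25):
  # Codebook loss: move codebook vectors toward frozen encoder outputs.
  reg_loss1 = tf.nn.l2_loss(tf.stop_gradient(x) - x_means)
  # Commitment loss: move encoder outputs toward frozen codebook vectors.
  reg_loss2 = beta * tf.nn.l2_loss(x - tf.stop_gradient(x_means))
  return reg_loss1 + reg_loss2

# Toy usage with made-up encoder outputs and their nearest codebook entries.
x = tf.constant([[0.1, 0.9], [1.2, -0.3]])
x_means = tf.constant([[0.0, 1.0], [1.0, 0.0]])
print(vq_loss(x, x_means).numpy())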

@@ -208,6 +206,8 @@ def embed(x):
       means = tf.get_variable(name="means",
                               shape=[hparams.v_size, hparams.hidden_size])
       h1 = tf.gather(means, x)
+    elif hparams.bottleneck_kind == "rounding":
+      h1 = tf.round(x)

     h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
     return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
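The new "rounding" branch is a codebook-free bottleneck: instead of looking up integer code indices in the learned means table, the latent is discretized by elementwise rounding, so embed() can recover h1 without any codebook. A toy comparison of the two branches (standalone sketch; the codebook and inputs are made up):

import tensorflow as tf

means = tf.constant([[0.0, 0.0], [1.0, 1.0]])  # toy codebook: v_size=2, hidden_size=2

# vq-vae: x holds integer indices into the codebook.
codes = tf.constant([1, 0])
print(tf.gather(means, codes).numpy())   # [[1. 1.], [0. 0.]]

# rounding: x is continuous; quantize each coordinate to the nearest integer.
latent = tf.constant([[0.2, 0.8], [1.6, 0.4]])
print(tf.round(latent).numpy())          # [[0. 1.], [2. 0.]]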
@@ -255,6 +255,9 @@ def embed(x):
     x_means_hot, x_means, l = kmeans(x, means, hparams, name="vq-vae-kmeans")
     h1 = tf.stop_gradient(x_means) + x - tf.stop_gradient(x)
     c = tf.argmax(x_means_hot, axis=-1)
+  if hparams.bottleneck_kind == "round":
+    c = tf.round(x)
+    h1 = x + tf.stop_gradient(tf.round(x) - x)
   h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
   res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
   return res, c, l, embed
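Here h1 = x + tf.stop_gradient(tf.round(x) - x) is the straight-through estimator: the forward value equals tf.round(x), but the backward pass treats rounding as the identity, so gradients still reach x even though tf.round has zero gradient almost everywhere. A small eager-mode check of that behavior (written against the TF 2 API for brevity; the model itself is TF 1 graph code):

import tensorflow as tf

x = tf.Variable([0.2, 0.8, 1.6])
with tf.GradientTape() as tape:
  # Forward pass: equals round(x). Backward pass: identity in x,
  # because the stop_gradient term contributes no derivative.
  h1 = x + tf.stop_gradient(tf.round(x) - x)
  loss = tf.reduce_sum(h1 ** 2)

print(h1.numpy())                      # [0. 1. 2.]
print(tape.gradient(loss, x).numpy())  # [0. 2. 4.], i.e. 2 * h1 passed straight through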
