Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit fd77a8b

Browse files
T2T Team authored and Ryan Sepassi committed
Fix the rounding bottleneck. With this change the input is projected down to 1-d, squashed with a sigmoid, scaled into the interval [0, v_size], and then rounded.
PiperOrigin-RevId: 179778221
1 parent 5388318 commit fd77a8b

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

tensor2tensor/models/transformer_vae.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,12 @@
1818
from __future__ import absolute_import
1919
from __future__ import division
2020
from __future__ import print_function
21-
2221
# Dependency imports
23-
24-
from six.moves import xrange # pylint: disable=redefined-builtin
25-
2622
from tensor2tensor.layers import common_layers
2723
from tensor2tensor.models import transformer
2824
from tensor2tensor.utils import expert_utils
2925
from tensor2tensor.utils import registry
3026
from tensor2tensor.utils import t2t_model
31-
3227
import tensorflow as tf
3328

3429

@@ -207,7 +202,7 @@ def embed(x):
207202
shape=[hparams.v_size, hparams.hidden_size])
208203
h1 = tf.gather(means, x)
209204
elif hparams.bottleneck_kind == "rounding":
210-
h1 = tf.round(x)
205+
h1 = x
211206

212207
h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
213208
return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
@@ -255,9 +250,19 @@ def embed(x):
255250
x_means_hot, x_means, l = kmeans(x, means, hparams, name="vq-vae-kmeans")
256251
h1 = tf.stop_gradient(x_means) + x - tf.stop_gradient(x)
257252
c = tf.argmax(x_means_hot, axis=-1)
258-
if hparams.bottleneck_kind == "round":
259-
c = tf.round(x)
260-
h1 = x + tf.stop_gradient(tf.round(x) - x)
253+
if hparams.bottleneck_kind == "rounding":
254+
h = tf.layers.dense(x, 1, name="vcc")
255+
256+
# Make h between 0 and 1
257+
h = tf.sigmoid(h)
258+
259+
# Multiply by v_size to scale it into the interval [0, v_size]
260+
h *= hparams.v_size
261+
262+
# Use the rounding bottleneck
263+
h1 = h + tf.stop_gradient(tf.round(h) - h)
264+
c = tf.squeeze(tf.round(h), axis=-1)
265+
c = tf.to_int32(c)
261266
h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
262267
res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
263268
return res, c, l, embed

0 commit comments

Comments
 (0)