|
18 | 18 | from __future__ import absolute_import
|
19 | 19 | from __future__ import division
|
20 | 20 | from __future__ import print_function
|
21 |
| - |
22 | 21 | # Dependency imports
|
23 |
| - |
24 |
| -from six.moves import xrange # pylint: disable=redefined-builtin |
25 |
| - |
26 | 22 | from tensor2tensor.layers import common_layers
|
27 | 23 | from tensor2tensor.models import transformer
|
28 | 24 | from tensor2tensor.utils import expert_utils
|
29 | 25 | from tensor2tensor.utils import registry
|
30 | 26 | from tensor2tensor.utils import t2t_model
|
31 |
| - |
32 | 27 | import tensorflow as tf
|
33 | 28 |
|
34 | 29 |
|
@@ -207,7 +202,7 @@ def embed(x):
|
207 | 202 | shape=[hparams.v_size, hparams.hidden_size])
|
208 | 203 | h1 = tf.gather(means, x)
|
209 | 204 | elif hparams.bottleneck_kind == "rounding":
|
210 |
| - h1 = tf.round(x) |
| 205 | + h1 = x |
211 | 206 |
|
212 | 207 | h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
|
213 | 208 | return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
|
@@ -255,9 +250,19 @@ def embed(x):
|
255 | 250 | x_means_hot, x_means, l = kmeans(x, means, hparams, name="vq-vae-kmeans")
|
256 | 251 | h1 = tf.stop_gradient(x_means) + x - tf.stop_gradient(x)
|
257 | 252 | c = tf.argmax(x_means_hot, axis=-1)
|
258 |
| - if hparams.bottleneck_kind == "round": |
259 |
| - c = tf.round(x) |
260 |
| - h1 = x + tf.stop_gradient(tf.round(x) - x) |
| 253 | + if hparams.bottleneck_kind == "rounding": |
| 254 | + h = tf.layers.dense(x, 1, name="vcc") |
| 255 | + |
| 256 | + # Make h between 0 and 1 |
| 257 | + h = tf.sigmoid(h) |
| 258 | + |
| 259 | + # Multiply by z_size to get it between [0, z_size] |
| 260 | + h *= hparams.v_size |
| 261 | + |
| 262 | + # Use the rounding bottleneck |
| 263 | + h1 = h + tf.stop_gradient(tf.round(h) - h) |
| 264 | + c = tf.squeeze(tf.round(h), axis=-1) |
| 265 | + c = tf.to_int32(c) |
261 | 266 | h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
|
262 | 267 | res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
|
263 | 268 | return res, c, l, embed
|
|
0 commit comments