@@ -163,11 +163,9 @@ def kmeans(x, means, hparams, name):
   with tf.variable_scope(name):
     x_means_hot = nearest(x, means, hparams)
     x_means = tf.gather(means, tf.argmax(x_means_hot, axis=-1))
-    x_flat = tf.reshape(x, [-1, hparams.hidden_size])
-    kl = tf.reduce_mean(tf.reduce_sum(tf.square(x_flat - x_means), axis=-1))
     reg_loss1 = tf.nn.l2_loss((tf.stop_gradient(x) - x_means))
     reg_loss2 = hparams.beta * tf.nn.l2_loss((x - tf.stop_gradient(x_means)))
-    l = kl + reg_loss1 + reg_loss2
+    l = reg_loss1 + reg_loss2
     return x_means_hot, x_means, l


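The hunk above drops the extra `kl` distance term, leaving the two standard VQ-VAE regularizers: a codebook term (stop-gradient on the encoder output) and a beta-weighted commitment term (stop-gradient on the codebook entries). A minimal sketch of how the two remaining terms combine, using eager TensorFlow and made-up toy tensors in place of the repo's `hparams`:

```python
import tensorflow as tf

x = tf.constant([[0.2, 0.9], [1.1, -0.3]])       # stand-in encoder outputs
x_means = tf.constant([[0.0, 1.0], [1.0, 0.0]])  # stand-in nearest codebook entries
beta = 0.25                                      # plays the role of hparams.beta

# Codebook term: gradient reaches only x_means, since x is stop-gradient'd.
reg_loss1 = tf.nn.l2_loss(tf.stop_gradient(x) - x_means)
# Commitment term: gradient reaches only x, since x_means is stop-gradient'd.
reg_loss2 = beta * tf.nn.l2_loss(x - tf.stop_gradient(x_means))
l = reg_loss1 + reg_loss2
```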
@@ -208,6 +206,8 @@ def embed(x):
       means = tf.get_variable(name="means",
                               shape=[hparams.v_size, hparams.hidden_size])
       h1 = tf.gather(means, x)
+    elif hparams.bottleneck_kind == "rounding":
+      h1 = tf.round(x)

     h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
     return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
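The new `"rounding"` branch in `embed` skips the codebook lookup entirely and just rounds its input. A small sketch (eager TensorFlow; `means`, `codes`, and `x` are toy values, not the repo's variables) contrasting the two branches:

```python
import tensorflow as tf

v_size, hidden_size = 8, 4
means = tf.random.normal([v_size, hidden_size])  # stand-in learned codebook

# "vq-vae": integer code ids are embedded via a codebook lookup.
codes = tf.constant([3, 5])
h1_vq = tf.gather(means, codes)                  # shape [2, hidden_size]

# "rounding": the continuous code is rounded entrywise, no lookup table.
x = tf.constant([[0.2, 1.7, -0.4, 0.9]])
h1_round = tf.round(x)                           # each entry snapped to nearest integer
```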
@@ -255,6 +255,9 @@ def embed(x):
     x_means_hot, x_means, l = kmeans(x, means, hparams, name="vq-vae-kmeans")
     h1 = tf.stop_gradient(x_means) + x - tf.stop_gradient(x)
     c = tf.argmax(x_means_hot, axis=-1)
+  if hparams.bottleneck_kind == "round":
+    c = tf.round(x)
+    h1 = x + tf.stop_gradient(tf.round(x) - x)
   h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
   res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
   return res, c, l, embed
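The added `"round"` branch uses a straight-through estimator: the forward value of `h1` equals `tf.round(x)`, but because the rounding residual is wrapped in `tf.stop_gradient`, the backward pass treats `h1` as if it were `x`. A minimal sketch of that behaviour in eager TensorFlow with `tf.GradientTape` (not the repo's graph-mode code):

```python
import tensorflow as tf

x = tf.Variable([[0.3, 1.6, -0.8]])
with tf.GradientTape() as tape:
  # Forward: h1 == tf.round(x); backward: d(h1)/d(x) behaves like the identity.
  h1 = x + tf.stop_gradient(tf.round(x) - x)
  loss = tf.reduce_sum(h1 ** 2)

grad = tape.gradient(loss, x)  # == 2 * tf.round(x); the gradient passes through the rounding
```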