 @registry.register_model
-class ResidualAutoencoder(basic.BasicAutoencoder):
+class AutoencoderAutoregressive(basic.BasicAutoencoder):
+  """Autoencoder with an autoregressive part."""
+
+  def body(self, features):
+    hparams = self._hparams
+    shape = common_layers.shape_list(features["targets"])
+    # Run the basic autoencoder part first.
+    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
+    # Prepare inputs for autoregressive modes.
+    targets_keep_prob = 1.0 - hparams.autoregressive_dropout
+    targets_dropout = common_layers.dropout_with_broadcast_dims(
+        features["targets"], targets_keep_prob, broadcast_dims=[-1])
+    targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[3]])
+    targets_shifted = common_layers.shift_right_3d(targets1d)
+    basic1d = tf.reshape(basic_result, [shape[0], -1, shape[3]])
+    concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
+    # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
+    if hparams.autoregressive_forget_base:
+      concat1d = tf.reshape(features["targets"], [shape[0], -1, shape[3]])
+      concat1d = common_layers.shift_right_3d(concat1d)
+    # The autoregressive part depends on the mode.
+    if hparams.autoregressive_mode == "none":
+      assert not hparams.autoregressive_forget_base
+      return basic_result, losses
+    if hparams.autoregressive_mode == "conv3":
+      res = common_layers.conv1d(concat1d, shape[3], 3, padding="LEFT",
+                                 activation=common_layers.belu,
+                                 name="autoregressive_conv3")
+      return tf.reshape(res, shape), losses
+    if hparams.autoregressive_mode == "conv5":
+      res = common_layers.conv1d(concat1d, shape[3], 5, padding="LEFT",
+                                 activation=common_layers.belu,
+                                 name="autoregressive_conv5")
+      return tf.reshape(res, shape), losses
+    if hparams.autoregressive_mode == "sru":
+      res = common_layers.conv1d(concat1d, shape[3], 3, padding="LEFT",
+                                 activation=common_layers.belu,
+                                 name="autoregressive_sru_conv3")
+      res = common_layers.sru(res)
+      return tf.reshape(res, shape), losses
+
+    raise ValueError("Unsupported autoregressive mode: %s"
+                     % hparams.autoregressive_mode)
+
+
+@registry.register_model
+class AutoencoderResidual(AutoencoderAutoregressive):
   """Residual autoencoder."""

   def encoder(self, x):
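A note on the autoregressive wiring above: the targets are shifted right by one step and every convolution uses padding="LEFT", so step t of the output can only depend on target steps strictly before t (teacher forcing during training). Below is a minimal NumPy sketch of that property, with made-up helper names and under the assumption that common_layers.shift_right_3d pads one zero step at the front of the time axis and drops the last step; it is an illustration, not the library code.

import numpy as np

def shift_right_3d(x):
  # Assumed behaviour: prepend a zero step, drop the last step.
  return np.pad(x, ((0, 0), (1, 0), (0, 0)))[:, :-1, :]

def causal_conv1d(x, kernel):
  # Left-pad time by (kernel_width - 1) so output step t sees inputs <= t,
  # mirroring padding="LEFT" in the diff above.
  k = kernel.shape[0]
  x_pad = np.pad(x, ((0, 0), (k - 1, 0), (0, 0)))
  out = np.zeros((x.shape[0], x.shape[1], kernel.shape[2]))
  for t in range(x.shape[1]):
    window = x_pad[:, t:t + k, :]                  # steps t-k+1 .. t
    out[:, t, :] = np.einsum("bkc,kcd->bd", window, kernel)
  return out

targets = np.random.randn(2, 5, 3)                 # [batch, length, channels]
kernel = np.random.randn(3, 3, 3)                  # [width, in_ch, out_ch]
out = causal_conv1d(shift_right_3d(targets), kernel)

# Perturbing the targets at step 3 leaves outputs at steps 0..3 unchanged.
bumped = targets.copy()
bumped[:, 3, :] += 1.0
out2 = causal_conv1d(shift_right_3d(bumped), kernel)
assert np.allclose(out[:, :4, :], out2[:, :4, :])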
@@ -106,7 +152,7 @@ def decoder(self, x):


 @registry.register_model
-class BasicDiscreteAutoencoder(basic.BasicAutoencoder):
+class AutoencoderBasicDiscrete(AutoencoderAutoregressive):
   """Discrete autoencoder."""

   def bottleneck(self, x):
@@ -132,7 +178,7 @@ def sample(self):


 @registry.register_model
-class ResidualDiscreteAutoencoder(ResidualAutoencoder):
+class AutoencoderResidualDiscrete(AutoencoderResidual):
   """Discrete residual autoencoder."""

   def bottleneck(self, x, bottleneck_size=None):
@@ -160,13 +206,15 @@ def sample(self):
     size = [hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y,
             hp.bottleneck_size]
     rand = tf.random_uniform(size)
-    res1 = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
-    res2 = tf.zeros_like(rand) - 1.0
-    return tf.concat([res2[:, :, :, :2], res1[:, :, :, 2:]], axis=-1)
+    res = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
+    # If you want to set some first bits to a fixed value, do this:
+    # fixed = tf.zeros_like(rand) - 1.0
+    # res = tf.concat([fixed[:, :, :, :2], res[:, :, :, 2:]], axis=-1)
+    return res


 @registry.register_model
-class OrderedDiscreteAutoencoder(ResidualDiscreteAutoencoder):
+class AutoencoderOrderedDiscrete(AutoencoderResidualDiscrete):
   """Ordered discrete autoencoder."""

   def bottleneck(self, x):
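The sample() change above drops the hard-coded first two bottleneck channels and samples every bit at random; the commented-out lines show how to pin some bits again if needed. A tiny NumPy stand-in (not the model code) for the thresholding that turns uniform noise into bits in {-1, 1}:

import numpy as np

rand = np.random.uniform(size=(4, 2, 2, 8))        # [batch, h, w, bottleneck]
res = 2.0 * (rand > 0.5).astype(np.float32) - 1.0
assert np.all(np.isin(res, [-1.0, 1.0]))

# To pin the first two channels again (the behaviour that was removed):
# fixed = np.zeros_like(rand) - 1.0
# res = np.concatenate([fixed[..., :2], res[..., 2:]], axis=-1)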
@@ -195,7 +243,7 @@ def bottleneck(self, x):


 @registry.register_model
-class StackedAutoencoder(ResidualDiscreteAutoencoder):
+class AutoencoderStacked(AutoencoderResidualDiscrete):
   """A stacked autoencoder."""

   def stack(self, b, size, bottleneck_size, name):
@@ -290,9 +338,19 @@ def body(self, features):


 @registry.register_hparams
-def residual_autoencoder():
-  """Residual autoencoder model."""
+def autoencoder_autoregressive():
+  """Autoregressive autoencoder model."""
   hparams = basic.basic_autoencoder()
+  hparams.add_hparam("autoregressive_forget_base", False)
+  hparams.add_hparam("autoregressive_mode", "conv3")
+  hparams.add_hparam("autoregressive_dropout", 0.4)
+  return hparams
+
+
+@registry.register_hparams
+def autoencoder_residual():
+  """Residual autoencoder model."""
+  hparams = autoencoder_autoregressive()
   hparams.optimizer = "Adam"
   hparams.learning_rate_constant = 0.0001
   hparams.learning_rate_warmup_steps = 500
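On the new hparams: autoregressive_dropout=0.4 is applied to the teacher-forced targets with broadcast_dims=[-1], which (assuming dropout_with_broadcast_dims shares one mask value across each broadcast dimension) zeroes whole target vectors rather than individual channels, so the autoregressive layer cannot lean too heavily on the previous ground-truth step. A rough NumPy sketch of that assumed behaviour, with a made-up helper name:

import numpy as np

def dropout_broadcast_last_dim(x, keep_prob, rng):
  # One mask value per position, broadcast over the channel axis,
  # rescaled by 1 / keep_prob like standard inverted dropout.
  mask = (rng.uniform(size=x.shape[:-1] + (1,)) < keep_prob).astype(x.dtype)
  return x * mask / keep_prob

rng = np.random.default_rng(0)
targets = rng.normal(size=(2, 7, 16))              # [batch, positions, channels]
keep_prob = 1.0 - 0.4                              # autoregressive_dropout = 0.4
dropped = dropout_broadcast_last_dim(targets, keep_prob, rng)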
@@ -311,9 +369,9 @@ def residual_autoencoder():


 @registry.register_hparams
-def basic_discrete_autoencoder():
+def autoencoder_basic_discrete():
   """Basic autoencoder model."""
-  hparams = basic.basic_autoencoder()
+  hparams = autoencoder_autoregressive()
   hparams.num_hidden_layers = 5
   hparams.hidden_size = 64
   hparams.bottleneck_size = 4096
@@ -324,9 +382,9 @@ def basic_discrete_autoencoder():


 @registry.register_hparams
-def residual_discrete_autoencoder():
+def autoencoder_residual_discrete():
   """Residual discrete autoencoder model."""
-  hparams = residual_autoencoder()
+  hparams = autoencoder_residual()
   hparams.bottleneck_size = 4096
   hparams.bottleneck_noise = 0.1
   hparams.bottleneck_warmup_steps = 3000
@@ -339,9 +397,9 @@ def residual_discrete_autoencoder():


 @registry.register_hparams
-def residual_discrete_autoencoder_big():
+def autoencoder_residual_discrete_big():
   """Residual discrete autoencoder model, big version."""
-  hparams = residual_discrete_autoencoder()
+  hparams = autoencoder_residual_discrete()
   hparams.hidden_size = 128
   hparams.max_hidden_size = 4096
   hparams.bottleneck_noise = 0.1
@@ -351,15 +409,15 @@ def residual_discrete_autoencoder_big():


 @registry.register_hparams
-def ordered_discrete_autoencoder():
+def autoencoder_ordered_discrete():
   """Basic autoencoder model."""
-  hparams = residual_discrete_autoencoder()
+  hparams = autoencoder_residual_discrete()
   return hparams


 @registry.register_hparams
-def stacked_autoencoder():
+def autoencoder_stacked():
   """Stacked autoencoder model."""
-  hparams = residual_discrete_autoencoder()
+  hparams = autoencoder_residual_discrete()
   hparams.bottleneck_size = 128
   return hparams
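Assuming the registry derives a model's default registration name from its class name by CamelCase-to-snake_case conversion (as elsewhere in tensor2tensor), the renamed classes and the renamed hparams sets above now line up, e.g. --model=autoencoder_residual_discrete pairs with --hparams_set=autoencoder_residual_discrete. A small sketch of that assumed name mapping:

import re

def camel_to_snake(name):
  # Insert "_" before every uppercase letter except the first, then lowercase.
  return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

for cls in ("AutoencoderAutoregressive", "AutoencoderResidual",
            "AutoencoderBasicDiscrete", "AutoencoderResidualDiscrete",
            "AutoencoderOrderedDiscrete", "AutoencoderStacked"):
  print(cls, "->", camel_to_snake(cls))   # e.g. autoencoder_residual_discrete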