 from six.moves import xrange  # pylint: disable=redefined-builtin

+from tensor2tensor.layers import common_attention
 from tensor2tensor.layers import common_layers
 from tensor2tensor.models import transformer
 from tensor2tensor.utils import registry
@@ -49,13 +50,43 @@ def residual_conv(x, repeat, hparams, name, reuse=None):
     return x


-def decompress_step(source, hparams, first_relu, name):
+def attend(x, source, hparams, name):
+  with tf.variable_scope(name):
+    x = tf.squeeze(x, axis=2)
+    if len(source.get_shape()) > 3:
+      source = tf.squeeze(source, axis=2)
+    source = common_attention.add_timing_signal_1d(source)
+    y = common_attention.multihead_attention(
+        common_layers.layer_preprocess(x, hparams), source, None,
+        hparams.attention_key_channels or hparams.hidden_size,
+        hparams.attention_value_channels or hparams.hidden_size,
+        hparams.hidden_size, hparams.num_heads,
+        hparams.attention_dropout)
+    res = common_layers.layer_postprocess(x, y, hparams)
+    return tf.expand_dims(res, axis=2)
+
+
+def interleave(x, y, axis=1):
+  x = tf.expand_dims(x, axis=axis + 1)
+  y = tf.expand_dims(y, axis=axis + 1)
+  return tf.concat([x, y], axis=axis + 1)
+
+
+def decompress_step(source, c, hparams, first_relu, name):
   """Decompression function."""
   with tf.variable_scope(name):
     shape = tf.shape(source)
-    thicker = common_layers.conv_block(
-        source, hparams.hidden_size * 2, [((1, 1), (1, 1))],
-        first_relu=first_relu, name="decompress_conv")
+    if c is not None:
+      source = attend(source, c, hparams, "decompress_attend")
+    first = common_layers.conv_block(
+        source,
+        hparams.hidden_size, [((1, 1), (3, 1)), ((1, 1), (3, 1))],
+        first_relu=first_relu, padding="SAME", name="decompress_conv1")
+    second = common_layers.conv_block(
+        tf.concat([source, first], axis=3),
+        hparams.hidden_size, [((1, 1), (3, 1)), ((1, 1), (3, 1))],
+        first_relu=first_relu, padding="SAME", name="decompress_conv2")
+    thicker = interleave(first, second)
     return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])
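A quick NumPy sketch (illustrative only, not from the codebase) of the shape trick above: interleave() stacks its two inputs along a fresh axis, and the final tf.reshape in decompress_step then alternates them along the length dimension, doubling it. attend() itself expects 4-D [batch, length, 1, hidden] tensors and lets x attend to source with multi-head attention.

import numpy as np

batch, length, hidden = 2, 3, 4
first = np.zeros((batch, length, 1, hidden))   # stands in for the decompress_conv1 output
second = np.ones((batch, length, 1, hidden))   # stands in for the decompress_conv2 output

# interleave(first, second): insert a new axis after the length axis, concat on it.
stacked = np.concatenate([first[:, :, None], second[:, :, None]], axis=2)
# stacked is [batch, length, 2, 1, hidden]; reshaping flattens (length, 2) into
# 2 * length, so positions alternate: first[0], second[0], first[1], second[1], ...
thicker = stacked.reshape(batch, 2 * length, 1, hidden)
print(thicker[0, :, 0, 0])  # [0. 1. 0. 1. 0. 1.]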
@@ -71,23 +102,25 @@ def vae(x, hparams, name):
     return z, tf.reduce_mean(kl), mu, log_sigma


-def compress(inputs, hparams, name):
+def compress(x, c, hparams, name):
   """Compress."""
   with tf.variable_scope(name):
     # Run compression by strided convs.
-    cur = inputs
+    cur = x
     for i in xrange(hparams.num_compress_steps):
+      if c is not None:
+        cur = attend(cur, c, hparams, "compress_attend_%d" % i)
       cur = residual_conv(cur, 1, hparams, "compress_rc_%d" % i)
       cur = common_layers.conv_block(
           cur, hparams.hidden_size, [((1, 1), (2, 1))],
           strides=(2, 1), name="compress_%d" % i)
     return cur


-def vae_compress(inputs, hparams, compress_name, decompress_name, reuse=None):
+def vae_compress(x, c, hparams, compress_name, decompress_name, reuse=None):
   """Compress, then VAE."""
   with tf.variable_scope(compress_name, reuse=reuse):
-    cur = compress(inputs, hparams, "compress")
+    cur = compress(x, c, hparams, "compress")
     # Convolve and ReLu to get state.
     cur = common_layers.conv_block(
         cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
@@ -100,7 +133,7 @@ def vae_compress(inputs, hparams, compress_name, decompress_name, reuse=None):
     for i in xrange(hparams.num_compress_steps):
       j = hparams.num_compress_steps - i - 1
       z = residual_conv(z, 1, hparams, "decompress_rc_%d" % j)
-      z = decompress_step(z, hparams, i > 0, "decompress__step_%d" % j)
+      z = decompress_step(z, c, hparams, i > 0, "decompress__step_%d" % j)
     return z, kl_loss, mu, log_sigma


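For orientation, a tiny arithmetic sketch (illustrative only, assuming the num_compress_steps = 4 default set further down): each compress step halves the length axis with a stride-(2, 1) convolution, and the decompress loop above doubles it the same number of times, while the new c argument lets every step also attend to a conditioning tensor when one is passed.

num_compress_steps = 4                              # transformer_vae_small default
length = 64
compressed = length // (2 ** num_compress_steps)    # 64 -> 4 positions after compress()
restored = compressed * (2 ** num_compress_steps)   # 4 -> 64 after the decompress loop
print(compressed, restored)                         # 4 64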
@@ -124,6 +157,13 @@ def dropmask(targets, targets_dropout_max, is_training):
   return targets * keep_mask


+def ffn(x, hparams, name):
+  with tf.variable_scope(name):
+    y = transformer.transformer_ffn_layer(
+        common_layers.layer_preprocess(x, hparams), hparams)
+    return common_layers.layer_postprocess(x, y, hparams)
+
+
 def vae_transformer_internal(inputs, targets, target_space, hparams):
   """VAE Transformer, main step used for training."""
   with tf.variable_scope("vae_transformer"):
@@ -140,36 +180,40 @@ def vae_transformer_internal(inputs, targets, target_space, hparams):
     inputs = encode(inputs, target_space, hparams, "input_enc")

     # Dropout targets or swap for zeros 5% of the time.
+    targets_nodrop = targets
     max_prestep = hparams.kl_warmup_steps
     prob_targets = 0.95 if is_training else 1.0
     targets_dropout_max = common_layers.inverse_lin_decay(max_prestep) - 0.01
     targets = dropmask(targets, targets_dropout_max * 0.7, is_training)
     targets = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                       lambda: targets, lambda: tf.zeros_like(targets))
-
-    # Join targets with inputs, run encoder.
-    # to_encode = common_layers.conv_block(
-    #     tf.expand_dims(tf.concat([targets, inputs], axis=2), axis=2),
-    #     hparams.hidden_size, [((1, 1), (1, 1))],
-    #     first_relu=False, name="join_targets")
-    # to_compress = encode(tf.squeeze(to_encode, axis=2),
-    #                      target_space, hparams, "enc")
+    targets = targets_nodrop

     # Compress and vae.
-    z, kl_loss, _, _ = vae_compress(tf.expand_dims(targets, axis=2), hparams,
-                                    "vae_compress", "vae_decompress")
+    z = tf.get_variable("z", [hparams.hidden_size])
+    z = tf.reshape(z, [1, 1, 1, -1])
+    z = tf.tile(z, [tf.shape(inputs)[0], 1, 1, 1])
+
+    z = attend(z, inputs, hparams, "z_attendsi")
+    z = ffn(z, hparams, "zff2")
+    z = attend(z, targets, hparams, "z_attendst2")
+    z = ffn(z, hparams, "zff3")
+    z, kl_loss, _, _ = vae(z, hparams, name="vae")
+    z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense")
+
+    # z, kl_loss, _, _ = vae_compress(
+    #     tf.expand_dims(targets, axis=2), tf.expand_dims(inputs, axis=2),
+    #     hparams, "vae_compress", "vae_decompress")

-    # Join z with inputs, run decoder.
-    to_decode = common_layers.conv_block(
-        tf.concat([z, tf.expand_dims(inputs, axis=2)], axis=3),
-        hparams.hidden_size, [((1, 1), (1, 1))], name="join_z")
-    ret = encode(tf.squeeze(to_decode, axis=2), target_space, hparams, "dec")
-    # to_decode = residual_conv(to_decode, 2, hparams, "dec_conv")
-    # ret = tf.squeeze(to_decode, axis=2)
+    decoder_in = tf.squeeze(z, axis=2) + tf.zeros_like(targets)
+    (decoder_input, decoder_self_attention_bias) = (
+        transformer.transformer_prepare_decoder(decoder_in, hparams))
+    ret = transformer.transformer_decoder(
+        decoder_input, inputs, decoder_self_attention_bias, None, hparams)

-    # Randomize decoder inputs..
-    kl_loss *= common_layers.inverse_exp_decay(max_prestep) * 10.0
-    return tf.expand_dims(ret, axis=2), kl_loss
+    kl_loss *= common_layers.inverse_exp_decay(int(max_prestep * 1.5)) * 5.0
+    losses = {"kl": kl_loss}
+    return tf.expand_dims(ret, axis=2), losses


 @registry.register_model
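The latent is now a single learned vector per example: it attends to the encoded inputs and to the targets, goes through the VAE, and is broadcast to every target position by tf.squeeze(z, axis=2) + tf.zeros_like(targets) before the Transformer decoder runs. A minimal NumPy sketch of that broadcast step (illustrative shapes only, not from the codebase):

import numpy as np

batch, length, hidden = 2, 5, 4
z = np.arange(batch * hidden, dtype=float).reshape(batch, 1, hidden)  # one latent per example
targets = np.zeros((batch, length, hidden))                           # decoder targets

decoder_in = z + np.zeros_like(targets)  # broadcasts the latent across all positions
print(decoder_in.shape)                  # (2, 5, 4): the same latent repeated per position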
@@ -203,13 +247,15 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
     sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
     samples = tf.concat(sharded_samples, 0)

-    # 2nd step.
-    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-      features["targets"] = samples
-      sharded_logits, _ = self.model_fn(
-          features, False, last_position_only=last_position_only)
-      sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
-      samples = tf.concat(sharded_samples, 0)
+    # More steps.
+    how_many_more_steps = 20
+    for _ in xrange(how_many_more_steps):
+      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+        features["targets"] = samples
+        sharded_logits, _ = self.model_fn(
+            features, False, last_position_only=last_position_only)
+        sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
+        samples = tf.concat(sharded_samples, 0)

     if inputs_old is not None:  # Restore to not confuse Estimator.
       features["inputs"] = inputs_old
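Inference thus becomes an iterative-refinement loop: the model is re-run with variable reuse, each pass feeding the previous argmax samples back in as features["targets"]. A hypothetical plain-Python sketch of the pattern (decode_step and initial_samples are placeholders, not names from the codebase):

def iterative_refine(decode_step, initial_samples, num_steps=20):
  """Repeatedly re-decode, feeding the previous samples back as targets."""
  samples = initial_samples
  for _ in range(num_steps):
    samples = decode_step(samples)  # one forward pass plus argmax over the logits
  return samples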
@@ -221,9 +267,10 @@ def transformer_vae_small():
   """Set of hyperparameters."""
   hparams = transformer.transformer_small()
   hparams.batch_size = 2048
+  hparams.learning_rate_warmup_steps = 16000
   hparams.add_hparam("z_size", 128)
   hparams.add_hparam("num_compress_steps", 4)
-  hparams.add_hparam("kl_warmup_steps", 50000)
+  hparams.add_hparam("kl_warmup_steps", 60000)
   return hparams


@@ -233,9 +280,9 @@ def transformer_vae_base():
   hparams = transformer_vae_small()
   hparams.hidden_size = 512
   hparams.filter_size = 2048
-  hparams.attention_dropout = 0.1
-  hparams.relu_dropout = 0.1
-  hparams.dropout = 0.1
-  hparams.num_hidden_layers = 4
+  hparams.attention_dropout = 0.0
+  hparams.relu_dropout = 0.0
+  hparams.dropout = 0.0
+  hparams.num_hidden_layers = 3
   hparams.z_size = 256
   return hparams