@@ -223,34 +223,35 @@ def evolved_transformer_encoder(encoder_input,
         hidden_state = common_layers.layer_postprocess(
             residual_state, hidden_state, hparams)
 
-        with tf.variable_scope("self_attention"):
-          residual_state = hidden_state
-          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+        if hparams.get("et_encoder_self_attention", True):
+          with tf.variable_scope("self_attention"):
+            residual_state = hidden_state
+            hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
 
-          hidden_state = common_attention.multihead_attention(
-              hidden_state,
-              None,
-              encoder_self_attention_bias,
-              hparams.attention_key_channels or hparams.hidden_size,
-              hparams.attention_value_channels or hparams.hidden_size,
-              hparams.hidden_size,
-              hparams.num_heads,
-              hparams.attention_dropout,
-              attention_type=hparams.self_attention_type,
-              max_relative_position=hparams.max_relative_position,
-              heads_share_relative_embedding=(
-                  hparams.heads_share_relative_embedding),
-              add_relative_to_values=hparams.add_relative_to_values,
-              save_weights_to=save_weights_to,
-              make_image_summary=make_image_summary,
-              dropout_broadcast_dims=attention_dropout_broadcast_dims,
-              max_length=hparams.get("max_length"),
-              vars_3d=hparams.get("attention_variables_3d"),
-              activation_dtype=hparams.get("activation_dtype", "float32"),
-              weight_dtype=hparams.get("weight_dtype", "float32"))
+            hidden_state = common_attention.multihead_attention(
+                hidden_state,
+                None,
+                encoder_self_attention_bias,
+                hparams.attention_key_channels or hparams.hidden_size,
+                hparams.attention_value_channels or hparams.hidden_size,
+                hparams.hidden_size,
+                hparams.num_heads,
+                hparams.attention_dropout,
+                attention_type=hparams.self_attention_type,
+                max_relative_position=hparams.max_relative_position,
+                heads_share_relative_embedding=(
+                    hparams.heads_share_relative_embedding),
+                add_relative_to_values=hparams.add_relative_to_values,
+                save_weights_to=save_weights_to,
+                make_image_summary=make_image_summary,
+                dropout_broadcast_dims=attention_dropout_broadcast_dims,
+                max_length=hparams.get("max_length"),
+                vars_3d=hparams.get("attention_variables_3d"),
+                activation_dtype=hparams.get("activation_dtype", "float32"),
+                weight_dtype=hparams.get("weight_dtype", "float32"))
 
-          hidden_state = common_layers.layer_postprocess(
-              residual_state, hidden_state, hparams)
+            hidden_state = common_layers.layer_postprocess(
+                residual_state, hidden_state, hparams)
 
         with tf.variable_scope("dense_layers"):
           residual_state = hidden_state
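
The hunk gates the encoder's entire self-attention branch behind an "et_encoder_self_attention" hparam, read with a default of True so existing configs keep their current behavior. A minimal sketch of how the flag might be turned off when registering an hparams set, assuming the standard tensor2tensor registry API and the existing evolved_transformer_base hparams; the evolved_transformer_no_enc_self_attention name is hypothetical:

from tensor2tensor.models import evolved_transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def evolved_transformer_no_enc_self_attention():
  """Hypothetical hparams set that skips the encoder self-attention branch."""
  hparams = evolved_transformer.evolved_transformer_base()
  # Flag introduced by this diff; when absent, hparams.get(..., True)
  # keeps self-attention enabled, so only an explicit False disables it.
  hparams.add_hparam("et_encoder_self_attention", False)
  return hparams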