This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c104976

eli7 authored and copybara-github committed
adding hparam to make encoder self-attention optional.
PiperOrigin-RevId: 313707269
1 parent f65b5e4 commit c104976

1 file changed: +27 −26 lines

tensor2tensor/models/evolved_transformer.py

@@ -223,34 +223,35 @@ def evolved_transformer_encoder(encoder_input,
       hidden_state = common_layers.layer_postprocess(
           residual_state, hidden_state, hparams)

-      with tf.variable_scope("self_attention"):
-        residual_state = hidden_state
-        hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+      if hparams.get("et_encoder_self_attention", True):
+        with tf.variable_scope("self_attention"):
+          residual_state = hidden_state
+          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

-        hidden_state = common_attention.multihead_attention(
-            hidden_state,
-            None,
-            encoder_self_attention_bias,
-            hparams.attention_key_channels or hparams.hidden_size,
-            hparams.attention_value_channels or hparams.hidden_size,
-            hparams.hidden_size,
-            hparams.num_heads,
-            hparams.attention_dropout,
-            attention_type=hparams.self_attention_type,
-            max_relative_position=hparams.max_relative_position,
-            heads_share_relative_embedding=(
-                hparams.heads_share_relative_embedding),
-            add_relative_to_values=hparams.add_relative_to_values,
-            save_weights_to=save_weights_to,
-            make_image_summary=make_image_summary,
-            dropout_broadcast_dims=attention_dropout_broadcast_dims,
-            max_length=hparams.get("max_length"),
-            vars_3d=hparams.get("attention_variables_3d"),
-            activation_dtype=hparams.get("activation_dtype", "float32"),
-            weight_dtype=hparams.get("weight_dtype", "float32"))
+          hidden_state = common_attention.multihead_attention(
+              hidden_state,
+              None,
+              encoder_self_attention_bias,
+              hparams.attention_key_channels or hparams.hidden_size,
+              hparams.attention_value_channels or hparams.hidden_size,
+              hparams.hidden_size,
+              hparams.num_heads,
+              hparams.attention_dropout,
+              attention_type=hparams.self_attention_type,
+              max_relative_position=hparams.max_relative_position,
+              heads_share_relative_embedding=(
+                  hparams.heads_share_relative_embedding),
+              add_relative_to_values=hparams.add_relative_to_values,
+              save_weights_to=save_weights_to,
+              make_image_summary=make_image_summary,
+              dropout_broadcast_dims=attention_dropout_broadcast_dims,
+              max_length=hparams.get("max_length"),
+              vars_3d=hparams.get("attention_variables_3d"),
+              activation_dtype=hparams.get("activation_dtype", "float32"),
+              weight_dtype=hparams.get("weight_dtype", "float32"))

-        hidden_state = common_layers.layer_postprocess(
-            residual_state, hidden_state, hparams)
+          hidden_state = common_layers.layer_postprocess(
+              residual_state, hidden_state, hparams)

       with tf.variable_scope("dense_layers"):
         residual_state = hidden_state
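
For context, a minimal sketch of how the new flag might be used to disable encoder self-attention. The encoder reads it with hparams.get("et_encoder_self_attention", True), so existing configs keep the default behavior; the hparams-set name evolved_transformer_base and the use of add_hparam are assumptions based on the usual tensor2tensor conventions, not part of this commit.

# Sketch only: disable encoder self-attention (assumes evolved_transformer_base exists in this module).
from tensor2tensor.models import evolved_transformer

hparams = evolved_transformer.evolved_transformer_base()
# The flag is read via hparams.get(...) with a default of True, so it must be
# added explicitly to turn the self-attention branch off.
hparams.add_hparam("et_encoder_self_attention", False)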
