
Commit af4f1e0

T2T Team authored and Ryan Sepassi committed
Simplify calls to embedding_to_padding; we always end up converting the padding mask to a float tensor.
PiperOrigin-RevId: 164777753
1 parent ae49192 · commit af4f1e0


tensor2tensor/layers/common_attention.py

Lines changed: 7 additions & 8 deletions
@@ -166,17 +166,17 @@ def add_positional_embedding_nd(x, max_length, name):
 
 
 def embedding_to_padding(emb):
-  """Input embeddings -> is_padding.
+  """Calculates the padding mask based on which embeddings are all zero.
 
   We have hacked symbol_modality to return all-zero embeddings for padding.
 
   Args:
     emb: a Tensor with shape [..., depth].
   Returns:
-    a boolean Tensor with shape [...].
+    a float Tensor with shape [...].
   """
   emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
-  return tf.equal(emb_sum, 0.0)
+  return tf.to_float(tf.equal(emb_sum, 0.0))
 
 
 def attention_bias_lower_triangle(length):
@@ -197,13 +197,13 @@ def attention_bias_ignore_padding(memory_padding):
   """Create an bias tensor to be added to attention logits.
 
   Args:
-    memory_padding: a boolean `Tensor` with shape [batch, memory_length].
+    memory_padding: a float `Tensor` with shape [batch, memory_length].
 
   Returns:
     a `Tensor` with shape [batch, 1, 1, memory_length].
   """
-  ret = tf.to_float(memory_padding) * -1e9
-  return tf.expand_dims(tf.expand_dims(ret, 1), 1)
+  ret = memory_padding * -1e9
+  return tf.expand_dims(tf.expand_dims(ret, axis=1), axis=1)
 
 
 def attention_bias_proximal(length):
@@ -523,8 +523,7 @@ def pad_l_and_r(x, pad_length):
   # [batch, heads, blocks, block_length, dim]
   k_new = tf.transpose(k_new, [2, 3, 0, 1, 4])
 
-  attention_bias = tf.expand_dims(
-      tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
+  attention_bias = tf.expand_dims(embedding_to_padding(k_new) * -1e9, axis=-2)
 
   v_t = tf.transpose(v, [2, 0, 1, 3])
   v_new = tf.gather(v_t, gather_indices)
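
For context, here is a minimal sketch (not part of the commit) of how the simplified helpers compose after this change, assuming the TF 1.x API tensor2tensor used at the time (tf.to_float, tf.expand_dims); the example tensors are made up:

# A minimal sketch (not from the commit) of how the simplified helpers compose,
# assuming the TF 1.x API tensor2tensor used at the time.
import tensorflow as tf


def embedding_to_padding(emb):
  """Returns a float mask that is 1.0 wherever the embedding is all zeros."""
  emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
  return tf.to_float(tf.equal(emb_sum, 0.0))


def attention_bias_ignore_padding(memory_padding):
  """Turns the float padding mask into a large negative attention bias."""
  ret = memory_padding * -1e9
  return tf.expand_dims(tf.expand_dims(ret, axis=1), axis=1)


# Hypothetical usage: a batch of 2 sequences, length 3, depth 4, where some
# positions hold all-zero padding embeddings.
emb = tf.constant([[[1.0] * 4, [2.0] * 4, [0.0] * 4],
                   [[3.0] * 4, [0.0] * 4, [0.0] * 4]])
padding = embedding_to_padding(emb)            # float mask, shape [2, 3]
bias = attention_bias_ignore_padding(padding)  # shape [2, 1, 1, 3]

Because embedding_to_padding now returns the float mask directly, callers such as attention_bias_ignore_padding and the local-attention path above can drop their own tf.to_float conversions.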
