@@ -166,17 +166,17 @@ def add_positional_embedding_nd(x, max_length, name):
 
 
 def embedding_to_padding(emb):
-  """Input embeddings -> is_padding.
+  """Calculates the padding mask based on which embeddings are all zero.
 
   We have hacked symbol_modality to return all-zero embeddings for padding.
 
   Args:
     emb: a Tensor with shape [..., depth].
   Returns:
-    a boolean Tensor with shape [...].
+    a float Tensor with shape [...].
   """
   emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
-  return tf.equal(emb_sum, 0.0)
+  return tf.to_float(tf.equal(emb_sum, 0.0))
 
 
 def attention_bias_lower_triangle(length):
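
A minimal sketch (not part of the diff) of how the updated embedding_to_padding behaves, assuming a TensorFlow 1.x environment where tf.to_float and tf.Session are available; the toy tensor and its shape are illustrative only.

import tensorflow as tf

def embedding_to_padding(emb):
  """Returns 1.0 at positions whose embedding vector is all zeros, else 0.0."""
  emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
  return tf.to_float(tf.equal(emb_sum, 0.0))

# Toy batch of 3 positions with depth 2; the middle position is padding.
emb = tf.constant([[[0.5, -0.2], [0.0, 0.0], [1.0, 3.0]]])
mask = embedding_to_padding(emb)  # shape [1, 3], float32

with tf.Session() as sess:
  print(sess.run(mask))  # [[0. 1. 0.]]

The change from a boolean to a float mask lets callers scale the mask directly (e.g. multiply by -1e9) without an extra cast, as the later hunks show.
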
@@ -197,13 +197,13 @@ def attention_bias_ignore_padding(memory_padding):
   """Create an bias tensor to be added to attention logits.
 
   Args:
-    memory_padding: a boolean `Tensor` with shape [batch, memory_length].
+    memory_padding: a float `Tensor` with shape [batch, memory_length].
 
   Returns:
     a `Tensor` with shape [batch, 1, 1, memory_length].
   """
-  ret = tf.to_float(memory_padding) * -1e9
-  return tf.expand_dims(tf.expand_dims(ret, 1), 1)
+  ret = memory_padding * -1e9
+  return tf.expand_dims(tf.expand_dims(ret, axis=1), axis=1)
 
 
 def attention_bias_proximal(length):
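
A sketch (not from the PR) of how this bias is typically consumed: it is added to the attention logits before the softmax, so padded memory positions receive near-zero weight. The logits tensor and its shape below are hypothetical; TF 1.x is assumed.

import tensorflow as tf

def attention_bias_ignore_padding(memory_padding):
  ret = memory_padding * -1e9
  return tf.expand_dims(tf.expand_dims(ret, axis=1), axis=1)

memory_padding = tf.constant([[0.0, 0.0, 1.0]])       # [batch=1, memory_length=3]
bias = attention_bias_ignore_padding(memory_padding)  # [1, 1, 1, 3]

# Hypothetical attention logits: [batch, heads, query_length, memory_length].
logits = tf.random_normal([1, 2, 4, 3])
weights = tf.nn.softmax(logits + bias)  # last memory position gets ~0 weight

with tf.Session() as sess:
  print(sess.run(weights)[0, 0, 0])
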
@@ -523,8 +523,7 @@ def pad_l_and_r(x, pad_length):
   # [batch, heads, blocks, block_length, dim]
   k_new = tf.transpose(k_new, [2, 3, 0, 1, 4])
 
-  attention_bias = tf.expand_dims(
-      tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
+  attention_bias = tf.expand_dims(embedding_to_padding(k_new) * -1e9, axis=-2)
 
   v_t = tf.transpose(v, [2, 0, 1, 3])
   v_new = tf.gather(v_t, gather_indices)
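
Since embedding_to_padding now returns floats, the tf.to_float wrapper in this hunk is a no-op and is simply dropped. A small equivalence sketch with toy shapes (TF 1.x assumed, values hypothetical):

import tensorflow as tf

def embedding_to_padding(emb):
  emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
  return tf.to_float(tf.equal(emb_sum, 0.0))

# Toy keys: [blocks=1, block_length=2, dim=2]; the first position is all-zero padding.
k_new = tf.constant([[[0.0, 0.0], [2.0, 1.0]]])
old_bias = tf.expand_dims(tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
new_bias = tf.expand_dims(embedding_to_padding(k_new) * -1e9, axis=-2)

with tf.Session() as sess:
  o, n = sess.run([old_bias, new_bias])
  assert (o == n).all()  # identical values; only the redundant cast is gone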