
Commit a0bd017

Ashish Vaswani authored and Ryan Sepassi committed
Reverted to the previous masked_local_attention_1d because the current one was producing 0 losses, indicating that it was peeking into the future. The way the attention bias was being added also seemed wrong. Renamed unmasked_local_attention_1d to local_attention_1d; users can specify local_attention_1d if they want to attend both left and right of the query block.
PiperOrigin-RevId: 164312109
1 parent e8ae589 commit a0bd017
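
For quick reference, a minimal usage sketch of the two functions as they stand after this commit, modeled on the calls in common_attention_test.py below. The tensor shapes and the block_length/filter_width values are illustrative only, not taken from the commit.

import numpy as np
import tensorflow as tf
from tensor2tensor.layers import common_attention

# Illustrative shapes: [batch, heads, length, depth].
q = tf.constant(np.random.rand(5, 4, 25, 16), dtype=tf.float32)
k = tf.constant(np.random.rand(5, 4, 25, 16), dtype=tf.float32)
v = tf.constant(np.random.rand(5, 4, 25, 16), dtype=tf.float32)

# Causal variant: each query sees positions <= itself within its own
# block and the previous block.
masked = common_attention.masked_local_attention_1d(q, k, v, block_length=4)

# Renamed from unmasked_local_attention_1d: queries may also attend to
# the right of their block. block_length/filter_width values are
# illustrative.
unmasked = common_attention.local_attention_1d(
    q, k, v, block_length=4, filter_width=3)

with tf.Session() as session:
  out_masked, out_unmasked = session.run([masked, unmasked])
  print(out_masked.shape, out_unmasked.shape)  # expected: both (5, 4, 25, 16)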

2 files changed: 57 additions & 79 deletions

tensor2tensor/layers/common_attention.py

Lines changed: 55 additions & 77 deletions
@@ -361,122 +361,100 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)
 
 
-def masked_local_attention_1d(q,
-                              k,
-                              v,
-                              block_length=128,
-                              look_right=True,
-                              use_whole_block=False,
-                              name=None):
-  """Attention to the source position and a neigborhood around it.
-
-  The sequence is divided into blocks of length block_size. Attention for a
-  given query position can only see memory positions within a certain number
-  of positions before and behind it.
-
-
-  If look_right is True then each query will attend to block_length//2
-  positions either side, otherwise it will attend to block_length previous
-  positions.
+def masked_local_attention_1d(
+    q, k, v, block_length=128, name=None):
+  """Attention to the source position and a neigborhood to the left of it.
+
+  The sequence is divided into blocks of length block_size.
+  Attention for a given query position can only see memory positions
+  less than or equal to the query position, in the corresponding block
+  and the previous block.
 
-  If use_whole_block is True then no mask will be applied to the local blocks
-  meaning the full blocks are used (if look_right is True then the elements to
-  the right of the current position are still masked out). This allows to
-  attend to more elements without additional overhead, but means we have
-  inconsistent window positions and sizes.
+  If mask_right is True, then a target position cannot see greater source
+  positions.
 
   Args:
-    q: a Tensor with shape [batch, heads, length_q, depth_k]
-    k: a Tensor with shape [batch, heads, length_kv, depth_k]
-    v: a Tensor with shape [batch, heads, length_kv, depth_v]
+    q: a Tensor with shape [batch, heads, length, depth_k]
+    k: a Tensor with shape [batch, heads, length, depth_k]
+    v: a Tensor with shape [batch, heads, length, depth_v]
     block_length: an integer
-    look_right: a bool
-    use_whole_block: a bool
     name: an optional string
 
   Returns:
     a Tensor of shape [batch, heads, length, depth_v]
   """
-  with tf.variable_scope(
-      name, default_name="local_attention_1d", values=[q, k, v]):
+  with tf.variable_scope(name, default_name="local_attention_1d",
+                         values=[q, k, v]):
     v_shape = v.get_shape()
     batch = tf.shape(q)[0]
     heads = tf.shape(q)[1]
     length = tf.shape(q)[2]
+    # If (length < 2 * block_length), then we use only one block.
+    block_length = tf.where(tf.less(length, block_length * 2),
+                            length, block_length)
     depth_k = tf.shape(q)[3]
     depth_v = tf.shape(v)[3]
     original_length = length
-
-    # If (length < block_length), then we use only one block.
-    block_length = tf.where(tf.less(length, block_length), length, block_length)
-    # Pad to desired length.
     padding_size = tf.mod(-length, block_length)
     length += padding_size
-    num_blocks = tf.div(length, block_length)
     padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
     q = tf.pad(q, padding)
+    k = tf.pad(k, padding)
+    v = tf.pad(v, padding)
+    num_blocks = tf.div(length, block_length)
 
-    if not look_right:
-      # Add extra padding so we son't have to do an initial query block.
-      extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]]
-    else:
-      # We shift everything over by half a block so query is in center.
-      pad_right = block_length // 2
-      pad_left = block_length - pad_right
-      extra_padding = [[0, 0], [0, 0], [pad_left, padding_size + pad_right],
-                       [0, 0]]
-    k = tf.pad(k, extra_padding)
-    v = tf.pad(v, extra_padding)
-
-    # Reshape into blocks.
+    # compute attention for the first query block.
+    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_output = dot_product_attention(
+        first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
+        name="fist_block")
+
+    # compute attention for all subsequent query blocks.
     q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
-    k = tf.reshape(k, [batch, heads, num_blocks + 1, block_length, depth_k])
-    v = tf.reshape(v, [batch, heads, num_blocks + 1, block_length, depth_v])
+    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
+    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
 
-    # Get local blocks by slicing.
     def local(x):
       """Create a local version of the keys or values."""
-      prev_block = tf.slice(x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1])
-      cur_block = tf.slice(x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
+      prev_block = tf.slice(
+          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1])
+      cur_block = tf.slice(
+          x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
       return tf.concat([prev_block, cur_block], 3)
-
     local_k = local(k)
     local_v = local(v)
-    local_length = tf.shape(local_k)[3]
+    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
 
-    # [batch, heads, num_blocks, block_length, local_length]
-    attention = tf.matmul(q, local_k, transpose_b=True)
-    attention = tf.nn.softmax(attention)
-
-    # Get local mask
-    if not use_whole_block:
-      good_part = tf.matrix_band_part(
-          tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
-    elif not look_right:
-      good_part = tf.matrix_band_part(
-          tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
-    else:
-      good_part = tf.ones([block_length, local_length])
+    local_length = tf.shape(local_k)[3]
 
-    attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length])
+    # [batch, heads, num_blocks - 1, block_length, local_length]
+    attention = tf.matmul(tail_q, local_k, transpose_b=True)
 
+    # make sure source_pos <= target_pos
+    good_part = tf.matrix_band_part(
+        tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+    mask = (1.0 - good_part) * -1e9
+    attention += tf.reshape(mask, [1, 1, 1, block_length, local_length])
+    attention = tf.nn.softmax(attention)
     # TODO(noam): figure out how to show a summary for the remaining blocks.
     # The naive way currently causes errors due to empty tensors.
+    # output: [batch, heads, num_blocks-1, block_length, depth_v]
     output = tf.matmul(attention, local_v)
     output = tf.reshape(output, [batch, heads, -1, depth_v])
-
-    # Remove added padding
+    output = tf.concat([first_output, output], axis=2)
     output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
     output.set_shape(v_shape)
     return output
 
 
-def unmasked_local_attention_1d(q,
-                                k,
-                                v,
-                                block_length=128,
-                                filter_width=100,
-                                name=None):
+def local_attention_1d(q,
+                       k,
+                       v,
+                       block_length=128,
+                       filter_width=100,
+                       name=None):
   """strided block local self-attention.
 
   Args:
@@ -644,7 +622,7 @@ def multihead_attention(query_antecedent,
       x = masked_local_attention_1d(q, k, v, block_length=block_length)
     else:
       assert attention_type == "local_unmasked"
-      x = unmasked_local_attention_1d(
+      x = local_attention_1d(
          q, k, v, block_length=block_length, filter_width=block_width)
   x = combine_heads(x)
   x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
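
To make the restored masking logic easier to follow, here is a standalone NumPy sketch (not part of the commit) of the bias that masked_local_attention_1d builds with tf.matrix_band_part. For each query block the keys are [previous block; current block]; every position of the previous block is visible, positions of the current block after the query are not, and the disallowed entries receive a large negative bias added to the logits before the softmax, instead of multiplying the softmax output by the mask as the removed code did.

import numpy as np

def block_local_causal_bias(block_length):
  """NumPy version of the band-part bias used by masked_local_attention_1d.

  Keys for each query block are [previous block; current block], so
  local_length = 2 * block_length. Query i of the current block may see
  the whole previous block plus positions 0..i of its own block.
  """
  local_length = 2 * block_length
  good_part = np.zeros((block_length, local_length))
  for i in range(block_length):
    # Band of width block_length ending at the query's own position.
    good_part[i, :block_length + i + 1] = 1.0
  # Large negative bias on the masked entries, added to the logits
  # before the softmax.
  return (1.0 - good_part) * -1e9

print(block_local_causal_bias(3))
# Row i is zero up to column block_length + i and -1e9 afterwards.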

tensor2tensor/layers/common_attention_test.py

Lines changed: 2 additions & 2 deletions
@@ -65,7 +65,7 @@ def testLocalUnmaskedAttention(self):
     x = np.random.rand(5, 4, 25, 16)
     y = np.random.rand(5, 4, 25, 16)
     with self.test_session() as session:
-      a = common_attention.unmasked_local_attention_1d(
+      a = common_attention.local_attention_1d(
           tf.constant(x, dtype=tf.float32),
           tf.constant(y, dtype=tf.float32),
           tf.constant(y, dtype=tf.float32),
@@ -79,7 +79,7 @@ def testLocalUnmaskedAttentionMatchingBlockLength(self):
     x = np.random.rand(5, 4, 25, 16)
     y = np.random.rand(5, 4, 25, 16)
     with self.test_session() as session:
-      a = common_attention.unmasked_local_attention_1d(
+      a = common_attention.local_attention_1d(
           tf.constant(x, dtype=tf.float32),
           tf.constant(y, dtype=tf.float32),
           tf.constant(y, dtype=tf.float32),
