This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 35416da

alexyku authored and Ryan Sepassi committed
adding function for local_attention_2d
PiperOrigin-RevId: 164869818
1 parent 12c59a7 commit 35416da


2 files changed: +142, -9 lines changed


tensor2tensor/layers/common_attention.py

Lines changed: 105 additions & 0 deletions
@@ -541,6 +541,111 @@ def pad_l_and_r(x, pad_length):
     return output


+def local_attention_2d(q,
+                       k,
+                       v,
+                       block_length=128,
+                       filter_flange=100,
+                       name=None):
+  """strided block local self-attention.
+
+  Args:
+    q: a Tensor with shape [batch, heads, h, w, depth_k]
+    k: a Tensor with shape [batch, heads, h, w, depth_k]
+    v: a Tensor with shape [batch, heads, h, w, depth_v]
+    block_length: an integer indicating the side length of each square block.
+    filter_flange: an integer indicating how much to look around each block.
+    name: an optional string
+
+  Returns:
+    a Tensor of shape [batch, heads, h, w, depth_v]
+  """
+  with tf.variable_scope(
+      name, default_name="local_self_attention_2d", values=[q, k, v]):
+    v_shape = tf.shape(v)
+    depth_v = tf.shape(v)[4]
+    batch_size = tf.shape(q)[0]
+    num_heads = tf.shape(q)[1]
+    original_length = tf.shape(q)[2] * tf.shape(q)[3]
+
+    def reshape_range(tensor, i, j, shape):
+      """Reshapes a tensor between dimensions i and j."""
+      target_shape = tf.concat(
+          [tf.shape(tensor)[:i], shape, tf.shape(tensor)[j:]],
+          axis=0)
+      return tf.reshape(tensor, target_shape)
+
+    def pad_to_multiple(x, d):
+      """Making sure x is a multiple of d."""
+      height_padding = -tf.shape(x)[1] % d
+      width_padding = -tf.shape(x)[2] % d
+      paddings = [[0, 0], [0, 0], [0, height_padding],
+                  [0, width_padding], [0, 0]]
+      return tf.pad(x, paddings)
+
+    def gather_indices(x, block_length, stride):
+      """Getting gather indices."""
+      # making an identity matrix kernel
+      kernel = tf.eye(block_length ** 2)
+      kernel = reshape_range(kernel, 0, 1, [block_length, block_length, 1])
+      # making indices [1, h, w, 1] to apply convs
+      indices = tf.range(0, tf.shape(x)[2] * tf.shape(x)[3], delta=1)
+      indices = tf.reshape(indices, [1, tf.shape(x)[2], tf.shape(x)[3], 1])
+      indices = tf.nn.conv2d(
+          tf.cast(indices, tf.float32),
+          kernel,
+          strides=[1, stride, stride, 1],
+          padding="VALID")
+      # making indices [num_blocks, dim] to gather
+      num_blocks = tf.reduce_prod(tf.shape(indices)[:2])
+      indices = tf.reshape(indices, [num_blocks, -1])
+      return tf.cast(indices, tf.int32)
+
+    def gather_blocks(x, indices):
+      """Gathers flattened blocks from x."""
+      x_shape = tf.shape(x)
+      x = reshape_range(x, 2, 4, [tf.reduce_prod(x_shape[2:4])])
+      # [length, batch, heads, dim]
+      x_t = tf.transpose(x, [2, 0, 1, 3])
+      x_new = tf.gather(x_t, indices)
+      # returns [batch, heads, num_blocks, block_length ** 2, dim]
+      return tf.transpose(x_new, [2, 3, 0, 1, 4])
+
+    q = pad_to_multiple(q, block_length)
+    k = pad_to_multiple(k, block_length)
+    v = pad_to_multiple(v, block_length)
+
+    # Setting up k and v values
+    paddings = [[0, 0], [0, 0], [filter_flange, filter_flange],
+                [filter_flange, filter_flange], [0, 0]]
+    k = tf.pad(k, paddings)
+    v = tf.pad(v, paddings)
+
+    # Setting up q blocks
+    q_indices = gather_indices(q, block_length, block_length)
+    q_new = gather_blocks(q, q_indices)
+
+    # Setting up k and v blocks
+    full_filter_width = block_length + 2 * filter_flange
+    k_and_v_indices = gather_indices(k, full_filter_width, block_length)
+    k_new = gather_blocks(k, k_and_v_indices)
+    v_new = gather_blocks(v, k_and_v_indices)
+
+    attention_bias = tf.expand_dims(
+        tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
+
+    logits = tf.matmul(q_new, k_new, transpose_b=True)
+
+    attention = tf.nn.softmax(logits + attention_bias)
+    output = tf.matmul(attention, v_new)
+
+    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])
+    # Remove the padding if introduced
+    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
+    # [batch, heads, h, w, depth_v]
+    return tf.reshape(output, v_shape)
+
+
 def multihead_attention(query_antecedent,
                         memory_antecedent,
                         bias,
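
As a quick orientation, here is a minimal usage sketch of the new function. It is not part of the commit: it assumes a TF 1.x graph-mode session, illustrative tensor shapes, and that common_attention is importable from tensor2tensor.layers, as in the test file below.

# Minimal usage sketch (illustrative only, not from the commit).
import numpy as np
import tensorflow as tf

from tensor2tensor.layers import common_attention

# Inputs follow the documented layout [batch, heads, h, w, depth].
q = tf.constant(np.random.rand(2, 4, 16, 16, 32), dtype=tf.float32)
k = tf.constant(np.random.rand(2, 4, 16, 16, 32), dtype=tf.float32)
v = tf.constant(np.random.rand(2, 4, 16, 16, 32), dtype=tf.float32)

# Each query block of side 8 also looks 4 positions beyond its edges.
out = common_attention.local_attention_2d(
    q, k, v, block_length=8, filter_flange=4)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(out).shape)  # (2, 4, 16, 16, 32), same shape as v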
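
The less obvious part of the implementation is gather_indices, which reshapes an identity matrix into a conv kernel so that a strided tf.nn.conv2d enumerates the flattened positions belonging to each block. A toy sketch of that trick (hypothetical toy sizes, not from the commit) makes the output easier to picture:

# Toy illustration of the identity-kernel gather trick (assumed toy sizes).
import tensorflow as tf

block_length, stride, h, w = 2, 2, 4, 4

# Kernel of shape [block_length, block_length, 1, block_length**2]: output
# channel n copies the n-th (row-major) position inside the sliding window.
kernel = tf.reshape(tf.eye(block_length ** 2),
                    [block_length, block_length, 1, block_length ** 2])

# Flattened position ids laid out on a [1, h, w, 1] grid.
indices = tf.reshape(tf.range(h * w), [1, h, w, 1])
indices = tf.nn.conv2d(tf.cast(indices, tf.float32), kernel,
                       strides=[1, stride, stride, 1], padding="VALID")
indices = tf.cast(tf.reshape(indices, [-1, block_length ** 2]), tf.int32)

with tf.Session() as sess:
  print(sess.run(indices))
  # [[ 0  1  4  5]
  #  [ 2  3  6  7]
  #  [ 8  9 12 13]
  #  [10 11 14 15]]
  # Row i lists the flattened positions that gather_blocks pulls out for
  # block i; with a larger kernel and the same stride, the same trick
  # yields the overlapping (block + flange) memory windows for k and v.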

tensor2tensor/layers/common_attention_test.py

Lines changed: 37 additions & 9 deletions
@@ -41,14 +41,14 @@ def testDotProductAttention(self):
       res = session.run(a)
     self.assertEqual(res.shape, (5, 7, 12, 32))
 
-  def testMaskedLocalAttention(self):
-    q = np.array([[[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [
-        1.0, 0.0, 0.0, 0.0
-    ], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0],
+  def testMaskedLocalAttention1D(self):
+    q = np.array([[[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0],
+                    [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0],
+                    [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]]])
-    k = np.array([[[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [
-        1.0, 0.0, 0.0, 0.0
-    ], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0],
+    k = np.array([[[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0],
+                    [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0],
+                    [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]]])
     v = np.ones((1, 1, 8, 1))
     with self.test_session() as session:
@@ -61,7 +61,7 @@ def testMaskedLocalAttention(self):
 
     self.assertEqual(res.shape, (1, 1, 8, 1))
 
-  def testLocalUnmaskedAttention(self):
+  def testLocalUnmaskedAttention1D(self):
     x = np.random.rand(5, 4, 25, 16)
     y = np.random.rand(5, 4, 25, 16)
     with self.test_session() as session:
@@ -75,7 +75,7 @@ def testLocalUnmaskedAttention(self):
       res = session.run(a)
     self.assertEqual(res.shape, (5, 4, 25, 16))
 
-  def testLocalUnmaskedAttentionMatchingBlockLength(self):
+  def testLocalUnmaskedAttention1DMatchingBlockLength(self):
     x = np.random.rand(5, 4, 25, 16)
     y = np.random.rand(5, 4, 25, 16)
     with self.test_session() as session:
@@ -89,6 +89,34 @@ def testLocalUnmaskedAttentionMatchingBlockLength(self):
       res = session.run(a)
     self.assertEqual(res.shape, (5, 4, 25, 16))
 
+  def testLocalUnmaskedAttention2D(self):
+    x = np.random.rand(5, 4, 25, 25, 16)
+    y = np.random.rand(5, 4, 25, 25, 16)
+    with self.test_session() as session:
+      a = common_attention.local_attention_2d(
+          tf.constant(x, dtype=tf.float32),
+          tf.constant(y, dtype=tf.float32),
+          tf.constant(y, dtype=tf.float32),
+          block_length=4,
+          filter_flange=3)
+      session.run(tf.global_variables_initializer())
+      res = session.run(a)
+    self.assertEqual(res.shape, (5, 4, 25, 25, 16))
+
+  def testLocalUnmaskedAttention2DMatchingBlockLength(self):
+    x = np.random.rand(5, 4, 25, 25, 16)
+    y = np.random.rand(5, 4, 25, 25, 16)
+    with self.test_session() as session:
+      a = common_attention.local_attention_2d(
+          tf.constant(x, dtype=tf.float32),
+          tf.constant(y, dtype=tf.float32),
+          tf.constant(y, dtype=tf.float32),
+          block_length=5,
+          filter_flange=3)
+      session.run(tf.global_variables_initializer())
+      res = session.run(a)
+    self.assertEqual(res.shape, (5, 4, 25, 25, 16))
+
 
 if __name__ == "__main__":
   tf.test.main()
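
One detail the two new 2D tests exercise is the padding path: block_length=4 does not divide the 25x25 spatial extent, while block_length=5 does. A small plain-Python sketch of pad_to_multiple's arithmetic (not part of the commit) shows why both cases still assert the original output shape:

# Padding arithmetic behind the two 2D test cases above (illustrative only).
h = w = 25
for block_length in (4, 5):
  height_padding = -h % block_length  # same -x % d trick as pad_to_multiple
  width_padding = -w % block_length
  print(block_length, h + height_padding, w + width_padding)
# 4 28 28  -> padded internally, then sliced/reshaped back to 25 x 25
# 5 25 25  -> already a multiple of the block side, no padding added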

0 commit comments
