This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 0df0f50

Merge pull request #176 from EndingCredits/master

Alternative Transformer Fix + Sliding Window Attention

2 parents 69e40fb + af52f5f

5 files changed: +338 −222 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ _pycache__/
 # PyPI distribution artifacts.
 build/
 dist/
+data/
 
 # Sublime project files
 *.sublime-project

tensor2tensor/models/common_attention.py

Lines changed: 108 additions & 133 deletions
@@ -344,23 +344,33 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)
 
 
-def masked_local_attention_1d(
-    q, k, v, block_length=128, name=None):
-  """Attention to the source position and a neigborhood to the left of it.
+def local_attention_1d(q, k, v, bias=None,
+                       block_length=128, look_right=True, use_whole_block=False,
+                       truncate_bias=True, name=None):
+  """Attention to the source position and a neigborhood around it.
 
-  The sequence is divided into blocks of length block_size.
-  Attention for a given query position can only see memory positions
-  less than or equal to the query position, in the corresponding block
-  and the previous block.
+  The sequence is divided into blocks of length block_size. Attention for a
+  given query position can only see memory positions within a certain number
+  of positions before and behind it.
 
-  If mask_right is True, then a target position cannot see greater source
+  If look_right is True then each query will attend to block_length//2
+  positions either side, otherwise it will attend to block_length previous
   positions.
 
+  If use_whole_block is True then no mask will be applied to the local blocks
+  meaning the full blocks are used (if look_right is True then the elements to
+  the right of the current position are still masked out). This allows use to
+  attend to more elements without additional overhead, but means we have
+  inconsistent window positions and sizes.
+
   Args:
-    q: a Tensor with shape [batch, heads, length, depth_k]
-    k: a Tensor with shape [batch, heads, length, depth_k]
-    v: a Tensor with shape [batch, heads, length, depth_v]
+    q: a Tensor with shape [batch, heads, length_q, depth_k]
+    k: a Tensor with shape [batch, heads, length_kv, depth_k]
+    v: a Tensor with shape [batch, heads, length_kv, depth_v]
+    bias: Not currently used [batch, heads, length_q, length_k]
     block_length: an integer
+    look_right: a bool
+    use_whole_block: a bool
     name: an optional string
 
   Returns:
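
The windowing described in the new docstring can be illustrated with a short NumPy sketch (toy sizes and the -1 padding marker below are illustrative only, not code from this commit): the keys are padded by one extra block and reshaped into num_blocks + 1 blocks, so each block of queries sees its own block concatenated with the previous one, a window of 2 * block_length key positions.

import numpy as np

block_length = 4
length = 12                           # assume length is already a multiple of block_length
key_positions = np.arange(length)     # stand-ins for key positions

# Pad one block on the left (the look_right=False case); -1 marks padding.
padded = np.concatenate([np.full(block_length, -1), key_positions])
blocks = padded.reshape(-1, block_length)            # num_blocks + 1 blocks

prev_block = blocks[:-1]                             # blocks 0 .. num_blocks - 1
cur_block = blocks[1:]                               # blocks 1 .. num_blocks
local_window = np.concatenate([prev_block, cur_block], axis=1)

print(local_window)
# Row i lists the key positions visible to query block i:
# [[-1 -1 -1 -1  0  1  2  3]
#  [ 0  1  2  3  4  5  6  7]
#  [ 4  5  6  7  8  9 10 11]]
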
@@ -372,146 +382,110 @@ def masked_local_attention_1d(
     batch = tf.shape(q)[0]
     heads = tf.shape(q)[1]
     length = tf.shape(q)[2]
-    # If (length < 2 * block_length), then we use only one block.
-    block_length = tf.where(tf.less(length, block_length * 2),
-                            length, block_length)
     depth_k = tf.shape(q)[3]
     depth_v = tf.shape(v)[3]
+
     original_length = length
+
+    #Pad to desired length
+    #If (length < block_length), then we use only one block.
+    block_length = tf.where(tf.less(length, block_length),
+                            length, block_length)
     padding_size = tf.mod(-length, block_length)
     length += padding_size
-    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
-    q = tf.pad(q, padding)
-    k = tf.pad(k, padding)
-    v = tf.pad(v, padding)
     num_blocks = tf.div(length, block_length)
 
-    # compute attention for the first query block.
-    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_output = dot_product_attention(
-        first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
-        name="fist_block")
+    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
+    q = tf.pad(q, padding)
 
-    # compute attention for all subsequent query blocks.
+    if not look_right:
+      #Add extra padding so we son't have to do an initial query
+      extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]]
+      bp = [[0, 0], [0, 0], [0, padding_size], [block_length, padding_size]]
+    else:
+      #We shift everything over by half a block so query is in centre
+      pad_right = block_length // 2
+      pad_left = block_length - pad_right
+      extra_padding = [[0, 0], [0, 0],
+                       [pad_left, padding_size+pad_right], [0, 0]]
+      bp = [[0, 0], [0, 0],
+            [0, padding_size], [pad_left, padding_size+pad_right]]
+    k = tf.pad(k, extra_padding)
+    v = tf.pad(v, extra_padding)
+
+    # Reshape into blocks
     q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
-    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
-    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
+    k = tf.reshape(k, [batch, heads, num_blocks+1, block_length, depth_k])
+    v = tf.reshape(v, [batch, heads, num_blocks+1, block_length, depth_v])
 
+    # Get local blocks by slicing
     def local(x):
       """Create a local version of the keys or values."""
       prev_block = tf.slice(
-          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1])
+          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1])
       cur_block = tf.slice(
           x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
       return tf.concat([prev_block, cur_block], 3)
     local_k = local(k)
     local_v = local(v)
-    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
-
     local_length = tf.shape(local_k)[3]
 
-    # [batch, heads, num_blocks - 1, block_length, local_length]
-    attention = tf.matmul(tail_q, local_k, transpose_b=True)
-
-    # make sure source_pos <= target_pos
-    good_part = tf.matrix_band_part(
-        tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
-    mask = (1.0 - good_part) * -1e9
-    attention += tf.reshape(mask, [1, 1, 1, block_length, local_length])
+    # [batch, heads, num_blocks, block_length, local_length]
+    attention = tf.matmul(q, local_k, transpose_b=True)
+
+    # Apply bias (N.B: This is not currently working)
+    if bias is not None:
+      with tf.name_scope('bias'):
+        b_batch = tf.shape(bias)[0]
+        b_heads = tf.shape(bias)[1]
+        bias_ = bias
+        #bias = 1.0 + tf.clip_by_value(bias, -1.0, 1.0)
+        if truncate_bias:
+          # Use only the query dimension
+          bias = tf.expand_dims(bias[:,:,:,0], 2)
+          bias = tf.pad(bias, extra_padding, name='bias_pad_b')# 17, 5, 3
+          bias = tf.reshape(bias,
+              [b_batch, b_heads, 1, num_blocks+1, block_length],
+              name='divide_blocks')
+          local_b = tf.reshape(local(bias),
+              [b_batch, b_heads, num_blocks, 1, -1], name='reshape_local')
+        else:
+          bias = tf.pad(bias, bp, name='pad')
+          bias = tf.reshape(bias,
+              [b_batch, b_heads, num_blocks, block_length,
+               num_blocks+1, block_length], name='divide_blocks')
+          bias = tf.transpose(bias, [4,2,0,1,3,5])
+          bias = tf.reshape(bias,
+              [num_blocks*(num_blocks+1), b_batch, b_heads,
+               block_length, block_length], name='combine')
+          indices = (num_blocks+1)*tf.range(num_blocks)
+          prev_block = tf.gather(bias, indices)
+          cur_block = tf.gather(bias, indices+num_blocks)
+          local_b = tf.concat([prev_block, cur_block], 4)
+          local_b = tf.transpose(local_b, [1,2,0,3,4])
+          return l-local_b
+      attention += local_b
+
     attention = tf.nn.softmax(attention)
-    # TODO(noam): figure out how to show a summary for the remaining blocks.
-    # The naive way currently causes errors due to empty tensors.
-    # output: [batch, heads, num_blocks-1, block_length, depth_v]
-    output = tf.matmul(attention, local_v)
-    output = tf.reshape(output, [batch, heads, -1, depth_v])
-    output = tf.concat([first_output, output], axis=2)
-    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
-    output.set_shape(v_shape)
-    return output
-
+
+    # Get local mask
+    if not use_whole_block:
+      good_part = tf.matrix_band_part(
+          tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
+    elif not look_right:
+      good_part = tf.matrix_band_part(
+          tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+    else:
+      good_part = tf.ones([block_length, local_length])
 
-def unmasked_local_attention_1d(q, k, v, block_length=128, filter_width=100,
-                                name=None):
-  """strided block local self-attention.
+    #good_part = tf.cast(good_part, tf.float64)
+    attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length])
 
-  Args:
-    q: a Tensor with shape [batch, heads, length, depth_k]
-    k: a Tensor with shape [batch, heads, length, depth_k]
-    v: a Tensor with shape [batch, heads, length, depth_v]
-    block_length: an integer
-    filter_width: an integer indicating how much to look left.
-    name: an optional string
+
+    output = tf.matmul(attention, local_v)
+    output = tf.reshape(output, [batch, heads, -1, depth_v])
 
-  Returns:
-    a Tensor of shape [batch, heads, length, depth_v]
-  """
-  with tf.variable_scope(name, default_name="local_self_attention_1d",
-                         values=[q, k, v]):
-    v_shape = v.get_shape()
-    depth_v = tf.shape(v)[3]
-    batch_size = tf.shape(q)[0]
-    num_heads = tf.shape(q)[1]
-    original_length = tf.shape(q)[2]
-    # making sure q is a multiple of d
-    def pad_to_multiple(x, pad_length):
-      x_length = tf.shape(x)[2]
-      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])
-    def pad_l_and_r(x, pad_length):
-      return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]])
-    q = pad_to_multiple(q, block_length)
-    k = pad_to_multiple(k, block_length)
-    v = pad_to_multiple(v, block_length)
-
-    # Setting up q blocks
-    new_q_shape = tf.shape(q)
-    # Setting up q blocks
-    q = tf.reshape(q, [new_q_shape[0], new_q_shape[1],
-                       new_q_shape[2]//block_length,
-                       block_length, new_q_shape[3]])
-
-    # Setting up k and v values
-    k = pad_l_and_r(k, filter_width)
-    v = pad_l_and_r(v, filter_width)
-
-    length = tf.shape(k)[2]
-    full_filter_width = block_length + 2*filter_width
-    # getting gather indices
-    indices = tf.range(0, length, delta=1, name="index_range")
-    # making indices [1, length, 1] to appy convs
-    indices = tf.reshape(indices, [1, -1, 1])
-    kernel = tf.expand_dims(tf.eye(full_filter_width), axis=1)
-    gather_indices = tf.nn.conv1d(
-        tf.cast(indices, tf.float32),
-        kernel,
-        block_length,
-        padding="VALID",
-        name="gather_conv")
-
-    gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)
-
-    # [length, batch, heads, dim]
-    k_t = tf.transpose(k, [2, 0, 1, 3])
-    k_new = tf.gather(k_t, gather_indices)
-
-    # [batch, heads, blocks, block_length, dim]
-    k_new = tf.transpose(k_new, [2, 3, 0, 1, 4])
-
-    attention_bias = tf.expand_dims(
-        tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
-
-    v_t = tf.transpose(v, [2, 0, 1, 3])
-    v_new = tf.gather(v_t, gather_indices)
-    v_new = tf.transpose(v_new, [2, 3, 0, 1, 4])
-
-    logits = tf.matmul(q, k_new, transpose_b=True)
-
-    attention = tf.nn.softmax(logits+attention_bias)
-    output = tf.matmul(attention, v_new)
-
-    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])
-    # Remove the padding if introduced
+    # Remove added padding
     output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
     output.set_shape(v_shape)
     return output
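
In the masking step above, tf.matrix_band_part(tf.ones([block_length, local_length]), 0, block_length) keeps, in row i, columns i through i + block_length of the 2 * block_length wide local window; note the mask is applied multiplicatively after the softmax, so masked weights are zeroed without renormalizing the rest. A minimal NumPy equivalent with toy sizes (illustrative only, not code from the commit):

import numpy as np

block_length, local_length = 4, 8       # local_length = 2 * block_length
rows = np.arange(block_length)[:, None]
cols = np.arange(local_length)[None, :]

# Same band as tf.matrix_band_part(ones, 0, block_length):
# keep entries with rows <= cols <= rows + block_length.
good_part = ((cols >= rows) & (cols <= rows + block_length)).astype(np.float32)

print(good_part)
# [[1. 1. 1. 1. 1. 0. 0. 0.]
#  [0. 1. 1. 1. 1. 1. 0. 0.]
#  [0. 0. 1. 1. 1. 1. 1. 0.]
#  [0. 0. 0. 1. 1. 1. 1. 1.]]
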
@@ -542,8 +516,8 @@ def multihead_attention(query_antecedent,
     dropout_rate: a floating point number
     image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
-    attention_type: a string, either "dot_product" or "local_mask_right" or
-      "local_unmasked"
+    attention_type: a string, either "dot_product" or "local" or
+      "local_mask_right"
     block_length: an integer - relevant for "local_mask_right"
     name: an optional string
 
@@ -592,11 +566,12 @@ def multihead_attention(query_antecedent,
     if attention_type == "dot_product":
       x = dot_product_attention(
           q, k, v, bias, dropout_rate, image_shapes)
-    elif attention_type == "local_mask_right":
-      x = masked_local_attention_1d(q, k, v, block_length=block_length)
+    elif attention_type == "local":
+      x = local_attention_1d(q, k, v, block_length=block_length)
     else:
-      assert attention_type == "local_unmasked"
-      x = unmasked_local_attention_1d(q, k, v, block_length=block_length)
+      assert attention_type == "local_mask_right"
+      x = local_attention_1d(
+          q, k, v, block_length=block_length, look_right=False)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     return x
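
With the changes above, multihead_attention exposes the new behaviour through attention_type="local" (windowed, look_right=True) and attention_type="local_mask_right" (causal, look_right=False). A hedged sketch of calling the new function directly in TF 1.x graph mode follows; the shapes, block_length value, and variable names are made up for illustration:

import tensorflow as tf
from tensor2tensor.models import common_attention

# [batch, heads, length, depth]; toy shapes.
q = tf.random_normal([2, 4, 64, 32])
k = tf.random_normal([2, 4, 64, 32])
v = tf.random_normal([2, 4, 64, 16])

# Window of roughly block_length//2 positions either side of each query.
y_local = common_attention.local_attention_1d(q, k, v, block_length=16)

# Attend only to up to block_length previous positions ("local_mask_right").
y_causal = common_attention.local_attention_1d(
    q, k, v, block_length=16, look_right=False)

with tf.Session() as sess:
  print(sess.run(tf.shape(y_local)))   # expected: [2, 4, 64, 16]
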
New test file

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for common layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+import numpy as np
+from tensor2tensor.models import common_attention
+
+import tensorflow as tf
+
+
+class CommonAttentionTest(tf.test.TestCase):
+
+  def testLocalAttention(self):
+    q = np.array([[[ [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0] ]]])
+
+    k = np.array([[[ [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0] ]]])
+
+    b = np.array([[[ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                     [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                     [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                     [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                     [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ]]])
+
+    #b = np.ones((1,1,8,8))
+    #b = (1-b) * (-1e9)
+    v = np.ones((1, 1, 8, 1))
+
+    #q = np.random.rand(5, 7, 13, 3)
+    #k = np.random.rand(5, 7, 13, 3)
+    #v = np.random.rand(5, 7, 13, 11)
+    #b = np.random.rand(5, 1, 13, 1)
+
+    with self.test_session() as session:
+      q_ = tf.constant(q)
+      k_ = tf.constant(k)
+      v_ = tf.constant(v)
+      b_ = tf.constant(b)
+      y = common_attention.local_attention_1d(q_, k_, v_, b_, block_length=tf.constant(2))
+      res = session.run(y)
+      #print(q)
+      #rint(k)
+      print(res)
+      #self.assertEqual(res.shape, (5, 7, 13, 11))
+
+
+if __name__ == "__main__":
+  tf.test.main()
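
The test above prints the result rather than asserting on it (the shape check is commented out). A stricter variant is sketched below; it is not part of the commit, the class and method names are hypothetical, and it only asserts that the output keeps the value tensor's shape:

import numpy as np
import tensorflow as tf
from tensor2tensor.models import common_attention


class LocalAttentionShapeTest(tf.test.TestCase):

  def testLocalAttentionKeepsShape(self):
    # float32 inputs with shape [batch, heads, length, depth].
    q = np.random.rand(5, 7, 13, 3).astype(np.float32)
    k = np.random.rand(5, 7, 13, 3).astype(np.float32)
    v = np.random.rand(5, 7, 13, 11).astype(np.float32)
    with self.test_session() as session:
      y = common_attention.local_attention_1d(
          tf.constant(q), tf.constant(k), tf.constant(v), block_length=4)
      res = session.run(y)
    self.assertEqual(res.shape, (5, 7, 13, 11))


if __name__ == "__main__":
  tf.test.main()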
