This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2ced78d
Unify methods and started work on Bias
1 parent d6a6924

File tree: 3 files changed (+138 −138 lines)

  tensor2tensor/models/common_attention.py
  tensor2tensor/models/common_attention_test.py
  tensor2tensor/models/transformer_alternative.py

tensor2tensor/models/common_attention.py
94 additions & 113 deletions
@@ -345,24 +345,34 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)
 
 
-def masked_local_attention_1d(
-    q, k, v, block_length=128, mask_right=False, name=None):
-  """Attention to the source position and a neigborhood to the left of it.
 
-  The sequence is divided into blocks of length block_size.
-  Attention for a given query position can only see memory positions
-  less than or equal to the query position, in the corresponding block
-  and the previous block.
+def local_attention_1d(q, k, v, bias=None,
+                       block_length=128, look_right=True, use_whole_block=False,
+                       truncate_bias=True, name=None):
+  """Attention to the source position and a neigborhood around it.
 
-  If mask_right is True, then a target position cannot see greater source
+  The sequence is divided into blocks of length block_size. Attention for a
+  given query position can only see memory positions within a certain number
+  of positions before and behind it.
+
+  If look_right is True then each query will attend to block_length//2
+  positions either side, otherwise it will attend to block_length previous
   positions.
 
+  If use_whole_block is True then no mask will be applied to the local blocks
+  meaning the full blocks are used (if look_right is True then the elements to
+  the right of the current position are still masked out). This allows use to
+  attend to more elements without additional overhead, but means we have
+  inconsistent window positions and sizes.
+
   Args:
-    q: a Tensor with shape [batch, heads, length, depth_k]
-    k: a Tensor with shape [batch, heads, length, depth_k]
-    v: a Tensor with shape [batch, heads, length, depth_v]
+    q: a Tensor with shape [batch, heads, length_q, depth_k]
+    k: a Tensor with shape [batch, heads, length_kv, depth_k]
+    v: a Tensor with shape [batch, heads, length_kv, depth_v]
+    bias: Not currently used [batch, heads, length_q, length_k]
     block_length: an integer
-    mask_right: a bool
+    look_right: a bool
+    use_whole_block: a bool
     name: an optional string
 
   Returns:
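
To make the new windowing semantics concrete, here is a small illustrative NumPy sketch (not part of the commit) of the dense per-position window the docstring describes. The blocked TF implementation further down only approximates these windows block by block, which is the "inconsistent window positions and sizes" caveat mentioned above.

```python
import numpy as np

def dense_local_window(length, block_length, look_right=True):
  """0/1 matrix of the positions each query may attend to (illustrative)."""
  mask = np.zeros((length, length))
  for q_pos in range(length):
    if look_right:
      # roughly block_length//2 positions either side of the query
      lo = max(0, q_pos - block_length // 2)
      hi = min(length, q_pos + block_length // 2 + 1)
    else:
      # up to block_length previous positions, nothing to the right
      lo = max(0, q_pos - block_length)
      hi = q_pos + 1
    mask[q_pos, lo:hi] = 1.0
  return mask

print(dense_local_window(8, 4, look_right=True))
print(dense_local_window(8, 4, look_right=False))
```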
@@ -379,8 +389,9 @@ def masked_local_attention_1d(
 
   original_length = length
 
-  # If (length < 2 * block_length), then we use only one block.
-  block_length = tf.where(tf.less(length, block_length * 2),
+  #Pad to desired length
+  #If (length < 2 * block_length), then we use only one block.
+  block_length = tf.where(tf.less(length, block_length),
                           length, block_length)
   padding_size = tf.mod(-length, block_length)
   length += padding_size
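
As a quick illustration (not from the commit) of the padding arithmetic above: after the `tf.where` clamp shrinks `block_length` for short sequences, `tf.mod(-length, block_length)` is exactly the amount needed to round the length up to a multiple of `block_length`.

```python
# Illustrative plain-Python version of the clamp-and-pad arithmetic.
for length, block_length in [(13, 4), (16, 4), (5, 128)]:
  if length < block_length:                   # the tf.where(tf.less(...)) clamp
    block_length = length
  padding_size = (-length) % block_length     # tf.mod(-length, block_length)
  print(length, block_length, padding_size, length + padding_size)
# -> 13 4 3 16 | 16 4 0 16 | 5 5 0 5
```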
@@ -389,134 +400,100 @@
   padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
   q = tf.pad(q, padding)
 
-  if mask_right:
+  if not look_right:
     #Add extra padding so we son't have to do an initial query
     extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]]
+    bp = [[0, 0], [0, 0], [0, padding_size], [block_length, padding_size]]
   else:
     #We shift everything over by half a block so query is in centre
     pad_right = block_length // 2
     pad_left = block_length - pad_right
     extra_padding = [[0, 0], [0, 0],
-                     [pad_left,padding_size+pad_right], [0, 0]]
-
+                     [pad_left, padding_size+pad_right], [0, 0]]
+    bp = [[0, 0], [0, 0],
+          [0, padding_size], [pad_left, padding_size+pad_right]]
   k = tf.pad(k, extra_padding)
   v = tf.pad(v, extra_padding)
 
-
-  # compute attention for all subsequent query blocks.
+  # Reshape into blocks
   q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
   k = tf.reshape(k, [batch, heads, num_blocks+1, block_length, depth_k])
   v = tf.reshape(v, [batch, heads, num_blocks+1, block_length, depth_v])
 
+  # Get local blocks by slicing
   def local(x):
     """Create a local version of the keys or values."""
     prev_block = tf.slice(
         x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1])
     cur_block = tf.slice(
         x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
     return tf.concat([prev_block, cur_block], 3)
-
   local_k = local(k)
   local_v = local(v)
-
   local_length = tf.shape(local_k)[3]
 
   # [batch, heads, num_blocks, block_length, local_length]
   attention = tf.matmul(q, local_k, transpose_b=True)
+
+  # Apply bias (N.B: This is not currently working)
+  if bias is not None:
+    with tf.name_scope('bias'):
+      b_batch = tf.shape(bias)[0]
+      b_heads = tf.shape(bias)[1]
+      bias_ = bias
+      #bias = 1.0 + tf.clip_by_value(bias, -1.0, 1.0)
+      if truncate_bias:
+        # Use only the query dimension
+        bias = tf.expand_dims(bias[:,:,:,0], 2)
+        bias = tf.pad(bias, extra_padding, name='bias_pad_b')# 17, 5, 3
+        bias = tf.reshape(bias,
+                          [b_batch, b_heads, 1, num_blocks+1, block_length],
+                          name='divide_blocks')
+        local_b = tf.reshape(local(bias),
+                             [b_batch, b_heads, num_blocks, 1, -1], name='reshape_local')
+      else:
+        bias = tf.pad(bias, bp, name='pad')
+        bias = tf.reshape(bias,
+                          [b_batch, b_heads, num_blocks, block_length,
+                           num_blocks+1, block_length], name='divide_blocks')
+        bias = tf.transpose(bias, [4,2,0,1,3,5])
+        bias = tf.reshape(bias,
+                          [num_blocks*(num_blocks+1), b_batch, b_heads,
+                           block_length, block_length], name='combine')
+        indices = (num_blocks+1)*tf.range(num_blocks)
+        prev_block = tf.gather(bias, indices)
+        cur_block = tf.gather(bias, indices+num_blocks)
+        local_b = tf.concat([prev_block, cur_block], 4)
+        local_b = tf.transpose(local_b, [1,2,0,3,4])
+        return l-local_b
+      attention += local_b
+
+  attention = tf.nn.softmax(attention)
+
+  # Get local mask
+  if not use_whole_block:
+    good_part = tf.matrix_band_part(
+        tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
+  elif not look_right:
+    good_part = tf.matrix_band_part(
+        tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+  else:
+    good_part = tf.ones([block_length, local_length])
 
-  good_part = tf.matrix_band_part(
-      tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
-
-  good_part = tf.cast(good_part, tf.float64)
+  #good_part = tf.cast(good_part, tf.float64)
   attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length])
-  attention = tf.nn.softmax(attention)
 
+
   output = tf.matmul(attention, local_v)
   output = tf.reshape(output, [batch, heads, -1, depth_v])
 
-  # remove added padding
+  # Remove added padding
   output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
   output.set_shape(v_shape)
   return output
 
 
 
-
-###############################################################################
-### Not used, left in for reference ###########################################
-
-def windowed_local_attention_1d(q,
-                                k,
-                                v,
-                                window_start,
-                                window_end,
-                                bias,
-                                *args):
-  """ Local window wrapper for dot product attention. Each element only
-  attends to the elements from window_start to window_end. This reduces
-  the computational complexity for long sequences at the expense of eliminating
-  long-term dependencies.
-
-  N.B: For short input sequences this is much slower than just using
-  un-windowed attention. Use only for long sequences.
-
-  Args:
-    window_size: an integer
-    q: a Tensor with shape [batch, heads, length_q, depth_k]
-    k: a Tensor with shape [batch, heads, length_kv, depth_k]
-    v: a Tensor with shape [batch, heads, length_kv, depth_v]
-    window_start: an integer Tensor with shape [length_q]
-    window_end: an integer Tensor with shape [length_q]
-    bias: bias Tensor (see attention_bias())
-
-  Returns:
-    A Tensor.
-  """
-  with tf.name_scope("windowed"):
-
-    # Wrapper function for dot product attention with a single query vector
-    def single(index_begin, index_end, q, k, v, bias):
-      #Normalise range
-      #Reshape to right shape
-      q = tf.expand_dims(q, 2)
-      bias = tf.expand_dims(bias, 3)
-      #Get slices
-      k = k[:,:,index_begin:index_end,:]
-      v = v[:,:,index_begin:index_end,:]
-      out = dot_product_attention(q, k, v, *args)
-      out = tf.squeeze(out, 2)
-      return out
-
-    # We'll loop over each element of q, computing its corresponding output.
-    q = tf.transpose(q, [2, 0, 1, 3])
-    bias = tf.transpose(bias, [3, 0, 1, 2])
-    indices = tf.range(tf.shape(q)[0])
-    out = tf.map_fn(
-        lambda ii: single(
-            window_start[ii],
-            window_end[ii],
-            q[ii],
-            k,
-            v,
-            bias[ii]),
-        indices,
-        dtype=tf.float32)
-    out = tf.transpose(out, [1, 2, 0, 3])
-    return out
-
-def local_sliding_window(length, window_size, look_right=True):
-  indices = tf.range(length)
-  size = window_size
-  starts = tf.maximum(0, indices-size)
-  ends = tf.minimum(length-1, indices+size)
-  return starts, ends
-
-###                                                                         ###
-###############################################################################
-
-
-
 def multihead_attention(query_antecedent,
                         memory_antecedent,
                         bias,
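
The `local()` helper in the hunk above is the core of the blocked formulation. The following NumPy sketch (illustrative, not from the commit) shows the equivalent slicing and concatenation: each block of `block_length` queries ends up attending to the previous and current key blocks, i.e. `2 * block_length` memory positions.

```python
import numpy as np

batch, heads, num_blocks, block_length, depth = 1, 1, 3, 4, 2
# keys already padded and reshaped to [batch, heads, num_blocks+1, block_length, depth]
k = np.arange((num_blocks + 1) * block_length).reshape(
    batch, heads, num_blocks + 1, block_length, 1) * np.ones((1, 1, 1, 1, depth))

prev_block = k[:, :, :num_blocks]   # tf.slice(x, [0,0,0,0,0], [-1,-1,num_blocks,-1,-1])
cur_block = k[:, :, 1:]             # tf.slice(x, [0,0,1,0,0], [-1,-1,-1,-1,-1])
local_k = np.concatenate([prev_block, cur_block], axis=3)

print(local_k.shape)  # (1, 1, 3, 8, 2): each query block sees 2*block_length keys
```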
@@ -527,7 +504,8 @@ def multihead_attention(query_antecedent,
                         dropout_rate,
                         summaries=False,
                         image_shapes=None,
-                        window_size=None,
+                        attention_type="dot_product",
+                        block_length=128,
                         name=None):
   """Multihead scaled-dot-product attention with input/output transformations.
 
@@ -540,9 +518,11 @@ def multihead_attention(query_antecedent,
     output_depth: an integer
     num_heads: an integer dividing total_key_depth and total_value_depth
     dropout_rate: a floating point number
-    summaries: a boolean
-    window_size: option size of window for attention. Useful only for very long
-      sequence lengths.
+    image_shapes: optional tuple of integer scalars.
+      see comments for attention_image_summary()
+    attention_type: a string, either "dot_product" or "local" or
+      "local_mask_right"
+    block_length: an integer - relevant for "local_mask_right"
     name: an optional string
 
   Returns:
@@ -576,14 +556,15 @@ def multihead_attention(query_antecedent,
     v = split_heads(v, num_heads)
     key_depth_per_head = total_key_depth // num_heads
     q *= key_depth_per_head**-0.5
-    if window_size is None:
+    if attention_type == "dot_product":
       x = dot_product_attention(
-          q, k, v, bias, dropout_rate, summaries, image_shapes)
+          q, k, v, bias, dropout_rate, image_shapes)
+    elif attention_type == "local":
+      x = local_attention_1d(q, k, v, block_length=block_length)
     else:
-      length = tf.shape(q)[2]
-      window_start, window_end = local_sliding_window(length, window_size)
-      x = windowed_local_attention_1d(
-          q, k, v, window_start, window_end, bias, dropout_rate, False)
+      assert attention_type == "local_mask_right"
+      x = local_attention_1d(
+          q, k, v, block_length=block_length, look_right=False)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     return x
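
A hedged usage sketch (not part of the commit) of the reworked entry point: callers now pick the attention variant through `attention_type` and `block_length` rather than the removed `window_size`. The tensor shapes and surrounding setup below are assumptions for illustration only.

```python
import tensorflow as tf
from tensor2tensor.models import common_attention

inputs = tf.zeros([2, 16, 512])  # [batch, length, io_depth] -- assumed shapes
y = common_attention.multihead_attention(
    inputs,                      # query_antecedent
    None,                        # memory_antecedent=None -> self-attention
    None,                        # bias (unused by the local variants)
    total_key_depth=512,
    total_value_depth=512,
    output_depth=512,
    num_heads=8,
    dropout_rate=0.0,
    attention_type="local_mask_right",  # or "dot_product" / "local"
    block_length=8)
```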

tensor2tensor/models/common_attention_test.py
40 additions & 22 deletions
@@ -29,35 +29,53 @@
 class CommonAttentionTest(tf.test.TestCase):
 
   def testLocalAttention(self):
-    #q = np.array([[[ [1.0, 0.0, 0.0, 0.0],
-    #                 [1.0, 0.0, 0.0, 0.0],
-    #                 [1.0, 0.0, 0.0, 0.0],
-    #                 [1.0, 0.0, 0.0, 0.0],
-    #                 [1.0, 0.0, 0.0, 0.0],
-    #                 [1.0, 0.0, 0.0, 0.0],
-    #                 [1.0, 0.0, 0.0, 0.0],
-    #                 [1.0, 0.0, 0.0, 0.0] ]]])
-    #k = np.array([[[ [0.0, 0.0, 0.0, 0.0],
-    #                 [0.0, 0.0, 0.0, 0.0],
-    #                 [0.0, 0.0, 0.0, 0.0],
-    #                 [0.0, 0.0, 0.0, 0.0],
-    #                 [0.0, 0.0, 0.0, 0.0],
-    #                 [0.0, 0.0, 0.0, 0.0],
-    #                 [0.0, 0.0, 0.0, 0.0],
-    #                 [0.0, 0.0, 0.0, 0.0] ]]])
-    #v = np.ones((1, 1, 8, 1))
+    q = np.array([[[ [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0] ]]])
 
-    q = np.random.rand(5, 7, 13, 3)
-    k = np.random.rand(5, 7, 13, 3)
-    v = np.random.rand(5, 7, 13, 11)
+    k = np.array([[[ [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0],
+                     [1.0, 0.0, 0.0, 0.0] ]]])
+
+    b = np.array([[[ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                     [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                     [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                     [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                     [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ]]])
+
+    #b = np.ones((1,1,8,8))
+    #b = (1-b) * (-1e9)
+    v = np.ones((1, 1, 8, 1))
+
+    #q = np.random.rand(5, 7, 13, 3)
+    #k = np.random.rand(5, 7, 13, 3)
+    #v = np.random.rand(5, 7, 13, 11)
+    #b = np.random.rand(5, 1, 13, 1)
 
     with self.test_session() as session:
       q_ = tf.constant(q)
       k_ = tf.constant(k)
       v_ = tf.constant(v)
-      y = common_attention.masked_local_attention_1d(q_, k_, v_, block_length=tf.constant(3))
+      b_ = tf.constant(b)
+      y = common_attention.local_attention_1d(q_, k_, v_, b_, block_length=tf.constant(2))
       res = session.run(y)
-      self.assertEqual(res.shape, (5, 7, 13, 11))
+      #print(q)
+      #rint(k)
+      print(res)
+      #self.assertEqual(res.shape, (5, 7, 13, 11))
 
 
 if __name__ == "__main__":

tensor2tensor/models/transformer_alternative.py
4 additions & 3 deletions
@@ -174,8 +174,9 @@ def alt_transformer_decoder(decoder_input,
 
 def bias_to_mask(bias):
   # We need masks of the form batch size x input sequences
-  # Biases seem to be of the form batch_size x 1 x input sequences x vec dim
-  # Squeeze out dim one, and get the first element of each vector
+  # Biases are of the form batch_size x num_heads x input sequences x
+  # output sequences. Squeeze out dim one, and get the first element of
+  # each vector.
 
   bias = tf.squeeze(bias, [1])[:,:,0]
   bias = - tf.clip_by_value(bias, -1.0, 1.0)
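
To make the updated comment concrete, here is an illustrative NumPy version (not from the commit) of the same reduction, assuming the usual padding-bias convention of 0.0 for valid positions and a large negative value for padded ones, with a single head dimension.

```python
import numpy as np

length = 4
bias = np.zeros((1, 1, length, length))  # [batch, heads=1, length_q, length_k] -- assumed
bias[:, :, 2:, :] = -1e9                 # pretend the last two positions are padding
# Squeeze the head dim, take the first element of each row, negate the clipped values.
mask = -np.clip(np.squeeze(bias, 1)[:, :, 0], -1.0, 1.0)
print(mask)  # 1.0 marks the padded positions, (-)0.0 the valid ones
```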
@@ -189,7 +190,7 @@ def transformer_alt():
   """Set of hyperparameters."""
   hparams = transformer.transformer_base()
   hparams.batch_size = 2048
-  hparams.num_hidden_layers = 3
+  hparams.num_hidden_layers = 10
   hparams.add_hparam("layers_per_layer", 4)
   hparams.add_hparam("composite_layer_type", "ravanbakhsh") #ravanbakhsh or reembedding
   #hparams.add_hparam("composite_layer_type", "reembedding")
