@@ -361,122 +361,100 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)


-def masked_local_attention_1d(q,
-                              k,
-                              v,
-                              block_length=128,
-                              look_right=True,
-                              use_whole_block=False,
-                              name=None):
-  """Attention to the source position and a neigborhood around it.
-
-  The sequence is divided into blocks of length block_size. Attention for a
-  given query position can only see memory positions within a certain number
-  of positions before and behind it.
-
-
-  If look_right is True then each query will attend to block_length//2
-  positions either side, otherwise it will attend to block_length previous
-  positions.
+def masked_local_attention_1d(
+    q, k, v, block_length=128, name=None):
+  """Attention to the source position and a neighborhood to the left of it.
+
+  The sequence is divided into blocks of length block_size.
+  Attention for a given query position can only see memory positions
+  less than or equal to the query position, in the corresponding block
+  and the previous block.

-  If use_whole_block is True then no mask will be applied to the local blocks
-  meaning the full blocks are used (if look_right is True then the elements to
-  the right of the current position are still masked out). This allows to
-  attend to more elements without additional overhead, but means we have
-  inconsistent window positions and sizes.
+  If mask_right is True, then a target position cannot see greater source
+  positions.

   Args:
-    q: a Tensor with shape [batch, heads, length_q, depth_k]
-    k: a Tensor with shape [batch, heads, length_kv, depth_k]
-    v: a Tensor with shape [batch, heads, length_kv, depth_v]
+    q: a Tensor with shape [batch, heads, length, depth_k]
+    k: a Tensor with shape [batch, heads, length, depth_k]
+    v: a Tensor with shape [batch, heads, length, depth_v]
     block_length: an integer
-    look_right: a bool
-    use_whole_block: a bool
     name: an optional string

   Returns:
     a Tensor of shape [batch, heads, length, depth_v]
   """
-  with tf.variable_scope(
-      name, default_name="local_attention_1d", values=[q, k, v]):
+  with tf.variable_scope(name, default_name="local_attention_1d",
+                         values=[q, k, v]):
     v_shape = v.get_shape()
     batch = tf.shape(q)[0]
     heads = tf.shape(q)[1]
     length = tf.shape(q)[2]
+    # If (length < 2 * block_length), then we use only one block.
+    block_length = tf.where(tf.less(length, block_length * 2),
+                            length, block_length)
     depth_k = tf.shape(q)[3]
     depth_v = tf.shape(v)[3]
     original_length = length
-
-    # If (length < block_length), then we use only one block.
-    block_length = tf.where(tf.less(length, block_length), length, block_length)
-    # Pad to desired length.
     padding_size = tf.mod(-length, block_length)
     length += padding_size
-    num_blocks = tf.div(length, block_length)
     padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
     q = tf.pad(q, padding)
+    k = tf.pad(k, padding)
+    v = tf.pad(v, padding)
+    num_blocks = tf.div(length, block_length)

-    if not look_right:
-      # Add extra padding so we don't have to do an initial query block.
-      extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]]
-    else:
-      # We shift everything over by half a block so query is in center.
-      pad_right = block_length // 2
-      pad_left = block_length - pad_right
-      extra_padding = [[0, 0], [0, 0], [pad_left, padding_size + pad_right],
-                       [0, 0]]
-    k = tf.pad(k, extra_padding)
-    v = tf.pad(v, extra_padding)
-
-    # Reshape into blocks.
+    # compute attention for the first query block.
+    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
+    first_output = dot_product_attention(
+        first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
+        name="fist_block")
+
+    # compute attention for all subsequent query blocks.
     q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
-    k = tf.reshape(k, [batch, heads, num_blocks + 1, block_length, depth_k])
-    v = tf.reshape(v, [batch, heads, num_blocks + 1, block_length, depth_v])
+    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
+    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])

-    # Get local blocks by slicing.
     def local(x):
       """Create a local version of the keys or values."""
-      prev_block = tf.slice(x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1])
-      cur_block = tf.slice(x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
+      prev_block = tf.slice(
+          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1])
+      cur_block = tf.slice(
+          x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
       return tf.concat([prev_block, cur_block], 3)
-
     local_k = local(k)
     local_v = local(v)
-    local_length = tf.shape(local_k)[3]
+    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])

-    # [batch, heads, num_blocks, block_length, local_length]
-    attention = tf.matmul(q, local_k, transpose_b=True)
-    attention = tf.nn.softmax(attention)
-
-    # Get local mask
-    if not use_whole_block:
-      good_part = tf.matrix_band_part(
-          tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
-    elif not look_right:
-      good_part = tf.matrix_band_part(
-          tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
-    else:
-      good_part = tf.ones([block_length, local_length])
+    local_length = tf.shape(local_k)[3]

-    attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length])
+    # [batch, heads, num_blocks - 1, block_length, local_length]
+    attention = tf.matmul(tail_q, local_k, transpose_b=True)

+    # make sure source_pos <= target_pos
+    good_part = tf.matrix_band_part(
+        tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+    mask = (1.0 - good_part) * -1e9
+    attention += tf.reshape(mask, [1, 1, 1, block_length, local_length])
+    attention = tf.nn.softmax(attention)
     # TODO(noam): figure out how to show a summary for the remaining blocks.
     # The naive way currently causes errors due to empty tensors.
+    # output: [batch, heads, num_blocks-1, block_length, depth_v]
     output = tf.matmul(attention, local_v)
     output = tf.reshape(output, [batch, heads, -1, depth_v])
-
-    # Remove added padding
+    output = tf.concat([first_output, output], axis=2)
     output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
     output.set_shape(v_shape)
     return output


-def unmasked_local_attention_1d(q,
-                                k,
-                                v,
-                                block_length=128,
-                                filter_width=100,
-                                name=None):
+def local_attention_1d(q,
+                       k,
+                       v,
+                       block_length=128,
+                       filter_width=100,
+                       name=None):
   """strided block local self-attention.

   Args:
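
The masking step added in the first hunk can be sanity-checked outside of TensorFlow. The NumPy sketch below (an illustration only, not part of the commit) reproduces the tf.matrix_band_part(tf.ones([block_length, local_length]), -1, block_length) pattern for a toy block size: row i is a query at offset i in the current block, the first block_length columns are the previous block's keys, the remaining columns are the current block's keys, and an entry survives only when the key position does not exceed the query position.

import numpy as np

# Toy block size; in the diff block_length defaults to 128 and
# local_length = 2 * block_length after concatenating [prev_block, cur_block].
block_length = 4
local_length = 2 * block_length

# tf.matrix_band_part(ones, -1, block_length) keeps entry (i, j) iff
# j - i <= block_length; np.tril with k=block_length builds the same matrix.
good_part = np.tril(np.ones((block_length, local_length)), k=block_length)
print(good_part)
# Row 0 sees the whole previous block plus its own position (columns 0..4);
# each later row i additionally sees the i earlier positions in its own block.
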
@@ -644,7 +622,7 @@ def multihead_attention(query_antecedent,
       x = masked_local_attention_1d(q, k, v, block_length=block_length)
     else:
       assert attention_type == "local_unmasked"
-      x = unmasked_local_attention_1d(
+      x = local_attention_1d(
           q, k, v, block_length=block_length, filter_width=block_width)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
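
For readers who want to exercise the rewritten masked_local_attention_1d end to end, a minimal driver might look like the following. It is a sketch under two assumptions: TensorFlow 1.x graph mode, and that the file shown in this diff is importable as common_attention (the exact module path depends on the checkout).

import tensorflow as tf

import common_attention  # assumed import path for the file in this diff

batch, heads, length, depth = 2, 4, 300, 64
q = tf.random_normal([batch, heads, length, depth])
k = tf.random_normal([batch, heads, length, depth])
v = tf.random_normal([batch, heads, length, depth])

# With block_length=128 and length=300 the inputs are padded to 384 internally
# and the padding is sliced off again, so the output keeps the input shape.
out = common_attention.masked_local_attention_1d(q, k, v, block_length=128)

with tf.Session() as sess:
  print(sess.run(tf.shape(out)))  # expected: [2 4 300 64]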