@@ -345,24 +345,34 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)


-def masked_local_attention_1d(
-    q, k, v, block_length=128, mask_right=False, name=None):
-  """Attention to the source position and a neigborhood to the left of it.
-
-  The sequence is divided into blocks of length block_size.
-  Attention for a given query position can only see memory positions
-  less than or equal to the query position, in the corresponding block
-  and the previous block.
+def local_attention_1d(q, k, v, bias=None,
+                       block_length=128, look_right=True,
+                       use_whole_block=False, truncate_bias=True, name=None):
+  """Attention to the source position and a neighborhood around it.

-  If mask_right is True, then a target position cannot see greater source
+  The sequence is divided into blocks of length block_length. Attention for a
+  given query position can only see memory positions within a certain number
+  of positions before and behind it.
+
+  If look_right is True then each query will attend to block_length//2
+  positions on either side, otherwise it will attend to the block_length previous
   positions.

+  If use_whole_block is True then no mask will be applied to the local blocks,
+  meaning the full blocks are used (if look_right is True then the elements to
+  the right of the current position are still masked out). This allows us to
+  attend to more elements without additional overhead, but means we have
+  inconsistent window positions and sizes.
+
   Args:
-    q: a Tensor with shape [batch, heads, length, depth_k]
-    k: a Tensor with shape [batch, heads, length, depth_k]
-    v: a Tensor with shape [batch, heads, length, depth_v]
+    q: a Tensor with shape [batch, heads, length_q, depth_k]
+    k: a Tensor with shape [batch, heads, length_kv, depth_k]
+    v: a Tensor with shape [batch, heads, length_kv, depth_v]
+    bias: not currently used; [batch, heads, length_q, length_k]
     block_length: an integer
-    mask_right: a bool
+    look_right: a bool
+    use_whole_block: a bool
     name: an optional string

   Returns:
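As a rough illustration of the windows described in the new docstring (plain Python, not part of the commit; boundary clipping and the use_whole_block variant are ignored), each query position roughly sees the following range of memory positions:

    def attended_window(t, block_length=128, look_right=True):
      """Approximate memory window [lo, hi] visible to query position t."""
      if look_right:
        pad_right = block_length // 2
        pad_left = block_length - pad_right
        return max(0, t - pad_left), t + pad_right
      # look_right=False: the query itself plus the block_length previous positions.
      return max(0, t - block_length), t

    print(attended_window(300, look_right=True))   # (236, 364)
    print(attended_window(300, look_right=False))  # (172, 300)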
@@ -379,8 +389,9 @@ def masked_local_attention_1d(

     original_length = length

-    # If (length < 2 * block_length), then we use only one block.
-    block_length = tf.where(tf.less(length, block_length * 2),
+    #Pad to desired length
+    #If (length < block_length), then we use only one block.
+    block_length = tf.where(tf.less(length, block_length),
                             length, block_length)
     padding_size = tf.mod(-length, block_length)
     length += padding_size
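For reference, the padding arithmetic above rounds the sequence length up to a multiple of block_length. A small worked example (plain Python, values made up for illustration):

    length, block_length = 300, 128
    padding_size = -length % block_length       # same as tf.mod(-length, block_length) -> 84
    padded_length = length + padding_size       # 384, i.e. 3 full blocks
    num_blocks = padded_length // block_length  # 3
    print(padding_size, padded_length, num_blocks)  # 84 384 3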
@@ -389,134 +400,100 @@ def masked_local_attention_1d(
     padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
     q = tf.pad(q, padding)

-    if mask_right:
+    if not look_right:
       #Add extra padding so we don't have to do an initial query
       extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]]
+      bp = [[0, 0], [0, 0], [0, padding_size], [block_length, padding_size]]
     else:
       #We shift everything over by half a block so query is in centre
       pad_right = block_length // 2
       pad_left = block_length - pad_right
       extra_padding = [[0, 0], [0, 0],
-                       [pad_left,padding_size + pad_right], [0, 0]]
-
+                       [pad_left, padding_size + pad_right], [0, 0]]
+      bp = [[0, 0], [0, 0],
+            [0, padding_size], [pad_left, padding_size + pad_right]]
     k = tf.pad(k, extra_padding)
     v = tf.pad(v, extra_padding)

-
-    # compute attention for all subsequent query blocks.
+    # Reshape into blocks
     q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
     k = tf.reshape(k, [batch, heads, num_blocks + 1, block_length, depth_k])
     v = tf.reshape(v, [batch, heads, num_blocks + 1, block_length, depth_v])

+    # Get local blocks by slicing
     def local(x):
       """Create a local version of the keys or values."""
       prev_block = tf.slice(
           x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1])
       cur_block = tf.slice(
           x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
       return tf.concat([prev_block, cur_block], 3)
-
     local_k = local(k)
     local_v = local(v)
-
     local_length = tf.shape(local_k)[3]

     # [batch, heads, num_blocks, block_length, local_length]
     attention = tf.matmul(q, local_k, transpose_b=True)
+
+    # Apply bias (N.B: This is not currently working)
+    if bias is not None:
+      with tf.name_scope('bias'):
+        b_batch = tf.shape(bias)[0]
+        b_heads = tf.shape(bias)[1]
+        bias_ = bias
+        #bias = 1.0 + tf.clip_by_value(bias, -1.0, 1.0)
+        if truncate_bias:
+          # Use only the query dimension
+          bias = tf.expand_dims(bias[:, :, :, 0], 2)
+          bias = tf.pad(bias, extra_padding, name='bias_pad_b')
+          bias = tf.reshape(bias,
+                            [b_batch, b_heads, 1, num_blocks + 1, block_length],
+                            name='divide_blocks')
+          local_b = tf.reshape(local(bias),
+                               [b_batch, b_heads, num_blocks, 1, -1],
+                               name='reshape_local')
+        else:
+          bias = tf.pad(bias, bp, name='pad')
+          bias = tf.reshape(bias,
+                            [b_batch, b_heads, num_blocks, block_length,
+                             num_blocks + 1, block_length], name='divide_blocks')
+          bias = tf.transpose(bias, [4, 2, 0, 1, 3, 5])
+          bias = tf.reshape(bias,
+                            [num_blocks * (num_blocks + 1), b_batch, b_heads,
+                             block_length, block_length], name='combine')
+          indices = (num_blocks + 1) * tf.range(num_blocks)
+          prev_block = tf.gather(bias, indices)
+          cur_block = tf.gather(bias, indices + num_blocks)
+          local_b = tf.concat([prev_block, cur_block], 4)
+          local_b = tf.transpose(local_b, [1, 2, 0, 3, 4])
+        attention += local_b
+
+    attention = tf.nn.softmax(attention)
+
+    # Get local mask
+    if not use_whole_block:
+      good_part = tf.matrix_band_part(
+          tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
+    elif not look_right:
+      good_part = tf.matrix_band_part(
+          tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+    else:
+      good_part = tf.ones([block_length, local_length])

-    good_part = tf.matrix_band_part(
-        tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
-
-    good_part = tf.cast(good_part, tf.float64)
+    #good_part = tf.cast(good_part, tf.float64)
     attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length])
-    attention = tf.nn.softmax(attention)

+
     output = tf.matmul(attention, local_v)
     output = tf.reshape(output, [batch, heads, -1, depth_v])

-    # remove added padding
+    # Remove added padding
     output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
     output.set_shape(v_shape)
     return output



-
-###############################################################################
-### Not used, left in for reference ###########################################
-
-def windowed_local_attention_1d(q,
-                                k,
-                                v,
-                                window_start,
-                                window_end,
-                                bias,
-                                *args):
-  """Local window wrapper for dot product attention. Each element only
-  attends to the elements from window_start to window_end. This reduces
-  the computational complexity for long sequences at the expense of eliminating
-  long-term dependencies.
-
-  N.B: For short input sequences this is much slower than just using
-  un-windowed attention. Use only for long sequences.
-
-  Args:
-    window_size: an integer
-    q: a Tensor with shape [batch, heads, length_q, depth_k]
-    k: a Tensor with shape [batch, heads, length_kv, depth_k]
-    v: a Tensor with shape [batch, heads, length_kv, depth_v]
-    window_start: an integer Tensor with shape [length_q]
-    window_end: an integer Tensor with shape [length_q]
-    bias: bias Tensor (see attention_bias())
-
-  Returns:
-    A Tensor.
-  """
-  with tf.name_scope("windowed"):
-
-    # Wrapper function for dot product attention with a single query vector
-    def single(index_begin, index_end, q, k, v, bias):
-      #Normalise range
-      #Reshape to right shape
-      q = tf.expand_dims(q, 2)
-      bias = tf.expand_dims(bias, 3)
-      #Get slices
-      k = k[:,:,index_begin:index_end,:]
-      v = v[:,:,index_begin:index_end,:]
-      out = dot_product_attention(q, k, v, *args)
-      out = tf.squeeze(out, 2)
-      return out
-
-    # We'll loop over each element of q, computing its corresponding output.
-    q = tf.transpose(q, [2, 0, 1, 3])
-    bias = tf.transpose(bias, [3, 0, 1, 2])
-    indices = tf.range(tf.shape(q)[0])
-    out = tf.map_fn(
-        lambda ii: single(
-            window_start[ii],
-            window_end[ii],
-            q[ii],
-            k,
-            v,
-            bias[ii]),
-        indices,
-        dtype=tf.float32)
-    out = tf.transpose(out, [1, 2, 0, 3])
-    return out
-
-def local_sliding_window(length, window_size, look_right=True):
-  indices = tf.range(length)
-  size = window_size
-  starts = tf.maximum(0, indices - size)
-  ends = tf.minimum(length - 1, indices + size)
-  return starts, ends
-
-### ###
-###############################################################################
-
-
-
-
 def multihead_attention(query_antecedent,
                         memory_antecedent,
                         bias,
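Before the multihead_attention changes below, a small NumPy sketch (not part of the commit; shapes made up for illustration) of what local() above does: after the extra front padding, the keys form num_blocks + 1 blocks, and each query block is paired with its previous and current key block.

    import numpy as np

    batch, heads, num_blocks, block_length, depth = 1, 2, 3, 4, 5
    k_blocks = np.random.randn(batch, heads, num_blocks + 1, block_length, depth)

    prev_block = k_blocks[:, :, :num_blocks]  # blocks 0 .. num_blocks-1
    cur_block = k_blocks[:, :, 1:]            # blocks 1 .. num_blocks
    local_k = np.concatenate([prev_block, cur_block], axis=3)

    print(local_k.shape)  # (1, 2, 3, 8, 5): each query block sees 2*block_length keys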
@@ -527,7 +504,8 @@ def multihead_attention(query_antecedent,
                         dropout_rate,
                         summaries=False,
                         image_shapes=None,
-                        window_size=None,
+                        attention_type="dot_product",
+                        block_length=128,
                         name=None):
   """Multihead scaled-dot-product attention with input/output transformations.

@@ -540,9 +518,11 @@ def multihead_attention(query_antecedent,
     output_depth: an integer
     num_heads: an integer dividing total_key_depth and total_value_depth
     dropout_rate: a floating point number
-    summaries: a boolean
-    window_size: option size of window for attention. Useful only for very long
-      sequence lengths.
+    image_shapes: optional tuple of integer scalars.
+      see comments for attention_image_summary()
+    attention_type: a string, either "dot_product" or "local" or
+      "local_mask_right"
+    block_length: an integer - relevant for "local" and "local_mask_right"
     name: an optional string

   Returns:
@@ -576,14 +556,15 @@ def multihead_attention(query_antecedent,
     v = split_heads(v, num_heads)
     key_depth_per_head = total_key_depth // num_heads
     q *= key_depth_per_head**-0.5
-    if window_size is None:
+    if attention_type == "dot_product":
       x = dot_product_attention(
-          q, k, v, bias, dropout_rate, summaries, image_shapes)
+          q, k, v, bias, dropout_rate, image_shapes)
+    elif attention_type == "local":
+      x = local_attention_1d(q, k, v, block_length=block_length)
     else:
-      length = tf.shape(q)[2]
-      window_start, window_end = local_sliding_window(length, window_size)
-      x = windowed_local_attention_1d(
-          q, k, v, window_start, window_end, bias, dropout_rate, False)
+      assert attention_type == "local_mask_right"
+      x = local_attention_1d(
+          q, k, v, block_length=block_length, look_right=False)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     return x
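A hypothetical usage sketch of the new attention_type switch (not part of the commit; the shapes and hyperparameters are made up, and self-attention is expressed here by passing the same tensor for both antecedents). bias is left as None since the local attention paths above do not use it.

    x = tf.random_normal([8, 1024, 512])    # [batch, length, channels]
    y = multihead_attention(
        query_antecedent=x,
        memory_antecedent=x,                # self-attention
        bias=None,                          # ignored by the local attention paths
        total_key_depth=512,
        total_value_depth=512,
        output_depth=512,
        num_heads=8,
        dropout_rate=0.0,
        attention_type="local_mask_right",  # or "dot_product" / "local"
        block_length=256)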