@@ -344,23 +344,33 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)


-def masked_local_attention_1d(
-    q, k, v, block_length=128, name=None):
-  """Attention to the source position and a neigborhood to the left of it.
+def local_attention_1d(q, k, v, bias=None,
+                       block_length=128, look_right=True, use_whole_block=False,
+                       truncate_bias=True, name=None):
+  """Attention to the source position and a neighborhood around it.

-  The sequence is divided into blocks of length block_size.
-  Attention for a given query position can only see memory positions
-  less than or equal to the query position, in the corresponding block
-  and the previous block.
+  The sequence is divided into blocks of length block_length. Attention for a
+  given query position can only see memory positions within a certain number
+  of positions before and after it.

-  If mask_right is True, then a target position cannot see greater source
+  If look_right is True then each query will attend to block_length//2
+  positions on either side; otherwise it will attend to the block_length previous
   positions.

+  If use_whole_block is True then no mask will be applied to the local blocks,
+  meaning the full blocks are used (if look_right is True then the elements to
+  the right of the current position are still masked out). This allows us to
+  attend to more elements without additional overhead, but means we have
+  inconsistent window positions and sizes.
+

   Args:
-    q: a Tensor with shape [batch, heads, length, depth_k]
-    k: a Tensor with shape [batch, heads, length, depth_k]
-    v: a Tensor with shape [batch, heads, length, depth_v]
+    q: a Tensor with shape [batch, heads, length_q, depth_k]
+    k: a Tensor with shape [batch, heads, length_kv, depth_k]
+    v: a Tensor with shape [batch, heads, length_kv, depth_v]
+    bias: Not currently used. [batch, heads, length_q, length_k]
     block_length: an integer
+    look_right: a bool
+    use_whole_block: a bool
     name: an optional string

   Returns:
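
As a rough illustration of the windows described in the docstring above, here is a small pure-Python sketch (an illustration only, assuming use_whole_block=False and ignoring clipping at the sequence edges):

def visible_range(i, block_length=128, look_right=True):
  """Inclusive (lo, hi) range of memory positions that query i may attend to.

  Sketch of the windowing described in local_attention_1d's docstring;
  assumes use_whole_block=False and ignores sequence boundaries.
  """
  if look_right:
    pad_right = block_length // 2
    pad_left = block_length - pad_right
    return i - pad_left, i + pad_right
  # Left-only attention: the query itself plus block_length previous positions.
  return i - block_length, i

# e.g. visible_range(200, 128, True)  -> (136, 264)
#      visible_range(200, 128, False) -> (72, 200)
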
@@ -372,146 +382,110 @@ def masked_local_attention_1d(
     batch = tf.shape(q)[0]
     heads = tf.shape(q)[1]
     length = tf.shape(q)[2]
-    # If (length < 2 * block_length), then we use only one block.
-    block_length = tf.where(tf.less(length, block_length * 2),
-                            length, block_length)
     depth_k = tf.shape(q)[3]
     depth_v = tf.shape(v)[3]
+
     original_length = length
+
+    # Pad to desired length
+    # If (length < block_length), then we use only one block.
+    block_length = tf.where(tf.less(length, block_length),
+                            length, block_length)
     padding_size = tf.mod(-length, block_length)
     length += padding_size
-    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
-    q = tf.pad(q, padding)
-    k = tf.pad(k, padding)
-    v = tf.pad(v, padding)
     num_blocks = tf.div(length, block_length)

-    # compute attention for the first query block.
-    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_output = dot_product_attention(
-        first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
-        name="fist_block")
+    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
+    q = tf.pad(q, padding)

-    # compute attention for all subsequent query blocks.
+    if not look_right:
+      # Add extra padding so we don't have to do an initial query block
+      extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]]
+      bp = [[0, 0], [0, 0], [0, padding_size], [block_length, padding_size]]
+    else:
+      # We shift everything over by half a block so the query is in the centre
+      pad_right = block_length // 2
+      pad_left = block_length - pad_right
+      extra_padding = [[0, 0], [0, 0],
+                       [pad_left, padding_size + pad_right], [0, 0]]
+      bp = [[0, 0], [0, 0],
+            [0, padding_size], [pad_left, padding_size + pad_right]]
+    k = tf.pad(k, extra_padding)
+    v = tf.pad(v, extra_padding)
+
+    # Reshape into blocks
     q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
-    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
-    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
+    k = tf.reshape(k, [batch, heads, num_blocks + 1, block_length, depth_k])
+    v = tf.reshape(v, [batch, heads, num_blocks + 1, block_length, depth_v])

+    # Get local blocks by slicing
     def local(x):
       """Create a local version of the keys or values."""
       prev_block = tf.slice(
-          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1])
+          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1])
       cur_block = tf.slice(
           x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
       return tf.concat([prev_block, cur_block], 3)
     local_k = local(k)
     local_v = local(v)
-    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
-
     local_length = tf.shape(local_k)[3]

-    # [batch, heads, num_blocks - 1, block_length, local_length]
-    attention = tf.matmul(tail_q, local_k, transpose_b=True)
-
-    # make sure source_pos <= target_pos
-    good_part = tf.matrix_band_part(
-        tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
-    mask = (1.0 - good_part) * -1e9
-    attention += tf.reshape(mask, [1, 1, 1, block_length, local_length])
+    # [batch, heads, num_blocks, block_length, local_length]
+    attention = tf.matmul(q, local_k, transpose_b=True)
+
+    # Apply bias (N.B.: this is not currently working)
+    if bias is not None:
+      with tf.name_scope('bias'):
+        b_batch = tf.shape(bias)[0]
+        b_heads = tf.shape(bias)[1]
+        bias_ = bias
+        # bias = 1.0 + tf.clip_by_value(bias, -1.0, 1.0)
+        if truncate_bias:
+          # Use only the query dimension
+          bias = tf.expand_dims(bias[:, :, :, 0], 2)
+          bias = tf.pad(bias, extra_padding, name='bias_pad_b')  # 17, 5, 3
+          bias = tf.reshape(bias,
+                            [b_batch, b_heads, 1, num_blocks + 1, block_length],
+                            name='divide_blocks')
+          local_b = tf.reshape(local(bias),
+                               [b_batch, b_heads, num_blocks, 1, -1],
+                               name='reshape_local')
+        else:
+          bias = tf.pad(bias, bp, name='pad')
+          bias = tf.reshape(bias,
+                            [b_batch, b_heads, num_blocks, block_length,
+                             num_blocks + 1, block_length], name='divide_blocks')
+          bias = tf.transpose(bias, [4, 2, 0, 1, 3, 5])
+          bias = tf.reshape(bias,
+                            [num_blocks * (num_blocks + 1), b_batch, b_heads,
+                             block_length, block_length], name='combine')
+          indices = (num_blocks + 1) * tf.range(num_blocks)
+          prev_block = tf.gather(bias, indices)
+          cur_block = tf.gather(bias, indices + num_blocks)
+          local_b = tf.concat([prev_block, cur_block], 4)
+          local_b = tf.transpose(local_b, [1, 2, 0, 3, 4])
+          # return l - local_b
+        attention += local_b
+
     attention = tf.nn.softmax(attention)
-    # TODO(noam): figure out how to show a summary for the remaining blocks.
-    # The naive way currently causes errors due to empty tensors.
-    # output: [batch, heads, num_blocks-1, block_length, depth_v]
-    output = tf.matmul(attention, local_v)
-    output = tf.reshape(output, [batch, heads, -1, depth_v])
-    output = tf.concat([first_output, output], axis=2)
-    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
-    output.set_shape(v_shape)
-    return output
-
+
+    # Get local mask
+    if not use_whole_block:
+      good_part = tf.matrix_band_part(
+          tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
+    elif not look_right:
+      good_part = tf.matrix_band_part(
+          tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+    else:
+      good_part = tf.ones([block_length, local_length])

-def unmasked_local_attention_1d(q, k, v, block_length=128, filter_width=100,
-                                name=None):
-  """strided block local self-attention.
+    # good_part = tf.cast(good_part, tf.float64)
+    attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length])

-  Args:
-    q: a Tensor with shape [batch, heads, length, depth_k]
-    k: a Tensor with shape [batch, heads, length, depth_k]
-    v: a Tensor with shape [batch, heads, length, depth_v]
-    block_length: an integer
-    filter_width: an integer indicating how much to look left.
-    name: an optional string
+
+    output = tf.matmul(attention, local_v)
+    output = tf.reshape(output, [batch, heads, -1, depth_v])

-  Returns:
-    a Tensor of shape [batch, heads, length, depth_v]
-  """
-  with tf.variable_scope(name, default_name="local_self_attention_1d",
-                         values=[q, k, v]):
-    v_shape = v.get_shape()
-    depth_v = tf.shape(v)[3]
-    batch_size = tf.shape(q)[0]
-    num_heads = tf.shape(q)[1]
-    original_length = tf.shape(q)[2]
-    # making sure q is a multiple of d
-    def pad_to_multiple(x, pad_length):
-      x_length = tf.shape(x)[2]
-      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])
-    def pad_l_and_r(x, pad_length):
-      return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]])
-    q = pad_to_multiple(q, block_length)
-    k = pad_to_multiple(k, block_length)
-    v = pad_to_multiple(v, block_length)
-
-    # Setting up q blocks
-    new_q_shape = tf.shape(q)
-    # Setting up q blocks
-    q = tf.reshape(q, [new_q_shape[0], new_q_shape[1],
-                       new_q_shape[2] // block_length,
-                       block_length, new_q_shape[3]])
-
-    # Setting up k and v values
-    k = pad_l_and_r(k, filter_width)
-    v = pad_l_and_r(v, filter_width)
-
-    length = tf.shape(k)[2]
-    full_filter_width = block_length + 2 * filter_width
-    # getting gather indices
-    indices = tf.range(0, length, delta=1, name="index_range")
-    # making indices [1, length, 1] to appy convs
-    indices = tf.reshape(indices, [1, -1, 1])
-    kernel = tf.expand_dims(tf.eye(full_filter_width), axis=1)
-    gather_indices = tf.nn.conv1d(
-        tf.cast(indices, tf.float32),
-        kernel,
-        block_length,
-        padding="VALID",
-        name="gather_conv")
-
-    gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)
-
-    # [length, batch, heads, dim]
-    k_t = tf.transpose(k, [2, 0, 1, 3])
-    k_new = tf.gather(k_t, gather_indices)
-
-    # [batch, heads, blocks, block_length, dim]
-    k_new = tf.transpose(k_new, [2, 3, 0, 1, 4])
-
-    attention_bias = tf.expand_dims(
-        tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
-
-    v_t = tf.transpose(v, [2, 0, 1, 3])
-    v_new = tf.gather(v_t, gather_indices)
-    v_new = tf.transpose(v_new, [2, 3, 0, 1, 4])
-
-    logits = tf.matmul(q, k_new, transpose_b=True)
-
-    attention = tf.nn.softmax(logits + attention_bias)
-    output = tf.matmul(attention, v_new)
-
-    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])
-    # Remove the padding if introduced
+    # Remove added padding
     output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
     output.set_shape(v_shape)
     return output
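
The padding/reshape/slice sequence above is easier to follow in isolation. Here is a minimal NumPy sketch of how the keys or values end up grouped into overlapping local windows, assuming a single batch and head (the function name and shapes are chosen for this example only, it is not the TF implementation above):

import numpy as np

def local_blocks(x, block_length, look_right=True):
  """NumPy sketch of the key/value blocking used above (single head, no batch):
  pad, split into blocks, and pair each block with the next one so that query
  block b sees memory blocks b and b + 1."""
  length, depth = x.shape
  pad_amount = -length % block_length
  if look_right:
    pad_right = block_length // 2
    pad_left = block_length - pad_right
    x = np.pad(x, [(pad_left, pad_amount + pad_right), (0, 0)])
  else:
    x = np.pad(x, [(block_length, pad_amount), (0, 0)])
  num_blocks = (length + pad_amount) // block_length
  blocks = x.reshape(num_blocks + 1, block_length, depth)
  prev_block = blocks[:num_blocks]   # blocks 0 .. num_blocks - 1
  cur_block = blocks[1:]             # blocks 1 .. num_blocks
  # [num_blocks, 2 * block_length, depth]: the local memory per query block.
  return np.concatenate([prev_block, cur_block], axis=1)

# Example: 10 positions with block_length 4 give 3 query blocks, each with a
# window of 8 memory positions (zero padding at the edges).
print(local_blocks(np.arange(10, dtype=np.float32)[:, None], 4).shape)  # (3, 8, 1)
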
@@ -542,8 +516,8 @@ def multihead_attention(query_antecedent,
     dropout_rate: a floating point number
     image_shapes: optional tuple of integer scalars.
       see comments for attention_image_summary()
-    attention_type: a string, either "dot_product" or "local_mask_right" or
-      "local_unmasked"
+    attention_type: a string, either "dot_product" or "local" or
+      "local_mask_right"
     block_length: an integer - relevant for "local_mask_right"
     name: an optional string
@@ -592,11 +566,12 @@ def multihead_attention(query_antecedent,
     if attention_type == "dot_product":
       x = dot_product_attention(
           q, k, v, bias, dropout_rate, image_shapes)
-    elif attention_type == "local_mask_right":
-      x = masked_local_attention_1d(q, k, v, block_length=block_length)
+    elif attention_type == "local":
+      x = local_attention_1d(q, k, v, block_length=block_length)
     else:
-      assert attention_type == "local_unmasked"
-      x = unmasked_local_attention_1d(q, k, v, block_length=block_length)
+      assert attention_type == "local_mask_right"
+      x = local_attention_1d(
+          q, k, v, block_length=block_length, look_right=False)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     return x
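
A quick shape check of the new dispatch, as a sketch assuming TensorFlow 1.x and the definitions in this module (the tensor sizes are arbitrary):

import tensorflow as tf

# [batch, heads, length, depth], assuming a TF 1.x graph and session.
q = tf.random_normal([2, 4, 300, 64])
k = tf.random_normal([2, 4, 300, 64])
v = tf.random_normal([2, 4, 300, 64])

# What attention_type="local" dispatches to: a window on both sides.
both_sides = local_attention_1d(q, k, v, block_length=128)
# What attention_type="local_mask_right" dispatches to: a left-only window.
left_only = local_attention_1d(q, k, v, block_length=128, look_right=False)

with tf.Session() as sess:
  print(sess.run(tf.shape(both_sides)))  # expected [2, 4, 300, 64]
  print(sess.run(tf.shape(left_only)))   # expected [2, 4, 300, 64]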