@@ -541,6 +541,111 @@ def pad_l_and_r(x, pad_length):
   return output
 
 
+def local_attention_2d(q,
+                       k,
+                       v,
+                       block_length=128,
+                       filter_flange=100,
+                       name=None):
+  """Strided block local self-attention.
+
+  Args:
+    q: a Tensor with shape [batch, heads, h, w, depth_k]
+    k: a Tensor with shape [batch, heads, h, w, depth_k]
+    v: a Tensor with shape [batch, heads, h, w, depth_v]
+    block_length: an integer indicating the side length of each square block.
+    filter_flange: an integer indicating how much to look around each block.
+    name: an optional string
+
+  Returns:
+    a Tensor of shape [batch, heads, h, w, depth_v]
+  """
+  with tf.variable_scope(
+      name, default_name="local_self_attention_2d", values=[q, k, v]):
+    v_shape = tf.shape(v)
+    depth_v = tf.shape(v)[4]
+    batch_size = tf.shape(q)[0]
+    num_heads = tf.shape(q)[1]
+    original_length = tf.shape(q)[2] * tf.shape(q)[3]
+
+    def reshape_range(tensor, i, j, shape):
+      """Reshapes a tensor between dimensions i and j."""
+      target_shape = tf.concat(
+          [tf.shape(tensor)[:i], shape, tf.shape(tensor)[j:]],
+          axis=0)
+      return tf.reshape(tensor, target_shape)
+
+    def pad_to_multiple(x, d):
+      """Pads the height and width of x up to a multiple of d."""
+      height_padding = -tf.shape(x)[2] % d
+      width_padding = -tf.shape(x)[3] % d
+      paddings = [[0, 0], [0, 0], [0, height_padding],
+                  [0, width_padding], [0, 0]]
+      return tf.pad(x, paddings)
+
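+    # gather_indices builds, for every block, the list of flat spatial
+    # positions that belong to it: convolving a [1, h, w, 1] map of flat
+    # position indices with an identity-matrix kernel of shape
+    # [block_length, block_length, 1, block_length ** 2] copies each
+    # position in a window into its own output channel, so each strided
+    # conv output holds one block's indices.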
+    def gather_indices(x, block_length, stride):
+      """Getting gather indices."""
+      # making an identity matrix kernel
+      kernel = tf.eye(block_length ** 2)
+      kernel = reshape_range(kernel, 0, 1, [block_length, block_length, 1])
+      # making indices [1, h, w, 1] to apply convs
+      indices = tf.range(0, tf.shape(x)[2] * tf.shape(x)[3], delta=1)
+      indices = tf.reshape(indices, [1, tf.shape(x)[2], tf.shape(x)[3], 1])
+      indices = tf.nn.conv2d(
+          tf.cast(indices, tf.float32),
+          kernel,
+          strides=[1, stride, stride, 1],
+          padding="VALID")
+      # making indices [num_blocks, dim] to gather
+      num_blocks = tf.reduce_prod(tf.shape(indices)[:3])
+      indices = tf.reshape(indices, [num_blocks, -1])
+      return tf.cast(indices, tf.int32)
+
+    def gather_blocks(x, indices):
+      """Gathers flattened blocks from x."""
+      x_shape = tf.shape(x)
+      x = reshape_range(x, 2, 4, [tf.reduce_prod(x_shape[2:4])])
+      # [length, batch, heads, dim]
+      x_t = tf.transpose(x, [2, 0, 1, 3])
+      x_new = tf.gather(x_t, indices)
+      # returns [batch, heads, num_blocks, block_length ** 2, dim]
+      return tf.transpose(x_new, [2, 3, 0, 1, 4])
+
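+    # Pad the spatial dimensions up to a multiple of block_length so the
+    # feature map tiles exactly into block_length x block_length blocks.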
+    q = pad_to_multiple(q, block_length)
+    k = pad_to_multiple(k, block_length)
+    v = pad_to_multiple(v, block_length)
+
+    # Setting up k and v values
+    paddings = [[0, 0], [0, 0], [filter_flange, filter_flange],
+                [filter_flange, filter_flange], [0, 0]]
+    k = tf.pad(k, paddings)
+    v = tf.pad(v, paddings)
+
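+    # Query blocks are non-overlapping (window and stride are both
+    # block_length).  Memory blocks are wider: each window spans
+    # block_length + 2 * filter_flange positions but uses the same stride,
+    # so every query block can also attend to a filter_flange halo around
+    # itself, which the zero padding above makes valid at the edges.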
+    # Setting up q blocks
+    q_indices = gather_indices(q, block_length, block_length)
+    q_new = gather_blocks(q, q_indices)
+
+    # Setting up k and v blocks
+    full_filter_width = block_length + 2 * filter_flange
+    k_and_v_indices = gather_indices(k, full_filter_width, block_length)
+    k_new = gather_blocks(k, k_and_v_indices)
+    v_new = gather_blocks(v, k_and_v_indices)
+
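+    # Positions introduced by the zero padding have all-zero key vectors;
+    # embedding_to_padding flags them, and the -1e9 bias pushes their
+    # softmax weights to (effectively) zero.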
+    attention_bias = tf.expand_dims(
+        tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
+
+    logits = tf.matmul(q_new, k_new, transpose_b=True)
+
+    attention = tf.nn.softmax(logits + attention_bias)
+    output = tf.matmul(attention, v_new)
+
+    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])
+    # Remove the padding if introduced
+    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
+    # [batch, heads, h, w, depth_v]
+    return tf.reshape(output, v_shape)
+
+
 def multihead_attention(query_antecedent,
                         memory_antecedent,
                         bias,
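A minimal usage sketch, not part of the commit: it assumes TensorFlow 1.x and that local_attention_2d (plus the embedding_to_padding helper it calls) is available from this module; the tensor sizes below are made-up toy values.

import tensorflow as tf

# Hypothetical toy shapes: [batch, heads, h, w, depth].
q = tf.random_normal([1, 2, 32, 32, 64])
k = tf.random_normal([1, 2, 32, 32, 64])
v = tf.random_normal([1, 2, 32, 32, 64])

# Smaller-than-default block and flange so the 32x32 map splits into
# several blocks; the output keeps the shape of v.
out = local_attention_2d(q, k, v, block_length=16, filter_flange=8)

with tf.Session() as sess:
  print(sess.run(tf.shape(out)))  # [1, 2, 32, 32, 64]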