@@ -22,16 +22,22 @@ def default(val, d):
 def max_neg_value(t):
     return -torch.finfo(t.dtype).max

+def stable_softmax(t, dim = -1, alpha = 32 ** 2):
+    t = t / alpha
+    t = t - torch.amax(t, dim = dim, keepdim = True)
+    return (t * alpha).softmax(dim = dim)
+
 # classes

 class Attention(nn.Module):
-    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
         self.seq_len = seq_len
         self.scale = dim_head ** -0.5

+        self.stable = stable
         self.causal = causal

         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
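As a quick sanity check on the stable_softmax introduced above, here is a minimal standalone sketch (assuming only PyTorch; the test tensor shape is illustrative). Softmax is invariant to subtracting a per-row constant, so rescaling by alpha, subtracting the row max, and scaling back gives the same result as torch.softmax in exact arithmetic, while keeping the intermediate logits non-positive, which is the usual motivation for a max-subtracted softmax in low-precision training.

import torch

def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    # same function as added in the hunk above: rescale, subtract the per-row max, scale back
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True)
    return (t * alpha).softmax(dim = dim)

# softmax is shift-invariant, so on ordinary inputs the result matches the plain softmax
x = torch.randn(2, 8, 64, 64)
assert torch.allclose(stable_softmax(x), x.softmax(dim = -1), atol = 1e-6)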
@@ -42,6 +48,8 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou

     def forward(self, x, mask = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax
+
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

@@ -60,7 +68,7 @@ def forward(self, x, mask = None):
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)

-        attn = dots.softmax(dim = -1)
+        attn = softmax(dots, dim = -1)

         out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -70,7 +78,7 @@ def forward(self, x, mask = None):
 # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation

 class SparseConvCausalAttention(nn.Module):
-    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., **kwargs):
+    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
         assert kernel_size % 2 == 1, 'kernel size must be odd'

@@ -82,6 +90,8 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
         self.kernel_size = kernel_size
         self.dilation = dilation

+        self.stable = stable
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

         self.to_out = nn.Sequential(
@@ -91,6 +101,7 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,

     def forward(self, x, mask = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax

         img_seq_len = img_size ** 2
         text_len = seq_len + 1 - img_seq_len
@@ -121,7 +132,7 @@ def forward(self, x, mask = None):
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)

-        attn_text = dots_text.softmax(dim = -1)
+        attn_text = softmax(dots_text, dim = -1)
         out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

         # image attention
@@ -163,7 +174,7 @@ def forward(self, x, mask = None):
         dots = torch.cat((dots_image_to_text, dots_image), dim = -1)
         dots.masked_fill_(mask, mask_value)

-        attn = dots.softmax(dim = -1)
+        attn = softmax(dots, dim = -1)

         # aggregate

@@ -185,7 +196,7 @@ def forward(self, x, mask = None):
 # sparse axial causal attention

 class SparseAxialCausalAttention(nn.Module):
-    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., **kwargs):
+    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
         assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
         self.axis = axis
@@ -196,6 +207,8 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
         self.scale = dim_head ** -0.5
         self.image_size = image_size

+        self.stable = stable
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

         self.to_out = nn.Sequential(
@@ -205,6 +218,7 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head

     def forward(self, x, mask = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax

         img_seq_len = img_size ** 2
         text_len = seq_len + 1 - img_seq_len
@@ -235,7 +249,7 @@ def forward(self, x, mask = None):
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)

-        attn_text = dots_text.softmax(dim = -1)
+        attn_text = softmax(dots_text, dim = -1)
         out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

         # image attention
@@ -267,7 +281,7 @@ def forward(self, x, mask = None):

         # attention.

-        attn = dots.softmax(dim = -1)
+        attn = softmax(dots, dim = -1)

         # aggregate

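A hypothetical usage sketch of the new flag (the class name and constructor arguments come from the diff above; the import path and tensor shapes are assumptions for illustration):

import torch
from dalle_pytorch.attention import Attention  # assumed module path

attn = Attention(dim = 512, seq_len = 256, heads = 8, dim_head = 64, stable = True)
x = torch.randn(1, 256, 512)
out = attn(x)  # (1, 256, 512); attention weights computed with stable_softmax instead of dots.softmax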