@@ -54,13 +54,15 @@ def forward(self, x):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., noncausal_attn_len = 0):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
         self.seq_len = seq_len
         self.scale = dim ** -0.5
+
         self.causal = causal
+        self.noncausal_attn_len = noncausal_attn_len
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Sequential(
@@ -84,6 +86,11 @@ def forward(self, x, mask = None):
         if self.causal:
             i, j = dots.shape[-2:]
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+
+            if self.noncausal_attn_len > 0:
+                ind = slice(0, self.noncausal_attn_len)
+                mask[ind, ind] = False
+
             dots.masked_fill_(mask, mask_value)
 
         attn = dots.softmax(dim = -1)
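
For intuition, here is a minimal standalone sketch (not part of the commit) of what the boolean mask built in the hunk above looks like once the non-causal prefix is carved out. The toy sizes `i = j = 5` and `noncausal_attn_len = 2` are assumptions for illustration; `True` marks positions that get filled with `mask_value`:

```python
import torch

# Toy reproduction of the masking logic above: the strict upper triangle is
# masked (future positions), then the top-left prefix block is un-masked so
# the first `noncausal_attn_len` tokens attend to each other bidirectionally.
i = j = 5
noncausal_attn_len = 2

mask = torch.ones(i, j).triu_(j - i + 1).bool()

ind = slice(0, noncausal_attn_len)
mask[ind, ind] = False

print(mask.int())
# tensor([[0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])
```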
@@ -146,6 +153,10 @@ def forward(self, x, mask = None):
             mask_value = -(torch.finfo(q.dtype).max / 2)
             attn_mask.masked_fill_(mask, mask_value)
 
+        if self.noncausal_attn_len:
+            ind = slice(0, self.noncausal_attn_len)
+            attn_mask[ind, ind] = 0.
+
         out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
         out = rearrange(out, 'b h n d -> b n (h d)')
         out = self.to_out(out)
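
The SparseAttention path uses an additive float mask rather than a boolean one, so the equivalent of "un-masking" is writing zeros back into the prefix block. A small sketch of that logic under the same toy sizes, again not part of the commit:

```python
import torch

# Additive-mask variant: disallowed positions hold a large negative value,
# allowed positions hold 0. Writing 0. into the top-left block re-enables
# bidirectional attention among the first `noncausal_attn_len` tokens.
n, noncausal_attn_len = 5, 2

bool_mask = torch.ones(n, n).triu_(1).bool()
mask_value = -(torch.finfo(torch.float32).max / 2)

attn_mask = torch.zeros(n, n)
attn_mask.masked_fill_(bool_mask, mask_value)

ind = slice(0, noncausal_attn_len)
attn_mask[ind, ind] = 0.
```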
@@ -165,6 +176,7 @@ def __init__(
         ff_mult = 4,
         attn_dropout = 0.,
         ff_dropout = 0.,
+        noncausal_attn_len = 0,
         sparse_attn = True,
         sparse_attn_global_indices = []
     ):
@@ -176,7 +188,7 @@ def __init__(
             attn_class = Attention if not sparse_attn else partial(SparseAttention, sparse_attn_global_indices = sparse_attn_global_indices)
 
             layers.append(nn.ModuleList([
-                PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
+                PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout, noncausal_attn_len = noncausal_attn_len)),
                 PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
             ]))
 
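Downstream, the new keyword just threads through the `Transformer` constructor into each attention layer. A hypothetical call for illustration; only the keyword arguments visible in this diff are confirmed, and the remaining arguments (e.g. `depth`) are assumptions:

```python
# Hypothetical usage sketch -- argument names other than those appearing in
# the diff (dim, seq_len, causal, heads, dim_head, noncausal_attn_len, ...)
# are assumptions, not confirmed by the commit.
model = Transformer(
    dim = 512,
    depth = 6,                  # assumed argument
    seq_len = 1024,
    causal = True,
    noncausal_attn_len = 256,   # first 256 tokens attend bidirectionally among themselves
    sparse_attn = False,
)
```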