@@ -134,6 +134,7 @@ def __init__(
         dim_head = 64,
         heads = 8,
         norm_qk = True,
+        causal = True,
         manual_norm_weights = False,
         s_qk_init = 1.,
         s_qk_scale = None,
@@ -145,6 +146,8 @@ def __init__(
         norm_eps = 0.
     ):
         super().__init__()
+        self.causal = causal
+
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
         self.l2norm = partial(l2norm, norm_eps = norm_eps)
 
@@ -179,7 +182,8 @@ def __init__(
 
     def forward(
         self,
-        x
+        x,
+        mask = None
     ):
         q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
 
@@ -202,12 +206,18 @@ def forward(
         q = self.rotary_emb.rotate_queries_or_keys(q)
         k = self.rotary_emb.rotate_queries_or_keys(k)
 
+        # for non-autoregressive masking
+
+        if exists(mask):
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+
         # scale is sqrt(dk)
 
         with self.sdpa_context_manager():
             out = F.scaled_dot_product_attention(
                 q, k, v,
-                is_causal = True,
+                attn_mask = mask,
+                is_causal = self.causal,
                 scale = self.attn_scale
             )
 
@@ -268,6 +278,7 @@ def __init__(
         ce_ignore_index = -1,
         manual_norm_weights = False,
         tied_embedding = False,
+        causal = True,
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
         s_logit_init: float = 1.,
@@ -294,6 +305,7 @@ def __init__(
         self.l2norm = partial(l2norm, norm_eps = norm_eps)
 
         self.dim = dim
+        self.causal = causal
         alpha_init = default(alpha_init, 1. / depth)
 
         self.token_embed = NormLinear_(dim, num_tokens)
@@ -332,6 +344,7 @@ def __init__(
                 dim,
                 dim_head = dim_head,
                 heads = heads,
+                causal = causal,
                 norm_qk = attn_norm_qk,
                 manual_norm_weights = manual_norm_weights,
                 s_qk_init = s_qk_init_,
@@ -382,18 +395,20 @@ def norm_weights_(self):
     def forward(
         self,
         ids,
+        mask = None,
         return_loss = False
     ):
         token_embed, l2norm = self.token_embed.weight, self.l2norm
 
         if return_loss:
+            assert self.causal
             ids, labels = ids[:, :-1], ids[:, 1:]
 
         tokens = token_embed[ids]
 
         for attn, ff, attn_alpha, ff_alpha in self.layers:
 
-            attn_out = l2norm(attn(tokens))
+            attn_out = l2norm(attn(tokens, mask = mask))
             tokens = l2norm(tokens.lerp(attn_out, attn_alpha()))
 
             ff_out = l2norm(ff(tokens))
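For reference, here is a minimal self-contained sketch (not part of the patch, assuming only `torch` and `einops`) of how the new `mask` argument is consumed: a boolean key-padding mask of shape `(batch, seq)` is rearranged to `(batch, 1, 1, seq)` so `scaled_dot_product_attention` can broadcast it over heads and query positions. Note that PyTorch raises an error if a non-None `attn_mask` is combined with `is_causal = True`, so the mask path is intended for the `causal = False` configuration, while the causal path keeps `mask = None` (and `return_loss = True` asserts `self.causal`, since the shifted next-token labels only make sense autoregressively).

```python
# Minimal sketch of the non-autoregressive masking path added in this diff.
# Shapes and values are illustrative; only torch and einops are assumed.

import torch
import torch.nn.functional as F
from einops import rearrange

batch, heads, seq, dim_head = 2, 8, 16, 64

q = torch.randn(batch, heads, seq, dim_head)
k = torch.randn(batch, heads, seq, dim_head)
v = torch.randn(batch, heads, seq, dim_head)

# boolean key-padding mask: True = attend, False = padding
mask = torch.ones(batch, seq, dtype = torch.bool)
mask[:, -4:] = False  # pretend the last 4 positions are padding

# same rearrange as in the patch: broadcast over heads and query positions
attn_mask = rearrange(mask, 'b j -> b 1 1 j')

# non-causal attention with the padding mask (attn_mask and is_causal = True
# cannot be combined, which is why the causal path passes mask = None)
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask = attn_mask,
    is_causal = False
)

print(out.shape)  # torch.Size([2, 8, 16, 64])
```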