@@ -98,8 +98,18 @@ def __init__(
     def forward(self, x, **kwargs):
         residual = x
 
-        branch_out = l2norm(self.fn(x, **kwargs))
-        out = l2norm(residual.lerp(branch_out, self.branch_scale()))
+        out = self.fn(x, **kwargs)
+
+        tuple_output = isinstance(out, tuple)
+
+        if tuple_output:
+            out, *rest = out
+
+        out = l2norm(out)
+        out = l2norm(residual.lerp(out, self.branch_scale()))
+
+        if tuple_output:
+            out = (out, *rest)
 
         return out
 
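For orientation, a minimal standalone sketch of the hypersphere residual update this hunk generalizes — `l2norm` and the interpolation mirror the diff, while the tensors and the `alpha` weight are made-up placeholders: the branch output is renormalized, interpolated into the residual stream with `lerp`, and renormalized again; the new code only adds unpacking of an optional tuple output so the attention branch can also hand back its values.

    import torch
    import torch.nn.functional as F

    def l2norm(t, dim = -1):
        # project back onto the unit hypersphere
        return F.normalize(t, dim = dim)

    def residual_update(residual, branch_out, alpha):
        # renormalize the branch, interpolate toward it, renormalize the result
        out = l2norm(branch_out)
        return l2norm(residual.lerp(out, alpha))

    # toy usage with placeholder shapes
    residual = l2norm(torch.randn(2, 8, 16))
    branch   = torch.randn(2, 8, 16)
    updated  = residual_update(residual, branch, alpha = 0.1)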
@@ -216,7 +226,9 @@ def forward(
         self,
         x,
         mask = None,
-        rotary_embed: Module | None = None
+        rotary_embed: Module | None = None,
+        value_residual = None,
+        return_values = False
     ):
         q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
 
@@ -245,6 +257,11 @@ def forward(
         if exists(mask):
             mask = rearrange(mask, 'b j -> b 1 1 j')
 
+        # maybe value residual, from resformer paper
+
+        if exists(value_residual):
+            v = v + value_residual
+
         # scale is sqrt(dk)
 
         with self.sdpa_context_manager():
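A rough illustration of the value residual idea from the ResFormer paper (https://arxiv.org/abs/2410.17897) that this hunk adds — the shapes and tensors below are hypothetical, only the addition mirrors the diff: layers after the first add the first layer's value projections to their own before attention is computed.

    import torch

    # hypothetical (batch, heads, seq, dim_head) value tensors
    first_layer_values = torch.randn(1, 4, 128, 64)  # v from the first attention layer
    current_values     = torch.randn(1, 4, 128, 64)  # v computed in a later layer

    # value residual: the later layer mixes in the first layer's values
    v = current_values + first_layer_values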
@@ -256,7 +273,12 @@ def forward(
         )
 
         out = self.merge_heads(out)
-        return self.to_out(out)
+        out = self.to_out(out)
+
+        if not return_values:
+            return out
+
+        return out, v
 
 # feedforward
 
@@ -315,6 +337,7 @@ def __init__(
         tied_embedding = False,
         num_hyperspheres = 1,
         causal = True,
+        add_value_residual = True,
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
         s_logit_init: float = 1.,
@@ -344,6 +367,8 @@ def __init__(
         self.causal = causal
         alpha_init = default(alpha_init, 1. / depth)
 
+        self.add_value_residual = add_value_residual # https://arxiv.org/abs/2410.17897v1
+
         self.token_embed = NormLinear_(dim, num_tokens)
 
         self.rotary_embed = RotaryEmbedding(dim_head)
@@ -448,8 +473,13 @@ def forward(
 
         tokens = token_embed[ids]
 
+        first_values = None
+
         for attn, ff in self.layers:
-            tokens = attn(tokens, mask = mask, rotary_embed = rotary_embed)
+            tokens, values = attn(tokens, mask = mask, rotary_embed = rotary_embed, return_values = True, value_residual = first_values if self.add_value_residual else None)
+
+            first_values = default(first_values, values)
+
             tokens = ff(tokens)
 
         if exists(self.to_logits):
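To show how the pieces fit together, a simplified sketch of the layer loop — `default`, `return_values`, and `value_residual` follow the diff, while the toy `layers` iterable and the omitted mask/rotary arguments are illustrative only: the first layer runs without a value residual, its values are captured exactly once, and every subsequent layer receives them.

    def default(v, d):
        # return v unless it is None, otherwise fall back to d
        return v if v is not None else d

    first_values = None

    for attn, ff in layers:  # assumed (attention, feedforward) pairs, as in the diff
        tokens, values = attn(tokens, return_values = True, value_residual = first_values)

        # only the first iteration sets first_values; later ones keep it
        first_values = default(first_values, values)

        tokens = ff(tokens)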