Skip to content

Commit e35d871

Browse files
committed
value residual for the experimental version
1 parent 6fcd3c6 commit e35d871

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

nGPT_pytorch/nGPTExperimental.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,21 @@ def __init__(
9898
def forward(self, x, **kwargs):
9999
residual = x
100100

101-
branch_out = l2norm(self.fn(x, **kwargs))
101+
branch_out = self.fn(x, **kwargs)
102102

103+
is_tuple_output = isinstance(branch_out, tuple)
104+
105+
if is_tuple_output:
106+
branch_out, *rest = branch_out
107+
108+
branch_out = l2norm(branch_out)
103109
not_ortho = einsum(branch_out, residual, '... d, ... d -> ...').square().mean()
104110

105111
out = l2norm(residual.lerp(branch_out, self.branch_scale()))
106112

113+
if is_tuple_output:
114+
out = (out, *rest)
115+
107116
return out, not_ortho
108117

109118
# for use with parametrize
@@ -222,14 +231,20 @@ def __init__(
222231
def forward(
223232
self,
224233
x,
225-
mask = None
234+
mask = None,
235+
value_residual = None
226236
):
227237
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
228238

229239
# split heads
230240

231241
q, k, v = map(self.split_heads, (q, k, v))
232242

243+
# value residual - https://arxiv.org/abs/2410.17897
244+
245+
if exists(value_residual):
246+
v = v + value_residual
247+
233248
# maybe query key norm
234249

235250
if self.norm_qk:
@@ -261,7 +276,7 @@ def forward(
261276
)
262277

263278
out = self.merge_heads(out)
264-
return self.to_out(out)
279+
return self.to_out(out), v
265280

266281
# feedforward
267282

@@ -460,12 +475,16 @@ def forward(
460475

461476
tokens = token_embed[ids]
462477

478+
value_residual = None
479+
463480
aux_loss = 0.
464481

465482
for attn, ff in self.layers:
466-
tokens, ortho_loss = attn(tokens, mask = mask)
483+
(tokens, values), ortho_loss = attn(tokens, mask = mask, value_residual = value_residual)
467484
aux_loss = aux_loss + ortho_loss
468485

486+
value_residual = default(value_residual, values)
487+
469488
tokens, ortho_loss = ff(tokens)
470489
aux_loss = aux_loss + ortho_loss
471490

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "nGPT-pytorch"
3-
version = "0.1.16"
3+
version = "0.1.17"
44
description = "nGPT"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch.utils.data import DataLoader, Dataset
1212
import torch.nn.utils.parametrize as parametrize
1313

14-
from nGPT_pytorch import nGPT
14+
from nGPT_pytorch.nGPTExperimental import nGPT
1515

1616
# constants
1717

0 commit comments

Comments
 (0)