Commit 44f20aa

add value residual learning
1 parent 61e484b commit 44f20aa

File tree

4 files changed (+46, -6 lines)

README.md

Lines changed: 9 additions & 0 deletions

````diff
@@ -64,3 +64,12 @@ $ python train.py
     url = {https://api.semanticscholar.org/CorpusID:1505432}
 }
 ```
+
+```bibtex
+@inproceedings{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
````

nGPT_pytorch/nGPT.py

Lines changed: 35 additions & 5 deletions

```diff
@@ -98,8 +98,18 @@ def __init__(
     def forward(self, x, **kwargs):
         residual = x

-        branch_out = l2norm(self.fn(x, **kwargs))
-        out = l2norm(residual.lerp(branch_out, self.branch_scale()))
+        out = self.fn(x, **kwargs)
+
+        tuple_output = isinstance(out, tuple)
+
+        if tuple_output:
+            out, *rest = out
+
+        out = l2norm(out)
+        out = l2norm(residual.lerp(out, self.branch_scale()))
+
+        if tuple_output:
+            out = (out, *rest)

         return out

@@ -216,7 +226,9 @@ def forward(
         self,
         x,
         mask = None,
-        rotary_embed: Module | None = None
+        rotary_embed: Module | None = None,
+        value_residual = None,
+        return_values = False
     ):
         q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

@@ -245,6 +257,11 @@ def forward(
         if exists(mask):
             mask = rearrange(mask, 'b j -> b 1 1 j')

+        # maybe value residual, from resformer paper
+
+        if exists(value_residual):
+            v = v + value_residual
+
         # scale is sqrt(dk)

         with self.sdpa_context_manager():
@@ -256,7 +273,12 @@ def forward(
             )

         out = self.merge_heads(out)
-        return self.to_out(out)
+        out = self.to_out(out)
+
+        if not return_values:
+            return out
+
+        return out, v

 # feedforward

@@ -315,6 +337,7 @@ def __init__(
         tied_embedding = False,
         num_hyperspheres = 1,
         causal = True,
+        add_value_residual = True,
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
         s_logit_init: float = 1.,
@@ -344,6 +367,8 @@ def __init__(
         self.causal = causal
         alpha_init = default(alpha_init, 1. / depth)

+        self.add_value_residual = add_value_residual # https://arxiv.org/abs/2410.17897v1
+
         self.token_embed = NormLinear_(dim, num_tokens)

         self.rotary_embed = RotaryEmbedding(dim_head)
@@ -448,8 +473,13 @@ def forward(

         tokens = token_embed[ids]

+        first_values = None
+
         for attn, ff in self.layers:
-            tokens = attn(tokens, mask = mask, rotary_embed = rotary_embed)
+            tokens, values = attn(tokens, mask = mask, rotary_embed = rotary_embed, return_values = True, value_residual = first_values if self.add_value_residual else None)
+
+            first_values = default(first_values, values)
+
             tokens = ff(tokens)

         if exists(self.to_logits):
```
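For context, value residual learning (from the ResFormer paper cited above, https://arxiv.org/abs/2410.17897) caches the value projections of the first attention layer and adds them to the values of every subsequent layer, which the authors report alleviates attention concentration in deeper layers. Below is a minimal, self-contained PyTorch sketch of the pattern this commit wires into nGPT; it is illustrative only, it omits the repository's hypersphere normalization, rotary embeddings, and learned residual scales, and the `ToyAttention` / `ToyValueResidualTransformer` names are invented for the example.

```python
# minimal sketch of value residual learning (ResFormer); not the repository's exact code
import torch
from torch import nn
import torch.nn.functional as F

class ToyAttention(nn.Module):
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim, bias = False)

    def forward(self, x, value_residual = None, return_values = False):
        b, n, d = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # value residual: mix the first layer's values into this layer's values
        if value_residual is not None:
            v = v + value_residual

        # split heads -> (batch, heads, seq, dim_head)
        split = lambda t: t.view(b, n, self.heads, -1).transpose(1, 2)
        out = F.scaled_dot_product_attention(split(q), split(k), split(v), is_causal = True)

        # merge heads and project out
        out = out.transpose(1, 2).reshape(b, n, d)
        out = self.to_out(out)

        return (out, v) if return_values else out

class ToyValueResidualTransformer(nn.Module):
    def __init__(self, dim, depth, add_value_residual = True):
        super().__init__()
        self.add_value_residual = add_value_residual
        self.layers = nn.ModuleList([ToyAttention(dim) for _ in range(depth)])

    def forward(self, x):
        first_values = None

        for attn in self.layers:
            out, values = attn(
                x,
                value_residual = first_values if self.add_value_residual else None,
                return_values = True
            )

            x = x + out  # plain residual stream, standing in for nGPT's hyperspherical lerp

            # cache only the very first layer's values, as in the commit above
            if first_values is None:
                first_values = values

        return x

tokens = torch.randn(2, 16, 64)                          # (batch, seq, dim)
model = ToyValueResidualTransformer(dim = 64, depth = 4)
print(model(tokens).shape)                               # torch.Size([2, 16, 64])
```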

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.1.18"
+version = "0.1.19"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

train.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -99,6 +99,7 @@ def base_decoding(
     dim = 512,
     depth = 8,
     tied_embedding = True,
+    add_value_residual = True,
     manual_norm_weights = not USE_PARAMETRIZE
 ).to(device)

```
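Usage note: the new flag can also be turned off when constructing the model. A minimal sketch, assuming the `nGPT` class exported by `nGPT_pytorch` and the byte-level `num_tokens = 256` setup of the training script; adjust to the actual constructor arguments if they differ.

```python
# sketch: toggling the new flag at model construction time; assumes the
# nGPT class from nGPT_pytorch and a byte-level vocabulary of 256 tokens
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    tied_embedding = True,
    add_value_residual = False  # set False to recover the pre-0.1.19 behavior
)
```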
