Skip to content

Commit a53cf2a

Browse files
committed
fix more mistakes
1 parent 01c4b5c commit a53cf2a

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

nGPT_pytorch/nGPT.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
self.to_v = NormLinear(dim, dim_inner)
7272

7373
self.rotary_emb = RotaryEmbedding(dim_head)
74-
self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** -0.25))
74+
self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** 0.25))
7575

7676
self.norm_qk = norm_qk
7777
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
@@ -207,10 +207,10 @@ def forward(
207207

208208
for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
209209

210-
attn_out = attn(tokens)
210+
attn_out = l2norm(attn(tokens))
211211
tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
212212

213-
ff_out = ff(tokens)
213+
ff_out = l2norm(ff(tokens))
214214
tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
215215

216216
logits = self.to_logits(tokens)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "nGPT-pytorch"
3-
version = "0.0.5"
3+
version = "0.0.6"
44
description = "nGPT"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def base_decoding(
8989
model = nGPT(
9090
num_tokens = 256,
9191
dim = 512,
92-
depth = 6
92+
depth = 8
9393
).to(device)
9494

9595
# prepare enwik8 data

0 commit comments

Comments
 (0)