Commit a34355d

line up the scaling as in the paper
1 parent f961ad8 commit a34355d

File tree

3 files changed: +48 -26 lines changed

nGPT_pytorch/nGPT.py
pyproject.toml
train.py

nGPT_pytorch/nGPT.py

Lines changed: 45 additions & 23 deletions
@@ -22,6 +22,25 @@ def default(v, d):
 def l2norm(t, dim = -1):
     return F.normalize(t, dim = dim, p = 2)
 
+# scale
+
+class Scale(Module):
+    """
+    latter part of section 2.5 in the paper
+    """
+    def __init__(
+        self,
+        dim,
+        init = 1.,
+        scale = 1.
+    ):
+        super().__init__()
+        self.scale = nn.Parameter(torch.ones(dim) * scale)
+        self.forward_scale = init / scale
+
+    def forward(self):
+        return self.scale * self.forward_scale
+
 # for use with parametrize
 
 class L2Norm(Module):
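
Note: Scale stores its parameter at magnitude `scale` but multiplies by `forward_scale = init / scale` on the forward pass, so the effective value starts at `init` while the stored parameter stays at a magnitude comparable to the normalized weights. A minimal sketch of that behavior (not part of this commit's diff; assumes the class above is importable from nGPT_pytorch.nGPT):

import torch
from nGPT_pytorch.nGPT import Scale  # assumes the class added in this commit

dim = 4
s = Scale(dim, init = 0.05, scale = dim ** -0.5)

print(s.scale)  # stored parameter: 0.5 (= 4 ** -0.5) per element
print(s())      # effective value at init: 0.05 per element, since forward() multiplies by init / scale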
@@ -87,13 +106,18 @@ def __init__(
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
 
+        dim_sqrt = dim ** 0.5
+        self.dim_sqrt = dim_sqrt
+        self.attn_scale = dim_head ** 0.5
+
         dim_inner = dim_head * heads
         self.to_q = NormLinear_(dim, dim_inner)
         self.to_k = NormLinear_(dim, dim_inner)
         self.to_v = NormLinear_(dim, dim_inner)
 
         self.rotary_emb = RotaryEmbedding(dim_head)
-        self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** 0.25))
+        self.q_scale = Scale(dim, 1, dim ** -0.5)
+        self.k_scale = Scale(dim, 1, dim ** -0.5)
 
         self.norm_qk = norm_qk
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
@@ -107,28 +131,31 @@ def forward(
     ):
         q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
 
+        # scaling queries and keys - this would line up with the popular use of qk rmsnorm from google deepmind and now black forest labs
+
+        q = q * self.q_scale()
+        k = k * self.k_scale()
+
+        # split heads
+
         q, k, v = map(self.split_heads, (q, k, v))
 
         # maybe query key norm
 
         if self.norm_qk:
             q, k = map(l2norm, (q, k))
 
-        # scaling queries and keys - this would line up with the popular use of qk rmsnorm from google deepmind and now black forest labs
-
-        q, k = (q * self.qk_scale), (k * self.qk_scale)
-
         # rotary positions
 
         q = self.rotary_emb.rotate_queries_or_keys(q)
         k = self.rotary_emb.rotate_queries_or_keys(k)
 
-        # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
+        # scale is sqrt(dk)
 
         out = F.scaled_dot_product_attention(
             q, k, v,
             is_causal = True,
-            scale = 1.
+            scale = self.attn_scale
         )
 
         out = self.merge_heads(out)
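
Note: with `norm_qk = True` the query and key rows are unit-norm, so each dot product is a cosine similarity in [-1, 1]; the attention logits are therefore scaled up by sqrt(d_k) rather than down by the usual 1/sqrt(d_k). A rough standalone sketch of the effective computation (not part of the diff; rotary embeddings and the learned per-dimension q/k scales are omitted, and PyTorch >= 2.1 is assumed for the `scale` argument):

import torch
import torch.nn.functional as F

batch, heads, seq, dim_head = 1, 8, 16, 64

q = F.normalize(torch.randn(batch, heads, seq, dim_head), dim = -1)
k = F.normalize(torch.randn(batch, heads, seq, dim_head), dim = -1)
v = torch.randn(batch, heads, seq, dim_head)

# logits are (q . k) * sqrt(dim_head) instead of the usual (q . k) / sqrt(dim_head)
out = F.scaled_dot_product_attention(q, k, v, is_causal = True, scale = dim_head ** 0.5)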
@@ -151,16 +178,16 @@ def __init__(
         self.to_hidden = NormLinear_(dim, dim_inner)
         self.to_gate = NormLinear_(dim, dim_inner)
 
-        self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
-        self.gate_scale = nn.Parameter(torch.ones(dim_inner))
+        self.hidden_scale = Scale(dim_inner)
+        self.gate_scale = Scale(dim_inner)
 
         self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)
 
     def forward(self, x):
         hidden, gate = self.to_hidden(x), self.to_gate(x)
 
-        hidden = hidden * self.hidden_scale
-        gate = gate * self.gate_scale * (self.dim ** 0.5)
+        hidden = hidden * self.hidden_scale()
+        gate = gate * self.gate_scale() * (self.dim ** 0.5)
 
         hidden = F.silu(gate) * hidden
         return self.to_out(hidden)
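
Note: because the feedforward projections use normalized weights and unit-norm inputs, their outputs are roughly O(1/sqrt(dim)) per element, so the gate is scaled back up by sqrt(dim) before the SiLU to keep the nonlinearity in a sensible range. A standalone sketch of the gating step with hypothetical stand-in tensors (not part of the diff):

import torch
import torch.nn.functional as F

dim, dim_inner = 512, 2048
hidden = torch.randn(2, 16, dim_inner) * dim ** -0.5  # stand-in for a NormLinear output
gate = torch.randn(2, 16, dim_inner) * dim ** -0.5

hidden_scale = torch.ones(dim_inner)  # Scale(dim_inner) returns ones at init
gate_scale = torch.ones(dim_inner)

hidden = hidden * hidden_scale
gate = gate * gate_scale * (dim ** 0.5)  # restore the gate to ~unit scale before the nonlinearity

out = F.silu(gate) * hidden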
@@ -187,28 +214,23 @@ def __init__(
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
 
         self.dim = dim
-
         residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
 
         self.token_embed = NormLinear_(dim, num_tokens)
 
         self.layers = ModuleList([])
-        self.residual_lerp_scales = nn.ParameterList([])
 
         for _ in range(depth):
             self.layers.append(ModuleList([
                 Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk, manual_norm_weights = manual_norm_weights),
                 FeedForward(dim, expand_factor = ff_expand_factor, manual_norm_weights = manual_norm_weights),
-            ]))
-
-            self.residual_lerp_scales.append(nn.ParameterList([
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
+                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
+                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
             ]))
 
         self.to_logits = NormLinear_(dim, num_tokens) if not tied_embedding else None
 
-        self.logit_scale = nn.Parameter(torch.ones(num_tokens))
+        self.logit_scale = Scale(num_tokens, 1., dim ** -0.5)
 
         self.ignore_index = ce_ignore_index
@@ -232,21 +254,21 @@ def forward(
         token_embed = self.token_embed.weight
         tokens = token_embed[ids]
 
-        for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
+        for attn, ff, attn_alpha, ff_alpha in self.layers:
 
             attn_out = l2norm(attn(tokens))
-            tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
+            tokens = l2norm(tokens.lerp(attn_out, attn_alpha()))
 
             ff_out = l2norm(ff(tokens))
-            tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
+            tokens = l2norm(tokens.lerp(ff_out, ff_alpha()))
 
         if exists(self.to_logits):
             logits = self.to_logits(tokens)
         else:
             # tied embeddings
             logits = einsum(tokens, token_embed, 'b n d, c d -> b n c')
 
-        logits = logits * self.logit_scale * (self.dim ** 0.5)
+        logits = logits * self.logit_scale()
 
         if not return_loss:
             return logits
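
Note: each layer's two Scale modules now act as learned per-dimension interpolation weights for the residual update on the hypersphere, with an effective initial value of `residual_lerp_scale_init = 1 / depth`. A minimal sketch of one such update, assuming unit-norm token vectors (not part of the diff):

import torch
import torch.nn.functional as F

def l2norm(t, dim = -1):
    return F.normalize(t, dim = dim, p = 2)

depth, dim = 8, 512
tokens = l2norm(torch.randn(2, 16, dim))
branch_out = l2norm(torch.randn(2, 16, dim))    # stand-in for attn(tokens) or ff(tokens)

alpha = torch.full((dim,), 1. / depth)          # what attn_alpha() / ff_alpha() return at init
tokens = l2norm(tokens.lerp(branch_out, alpha)) # step toward the branch output, then renormalize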

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.0.10"
+version = "0.0.11"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

train.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 NUM_BATCHES = int(1e5)
 BATCH_SIZE = 4
 GRAD_ACCUM_EVERY = 4
-LEARNING_RATE = 1e-4
+LEARNING_RATE = 1e-3
 VALIDATE_EVERY = 100
 PRIME_LENGTH = 128
 GENERATE_EVERY = 500
@@ -91,7 +91,7 @@ def base_decoding(
     num_tokens = 256,
     dim = 512,
     depth = 8,
-    manual_norm_weights = True
+    manual_norm_weights = True,
 ).to(device)
 
 # prepare enwik8 data
