
Commit c58093f

make all scaling hyperparameters configurable for completeness

1 parent a34355d

File tree: 3 files changed, +97 -16 lines


nGPT_pytorch/nGPT.py

Lines changed: 95 additions & 15 deletions
```diff
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from functools import partial
 
 import torch
```
```diff
@@ -19,6 +21,11 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def cast_tuple(t, length = 1):
+    out = t if isinstance(t, tuple) else ((t,) * length)
+    assert len(out) == length
+    return out
+
 def l2norm(t, dim = -1):
     return F.normalize(t, dim = dim, p = 2)
```

```diff
@@ -101,7 +108,9 @@ def __init__(
         dim_head = 64,
         heads = 8,
         norm_qk = True,
-        manual_norm_weights = False
+        manual_norm_weights = False,
+        s_qk_init = 1.,
+        s_qk_scale = None
     ):
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
@@ -167,7 +176,11 @@ def __init__(
         dim,
         *,
         expand_factor = 4,
-        manual_norm_weights = False
+        manual_norm_weights = False,
+        s_hidden_init = 1.,
+        s_hidden_scale = 1.,
+        s_gate_init = 1.,
+        s_gate_scale = 1.
     ):
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
@@ -178,8 +191,8 @@ def __init__(
         self.to_hidden = NormLinear_(dim, dim_inner)
         self.to_gate = NormLinear_(dim, dim_inner)
 
-        self.hidden_scale = Scale(dim_inner)
-        self.gate_scale = Scale(dim_inner)
+        self.hidden_scale = Scale(dim_inner, s_hidden_init, s_hidden_scale)
+        self.gate_scale = Scale(dim_inner, s_gate_init, s_gate_scale)
 
         self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)
```
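
The `Scale` module itself is not touched by this commit, but since `s_*_init` / `s_*_scale` pairs are now threaded into it everywhere, here is a minimal sketch of what a module with the `(dim, init, scale)` signature could look like, following the nGPT paper's trick of storing the parameter at magnitude `scale` while recovering an effective initial value of `init` on the forward pass. This is an assumption for illustration, not the repository's exact definition:

```python
import torch
from torch import nn

class Scale(nn.Module):
    # sketch only - nGPT.py already defines Scale; this just shows what the
    # (dim, init, scale) arguments passed through the diff above would control
    def __init__(self, dim, init = 1., scale = 1.):
        super().__init__()
        # the parameter is stored at magnitude `scale`, which sets its effective relative learning rate
        self.scale = nn.Parameter(torch.ones(dim) * scale)
        # multiplying by init / scale restores an effective initial value of `init`
        self.forward_scale = init / scale

    def forward(self):
        return self.scale * self.forward_scale
```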

```diff
@@ -206,31 +219,98 @@ def __init__(
         attn_norm_qk = True, # they say the query/key normalization is optional
         ff_expand_factor = 4.,
         ce_ignore_index = -1,
-        residual_lerp_scale_init = None,
         manual_norm_weights = False,
-        tied_embedding = False
+        tied_embedding = False,
+        # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
+        alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
+        s_logit_init: float = 1.,
+        s_logit_scale: float | None = None,
+        alpha_attn_init: float | tuple[float, ...] | None = None,
+        alpha_attn_scale: float | tuple[float, ...] | None = None,
+        alpha_ff_init: float | tuple[float, ...] | None = None,
+        alpha_ff_scale: float | tuple[float, ...] | None = None,
+        s_qk_init: float | tuple[float, ...] = 1.,
+        s_qk_scale: float | tuple[float, ...] | None = None,
+        s_ff_hidden_init: float | tuple[float, ...] = 1.,
+        s_ff_hidden_scale: float | tuple[float, ...] = 1.,
+        s_ff_gate_init: float | tuple[float, ...] = 1.,
+        s_ff_gate_scale: float | tuple[float, ...] = 1.
     ):
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
 
         self.dim = dim
-        residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
+        alpha_init = default(alpha_init, 1. / depth)
 
         self.token_embed = NormLinear_(dim, num_tokens)
 
         self.layers = ModuleList([])
 
-        for _ in range(depth):
-            self.layers.append(ModuleList([
-                Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk, manual_norm_weights = manual_norm_weights),
-                FeedForward(dim, expand_factor = ff_expand_factor, manual_norm_weights = manual_norm_weights),
-                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
-                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
-            ]))
+        scale_hparams = (
+            alpha_attn_init,
+            alpha_attn_scale,
+            alpha_ff_init,
+            alpha_ff_scale,
+            s_qk_init,
+            s_qk_scale,
+            s_ff_hidden_init,
+            s_ff_hidden_scale,
+            s_ff_gate_init,
+            s_ff_gate_scale
+        )
+
+        scale_hparams = tuple(cast_tuple(hparam, depth) for hparam in scale_hparams)
+
+        for (
+            alpha_attn_init_,
+            alpha_attn_scale_,
+            alpha_ff_init_,
+            alpha_ff_scale_,
+            s_qk_init_,
+            s_qk_scale_,
+            s_ff_hidden_init_,
+            s_ff_hidden_scale_,
+            s_ff_gate_init_,
+            s_ff_gate_scale_
+        ) in zip(*scale_hparams):
+
+            attn = Attention(
+                dim,
+                dim_head = dim_head,
+                heads = heads,
+                norm_qk = attn_norm_qk,
+                manual_norm_weights = manual_norm_weights,
+                s_qk_init = s_qk_init_,
+                s_qk_scale = s_qk_scale_,
+            )
+
+            ff = FeedForward(
+                dim,
+                expand_factor = ff_expand_factor,
+                manual_norm_weights = manual_norm_weights,
+                s_hidden_init = s_ff_hidden_init_,
+                s_hidden_scale = s_ff_hidden_scale_,
+                s_gate_init = s_ff_gate_init_,
+                s_gate_scale = s_ff_gate_scale_
+            )
+
+            attn_interp_factor = Scale(
+                dim,
+                default(alpha_attn_init_, alpha_init),
+                default(alpha_attn_scale_, dim ** -0.5)
+            )
+
+            ff_interp_factor = Scale(
+                dim,
+                default(alpha_ff_init_, alpha_init),
+                default(alpha_ff_scale_, dim ** -0.5)
+            )
+
+            self.layers.append(ModuleList([attn, ff, attn_interp_factor, ff_interp_factor]))
 
         self.to_logits = NormLinear_(dim, num_tokens) if not tied_embedding else None
 
-        self.logit_scale = Scale(num_tokens, 1., dim ** -0.5)
+        self.logit_scale = Scale(num_tokens, s_logit_init, default(s_logit_scale, dim ** -0.5))
 
         self.ignore_index = ce_ignore_index
```
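
Taken together, the new constructor arguments can be set globally or per layer. A hypothetical instantiation, assuming the package exports `nGPT` the way `train.py` uses it; the specific values below (including `num_tokens = 256` for a byte-level vocabulary and the forward call) are illustrative assumptions, not recommendations:

```python
import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,               # assumed byte-level vocab, as in the enwik8 training script
    dim = 512,
    depth = 8,
    manual_norm_weights = True,
    tied_embedding = True,
    # scale hyperparameters - a scalar applies to every layer, a tuple sets one value per layer
    alpha_init = 1. / 8,            # residual interpolation init shared by all branches...
    alpha_attn_init = (0.05,) * 8,  # ...overridden per layer for the attention branch
    s_logit_init = 1.,
    s_logit_scale = 512 ** -0.5,
    s_qk_init = 1.,
    s_ff_hidden_scale = 1.
)

ids = torch.randint(0, 256, (1, 128))
logits = model(ids)                 # assumed forward signature returning logits
```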

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.0.11"
+version = "0.0.12"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

train.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -92,6 +92,7 @@ def base_decoding(
     dim = 512,
     depth = 8,
     manual_norm_weights = True,
+    tied_embedding = True
 ).to(device)
 
 # prepare enwik8 data
```
