Commit 9779df0

address #1
1 parent 1271296 commit 9779df0

File tree: 3 files changed (+54, -27 lines)

- nGPT_pytorch/nGPT.py
- pyproject.toml
- train.py

nGPT_pytorch/nGPT.py

Lines changed: 50 additions & 25 deletions
```diff
@@ -1,8 +1,10 @@
+from functools import partial
+
 import torch
 from torch import nn
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
-import torch.nn.utils.parametrize as parametrize
+from torch.nn.utils.parametrize import register_parametrization
 
 from einops import rearrange
 from einops.layers.torch import Rearrange
@@ -35,16 +37,33 @@ def __init__(
         self,
         dim,
         dim_out,
-        norm_dim_in = True
+        norm_dim_in = True,
+        parametrize = True
     ):
         super().__init__()
         self.linear = nn.Linear(dim, dim_out, bias = False)
 
-        parametrize.register_parametrization(
-            self.linear,
-            'weight',
-            L2Norm(dim = -1 if norm_dim_in else 0)
-        )
+        self.parametrize = parametrize
+        self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0)
+
+        if parametrize:
+            register_parametrization(
+                self.linear,
+                'weight',
+                self.l2norm
+            )
+
+        self.norm_weights_()
+
+    @torch.no_grad()
+    def norm_weights_(self):
+        if self.parametrize:
+            normed = self.weight
+            original = self.linear.parametrizations.weight.original
+
+            original.copy_(normed)
+        else:
+            self.weight.copy_(self.l2norm(self.weight))
 
     @property
     def weight(self):
@@ -62,13 +81,16 @@ def __init__(
         *,
         dim_head = 64,
         heads = 8,
-        norm_qk = True
+        norm_qk = True,
+        manual_norm_weights = False
     ):
         super().__init__()
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+
         dim_inner = dim_head * heads
-        self.to_q = NormLinear(dim, dim_inner)
-        self.to_k = NormLinear(dim, dim_inner)
-        self.to_v = NormLinear(dim, dim_inner)
+        self.to_q = NormLinear_(dim, dim_inner)
+        self.to_k = NormLinear_(dim, dim_inner)
+        self.to_v = NormLinear_(dim, dim_inner)
 
         self.rotary_emb = RotaryEmbedding(dim_head)
         self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** 0.25))
@@ -77,7 +99,7 @@ def __init__(
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
+        self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)
 
     def forward(
         self,
@@ -117,19 +139,22 @@ def __init__(
         self,
         dim,
         *,
-        expand_factor = 4
+        expand_factor = 4,
+        manual_norm_weights = False
     ):
         super().__init__()
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+
         self.dim = dim
         dim_inner = int(dim * expand_factor * 2 / 3)
 
-        self.to_hidden = NormLinear(dim, dim_inner)
-        self.to_gate = NormLinear(dim, dim_inner)
+        self.to_hidden = NormLinear_(dim, dim_inner)
+        self.to_gate = NormLinear_(dim, dim_inner)
 
         self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
         self.gate_scale = nn.Parameter(torch.ones(dim_inner))
 
-        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
+        self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)
 
     def forward(self, x):
         hidden, gate = self.to_hidden(x), self.to_gate(x)
@@ -154,30 +179,33 @@ def __init__(
         attn_norm_qk = True, # they say the query/key normalization is optional
         ff_expand_factor = 4.,
         ce_ignore_index = -1,
-        residual_lerp_scale_init = None
+        residual_lerp_scale_init = None,
+        manual_norm_weights = False
     ):
         super().__init__()
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+
         self.dim = dim
 
         residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
 
-        self.token_embed = NormLinear(dim, num_tokens)
+        self.token_embed = NormLinear_(dim, num_tokens)
 
         self.layers = ModuleList([])
         self.residual_lerp_scales = nn.ParameterList([])
 
         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk),
-                FeedForward(dim, expand_factor = ff_expand_factor),
+                Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk, manual_norm_weights = manual_norm_weights),
+                FeedForward(dim, expand_factor = ff_expand_factor, manual_norm_weights = manual_norm_weights),
             ]))
 
             self.residual_lerp_scales.append(nn.ParameterList([
                 nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
                 nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
             ]))
 
-        self.to_logits = NormLinear(dim, num_tokens)
+        self.to_logits = NormLinear_(dim, num_tokens)
 
         self.logit_scale = nn.Parameter(torch.ones(num_tokens))
 
@@ -189,10 +217,7 @@ def norm_weights_(self):
             if not isinstance(module, NormLinear):
                 continue
 
-            normed = module.weight
-            original = module.linear.parametrizations.weight.original
-
-            original.copy_(normed)
+            module.norm_weights_()
 
     def forward(
         self,
```
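
The gist of the change to `NormLinear`: with the default `parametrize = True`, an `L2Norm` parametrization keeps the weight unit-norm on every access, while `parametrize = False` (selected by `manual_norm_weights = True` in the callers above) stores a plain weight that must be re-projected explicitly via `norm_weights_()`. A minimal usage sketch of the two modes, assuming `NormLinear` is importable from `nGPT_pytorch.nGPT` and that its `weight` property exposes the underlying `nn.Linear` weight:

```python
import torch
from nGPT_pytorch.nGPT import NormLinear  # assumed import path

# default mode: register_parametrization keeps each weight row unit-norm on every access
lin = NormLinear(512, 256)
print(lin.weight.norm(dim = -1))           # ~1.0 for every row

# manual mode: no parametrization; the weight is normalized once at init
# and must be re-projected explicitly whenever it is modified
lin_manual = NormLinear(512, 256, parametrize = False)

with torch.no_grad():
    lin_manual.weight.add_(0.1)            # drift off the unit hypersphere

lin_manual.norm_weights_()                 # project rows back to unit norm
print(lin_manual.weight.norm(dim = -1))    # ~1.0 again
```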

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.0.8"
+version = "0.0.9"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

train.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -8,6 +8,7 @@
 from torch.optim import Adam
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset
+import torch.nn.utils.parametrize as parametrize
 
 from nGPT_pytorch import nGPT
 
@@ -89,7 +90,8 @@ def base_decoding(
 model = nGPT(
     num_tokens = 256,
     dim = 512,
-    depth = 8
+    depth = 8,
+    manual_norm_weights = True
 ).to(device)
 
 # prepare enwik8 data
```
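
With `manual_norm_weights = True`, the `NormLinear` weights carry no parametrization, so the training script is expected to re-project them onto the unit hypersphere after each optimizer update via the model's `norm_weights_()` method (shown in the nGPT.py hunk above). A rough sketch of such a training step, not part of this diff; `model`, `optim`, `data` and the `return_loss` flag are assumed from the surrounding train.py:

```python
# hypothetical single training step under manual_norm_weights = True
loss = model(data, return_loss = True)  # assumed loss-returning forward
loss.backward()

optim.step()
optim.zero_grad()

# the optimizer update moves weights off the unit hypersphere,
# so re-normalize every NormLinear weight in place
model.norm_weights_()
```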
