Commit 01c4b5c

oops
1 parent 0c5f1c3 commit 01c4b5c

2 files changed: +11 -11 lines changed
nGPT_pytorch/nGPT.py

Lines changed: 10 additions & 10 deletions
@@ -35,15 +35,15 @@ def __init__(
         self,
         dim,
         dim_out,
-        norm_dim = -1
+        norm_dim_in = True
     ):
         super().__init__()
         self.linear = nn.Linear(dim, dim_out, bias = False)
 
         parametrize.register_parametrization(
             self.linear,
             'weight',
-            L2Norm(dim = norm_dim)
+            L2Norm(dim = -1 if norm_dim_in else 0)
         )
 
     @property
@@ -66,9 +66,9 @@ def __init__(
     ):
         super().__init__()
         dim_inner = dim_head * heads
-        self.to_q = NormLinear(dim, dim_inner, norm_dim = 0)
-        self.to_k = NormLinear(dim, dim_inner, norm_dim = 0)
-        self.to_v = NormLinear(dim, dim_inner, norm_dim = 0)
+        self.to_q = NormLinear(dim, dim_inner)
+        self.to_k = NormLinear(dim, dim_inner)
+        self.to_v = NormLinear(dim, dim_inner)
 
         self.rotary_emb = RotaryEmbedding(dim_head)
         self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** -0.25))
@@ -77,7 +77,7 @@ def __init__(
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.to_out = NormLinear(dim_inner, dim)
+        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
 
     def forward(
         self,
@@ -123,13 +123,13 @@ def __init__(
         self.dim = dim
         dim_inner = int(dim * expand_factor * 2 / 3)
 
-        self.to_hidden = NormLinear(dim, dim_inner, norm_dim = 0)
-        self.to_gate = NormLinear(dim, dim_inner, norm_dim = 0)
+        self.to_hidden = NormLinear(dim, dim_inner)
+        self.to_gate = NormLinear(dim, dim_inner)
 
         self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
         self.gate_scale = nn.Parameter(torch.ones(dim_inner))
 
-        self.to_out = NormLinear(dim_inner, dim)
+        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
 
     def forward(self, x):
         hidden, gate = self.to_hidden(x), self.to_gate(x)
@@ -177,7 +177,7 @@ def __init__(
                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
            ]))
 
-        self.to_logits = NormLinear(dim, num_tokens, norm_dim = 0)
+        self.to_logits = NormLinear(dim, num_tokens)
 
         self.logit_scale = nn.Parameter(torch.ones(num_tokens))
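The substance of the fix: NormLinear's constructor swaps the raw norm_dim argument for a boolean norm_dim_in, so projections into the residual stream (to_q, to_k, to_v, to_hidden, to_gate, to_logits) now normalize weight rows along the input dimension by default, while the out-projections pass norm_dim_in = False to normalize along the output dimension. A minimal runnable sketch of the new behavior follows; the L2Norm module is reimplemented here for illustration and is an assumption, so the repo's own class may differ in details.

# Hedged sketch of NormLinear after this commit.
# L2Norm here is a hypothetical stand-in for the repo's parametrization module.
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from torch import nn

class L2Norm(nn.Module):
    def __init__(self, dim = -1):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        # unit-normalize along the chosen dimension every time the weight is read
        return F.normalize(t, dim = self.dim)

class NormLinear(nn.Module):
    def __init__(self, dim, dim_out, norm_dim_in = True):
        super().__init__()
        self.linear = nn.Linear(dim, dim_out, bias = False)

        # weight has shape (dim_out, dim): dim = -1 normalizes each row over the
        # input features; dim = 0 normalizes each column over the output features
        parametrize.register_parametrization(
            self.linear,
            'weight',
            L2Norm(dim = -1 if norm_dim_in else 0)
        )

    def forward(self, x):
        return self.linear(x)

lin = NormLinear(4, 8)                       # q/k/v style: unit-norm rows
print(lin.linear.weight.norm(dim = -1))      # ~ all ones

out = NormLinear(8, 4, norm_dim_in = False)  # out-projection style: unit-norm columns
print(out.linear.weight.norm(dim = 0))       # ~ all ones

Since nn.Linear(dim, dim_out) stores its weight as (dim_out, dim), normalizing along dim = -1 keeps each output unit's vector over the embedding dimension at unit norm on the way in, and dim = 0 does the same on the way out; the previous release had these axes inverted, which is presumably the "oops" this commit corrects.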

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.0.4"
+version = "0.0.5"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
