
Commit f961ad8

add tied embeddings for @faresobeid to play around with
1 parent 9779df0 commit f961ad8

File tree

2 files changed: +13 −6 lines


nGPT_pytorch/nGPT.py

Lines changed: 12 additions & 5 deletions
@@ -6,7 +6,7 @@
 import torch.nn.functional as F
 from torch.nn.utils.parametrize import register_parametrization
 
-from einops import rearrange
+from einops import rearrange, einsum
 from einops.layers.torch import Rearrange
 
 from rotary_embedding_torch import RotaryEmbedding
@@ -180,7 +180,8 @@ def __init__(
         ff_expand_factor = 4.,
         ce_ignore_index = -1,
         residual_lerp_scale_init = None,
-        manual_norm_weights = False
+        manual_norm_weights = False,
+        tied_embedding = False
     ):
         super().__init__()
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
@@ -205,7 +206,7 @@ def __init__(
                 nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
             ]))
 
-        self.to_logits = NormLinear_(dim, num_tokens)
+        self.to_logits = NormLinear_(dim, num_tokens) if not tied_embedding else None
 
         self.logit_scale = nn.Parameter(torch.ones(num_tokens))
 
@@ -228,7 +229,8 @@ def forward(
         if return_loss:
             ids, labels = ids[:, :-1], ids[:, 1:]
 
-        tokens = self.token_embed.weight[ids]
+        token_embed = self.token_embed.weight
+        tokens = token_embed[ids]
 
         for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
 
@@ -238,7 +240,12 @@ def forward(
             ff_out = l2norm(ff(tokens))
             tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
 
-        logits = self.to_logits(tokens)
+        if exists(self.to_logits):
+            logits = self.to_logits(tokens)
+        else:
+            # tied embeddings
+            logits = einsum(tokens, token_embed, 'b n d, c d -> b n c')
+
         logits = logits * self.logit_scale * (self.dim ** 0.5)
 
         if not return_loss:

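For reference, with tied_embedding = True the to_logits projection is dropped and the logits are taken straight from the token embedding matrix: the einsum 'b n d, c d -> b n c' is just tokens @ token_embed.t(). Below is a minimal usage sketch, not part of the commit; the import path, the depth argument, and the hyperparameter values are assumptions based on the rest of this repository.

import torch
from nGPT_pytorch import nGPT  # assumed top-level export

# illustrative hyperparameters; tied_embedding = True reuses token_embed for the output logits
model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    tied_embedding = True
)

ids = torch.randint(0, 256, (2, 1024))

loss = model(ids, return_loss = True)  # next-token cross entropy, per the forward pass above
loss.backward()

Since the hidden states are l2-normalized at every step (and the embedding rows appear to be norm-parametrized through NormLinear), the tied logits act as cosine similarities before the logit_scale * dim ** 0.5 scaling shown in the diff above.
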
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.0.9"
+version = "0.0.10"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
