Commit f64fba1

revert
1 parent 8023931 commit f64fba1

2 files changed: +4 −9 lines changed

nGPT_pytorch/nGPT.py

Lines changed: 3 additions & 8 deletions
@@ -306,9 +306,9 @@ class nGPT(Module):
     def __init__(
         self,
         *,
+        num_tokens,
         dim,
         depth,
-        num_tokens = None,
         dim_head = 64,
         heads = 8,
         attn_norm_qk = True, # they say the query/key normalization is optional
@@ -347,12 +347,7 @@ def __init__(
         self.causal = causal
         alpha_init = default(alpha_init, 1. / depth)
 
-        # allow for plain stack of attention and feedforward, for trying to use in a different setting
-
-        only_transformer = not exists(num_tokens)
-        self.only_transformer = only_transformer
-
-        self.token_embed = NormLinear_(dim, num_tokens) if not only_transformer else None
+        self.token_embed = NormLinear_(dim, num_tokens)
 
         self.layers = ModuleList([])
 
@@ -426,7 +421,7 @@ def __init__(
 
         self.layers.append(ModuleList([attn_with_residual, ff_with_residual]))
 
-        self.to_logits = NormLinear_(dim, num_tokens) if (not tied_embedding or only_transformer) or not exists(num_tokens) else None
+        self.to_logits = NormLinear_(dim, num_tokens) if not tied_embedding else None
 
         self.logit_scale = Scale(num_tokens, s_logit_init, default(s_logit_scale, dim ** -0.5))
 

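For reference, a minimal usage sketch of the constructor after this revert: num_tokens is back to being a required keyword-only argument, so the token embedding and logits head are always built. The import path and the forward call below are assumptions based on the package name, not something shown in this diff; the other constructor arguments and their defaults are taken from the diff context.

import torch
from nGPT_pytorch import nGPT  # assumed import path, matching the package name

model = nGPT(
    num_tokens = 256,  # required keyword-only argument again after this commit
    dim = 512,
    depth = 8,
    dim_head = 64,     # defaults shown in the diff context
    heads = 8
)

ids = torch.randint(0, 256, (1, 1024))
logits = model(ids)  # assumed forward signature: token ids in, logits out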
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.1.8"
+version = "0.1.9"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
