Commit 61e484b

trust @jfpuget and default example to using parametrize
Parent: 6f72109

2 files changed (+9, -4 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.1.17"
+version = "0.1.18"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

train.py

Lines changed: 8 additions & 3 deletions
@@ -25,10 +25,14 @@
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
 SEQ_LEN = 512
+
 USE_AMP = True
+USE_PARAMETRIZE = True # whether to manually update weights after each optimizer step
 
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
+assert not (USE_AMP and not torch.cuda.is_available())
+
 # helpers
 
 def exists(v):
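The new USE_PARAMETRIZE flag picks how the unit-norm constraint on the weights is enforced. With the flag on, the model is constructed with manual_norm_weights disabled, so the weights can stay normalized through a parametrization (as the commit message suggests) and no explicit renormalization is needed after each optimizer step; with it off, the script falls back to calling model.norm_weights_() by hand, as the last hunk below shows. The added assert simply refuses to enable AMP on a machine without CUDA. Below is a minimal sketch of the parametrize mechanism using torch.nn.utils.parametrize on a plain Linear layer; the L2Normalize module and the choice of dim are illustrative, not nGPT-pytorch internals.

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import parametrize

class L2Normalize(nn.Module):
    # every read of the parametrized weight returns rows rescaled to unit l2 norm
    def forward(self, weight):
        return F.normalize(weight, p = 2, dim = -1)

linear = nn.Linear(512, 512, bias = False)
parametrize.register_parametrization(linear, 'weight', L2Normalize())

# linear.weight now always comes back unit-norm, including inside forward passes,
# so the training loop never has to renormalize it explicitly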
@@ -94,8 +98,8 @@ def base_decoding(
     num_tokens = 256,
     dim = 512,
     depth = 8,
-    manual_norm_weights = True,
-    tied_embedding = True
+    tied_embedding = True,
+    manual_norm_weights = not USE_PARAMETRIZE
 ).to(device)
 
 scaler = GradScaler(enabled = USE_AMP)
@@ -153,7 +157,8 @@ def __getitem__(self, index):
 
     optim.zero_grad()
 
-    model.norm_weights_()
+    if not USE_PARAMETRIZE:
+        model.norm_weights_()
 
     if i % VALIDATE_EVERY == 0:
         model.eval()
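For reference, a self-contained sketch of the manual branch that this hunk makes conditional: after optimizer.step() the weight rows are pushed back onto the unit sphere in-place, which is the role model.norm_weights_() plays for the full network. The toy Linear layer and loss here are stand-ins, not the script's model.

import torch
import torch.nn.functional as F
from torch import nn

USE_PARAMETRIZE = False          # the manual branch of the new flag

linear = nn.Linear(512, 512, bias = False)
optim = torch.optim.Adam(linear.parameters(), lr = 1e-3)

x = torch.randn(8, 512)
loss = linear(x).pow(2).mean()   # stand-in loss
loss.backward()
optim.step()
optim.zero_grad()

if not USE_PARAMETRIZE:
    # the optimizer step moved the rows off the sphere; renormalize them in-place
    with torch.no_grad():
        linear.weight.copy_(F.normalize(linear.weight, p = 2, dim = -1))

With USE_PARAMETRIZE = True this whole renormalization block disappears, and a parametrization like the one sketched earlier keeps the constraint satisfied automatically.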
