Skip to content

Commit 8be11d3

Browse files
committed
allow for model to auto norm weights on optimizer step
1 parent c03b634 commit 8be11d3

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

nGPT_pytorch/nGPT.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,14 @@ def norm_weights_(self):
459459

460460
module.norm_weights_()
461461

462+
def register_step_post_hook(self, optimizer):
463+
assert hasattr(optimizer, 'register_step_post_hook')
464+
465+
def hook(*_):
466+
self.norm_weights_()
467+
468+
return optimizer.register_step_post_hook(hook)
469+
462470
def forward(
463471
self,
464472
ids,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "nGPT-pytorch"
3-
version = "0.1.19"
3+
version = "0.1.20"
44
description = "nGPT"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ def __getitem__(self, index):
139139
train_loader = cycle(train_loader)
140140
val_loader = cycle(val_loader)
141141

142+
# if not using parametrize, register normalizing on optimizer step
143+
144+
if not USE_PARAMETRIZE:
145+
model.register_step_post_hook(optim)
146+
142147
# training
143148

144149
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
@@ -159,9 +164,6 @@ def __getitem__(self, index):
159164

160165
optim.zero_grad()
161166

162-
if not USE_PARAMETRIZE:
163-
model.norm_weights_()
164-
165167
if i % VALIDATE_EVERY == 0:
166168
model.eval()
167169
with torch.no_grad():

0 commit comments

Comments
 (0)