File tree Expand file tree Collapse file tree 3 files changed +14
-4
lines changed Expand file tree Collapse file tree 3 files changed +14
-4
lines changed Original file line number Diff line number Diff line change @@ -459,6 +459,14 @@ def norm_weights_(self):
459
459
460
460
module .norm_weights_ ()
461
461
462
def register_step_post_hook(self, optimizer):
    """Re-normalize this module's weights after every step of ``optimizer``.

    Registers a post-step hook on ``optimizer`` that calls
    ``self.norm_weights_()``, so callers no longer need to invoke
    ``norm_weights_`` by hand in the training loop.

    Args:
        optimizer: object exposing ``register_step_post_hook(hook)``
            (presumably a ``torch.optim.Optimizer`` - confirm against caller).

    Returns:
        Whatever ``optimizer.register_step_post_hook`` returns, unchanged.

    Raises:
        AssertionError: if ``optimizer`` has no ``register_step_post_hook``
            attribute.
    """
    # Explicit check instead of a bare ``assert`` so the validation still
    # runs under ``python -O``. AssertionError is kept so the failure type
    # seen by existing callers does not change.
    if not hasattr(optimizer, 'register_step_post_hook'):
        raise AssertionError(
            f'{type(optimizer).__name__} does not support '
            'register_step_post_hook'
        )

    def renormalize_after_step(*_):
        # Arguments passed by the optimizer to the hook are ignored.
        self.norm_weights_()

    return optimizer.register_step_post_hook(renormalize_after_step)
469
+
462
470
def forward (
463
471
self ,
464
472
ids ,
Original file line number Diff line number Diff line change 1
1
[project ]
2
2
name = " nGPT-pytorch"
3
- version = " 0.1.19 "
3
+ version = " 0.1.20 "
4
4
description = " nGPT"
5
5
authors = [
6
6
{ name = " Phil Wang" , email = " lucidrains@gmail.com" }
Original file line number Diff line number Diff line change @@ -139,6 +139,11 @@ def __getitem__(self, index):
139
139
train_loader = cycle (train_loader )
140
140
val_loader = cycle (val_loader )
141
141
142
+ # when not using parametrize, register a hook that re-normalizes the weights after every optimizer step
143
+
144
+ if not USE_PARAMETRIZE :
145
+ model .register_step_post_hook (optim )
146
+
142
147
# training
143
148
144
149
for i in tqdm .tqdm (range (NUM_BATCHES ), mininterval = 10.0 , desc = "training" ):
@@ -159,9 +164,6 @@ def __getitem__(self, index):
159
164
160
165
optim .zero_grad ()
161
166
162
- if not USE_PARAMETRIZE :
163
- model .norm_weights_ ()
164
-
165
167
if i % VALIDATE_EVERY == 0 :
166
168
model .eval ()
167
169
with torch .no_grad ():
You can’t perform that action at this time.
0 commit comments