File tree Expand file tree Collapse file tree 2 files changed +9
-4
lines changed Expand file tree Collapse file tree 2 files changed +9
-4
lines changed Original file line number Diff line number Diff line change 1
1
[project ]
2
2
name = " nGPT-pytorch"
3
- version = " 0.1.17 "
3
+ version = " 0.1.18 "
4
4
description = " nGPT"
5
5
authors = [
6
6
{ name = " Phil Wang" , email = " lucidrains@gmail.com" }
Original file line number Diff line number Diff line change 25
25
GENERATE_EVERY = 500
26
26
GENERATE_LENGTH = 512
27
27
SEQ_LEN = 512
28
+
28
29
USE_AMP = True
30
+ USE_PARAMETRIZE = True # if True, skip the manual model.norm_weights_() call after each optimizer step (the model is built with manual_norm_weights = not USE_PARAMETRIZE)
29
31
30
32
device = torch .device ('cuda:0' if torch .cuda .is_available () else 'cpu' )
31
33
34
+ assert not (USE_AMP and not torch .cuda .is_available ())
35
+
32
36
# helpers
33
37
34
38
def exists (v ):
@@ -94,8 +98,8 @@ def base_decoding(
94
98
num_tokens = 256 ,
95
99
dim = 512 ,
96
100
depth = 8 ,
97
- manual_norm_weights = True ,
98
- tied_embedding = True
101
+ tied_embedding = True ,
102
+ manual_norm_weights = not USE_PARAMETRIZE
99
103
).to (device )
100
104
101
105
scaler = GradScaler (enabled = USE_AMP )
@@ -153,7 +157,8 @@ def __getitem__(self, index):
153
157
154
158
optim .zero_grad ()
155
159
156
- model .norm_weights_ ()
160
+ if not USE_PARAMETRIZE :
161
+ model .norm_weights_ ()
157
162
158
163
if i % VALIDATE_EVERY == 0 :
159
164
model .eval ()
You can’t perform that action at this time.
0 commit comments