Skip to content

Commit 19f4212

Browse files
committed
make sure to freeze VAE parameters after being passed into DALL-E
1 parent a0e8ea4 commit 19f4212

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

dalle_pytorch/dalle_pytorch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def masked_mean(t, mask, dim = 1):
     t = t.masked_fill(~mask[:, :, None], 0.)
     return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
 
+def set_requires_grad(model, value):
+    for param in model.parameters():
+        param.requires_grad = value
+
 def eval_decorator(fn):
     def inner(model, *args, **kwargs):
         was_training = model.training
@@ -347,6 +351,7 @@ def __init__(
         self.total_seq_len = seq_len
 
         self.vae = vae
+        set_requires_grad(self.vae, False) # freeze VAE from being trained
 
         self.transformer = Transformer(
             dim = dim,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.11.2',
+  version = '0.11.3',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',

train_dalle.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
 def exists(val):
     return val is not None
 
+def get_trainable_params(model):
+    return [params for params in model.parameters() if params.requires_grad]
 
 # constants
 
@@ -229,7 +231,7 @@ def group_weight(model):
 
 # optimizer
 
-opt = Adam(dalle.parameters(), lr=LEARNING_RATE)
+opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
 
 if LR_DECAY:
     scheduler = ReduceLROnPlateau(
@@ -272,7 +274,7 @@ def group_weight(model):
         args=args,
         model=dalle,
         optimizer=opt,
-        model_parameters=dalle.parameters(),
+        model_parameters=get_trainable_params(dalle),
         training_data=ds if using_deepspeed else dl,
         lr_scheduler=scheduler if LR_DECAY else None,
         config_params=deepspeed_config,

0 commit comments

Comments (0)