Skip to content

Commit b33a48e

Browse files
committed
make sure when sampling, batch does not exceed training batch size
1 parent 8e5fb17 commit b33a48e

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ def cycle(dl):
4343
for data in dl:
4444
yield data
4545

46+
def num_to_groups(num, divisor):
47+
groups = num // divisor
48+
remainder = num % divisor
49+
arr = [divisor] * groups
50+
if remainder > 0:
51+
arr.append(remainder)
52+
return arr
53+
4654
def loss_backwards(fp16, loss, optimizer, **kwargs):
4755
if fp16:
4856
with amp.scale_loss(loss, optimizer) as scaled_loss:
@@ -434,13 +442,16 @@ def __init__(
434442
train_lr = 2e-5,
435443
train_num_steps = 100000,
436444
gradient_accumulate_every = 2,
437-
fp16 = False
445+
fp16 = False,
446+
step_start_ema = 2000
438447
):
439448
super().__init__()
440449
self.model = diffusion_model
441450
self.ema = EMA(ema_decay)
442451
self.ema_model = copy.deepcopy(self.model)
452+
self.step_start_ema = step_start_ema
443453

454+
self.batch_size = train_batch_size
444455
self.image_size = image_size
445456
self.gradient_accumulate_every = gradient_accumulate_every
446457
self.train_num_steps = train_num_steps
@@ -463,7 +474,7 @@ def reset_parameters(self):
463474
self.ema_model.load_state_dict(self.model.state_dict())
464475

465476
def step_ema(self):
466-
if self.step < 2000:
477+
if self.step < self.step_start_ema:
467478
self.reset_parameters()
468479
return
469480
self.ema.update_model_average(self.ema_model, self.model)
@@ -501,8 +512,10 @@ def train(self):
501512

502513
if self.step % SAVE_AND_SAMPLE_EVERY == 0:
503514
milestone = self.step // SAVE_AND_SAMPLE_EVERY
504-
all_images = self.ema_model.p_sample_loop((64, 3, self.image_size, self.image_size))
505-
utils.save_image(all_images, f'./sample-{milestone}.png', nrow=8)
515+
batches = num_to_groups(36, self.batch_size)
516+
all_images_list = list(map(lambda n: self.ema_model.sample(self.image_size, batch_size=n), batches))
517+
all_images = torch.cat(all_images_list, dim=0)
518+
utils.save_image(all_images, f'./sample-{milestone}.png', nrow=6)
506519
self.save(milestone)
507520

508521
self.step += 1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'denoising-diffusion-pytorch',
55
packages = find_packages(),
6-
version = '0.2.4',
6+
version = '0.3.0',
77
license='MIT',
88
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)