Skip to content

Commit 79c5f04

Browse files
committed
allow for turning off horizontal flip augmentation
1 parent 479f60c commit 79c5f04

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -541,15 +541,15 @@ def forward(self, img, *args, **kwargs):
541541
# dataset classes
542542

543543
class Dataset(data.Dataset):
544-
def __init__(self, folder, image_size, exts = ['jpg', 'jpeg', 'png']):
544+
def __init__(self, folder, image_size, exts = ['jpg', 'jpeg', 'png'], augment_horizontal_flip = False):
545545
super().__init__()
546546
self.folder = folder
547547
self.image_size = image_size
548548
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
549549

550550
self.transform = transforms.Compose([
551551
transforms.Resize(image_size),
552-
transforms.RandomHorizontalFlip(),
552+
transforms.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
553553
transforms.CenterCrop(image_size),
554554
transforms.ToTensor()
555555
])
@@ -580,7 +580,8 @@ def __init__(
580580
step_start_ema = 2000,
581581
update_ema_every = 10,
582582
save_and_sample_every = 1000,
583-
results_folder = './results'
583+
results_folder = './results',
584+
augment_horizontal_flip = True
584585
):
585586
super().__init__()
586587
self.model = diffusion_model
@@ -596,7 +597,7 @@ def __init__(
596597
self.gradient_accumulate_every = gradient_accumulate_every
597598
self.train_num_steps = train_num_steps
598599

599-
self.ds = Dataset(folder, image_size)
600+
self.ds = Dataset(folder, image_size, augment_horizontal_flip = augment_horizontal_flip)
600601
self.dl = cycle(data.DataLoader(self.ds, batch_size = train_batch_size, shuffle=True, pin_memory=True))
601602
self.opt = Adam(diffusion_model.parameters(), lr=train_lr)
602603

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'denoising-diffusion-pytorch',
55
packages = find_packages(),
6-
version = '0.18.0',
6+
version = '0.18.1',
77
license='MIT',
88
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)