@@ -541,15 +541,15 @@ def forward(self, img, *args, **kwargs):
541541# dataset classes
542542
543543class Dataset (data .Dataset ):
544- def __init__ (self , folder , image_size , exts = ['jpg' , 'jpeg' , 'png' ]):
544+ def __init__ (self , folder , image_size , exts = ['jpg' , 'jpeg' , 'png' ], augment_horizontal_flip = False ):
545545 super ().__init__ ()
546546 self .folder = folder
547547 self .image_size = image_size
548548 self .paths = [p for ext in exts for p in Path (f'{ folder } ' ).glob (f'**/*.{ ext } ' )]
549549
550550 self .transform = transforms .Compose ([
551551 transforms .Resize (image_size ),
552- transforms .RandomHorizontalFlip (),
552+ transforms .RandomHorizontalFlip () if augment_horizontal_flip else nn . Identity () ,
553553 transforms .CenterCrop (image_size ),
554554 transforms .ToTensor ()
555555 ])
@@ -580,7 +580,8 @@ def __init__(
580580 step_start_ema = 2000 ,
581581 update_ema_every = 10 ,
582582 save_and_sample_every = 1000 ,
583- results_folder = './results'
583+ results_folder = './results' ,
584+ augment_horizontal_flip = True
584585 ):
585586 super ().__init__ ()
586587 self .model = diffusion_model
@@ -596,7 +597,7 @@ def __init__(
596597 self .gradient_accumulate_every = gradient_accumulate_every
597598 self .train_num_steps = train_num_steps
598599
599- self .ds = Dataset (folder , image_size )
600+ self .ds = Dataset (folder , image_size , augment_horizontal_flip = augment_horizontal_flip )
600601 self .dl = cycle (data .DataLoader (self .ds , batch_size = train_batch_size , shuffle = True , pin_memory = True ))
601602 self .opt = Adam (diffusion_model .parameters (), lr = train_lr )
602603
0 commit comments