@@ -43,6 +43,14 @@ def cycle(dl):
4343 for data in dl :
4444 yield data
4545
46+ def num_to_groups (num , divisor ):
47+ groups = num // divisor
48+ remainder = num % divisor
49+ arr = [divisor ] * groups
50+ if remainder > 0 :
51+ arr .append (remainder )
52+ return arr
53+
4654def loss_backwards (fp16 , loss , optimizer , ** kwargs ):
4755 if fp16 :
4856 with amp .scale_loss (loss , optimizer ) as scaled_loss :
@@ -434,13 +442,16 @@ def __init__(
434442 train_lr = 2e-5 ,
435443 train_num_steps = 100000 ,
436444 gradient_accumulate_every = 2 ,
437- fp16 = False
445+ fp16 = False ,
446+ step_start_ema = 2000
438447 ):
439448 super ().__init__ ()
440449 self .model = diffusion_model
441450 self .ema = EMA (ema_decay )
442451 self .ema_model = copy .deepcopy (self .model )
452+ self .step_start_ema = step_start_ema
443453
454+ self .batch_size = train_batch_size
444455 self .image_size = image_size
445456 self .gradient_accumulate_every = gradient_accumulate_every
446457 self .train_num_steps = train_num_steps
@@ -463,7 +474,7 @@ def reset_parameters(self):
463474 self .ema_model .load_state_dict (self .model .state_dict ())
464475
465476 def step_ema (self ):
466- if self .step < 2000 :
477+ if self .step < self . step_start_ema :
467478 self .reset_parameters ()
468479 return
469480 self .ema .update_model_average (self .ema_model , self .model )
@@ -501,8 +512,10 @@ def train(self):
501512
502513 if self .step % SAVE_AND_SAMPLE_EVERY == 0 :
503514 milestone = self .step // SAVE_AND_SAMPLE_EVERY
504- all_images = self .ema_model .p_sample_loop ((64 , 3 , self .image_size , self .image_size ))
505- utils .save_image (all_images , f'./sample-{ milestone } .png' , nrow = 8 )
515+ batches = num_to_groups (36 , self .batch_size )
516+ all_images_list = list (map (lambda n : self .ema_model .sample (self .image_size , batch_size = n ), batches ))
517+ all_images = torch .cat (all_images_list , dim = 0 )
518+ utils .save_image (all_images , f'./sample-{ milestone } .png' , nrow = 6 )
506519 self .save (milestone )
507520
508521 self .step += 1
0 commit comments