@@ -4,6 +4,7 @@
 from torch import nn, einsum
 import torch.nn.functional as F
 from inspect import isfunction
+from collections import namedtuple
 from functools import partial

 from torch.utils.data import Dataset, DataLoader
@@ -22,6 +23,10 @@

 from accelerate import Accelerator

+# constants
+
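+# ModelPrediction bundles both parameterizations of the model output, so the
+# samplers below can read off either the predicted noise or the predicted
+# x_start without branching on the training objective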
+ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
+
 # helper functions

 def exists(x):
@@ -383,25 +388,29 @@ def cosine_beta_schedule(timesteps, s = 0.008):
 class GaussianDiffusion(nn.Module):
     def __init__(
         self,
-        denoise_fn,
+        model,
         *,
         image_size,
         channels = 3,
         timesteps = 1000,
+        sampling_timesteps = None,
         loss_type = 'l1',
         objective = 'pred_noise',
         beta_schedule = 'cosine',
         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
-        p2_loss_weight_k = 1
+        p2_loss_weight_k = 1,
+        ddim_sampling_eta = 1.
     ):
         super().__init__()
-        assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim)
+        assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)

         self.channels = channels
         self.image_size = image_size
-        self.denoise_fn = denoise_fn
+        self.model = model
         self.objective = objective

+        assert objective in {'pred_noise', 'pred_x0'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start)'
+
         if beta_schedule == 'linear':
             betas = linear_beta_schedule(timesteps)
         elif beta_schedule == 'cosine':
@@ -417,6 +426,14 @@ def __init__(
         self.num_timesteps = int(timesteps)
         self.loss_type = loss_type

+        # sampling related parameters
+
+        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
+
+        assert self.sampling_timesteps <= timesteps
+        self.is_ddim_sampling = self.sampling_timesteps < timesteps
+        self.ddim_sampling_eta = ddim_sampling_eta
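+        # per Song et al. (https://arxiv.org/abs/2010.02502), eta = 0. makes
+        # the DDIM updates fully deterministic, while eta = 1. recovers the
+        # noise level of ancestral DDPM sampling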
+
         # helper function to register buffer from float64 to float32

         register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
@@ -457,6 +474,12 @@ def predict_start_from_noise(self, x_t, t, noise):
             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
         )

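+    # q_sample gives x_t = sqrt(alphas_cumprod) * x_start + sqrt(1 - alphas_cumprod) * noise,
+    # so solving for the noise yields
+    # noise = (sqrt_recip_alphas_cumprod * x_t - x_start) / sqrt_recipm1_alphas_cumprod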
+    def predict_noise_from_start(self, x_t, t, x0):
+        return (
+            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+        )
+
     def q_posterior(self, x_start, x_t, t):
         posterior_mean = (
             extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -466,15 +489,22 @@ def q_posterior(self, x_start, x_t, t):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(self, x, t, clip_denoised: bool):
-        model_output = self.denoise_fn(x, t)
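+    # model_predictions normalizes both objectives to one interface: whichever
+    # quantity the network was trained to output, the other is derived, and
+    # both are returned so the DDPM and DDIM samplers can consume either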
+    def model_predictions(self, x, t):
+        model_output = self.model(x, t)

         if self.objective == 'pred_noise':
-            x_start = self.predict_start_from_noise(x, t = t, noise = model_output)
+            pred_noise = model_output
+            x_start = self.predict_start_from_noise(x, t, model_output)
+
         elif self.objective == 'pred_x0':
+            pred_noise = self.predict_noise_from_start(x, t, model_output)
             x_start = model_output
-        else:
-            raise ValueError(f'unknown objective {self.objective}')
+
+        return ModelPrediction(pred_noise, x_start)
+
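+    # p_mean_variance now reads the x_start estimate off the ModelPrediction
+    # tuple instead of branching on the objective itself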
+    def p_mean_variance(self, x, t, clip_denoised: bool):
+        preds = self.model_predictions(x, t)
+        x_start = preds.pred_x_start

         if clip_denoised:
             x_start.clamp_(-1., 1.)
@@ -483,32 +513,59 @@ def p_mean_variance(self, x, t, clip_denoised: bool):
         return model_mean, posterior_variance, posterior_log_variance

     @torch.no_grad()
-    def p_sample(self, x, t, clip_denoised = True):
+    def p_sample(self, x, t: int, clip_denoised = True):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised)
-        noise = torch.randn_like(x)
-        # no noise when t == 0
-        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
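+        # t is now a plain python int, so broadcast it to a batch of long
+        # timesteps before conditioning the model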
+        batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = batched_times, clip_denoised = clip_denoised)
+        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
+        return model_mean + (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
     def p_sample_loop(self, shape):
-        device = self.betas.device
+        batch, device = shape[0], self.betas.device

-        b = shape[0]
         img = torch.randn(shape, device = device)

-        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long))
+        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step'):
+            img = self.p_sample(img, t)
+
+        img = unnormalize_to_zero_to_one(img)
+        return img
+
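+    # ddim_sample walks an evenly spaced subset of the training timesteps in
+    # consecutive pairs (time, time_next), so it needs far fewer network
+    # evaluations than the full ancestral loop above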
+    @torch.no_grad()
+    def ddim_sample(self, shape, clip_denoised = False):
+        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
+
+        times = torch.linspace(0., total_timesteps, steps = sampling_timesteps + 2)[:-1]
+
+        times = list(reversed(times.int().tolist()))
+        time_pairs = list(zip(times[:-1], times[1:]))
+
+        img = torch.randn(shape, device = device)
+
+        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+            alpha = self.alphas_cumprod_prev[time]
+            alpha_next = self.alphas_cumprod_prev[time_next]
+
+            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+            pred_noise, x_start, *_ = self.model_predictions(img, time_cond)
+
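+            # per DDIM eq. 12: c1 is the sigma term controlling how much
+            # fresh noise is injected, c2 the coefficient on the direction
+            # pointing to x_t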
+            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+
+            img = x_start * alpha_next.sqrt() + \
+                  c1 * torch.randn_like(img) + \
+                  c2 * pred_noise

         img = unnormalize_to_zero_to_one(img)
         return img

     @torch.no_grad()
     def sample(self, batch_size = 16):
-        image_size = self.image_size
-        channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size))
+        image_size, channels = self.image_size, self.channels
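+        # fall back to the cheaper DDIM sampler whenever fewer sampling steps
+        # than training steps were requested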
+        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
+        return sample_fn((batch_size, channels, image_size, image_size))

     @torch.no_grad()
     def interpolate(self, x1, x2, t = None, lam = 0.5):
@@ -547,8 +604,8 @@ def p_losses(self, x_start, t, noise = None):
         b, c, h, w = x_start.shape
         noise = default(noise, lambda: torch.randn_like(x_start))

-        x = self.q_sample(x_start = x_start, t = t, noise = noise)
-        model_out = self.denoise_fn(x, t)
+        x = self.q_sample(x_start = x_start, t = t, noise = noise)
+        model_out = self.model(x, t)

         if self.objective == 'pred_noise':
             target = noise
@@ -677,15 +734,13 @@ def __init__(
         self.model, self.dl, self.opt = self.accelerator.prepare(self.model, self.dl, self.opt)

     def save(self, milestone):
-        if not self.accelerator.is_main_process:
+        if not self.accelerator.is_local_main_process:
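+            # note: is_local_main_process is true on the main process of each
+            # node, so every machine writes its own checkpoint copy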
             return

-        opt = self.accelerator.unwrap_model(self.opt)
-
         data = {
             'step': self.step,
             'model': self.accelerator.get_state_dict(self.model),
-            'opt': opt.state_dict(),
+            'opt': self.opt.state_dict(),
             'ema': self.ema.state_dict(),
             'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
         }
@@ -696,12 +751,10 @@ def load(self, milestone):
         data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

         model = self.accelerator.unwrap_model(self.model)
-        opt = self.accelerator.unwrap_model(self.opt)
-
         model.load_state_dict(data['model'])
-        opt.load_state_dict(data['opt'])

         self.step = data['step']
+        self.opt.load_state_dict(data['opt'])
         self.ema.load_state_dict(data['ema'])

         if exists(self.accelerator.scaler) and exists(data['scaler']):