1919from einops import rearrange , reduce
2020from einops .layers .torch import Rearrange
2121
22+ from ema_pytorch import EMA
23+
2224# helpers functions
2325
2426def exists (x ):
@@ -50,21 +52,6 @@ def unnormalize_to_zero_to_one(t):
5052
5153# small helper modules
5254
53- class EMA ():
54- def __init__ (self , beta ):
55- super ().__init__ ()
56- self .beta = beta
57-
58- def update_model_average (self , ma_model , current_model ):
59- for current_params , ma_params in zip (current_model .parameters (), ma_model .parameters ()):
60- old_weight , up_weight = ma_params .data , current_params .data
61- ma_params .data = self .update_average (old_weight , up_weight )
62-
63- def update_average (self , old , new ):
64- if old is None :
65- return new
66- return old * self .beta + (1 - self .beta ) * new
67-
6855class Residual (nn .Module ):
6956 def __init__ (self , fn ):
7057 super ().__init__ ()
@@ -612,8 +599,7 @@ def __init__(
612599 self .image_size = diffusion_model .image_size
613600
614601 self .model = diffusion_model
615- self .ema = EMA (ema_decay )
616- self .ema_model = copy .deepcopy (self .model )
602+ self .ema = EMA (diffusion_model , beta = ema_decay )
617603 self .update_ema_every = update_ema_every
618604
619605 self .step_start_ema = step_start_ema
@@ -636,22 +622,11 @@ def __init__(
636622 self .results_folder = Path (results_folder )
637623 self .results_folder .mkdir (exist_ok = True )
638624
639- self .reset_parameters ()
640-
641- def reset_parameters (self ):
642- self .ema_model .load_state_dict (self .model .state_dict ())
643-
644- def step_ema (self ):
645- if self .step < self .step_start_ema :
646- self .reset_parameters ()
647- return
648- self .ema .update_model_average (self .ema_model , self .model )
649-
650625 def save (self , milestone ):
651626 data = {
652627 'step' : self .step ,
653628 'model' : self .model .state_dict (),
654- 'ema' : self .ema_model .state_dict (),
629+ 'ema' : self .ema .state_dict (),
655630 'scaler' : self .scaler .state_dict ()
656631 }
657632 torch .save (data , str (self .results_folder / f'model-{ milestone } .pt' ))
@@ -661,7 +636,7 @@ def load(self, milestone):
661636
662637 self .step = data ['step' ]
663638 self .model .load_state_dict (data ['model' ])
664- self .ema_model .load_state_dict (data ['ema' ])
639+ self .ema .load_state_dict (data ['ema' ])
665640 self .scaler .load_state_dict (data ['scaler' ])
666641
667642 def train (self ):
@@ -681,15 +656,15 @@ def train(self):
681656 self .scaler .update ()
682657 self .opt .zero_grad ()
683658
684- if self .step % self .update_ema_every == 0 :
685- self .step_ema ()
659+ self .ema .update ()
686660
687661 if self .step != 0 and self .step % self .save_and_sample_every == 0 :
688- self .ema_model .eval ()
662+ self .ema .ema_model .eval ()
663+ with torch .no_grad ():
664+ milestone = self .step // self .save_and_sample_every
665+ batches = num_to_groups (36 , self .batch_size )
666+ all_images_list = list (map (lambda n : self .ema .ema_model .sample (batch_size = n ), batches ))
689667
690- milestone = self .step // self .save_and_sample_every
691- batches = num_to_groups (36 , self .batch_size )
692- all_images_list = list (map (lambda n : self .ema_model .sample (batch_size = n ), batches ))
693668 all_images = torch .cat (all_images_list , dim = 0 )
694669 utils .save_image (all_images , str (self .results_folder / f'sample-{ milestone } .png' ), nrow = 6 )
695670 self .save (milestone )
0 commit comments