Skip to content

Commit 8c3609a

Browse files
committed
move EMA logic out of the repository for clarity
1 parent 1586d1a commit 8c3609a

File tree

2 files changed

+13
-37
lines changed

2 files changed

+13
-37
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from einops import rearrange, reduce
2020
from einops.layers.torch import Rearrange
2121

22+
from ema_pytorch import EMA
23+
2224
# helpers functions
2325

2426
def exists(x):
@@ -50,21 +52,6 @@ def unnormalize_to_zero_to_one(t):
5052

5153
# small helper modules
5254

53-
class EMA():
54-
def __init__(self, beta):
55-
super().__init__()
56-
self.beta = beta
57-
58-
def update_model_average(self, ma_model, current_model):
59-
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
60-
old_weight, up_weight = ma_params.data, current_params.data
61-
ma_params.data = self.update_average(old_weight, up_weight)
62-
63-
def update_average(self, old, new):
64-
if old is None:
65-
return new
66-
return old * self.beta + (1 - self.beta) * new
67-
6855
class Residual(nn.Module):
6956
def __init__(self, fn):
7057
super().__init__()
@@ -612,8 +599,7 @@ def __init__(
612599
self.image_size = diffusion_model.image_size
613600

614601
self.model = diffusion_model
615-
self.ema = EMA(ema_decay)
616-
self.ema_model = copy.deepcopy(self.model)
602+
self.ema = EMA(diffusion_model, beta = ema_decay)
617603
self.update_ema_every = update_ema_every
618604

619605
self.step_start_ema = step_start_ema
@@ -636,22 +622,11 @@ def __init__(
636622
self.results_folder = Path(results_folder)
637623
self.results_folder.mkdir(exist_ok = True)
638624

639-
self.reset_parameters()
640-
641-
def reset_parameters(self):
642-
self.ema_model.load_state_dict(self.model.state_dict())
643-
644-
def step_ema(self):
645-
if self.step < self.step_start_ema:
646-
self.reset_parameters()
647-
return
648-
self.ema.update_model_average(self.ema_model, self.model)
649-
650625
def save(self, milestone):
651626
data = {
652627
'step': self.step,
653628
'model': self.model.state_dict(),
654-
'ema': self.ema_model.state_dict(),
629+
'ema': self.ema.state_dict(),
655630
'scaler': self.scaler.state_dict()
656631
}
657632
torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
@@ -661,7 +636,7 @@ def load(self, milestone):
661636

662637
self.step = data['step']
663638
self.model.load_state_dict(data['model'])
664-
self.ema_model.load_state_dict(data['ema'])
639+
self.ema.load_state_dict(data['ema'])
665640
self.scaler.load_state_dict(data['scaler'])
666641

667642
def train(self):
@@ -681,15 +656,15 @@ def train(self):
681656
self.scaler.update()
682657
self.opt.zero_grad()
683658

684-
if self.step % self.update_ema_every == 0:
685-
self.step_ema()
659+
self.ema.update()
686660

687661
if self.step != 0 and self.step % self.save_and_sample_every == 0:
688-
self.ema_model.eval()
662+
self.ema.ema_model.eval()
663+
with torch.no_grad():
664+
milestone = self.step // self.save_and_sample_every
665+
batches = num_to_groups(36, self.batch_size)
666+
all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
689667

690-
milestone = self.step // self.save_and_sample_every
691-
batches = num_to_groups(36, self.batch_size)
692-
all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
693668
all_images = torch.cat(all_images_list, dim=0)
694669
utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = 6)
695670
self.save(milestone)

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'denoising-diffusion-pytorch',
55
packages = find_packages(),
6-
version = '0.20.2',
6+
version = '0.21.0',
77
license='MIT',
88
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
99
author = 'Phil Wang',
@@ -16,6 +16,7 @@
1616
],
1717
install_requires=[
1818
'einops',
19+
'ema-pytorch',
1920
'pillow',
2021
'torch',
2122
'torchvision',

0 commit comments

Comments
 (0)