
Commit c535d31

Merge pull request #51 from lucidrains/pw/elucidating-ddpm
elucidating diffusion, first pass
2 parents d26acbc + 5db64fe

File tree

3 files changed: +226 −0 lines

README.md

Lines changed: 10 additions & 0 deletions
@@ -133,3 +133,13 @@ Samples and model checkpoints will be logged to `./results` periodically
     volume = {abs/2204.00227}
 }
 ```
+
+```bibtex
+@article{Karras2022ElucidatingTD,
+    title   = {Elucidating the Design Space of Diffusion-Based Generative Models},
+    author  = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2206.00364}
+}
+```

denoising_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 from denoising_diffusion_pytorch.learned_gaussian_diffusion import LearnedGaussianDiffusion
 from denoising_diffusion_pytorch.continuous_time_gaussian_diffusion import ContinuousTimeGaussianDiffusion
 from denoising_diffusion_pytorch.weighted_objective_gaussian_diffusion import WeightedObjectiveGaussianDiffusion
+from denoising_diffusion_pytorch.elucidated_diffusion import ElucidatedDiffusion
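
This export makes the new module available at the package top level, so downstream code can write `from denoising_diffusion_pytorch import ElucidatedDiffusion` directly.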
denoising_diffusion_pytorch/elucidated_diffusion.py

Lines changed: 215 additions & 0 deletions

@@ -0,0 +1,215 @@
```python
from math import sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F

from tqdm import tqdm
from einops import rearrange, repeat, reduce

# helpers

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# tensor helpers

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# normalization functions

def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

# main class
class ElucidatedDiffusion(nn.Module):
    def __init__(
        self,
        net,
        *,
        image_size,
        channels = 3,
        num_sample_steps = 32, # number of sampling steps
        sigma_min = 0.002,     # min noise level
        sigma_max = 80,        # max noise level
        sigma_data = 0.5,      # standard deviation of data distribution
        rho = 7,               # controls the sampling schedule
        P_mean = -1.2,         # mean of log-normal distribution from which noise is drawn for training
        P_std = 1.2,           # standard deviation of log-normal distribution from which noise is drawn for training
        S_churn = 80,          # parameters for stochastic sampling - depends on dataset, Table 5 in paper
        S_tmin = 0.05,
        S_tmax = 50,
        S_noise = 1.003,
    ):
        super().__init__()
        assert net.learned_sinusoidal_cond

        self.net = net

        # image dimensions

        self.channels = channels
        self.image_size = image_size

        # parameters

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

        self.rho = rho

        self.P_mean = P_mean
        self.P_std = P_std

        self.num_sample_steps = num_sample_steps # otherwise known as N in the paper

        self.S_churn = S_churn
        self.S_tmin = S_tmin
        self.S_tmax = S_tmax
        self.S_noise = S_noise

    @property
    def device(self):
        return next(self.net.parameters()).device
    # derived preconditioning params - Table 1

    def c_skip(self, sigma):
        return (self.sigma_data ** 2) / (sigma ** 2 + self.sigma_data ** 2)

    def c_out(self, sigma):
        return sigma * self.sigma_data * (self.sigma_data ** 2 + sigma ** 2) ** -0.5

    def c_in(self, sigma):
        return 1 * (sigma ** 2 + self.sigma_data ** 2) ** -0.5

    def c_noise(self, sigma):
        return log(sigma) * 0.25

    # noise distribution

    def noise_distribution(self, batch_size):
        return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp()

    def loss_weight(self, sigma):
        return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2

    # sample schedule
    # equation (5) in the paper

    def sample_schedule(self, num_sample_steps = None):
        num_sample_steps = default(num_sample_steps, self.num_sample_steps)

        N = num_sample_steps
        inv_rho = 1 / self.rho

        steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
        sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho

        sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
        return sigmas
    # preconditioned network output
    # equation (7) in the paper

    def preconditioned_network_forward(self, noised_images, sigma):
        batch, device = noised_images.shape[0], noised_images.device

        if isinstance(sigma, float):
            sigma = torch.full((batch,), sigma, device = device)

        padded_sigma = rearrange(sigma, 'b -> b 1 1 1')

        net_out = self.net(
            self.c_in(padded_sigma) * noised_images,
            self.c_noise(sigma)
        )

        return self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out
    # sampling
    # stochastic sampler with Heun's 2nd order correction - Algorithm 2 in the paper

    @torch.no_grad()
    def sample(self, batch_size = 16):
        shape = (batch_size, self.channels, self.image_size, self.image_size)

        # get the sigma schedule, derive the per-step gammas, and pair each (sigma, gamma) with the next sigma

        sigmas = self.sample_schedule()

        gammas = torch.where(
            (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax),
            min(self.S_churn / self.num_sample_steps, sqrt(2) - 1),
            0.
        )

        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))

        # images is noise at the beginning

        init_sigma = sigmas[0]

        images = init_sigma * torch.randn(shape, device = self.device)

        # gradually denoise

        for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'):
            sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))

            eps = self.S_noise * torch.randn(shape, device = self.device) # churn noise, drawn with standard deviation S_noise

            sigma_hat = sigma + gamma * sigma
            images_hat = images + sqrt(sigma_hat ** 2 - sigma ** 2) * eps

            model_output = self.preconditioned_network_forward(images_hat, sigma_hat)
            denoised_over_sigma = (images_hat - model_output) / sigma_hat

            images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma

            # second order correction, if not the last timestep

            if sigma_next != 0:
                model_output_next = self.preconditioned_network_forward(images_next, sigma_next)
                denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
                images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)

            images = images_next

        images = images.clamp(-1., 1.)
        return unnormalize_to_zero_to_one(images)
    # training

    def forward(self, images):
        batch_size, c, h, w, device, image_size, channels = *images.shape, images.device, self.image_size, self.channels

        assert h == image_size and w == image_size, f'height and width of image must be {image_size}'
        assert c == channels, 'mismatch of image channels'

        images = normalize_to_neg_one_to_one(images)

        sigmas = self.noise_distribution(batch_size)
        padded_sigmas = rearrange(sigmas, 'b -> b 1 1 1')

        noise = torch.randn_like(images)

        noised_images = images + padded_sigmas * noise # alphas are 1. in the paper

        denoised = self.preconditioned_network_forward(noised_images, sigmas)

        losses = F.mse_loss(denoised, images, reduction = 'none')
        losses = reduce(losses, 'b ... -> b', 'mean')

        losses = losses * self.loss_weight(sigmas)

        return losses.mean()
```
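
For context, here is a minimal usage sketch in the style of the repo's existing README examples. It is an illustration, not part of the commit: the `Unet` keyword `learned_sinusoidal_cond = True` is assumed from the constructor's `assert net.learned_sinusoidal_cond`, and the sizes, step counts, and batch sizes are arbitrary.

```python
import torch
from denoising_diffusion_pytorch import Unet, ElucidatedDiffusion

# network to be wrapped - must expose `learned_sinusoidal_cond` (asserted by ElucidatedDiffusion)
net = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    learned_sinusoidal_cond = True
)

diffusion = ElucidatedDiffusion(
    net,
    image_size = 128,
    num_sample_steps = 32  # N in the paper
)

training_images = torch.rand(8, 3, 128, 128) # images are expected normalized to [0, 1]
loss = diffusion(training_images)            # sigma-weighted denoising loss
loss.backward()

# after much training
sampled_images = diffusion.sample(batch_size = 4) # (4, 3, 128, 128), values in [0, 1]
```

Note that `sample` returns images in `[0, 1]`: the class normalizes inputs to `[-1, 1]` internally and unnormalizes before returning.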
