
Commit cb4d571: bring in ddim sampling

1 parent a0c3443 · commit cb4d571

File tree: 6 files changed, +133 -51 lines changed


README.md (13 additions, 2 deletions)

````diff
@@ -60,8 +60,9 @@ model = Unet(
 diffusion = GaussianDiffusion(
     model,
     image_size = 128,
-    timesteps = 1000,   # number of steps
-    loss_type = 'l1'    # L1 or L2
+    timesteps = 1000,           # number of steps
+    sampling_timesteps = 250,   # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
+    loss_type = 'l1'            # L1 or L2
 ).cuda()
 
 trainer = Trainer(
@@ -159,3 +160,13 @@ $ accelerate launch train.py
     volume = {abs/2206.00364}
 }
 ```
+
+```bibtex
+@article{Song2021DenoisingDI,
+    title   = {Denoising Diffusion Implicit Models},
+    author  = {Jiaming Song and Chenlin Meng and Stefano Ermon},
+    journal = {ArXiv},
+    year    = {2021},
+    volume  = {abs/2010.02502}
+}
+```
````
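For context, a minimal end-to-end sketch of the new option (not part of the commit; it assumes the `Unet` hyperparameters `dim = 64, dim_mults = (1, 2, 4, 8)` shown earlier in the README). Setting `sampling_timesteps` below `timesteps` makes `sample` dispatch to the new DDIM path:

```python
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of training timesteps
    sampling_timesteps = 250,   # < timesteps, so is_ddim_sampling is True and ddim_sample is used
    loss_type = 'l1'
)

sampled_images = diffusion.sample(batch_size = 4)
print(sampled_images.shape)  # (4, 3, 128, 128)
```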

denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py (6 additions, 6 deletions)

```diff
@@ -112,7 +112,7 @@ def forward(self, x):
 class ContinuousTimeGaussianDiffusion(nn.Module):
     def __init__(
         self,
-        denoise_fn,
+        model,
         *,
         image_size,
         channels = 3,
@@ -126,9 +126,9 @@ def __init__(
         p2_loss_weight_k = 1
     ):
         super().__init__()
-        assert denoise_fn.learned_sinusoidal_cond
+        assert model.learned_sinusoidal_cond
 
-        self.denoise_fn = denoise_fn
+        self.model = model
 
         # image dimensions
 
@@ -170,7 +170,7 @@ def __init__(
 
     @property
     def device(self):
-        return next(self.denoise_fn.parameters()).device
+        return next(self.model.parameters()).device
 
     @property
     def loss_fn(self):
@@ -195,7 +195,7 @@ def p_mean_variance(self, x, time, time_next):
         alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))
 
         batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
-        pred_noise = self.denoise_fn(x, batch_log_snr)
+        pred_noise = self.model(x, batch_log_snr)
 
         if self.clip_sample_denoised:
             x_start = (x - sigma * pred_noise) / alpha
@@ -266,7 +266,7 @@ def p_losses(self, x_start, times, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
-        model_out = self.denoise_fn(x, log_snr)
+        model_out = self.model(x, log_snr)
 
         losses = self.loss_fn(model_out, noise, reduction = 'none')
         losses = reduce(losses, 'b ... -> b', 'mean')
```

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py (84 additions, 31 deletions)

```diff
@@ -4,6 +4,7 @@
 from torch import nn, einsum
 import torch.nn.functional as F
 from inspect import isfunction
+from collections import namedtuple
 from functools import partial
 
 from torch.utils.data import Dataset, DataLoader
@@ -22,6 +23,10 @@
 
 from accelerate import Accelerator
 
+# constants
+
+ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
+
 # helpers functions
 
 def exists(x):
@@ -383,25 +388,29 @@ def cosine_beta_schedule(timesteps, s = 0.008):
 class GaussianDiffusion(nn.Module):
     def __init__(
         self,
-        denoise_fn,
+        model,
         *,
         image_size,
         channels = 3,
         timesteps = 1000,
+        sampling_timesteps = None,
         loss_type = 'l1',
         objective = 'pred_noise',
         beta_schedule = 'cosine',
         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
-        p2_loss_weight_k = 1
+        p2_loss_weight_k = 1,
+        ddim_sampling_eta = 1.
     ):
         super().__init__()
-        assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim)
+        assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
 
         self.channels = channels
         self.image_size = image_size
-        self.denoise_fn = denoise_fn
+        self.model = model
         self.objective = objective
 
+        assert objective in {'pred_noise', 'pred_x0'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start)'
+
         if beta_schedule == 'linear':
             betas = linear_beta_schedule(timesteps)
         elif beta_schedule == 'cosine':
@@ -417,6 +426,14 @@ def __init__(
         self.num_timesteps = int(timesteps)
         self.loss_type = loss_type
 
+        # sampling related parameters
+
+        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
+
+        assert self.sampling_timesteps <= timesteps
+        self.is_ddim_sampling = self.sampling_timesteps < timesteps
+        self.ddim_sampling_eta = ddim_sampling_eta
+
         # helper function to register buffer from float64 to float32
 
         register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
@@ -457,6 +474,12 @@ def predict_start_from_noise(self, x_t, t, noise):
             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
         )
 
+    def predict_noise_from_start(self, x_t, t, x0):
+        return (
+            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+        )
+
     def q_posterior(self, x_start, x_t, t):
         posterior_mean = (
             extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
@@ -466,15 +489,22 @@ def q_posterior(self, x_start, x_t, t):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def p_mean_variance(self, x, t, clip_denoised: bool):
-        model_output = self.denoise_fn(x, t)
+    def model_predictions(self, x, t):
+        model_output = self.model(x, t)
 
         if self.objective == 'pred_noise':
-            x_start = self.predict_start_from_noise(x, t = t, noise = model_output)
+            pred_noise = model_output
+            x_start = self.predict_start_from_noise(x, t, model_output)
+
         elif self.objective == 'pred_x0':
+            pred_noise = self.predict_noise_from_start(x, t, model_output)
             x_start = model_output
-        else:
-            raise ValueError(f'unknown objective {self.objective}')
+
+        return ModelPrediction(pred_noise, x_start)
+
+    def p_mean_variance(self, x, t, clip_denoised: bool):
+        preds = self.model_predictions(x, t)
+        x_start = preds.pred_x_start
 
         if clip_denoised:
             x_start.clamp_(-1., 1.)
@@ -483,32 +513,59 @@ def p_mean_variance(self, x, t, clip_denoised: bool):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, x, t, clip_denoised=True):
+    def p_sample(self, x, t: int, clip_denoised = True):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
-        noise = torch.randn_like(x)
-        # no noise when t == 0
-        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = batched_times, clip_denoised = clip_denoised)
+        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
+        return model_mean + (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
     def p_sample_loop(self, shape):
-        device = self.betas.device
+        batch, device = shape[0], self.betas.device
 
-        b = shape[0]
         img = torch.randn(shape, device=device)
 
-        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
+        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step'):
+            img = self.p_sample(img, t)
+
+        img = unnormalize_to_zero_to_one(img)
+        return img
+
+    @torch.no_grad()
+    def ddim_sample(self, shape, clip_denoised = False):
+        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
+
+        times = torch.linspace(0., total_timesteps, steps = sampling_timesteps + 2)[:-1]
+
+        times = list(reversed(times.int().tolist()))
+        time_pairs = list(zip(times[:-1], times[1:]))
+
+        img = torch.randn(shape, device = device)
+
+        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+            alpha = self.alphas_cumprod_prev[time]
+            alpha_next = self.alphas_cumprod_prev[time_next]
+
+            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+            pred_noise, x_start, *_ = self.model_predictions(img, time_cond)
+
+            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+
+            img = x_start * alpha_next.sqrt() + \
+                  c1 * torch.randn_like(img) + \
+                  c2 * pred_noise
 
         img = unnormalize_to_zero_to_one(img)
         return img
 
     @torch.no_grad()
     def sample(self, batch_size = 16):
-        image_size = self.image_size
-        channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size))
+        image_size, channels = self.image_size, self.channels
+        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
+        return sample_fn((batch_size, channels, image_size, image_size))
 
     @torch.no_grad()
     def interpolate(self, x1, x2, t = None, lam = 0.5):
@@ -547,8 +604,8 @@ def p_losses(self, x_start, t, noise = None):
         b, c, h, w = x_start.shape
         noise = default(noise, lambda: torch.randn_like(x_start))
 
-        x = self.q_sample(x_start=x_start, t=t, noise=noise)
-        model_out = self.denoise_fn(x, t)
+        x = self.q_sample(x_start = x_start, t = t, noise = noise)
+        model_out = self.model(x, t)
 
         if self.objective == 'pred_noise':
             target = noise
@@ -677,15 +734,13 @@ def __init__(
         self.model, self.dl, self.opt = self.accelerator.prepare(self.model, self.dl, self.opt)
 
     def save(self, milestone):
-        if not self.accelerator.is_main_process:
+        if not self.accelerator.is_local_main_process:
             return
 
-        opt = self.accelerator.unwrap_model(self.opt)
-
         data = {
             'step': self.step,
             'model': self.accelerator.get_state_dict(self.model),
-            'opt': opt.state_dict(),
+            'opt': self.opt.state_dict(),
             'ema': self.ema.state_dict(),
             'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
         }
@@ -696,12 +751,10 @@ def load(self, milestone):
         data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))
 
         model = self.accelerator.unwrap_model(self.model)
-        opt = self.accelerator.unwrap_model(self.opt)
-
         model.load_state_dict(data['model'])
-        opt.load_state_dict(data['opt'])
 
         self.step = data['step']
+        self.opt.load_state_dict(data['opt'])
         self.ema.load_state_dict(data['ema'])
 
         if exists(self.accelerator.scaler) and exists(data['scaler']):
```
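For reference (this note is not part of the commit), the update performed inside the new `ddim_sample` loop is a transcription of the DDIM transition from Song et al. (2021), with `alpha` the cumulative product at the current sampling time t and `alpha_next` the one at the next time t':

```latex
x_{t'} = \sqrt{\bar\alpha_{t'}}\,\hat{x}_0 + c_1\,\varepsilon + c_2\,\hat\varepsilon_\theta(x_t),
\qquad \varepsilon \sim \mathcal{N}(0, I)

c_1 = \eta \sqrt{\frac{\left(1 - \bar\alpha_t / \bar\alpha_{t'}\right)\left(1 - \bar\alpha_{t'}\right)}{1 - \bar\alpha_t}},
\qquad
c_2 = \sqrt{1 - \bar\alpha_{t'} - c_1^2}
```

Here η is `ddim_sampling_eta`: η = 0 gives the fully deterministic DDIM sampler, while η = 1 matches the DDPM posterior variance, which is why the constructor defaults to 1.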

denoising_diffusion_pytorch/learned_gaussian_diffusion.py (22 additions, 5 deletions)

```diff
@@ -1,4 +1,5 @@
 import torch
+from collections import namedtuple
 from math import pi, sqrt, log as ln
 from inspect import isfunction
 from torch import nn, einsum
@@ -10,6 +11,8 @@
 
 NAT = 1. / ln(2)
 
+ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start', 'pred_variance'])
+
 # helper functions
 
 def exists(x):
@@ -67,17 +70,31 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
 class LearnedGaussianDiffusion(GaussianDiffusion):
     def __init__(
         self,
-        denoise_fn,
+        model,
         vb_loss_weight = 0.001, # lambda was 0.001 in the paper
         *args,
         **kwargs
     ):
-        super().__init__(denoise_fn, *args, **kwargs)
-        assert denoise_fn.out_dim == (denoise_fn.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
+        super().__init__(model, *args, **kwargs)
+        assert model.out_dim == (model.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
         self.vb_loss_weight = vb_loss_weight
 
+    def model_predictions(self, x, t):
+        model_output = self.model(x, t)
+        model_output, pred_variance = model_output.chunk(2, dim = 1)
+
+        if self.objective == 'pred_noise':
+            pred_noise = model_output
+            x_start = self.predict_start_from_noise(x, t, model_output)
+
+        elif self.objective == 'pred_x0':
+            pred_noise = self.predict_noise_from_start(x, t, model_output)
+            x_start = model_output
+
+        return ModelPrediction(pred_noise, x_start, pred_variance)
+
     def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
-        model_output = default(model_output, lambda: self.denoise_fn(x, t))
+        model_output = default(model_output, lambda: self.model(x, t))
         pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)
 
         min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
@@ -102,7 +119,7 @@ def p_losses(self, x_start, t, noise = None, clip_denoised = False):
 
         # model output
 
-        model_output = self.denoise_fn(x_t, t)
+        model_output = self.model(x_t, t)
 
         # calculating kl loss for learned variance (interpolation)
 
```
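A small illustrative snippet (not from the commit) of what the `chunk` in the new `model_predictions` override does to a learned-variance Unet output, whose `out_dim` is twice the channel count:

```python
import torch

# stand-in for a learned-variance Unet output: (batch, channels * 2, height, width)
model_output = torch.randn(2, 6, 32, 32)

# split along the channel dim: first half is the noise / x_start prediction,
# second half the per-pixel variance interpolation fraction
model_output, pred_variance = model_output.chunk(2, dim = 1)

print(model_output.shape, pred_variance.shape)  # both torch.Size([2, 3, 32, 32])
```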

denoising_diffusion_pytorch/weighted_objective_gaussian_diffusion.py (7 additions, 6 deletions)

```diff
@@ -22,22 +22,23 @@ def default(val, d):
 class WeightedObjectiveGaussianDiffusion(GaussianDiffusion):
     def __init__(
         self,
-        denoise_fn,
+        model,
         *args,
         pred_noise_loss_weight = 0.1,
         pred_x_start_loss_weight = 0.1,
         **kwargs
     ):
-        super().__init__(denoise_fn, *args, **kwargs)
-        channels = denoise_fn.channels
-        assert denoise_fn.out_dim == (channels * 2 + 2), 'dimension out (out_dim) of unet must be twice the number of channels + 2 (for the softmax weighted sum) - for channels of 3, this should be (3 * 2) + 2 = 8'
+        super().__init__(model, *args, **kwargs)
+        channels = model.channels
+        assert model.out_dim == (channels * 2 + 2), 'dimension out (out_dim) of unet must be twice the number of channels + 2 (for the softmax weighted sum) - for channels of 3, this should be (3 * 2) + 2 = 8'
+        assert not self.is_ddim_sampling, 'ddim sampling cannot be used'
 
         self.split_dims = (channels, channels, 2)
         self.pred_noise_loss_weight = pred_noise_loss_weight
         self.pred_x_start_loss_weight = pred_x_start_loss_weight
 
     def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
-        model_output = self.denoise_fn(x, t)
+        model_output = self.model(x, t)
 
         pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1)
         normalized_weights = weights.softmax(dim = 1)
@@ -58,7 +59,7 @@ def p_losses(self, x_start, t, noise = None, clip_denoised = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
         x_t = self.q_sample(x_start = x_start, t = t, noise = noise)
 
-        model_output = self.denoise_fn(x_t, t)
+        model_output = self.model(x_t, t)
         pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1)
 
         # get loss for predicted noise and x_start
```
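Similarly, a hypothetical shape check (not from the commit) for the `split` used by this class, where the Unet's `out_dim` must be `channels * 2 + 2`:

```python
import torch

channels = 3
split_dims = (channels, channels, 2)

# stand-in for the unet output: (batch, channels * 2 + 2, height, width)
model_output = torch.randn(2, 8, 32, 32)

pred_noise, pred_x_start, weights = model_output.split(split_dims, dim = 1)
normalized_weights = weights.softmax(dim = 1)  # softmax over the two weight channels

print(pred_noise.shape, pred_x_start.shape, weights.shape)
# torch.Size([2, 3, 32, 32]) torch.Size([2, 3, 32, 32]) torch.Size([2, 2, 32, 32])
```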
