Skip to content

Commit eaf9d9f

Browse files
committed
The U-Net needs to be conditioned on log(SNR) in p_mean_variance for continuous-time Gaussian diffusion
1 parent (3bf5e76), commit eaf9d9f

File tree

2 files changed: 6 lines added (+6), 9 lines removed (-9)

2 files changed

+6
-9
lines changed

denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,6 @@ def __init__(
5959
*,
6060
image_size,
6161
channels = 3,
62-
cond_scale = 500,
6362
loss_type = 'l1',
6463
noise_schedule = 'linear',
6564
num_sample_steps = 500
@@ -76,7 +75,6 @@ def __init__(
7675

7776
# continuous noise schedule related stuff
7877

79-
self.cond_scale = cond_scale # the log(snr) will be scaled by this value
8078
self.loss_type = loss_type
8179

8280
if noise_schedule == 'linear':
@@ -108,18 +106,16 @@ def p_mean_variance(self, x, time, time_next):
108106
# todo - derive x_start from the posterior mean and do dynamic thresholding
109107
# assumed that is what is going on in Imagen
110108

111-
batch = x.shape[0]
112-
batch_time = repeat(time, ' -> b', b = batch)
113-
114-
pred_noise = self.denoise_fn(x, batch_time * self.cond_scale)
115-
116109
log_snr = self.log_snr(time)
117110
log_snr_next = self.log_snr(time_next)
118111
c = -expm1(log_snr - log_snr_next)
119112

120113
squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
121114
squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()
122115

116+
batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
117+
pred_noise = self.denoise_fn(x, batch_log_snr)
118+
123119
model_mean = sqrt(squared_alpha_next / squared_alpha) * (x - c * sqrt(squared_sigma) * pred_noise)
124120
posterior_variance = squared_sigma_next * c
125121

@@ -151,6 +147,7 @@ def p_sample_loop(self, shape):
151147
times_next = steps[i + 1]
152148
img = self.p_sample(img, times, times_next)
153149

150+
img.clamp_(-1., 1.)
154151
img = unnormalize_to_zero_to_one(img)
155152
return img
156153

@@ -180,7 +177,7 @@ def p_losses(self, x_start, times, noise = None):
180177

181178
x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
182179

183-
model_out = self.denoise_fn(x, log_snr * self.cond_scale)
180+
model_out = self.denoise_fn(x, log_snr)
184181
return self.loss_fn(model_out, noise)
185182

186183
def forward(self, img, *args, **kwargs):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'denoising-diffusion-pytorch',
55
packages = find_packages(),
6-
version = '0.16.5',
6+
version = '0.16.7',
77
license='MIT',
88
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)