Skip to content

Commit 33eee3b

Browse files
committed
unet needs to be conditioned on log(snr) in p_mean_variance for continuous time gaussian diffusion
1 parent 3bf5e76 commit 33eee3b

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,16 @@ def p_mean_variance(self, x, time, time_next):
108108
# todo - derive x_start from the posterior mean and do dynamic thresholding
109109
# assumed that is what is going on in Imagen
110110

111-
batch = x.shape[0]
112-
batch_time = repeat(time, ' -> b', b = batch)
113-
114-
pred_noise = self.denoise_fn(x, batch_time * self.cond_scale)
115-
116111
log_snr = self.log_snr(time)
117112
log_snr_next = self.log_snr(time_next)
118113
c = -expm1(log_snr - log_snr_next)
119114

120115
squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
121116
squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()
122117

118+
batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
119+
pred_noise = self.denoise_fn(x, batch_log_snr)
120+
123121
model_mean = sqrt(squared_alpha_next / squared_alpha) * (x - c * sqrt(squared_sigma) * pred_noise)
124122
posterior_variance = squared_sigma_next * c
125123

@@ -151,6 +149,7 @@ def p_sample_loop(self, shape):
151149
times_next = steps[i + 1]
152150
img = self.p_sample(img, times, times_next)
153151

152+
img.clamp_(-1., 1.)
154153
img = unnormalize_to_zero_to_one(img)
155154
return img
156155

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'denoising-diffusion-pytorch',
55
packages = find_packages(),
6-
version = '0.16.5',
6+
version = '0.16.6',
77
license='MIT',
88
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)