
Commit 339fc02

successfully did some basic math and clipped the predicted x0 intermediate for the continuous time case
1 parent 4284c88 commit 339fc02
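
The "basic math" mentioned in the commit message appears to be the rewrite sketched below: the sampler's posterior mean is expressed in terms of the predicted x0, so that x0 can be clamped to [-1, 1] before it re-enters the mean. This is a reading of the diff that follows, not text from the commit itself. In the code's notation, alpha^2 = sigmoid(log_snr), sigma^2 = sigmoid(-log_snr), and c = -expm1(log_snr - log_snr_next):

$$
\hat{x}_0 = \frac{x - \sigma\,\hat{\epsilon}}{\alpha}
\quad\Longrightarrow\quad
\sigma\,\hat{\epsilon} = x - \alpha\,\hat{x}_0,
$$

$$
\mu = \frac{\alpha_{\text{next}}}{\alpha}\bigl(x - c\,\sigma\,\hat{\epsilon}\bigr)
    = \frac{\alpha_{\text{next}}}{\alpha}\,(1 - c)\,x + \alpha_{\text{next}}\,c\,\hat{x}_0,
$$

which is exactly the new model_mean = alpha_next / alpha * x * (1 - c) + alpha_next * c * x_start expression in the diff, with x_start clamped before the second term is formed.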

File tree

2 files changed: +15, -5 lines

denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

Lines changed: 14 additions & 4 deletions

@@ -115,6 +115,7 @@ def __init__(
         loss_type = 'l1',
         noise_schedule = 'linear',
         num_sample_steps = 500,
+        clip_sample_denoised = True,
         learned_schedule_net_hidden_dim = 1024,
         learned_noise_schedule_frac_gradient = 1. # between 0 and 1, determines what percentage of gradients go back, so one can update the learned noise schedule more slowly
     ):
@@ -149,6 +150,7 @@ def __init__(
         # sampling

         self.num_sample_steps = num_sample_steps
+        self.clip_sample_denoised = clip_sample_denoised

     @property
     def device(self):
@@ -167,20 +169,28 @@ def p_mean_variance(self, x, time, time_next):
         # reviewer found an error in the equation in the paper (missing sigma)
         # following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt

-        # todo - derive x_start from the posterior mean and do dynamic thresholding
-        # assumed that is what is going on in Imagen
-
         log_snr = self.log_snr(time)
         log_snr_next = self.log_snr(time_next)
         c = -expm1(log_snr - log_snr_next)

         squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
         squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()

+        alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))
+
         batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
         pred_noise = self.denoise_fn(x, batch_log_snr)

-        model_mean = sqrt(squared_alpha_next / squared_alpha) * (x - c * sqrt(squared_sigma) * pred_noise)
+        if self.clip_sample_denoised:
+            x_start = (x - sigma * pred_noise) / alpha
+
+            # in Imagen, this was changed to dynamic thresholding
+            x_start.clamp_(-1., 1.)
+
+            model_mean = alpha_next / alpha * x * (1 - c) + alpha_next * c * x_start
+        else:
+            model_mean = alpha_next / alpha * (x - c * sigma * pred_noise)
+
         posterior_variance = squared_sigma_next * c

         return model_mean, posterior_variance
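
Below is a minimal usage sketch of the new flag. The import path, the class name ContinuousTimeGaussianDiffusion, the Unet denoiser, the image_size argument, and the sample() call are assumptions about the surrounding repository and are not shown in this diff; only num_sample_steps and clip_sample_denoised come from the commit above.

import torch
from denoising_diffusion_pytorch import Unet
from denoising_diffusion_pytorch.continuous_time_gaussian_diffusion import ContinuousTimeGaussianDiffusion

# hypothetical setup -- everything except the two keyword arguments noted
# below is assumed rather than taken from this commit
model = Unet(dim = 64)

diffusion = ContinuousTimeGaussianDiffusion(
    model,
    image_size = 128,              # assumed constructor argument
    num_sample_steps = 500,        # default shown in the diff
    clip_sample_denoised = True    # new flag: clamp the predicted x0 to [-1., 1.] at each sampling step
)

# with the flag enabled, p_mean_variance derives x_start from the predicted noise,
# clamps it, and rebuilds the posterior mean from the clamped value before sampling
sampled_images = diffusion.sample(batch_size = 4)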

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.17.4',
+  version = '0.17.5',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',
