Skip to content

Commit caa5af1

Browse files
committed
final cleanup
1 parent 55c658b commit caa5af1

File tree

2 files changed

+2
-14
lines changed

2 files changed

+2
-14
lines changed

denoising_diffusion_pytorch/learned_gaussian_diffusion.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,6 @@ def __init__(
7676
assert denoise_fn.out_dim == (denoise_fn.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
7777
self.vb_loss_weight = vb_loss_weight
7878

79-
def q_posterior_mean_variance(self, x_start, x_t, t):
80-
"""
81-
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
82-
"""
83-
posterior_mean = (
84-
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
85-
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
86-
)
87-
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
88-
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
89-
return posterior_mean, posterior_variance, posterior_log_variance_clipped
90-
9179
def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
9280
model_output = default(model_output, lambda: self.denoise_fn(x, t))
9381
pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)
@@ -118,7 +106,7 @@ def p_losses(self, x_start, t, noise = None, clip_denoised = False):
118106

119107
# calculating kl loss for learned variance (interpolation)
120108

121-
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start = x_start, x_t = x_t, t = t)
109+
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_t, t = t)
122110
model_mean, _, model_log_variance = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)
123111

124112
# kl loss with detached model predicted mean, for stability reasons as in paper

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'denoising-diffusion-pytorch',
55
packages = find_packages(),
6-
version = '0.14.2',
6+
version = '0.14.3',
77
license='MIT',
88
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)