Skip to content

Commit 2785345

Browse files
committed
add gaussian diffusion where model predicts both noise and x_start, with a learned weighting between the two (experimental)
1 parent 84ebb9a commit 2785345

File tree

3 files changed

+83
-1
lines changed

3 files changed

+83
-1
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
2+
23
from denoising_diffusion_pytorch.learned_gaussian_diffusion import LearnedGaussianDiffusion
4+
from denoising_diffusion_pytorch.weighted_objective_gaussian_diffusion import WeightedObjectiveGaussianDiffusion
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import torch
2+
from inspect import isfunction
3+
from torch import nn, einsum
4+
from einops import rearrange
5+
6+
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, extract, unnormalize_to_zero_to_one
7+
8+
# helper functions
9+
10+
def exists(x):
    """Return True for any value other than None (falsy values such as 0 or '' still count)."""
    if x is None:
        return False
    return True
12+
13+
def default(val, d):
    """Return val unless it is None, otherwise fall back to d.

    When d is a plain Python function (as judged by inspect.isfunction, e.g. a
    lambda), it is called with no arguments and its result is returned, so an
    expensive fallback is only computed when actually needed. Any other d is
    returned as-is.
    """
    if not exists(val):
        # lazily evaluate function fallbacks, pass everything else through
        return d() if isfunction(d) else d
    return val
17+
18+
# some improvisation on my end
19+
# where I have the model learn to predict both noise and x0
20+
# and learn the weighted sum for each depending on time step
21+
22+
class WeightedObjectiveGaussianDiffusion(GaussianDiffusion):
    """Gaussian diffusion where the network predicts noise, x_start and a blend of the two.

    The network output is split along dim 1 into chunks of sizes
    (channels, channels, 2): a noise prediction, a direct x_start prediction
    and two logits. The logits are softmax-normalized over dim 1 and used to
    take a weighted sum of (a) the x_start derived from the predicted noise
    and (b) the directly predicted x_start.
    """

    def __init__(
        self,
        denoise_fn,
        *args,
        pred_noise_loss_weight = 0.1,
        pred_x_start_loss_weight = 0.1,
        **kwargs
    ):
        """
        Args:
            denoise_fn: network exposing `channels` and `out_dim` attributes;
                `out_dim` must equal `channels * 2 + 2`.
            *args, **kwargs: forwarded unchanged to GaussianDiffusion.__init__.
            pred_noise_loss_weight: multiplier on the auxiliary noise loss.
            pred_x_start_loss_weight: multiplier on the auxiliary x_start loss.
        """
        super().__init__(denoise_fn, *args, **kwargs)
        channels = denoise_fn.channels
        assert denoise_fn.out_dim == (channels * 2 + 2), 'dimension out (out_dim) of unet must be twice the number of channels + 2 (for the softmax weighted sum) - for channels of 3, this should be (3 * 2) + 2 = 8'

        # chunk sizes used to split the network output along dim 1:
        # predicted noise, predicted x_start, two blending logits
        self.split_dims = (channels, channels, 2)
        self.pred_noise_loss_weight = pred_noise_loss_weight
        self.pred_x_start_loss_weight = pred_x_start_loss_weight

    def _blend_x_starts(self, weights, x_start_from_noise, pred_x_start):
        """Softmax the two logits over dim 1 and take the weighted sum of both x_start estimates."""
        normalized_weights = weights.softmax(dim = 1)
        # stack the two candidates on a new dim 1 so it lines up with the weight dim `j`
        x_starts = torch.stack((x_start_from_noise, pred_x_start), dim = 1)
        return einsum('b j h w, b j c h w -> b c h w', normalized_weights, x_starts)

    def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
        """Compute the posterior mean, variance and log variance from the blended x_start.

        Args:
            x: current noised sample passed to the network.
            t: timesteps passed to the network.
            clip_denoised: when truthy, clamp the blended x_start to [-1, 1].
            model_output: optional precomputed network output; when None the
                network is evaluated on (x, t).

        Returns:
            The (model_mean, model_variance, model_log_variance) tuple produced
            by self.q_posterior.
        """
        # honor a caller-supplied output instead of always re-running the network
        model_output = default(model_output, lambda: self.denoise_fn(x, t))

        pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1)

        x_start_from_noise = self.predict_start_from_noise(x, t = t, noise = pred_noise)
        weighted_x_start = self._blend_x_starts(weights, x_start_from_noise, pred_x_start)

        if clip_denoised:
            weighted_x_start.clamp_(-1., 1.)

        model_mean, model_variance, model_log_variance = self.q_posterior(weighted_x_start, x, t)

        return model_mean, model_variance, model_log_variance

    def p_losses(self, x_start, t, noise = None, clip_denoised = False):
        """Return the training loss: blended x_start loss plus two weighted auxiliary losses.

        Args:
            x_start: clean input used as the regression target.
            t: timesteps used for noising and passed to the network.
            noise: optional noise; sampled with torch.randn_like(x_start) when None.
            clip_denoised: accepted for signature compatibility; not read by
                this method.

        Returns:
            loss_fn(x_start, blended x_start)
            + pred_x_start_loss_weight * loss_fn(x_start, predicted x_start)
            + pred_noise_loss_weight * loss_fn(noise, predicted noise).
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_t = self.q_sample(x_start = x_start, t = t, noise = noise)

        model_output = self.denoise_fn(x_t, t)
        pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1)

        # get loss for predicted noise and x_start
        # with the loss weight given at initialization

        noise_loss = self.loss_fn(noise, pred_noise) * self.pred_noise_loss_weight
        x_start_loss = self.loss_fn(x_start, pred_x_start) * self.pred_x_start_loss_weight

        # calculate x_start from predicted noise
        # then do a weighted sum of the x_start prediction, weights also predicted by the model (softmax normalized)

        x_start_from_pred_noise = self.predict_start_from_noise(x_t, t, pred_noise)
        # bound the noise-derived estimate to [-2, 2] before it enters the blend
        x_start_from_pred_noise = x_start_from_pred_noise.clamp(-2., 2.)
        weighted_x_start = self._blend_x_starts(weights, x_start_from_pred_noise, pred_x_start)

        # main loss to x_start with the weighted one

        weighted_x_start_loss = self.loss_fn(x_start, weighted_x_start)
        return weighted_x_start_loss + x_start_loss + noise_loss

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'denoising-diffusion-pytorch',
55
packages = find_packages(),
6-
version = '0.15.0',
6+
version = '0.15.1',
77
license='MIT',
88
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)