Skip to content

Commit 3bf5e76

Browse files
committed
use a non-sinusoidal embedded condition for continuous time gaussian diffusion conditioned on log(snr)
1 parent 532178a commit 3bf5e76

File tree

3 files changed

+24
-9
lines changed

3 files changed

+24
-9
lines changed

denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
num_sample_steps = 500
6666
):
6767
super().__init__()
68+
assert not denoise_fn.sinusoidal_cond_mlp
6869

6970
self.denoise_fn = denoise_fn
7071

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from tqdm import tqdm
1818
from einops import rearrange
19+
from einops.layers.torch import Rearrange
1920

2021
# helpers functions
2122

@@ -211,6 +212,18 @@ def forward(self, x):
211212

212213
# model
213214

215+
def MLP(dim_in, dim_hidden):
216+
return nn.Sequential(
217+
Rearrange('... -> ... 1'),
218+
nn.Linear(1, dim_hidden),
219+
nn.GELU(),
220+
nn.LayerNorm(dim_hidden),
221+
nn.Linear(dim_hidden, dim_hidden),
222+
nn.GELU(),
223+
nn.LayerNorm(dim_hidden),
224+
nn.Linear(dim_hidden, dim_hidden)
225+
)
226+
214227
class Unet(nn.Module):
215228
def __init__(
216229
self,
@@ -219,9 +232,9 @@ def __init__(
219232
out_dim = None,
220233
dim_mults=(1, 2, 4, 8),
221234
channels = 3,
222-
with_time_emb = True,
223235
resnet_block_groups = 8,
224-
learned_variance = False
236+
learned_variance = False,
237+
sinusoidal_cond_mlp = True
225238
):
226239
super().__init__()
227240

@@ -239,17 +252,19 @@ def __init__(
239252

240253
# time embeddings
241254

242-
if with_time_emb:
243-
time_dim = dim * 4
255+
time_dim = dim * 4
256+
257+
self.sinusoidal_cond_mlp = sinusoidal_cond_mlp
258+
259+
if sinusoidal_cond_mlp:
244260
self.time_mlp = nn.Sequential(
245261
SinusoidalPosEmb(dim),
246262
nn.Linear(dim, time_dim),
247263
nn.GELU(),
248264
nn.Linear(time_dim, time_dim)
249265
)
250266
else:
251-
time_dim = None
252-
self.time_mlp = None
267+
self.time_mlp = MLP(1, time_dim)
253268

254269
# layers
255270

@@ -292,8 +307,7 @@ def __init__(
292307

293308
def forward(self, x, time):
294309
x = self.init_conv(x)
295-
296-
t = self.time_mlp(time) if exists(self.time_mlp) else None
310+
t = self.time_mlp(time)
297311

298312
h = []
299313

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'denoising-diffusion-pytorch',
55
packages = find_packages(),
6-
version = '0.16.4',
6+
version = '0.16.5',
77
license='MIT',
88
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)