Skip to content

Commit ee85704

Browse files
committed
add a missing ff block after latents attending to patches, addressing #3
1 parent 1986201 commit ee85704

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

rin_pytorch/rin_pytorch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def __init__(
285285
dim_latent = default(dim_latent, dim)
286286

287287
self.latents_attend_to_patches = Attention(dim_latent, dim_context = dim, norm = True, norm_context = True, **attn_kwargs)
288+
self.latents_cross_attn_ff = FeedForward(dim_latent)
288289

289290
self.latent_self_attns = nn.ModuleList([])
290291
for _ in range(latent_self_attn_depth):
@@ -309,6 +310,8 @@ def forward(self, patches, latents, t):
309310

310311
latents = self.latents_attend_to_patches(latents, patches, time = t) + latents
311312

313+
latents = self.latents_cross_attn_ff(latents, time = t) + latents
314+
312315
# latent self attention
313316

314317
for attn, ff in self.latent_self_attns:
@@ -583,7 +586,7 @@ def ddpm_sample(self, shape, time_difference = None):
583586
x_start = model_output
584587

585588
elif self.objective == 'eps':
586-
x_start = (img - sigma * model_output) / alpha
589+
x_start = (img - sigma * model_output) / alpha.clamp(min = 1e-8)
587590

588591
# clip x0
589592

@@ -648,7 +651,7 @@ def ddim_sample(self, shape, time_difference = None):
648651
x_start = model_output
649652

650653
elif self.objective == 'eps':
651-
x_start = (img - sigma * model_output) / alpha
654+
x_start = (img - sigma * model_output) / alpha.clamp(min = 1e-8)
652655

653656
# clip x0
654657

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'RIN-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.2.0',
6+
version = '0.3.0',
77
license='MIT',
88
description = 'RIN - Recurrent Interface Network - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)