@@ -285,6 +285,7 @@ def __init__(
         dim_latent = default(dim_latent, dim)
 
         self.latents_attend_to_patches = Attention(dim_latent, dim_context = dim, norm = True, norm_context = True, **attn_kwargs)
+        self.latents_cross_attn_ff = FeedForward(dim_latent)
 
         self.latent_self_attns = nn.ModuleList([])
         for _ in range(latent_self_attn_depth):
@@ -309,6 +310,8 @@ def forward(self, patches, latents, t):
 
         latents = self.latents_attend_to_patches(latents, patches, time = t) + latents
 
+        latents = self.latents_cross_attn_ff(latents, time = t) + latents
+
         # latent self attention
 
         for attn, ff in self.latent_self_attns:
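These two hunks pair the latents-to-patches cross attention with its own FeedForward and residual connection, so the "read" step follows the usual transformer pattern of attention followed by a feedforward before the latent self-attention stack runs. A minimal, self-contained sketch of that pattern, using a hypothetical `LatentRead` module with plain PyTorch stand-ins for the repo's `Attention`/`FeedForward` and without the time conditioning argument:

```python
import torch
from torch import nn

class LatentRead(nn.Module):
    # Sketch only: latents (queries) read from patches (keys/values),
    # then pass through a feedforward, each step with a residual.
    def __init__(self, dim_latent, dim_patches, heads = 4):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            dim_latent, heads, kdim = dim_patches, vdim = dim_patches, batch_first = True
        )
        self.ff = nn.Sequential(
            nn.LayerNorm(dim_latent),
            nn.Linear(dim_latent, dim_latent * 4),
            nn.GELU(),
            nn.Linear(dim_latent * 4, dim_latent)
        )

    def forward(self, patches, latents):
        attended, _ = self.cross_attn(latents, patches, patches)
        latents = attended + latents           # residual around cross attention
        latents = self.ff(latents) + latents   # residual around the added feedforward
        return latents

# usage: latents (b, num_latents, dim_latent) reading from patches (b, num_patches, dim_patches)
block = LatentRead(dim_latent = 256, dim_patches = 512)
out = block(torch.randn(2, 196, 512), torch.randn(2, 64, 256))
```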
@@ -583,7 +586,7 @@ def ddpm_sample(self, shape, time_difference = None):
                 x_start = model_output
 
             elif self.objective == 'eps':
-                x_start = (img - sigma * model_output) / alpha
+                x_start = (img - sigma * model_output) / alpha.clamp(min = 1e-8)
 
             # clip x0
 
@@ -648,7 +651,7 @@ def ddim_sample(self, shape, time_difference = None):
                 x_start = model_output
 
             elif self.objective == 'eps':
-                x_start = (img - sigma * model_output) / alpha
+                x_start = (img - sigma * model_output) / alpha.clamp(min = 1e-8)
 
             # clip x0
 
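The clamp is applied identically in `ddpm_sample` and `ddim_sample`: under the log-SNR parameterization, `alpha` shrinks toward zero at the noisiest timesteps, so recovering `x_start` from a predicted-noise ('eps') objective divides by a value that can underflow and produce inf/nan. A small sketch of the conversion, with the `eps_to_x_start` helper name assumed for illustration:

```python
import torch

def log_snr_to_alpha_sigma(log_snr):
    # standard log-SNR parameterization: alpha^2 + sigma^2 = 1
    return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))

def eps_to_x_start(img, pred_eps, log_snr, eps_min = 1e-8):
    alpha, sigma = log_snr_to_alpha_sigma(log_snr)
    # img = alpha * x_start + sigma * noise  =>  x_start = (img - sigma * noise) / alpha
    # clamping the denominator keeps x_start finite when log_snr is very negative
    return (img - sigma * pred_eps) / alpha.clamp(min = eps_min)
```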