@@ -267,14 +267,65 @@ def forward(self, x, time = None):
 
 # model
 
+class RINBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        latent_self_attn_depth,
+        **attn_kwargs
+    ):
+        super().__init__()
+
+        self.latents_attend_to_patches = Attention(dim, norm = True, norm_context = True, **attn_kwargs)
+
+        self.latent_self_attns = nn.ModuleList([])
+        for _ in range(latent_self_attn_depth):
+            self.latent_self_attns.append(nn.ModuleList([
+                Attention(dim, norm = True, **attn_kwargs),
+                FeedForward(dim)
+            ]))
+
+        self.patches_peg = PEG(dim)
+        self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
+        self.patches_self_attn_ff = FeedForward(dim)
+
+        self.patches_attend_to_latents = Attention(dim, norm = True, norm_context = True, **attn_kwargs)
+        self.patches_cross_attn_ff = FeedForward(dim)
+
+    def forward(self, patches, latents, t):
+        patches = self.patches_peg(patches) + patches
+
+        # latents extract or cluster information from the patches
+
+        latents = self.latents_attend_to_patches(latents, patches, time = t) + latents
+
+        # latent self attention
+
+        for attn, ff in self.latent_self_attns:
+            latents = attn(latents, time = t) + latents
+            latents = ff(latents, time = t) + latents
+
+        # additional patches self attention with linear attention
+
+        patches = self.patches_self_attn(patches, time = t) + patches
+        patches = self.patches_self_attn_ff(patches) + patches
+
+        # patches attend to the latents
+
+        patches = self.patches_attend_to_latents(patches, latents, time = t) + patches
+
+        patches = self.patches_cross_attn_ff(patches, time = t) + patches
+
+        return patches, latents
+
 class RIN(nn.Module):
     def __init__(
         self,
         dim,
         image_size,
         patch_size = 16,
         channels = 3,
-        depth = 6,                          # weight tied depth. weight tied layers basically is recurrent, with latents as hiddens
+        depth = 6,                          # number of RIN blocks
         latent_self_attn_depth = 2,         # how many self attentions for the latents per each round of cross attending from pixel space to latents and back
         num_latents = 256,                  # they still had to use a fair amount of latents for good results (256), in line with the Perceiver line of papers from Deepmind
         learned_sinusoidal_dim = 16,
@@ -329,25 +380,9 @@ def __init__(
 
         # the main RIN body parameters - another attention is all you need moment
 
-        self.depth = depth
-
         attn_kwargs = {**attn_kwargs, 'time_cond_dim': time_dim}
 
-        self.latents_attend_to_patches = Attention(dim, norm = True, norm_context = True, **attn_kwargs)
-
-        self.latent_self_attns = nn.ModuleList([])
-        for _ in range(latent_self_attn_depth):
-            self.latent_self_attns.append(nn.ModuleList([
-                Attention(dim, norm = True, **attn_kwargs),
-                FeedForward(dim)
-            ]))
-
-        self.patches_peg = PEG(dim)
-        self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
-        self.patches_self_attn_ff = FeedForward(dim)
-
-        self.patches_attend_to_latents = Attention(dim, norm = True, norm_context = True, **attn_kwargs)
-        self.patches_cross_attn_ff = FeedForward(dim)
+        self.blocks = nn.ModuleList([RINBlock(dim, latent_self_attn_depth = latent_self_attn_depth, **attn_kwargs) for _ in range(depth)])
 
     def forward(
         self,
@@ -386,29 +421,8 @@ def forward(
 
         # the recurrent interface network body
 
-        for _ in range(self.depth):
-            patches = self.patches_peg(patches) + patches
-
-            # latents extract or cluster information from the patches
-
-            latents = self.latents_attend_to_patches(latents, patches, time = t) + latents
-
-            # latent self attention
-
-            for attn, ff in self.latent_self_attns:
-                latents = attn(latents, time = t) + latents
-                latents = ff(latents, time = t) + latents
-
-            # additional patches self attention with linear attention
-
-            patches = self.patches_self_attn(patches, time = t) + patches
-            patches = self.patches_self_attn_ff(patches) + patches
-
-            # patches attend to the latents
-
-            patches = self.latents_attend_to_patches(patches, latents, time = t) + patches
-
-            patches = self.patches_cross_attn_ff(patches, time = t) + patches
+        for block in self.blocks:
+            patches, latents = block(patches, latents, t)
 
         # to pixels
 
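For orientation, a minimal usage sketch of the model after this refactor. Only the constructor arguments and the per-block recurrence are taken from the diff above; the tensor shapes, the variable names, and the assumption that RIN.forward takes (images, times) and returns the denoising prediction are illustrative guesses, since the full forward signature is not shown in these hunks.

import torch

# hypothetical hyperparameters; the dim / image_size values are not from this commit
model = RIN(
    dim = 256,
    image_size = 128,
    patch_size = 16,
    channels = 3,
    depth = 6,                    # six separate RINBlocks now, rather than six weight-tied passes
    latent_self_attn_depth = 2,
    num_latents = 256
)

images = torch.randn(2, 3, 128, 128)    # (batch, channels, height, width)
times = torch.rand(2)                   # one diffusion time per image, assumed in [0, 1]

pred = model(images, times)             # each block refines (patches, latents) in sequence

The practical difference from the removed loop is that depth now controls how many independently parameterized RINBlocks are stacked in self.blocks, so the parameter count grows with depth, whereas the old weight-tied loop reused one set of layers recurrently with the latents acting as the hidden state.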