
Commit a81dbfa

un-weight-tie layers, given feedback from author
1 parent 80b48d8 commit a81dbfa

3 files changed: +58 -43 lines changed


README.md

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,7 @@ model = RIN(
     dim = 256,                          # model dimensions
     image_size = 128,                   # image size
     patch_size = 8,                     # patch size
-    depth = 6,                          # recurrent depth
+    depth = 6,                          # depth
     num_latents = 128,                  # number of latents. they used 256 in the paper
     latent_self_attn_depth = 4,         # number of latent self attention blocks per recurrent step, K in the paper
 ).cuda()
@@ -78,6 +78,7 @@ model = RIN(
     dim = 256,                          # model dimensions
     image_size = 128,                   # image size
     patch_size = 8,                     # patch size
+    depth = 6,                          # depth
     num_latents = 128,                  # number of latents. they used 256 in the paper
     latent_self_attn_depth = 4,         # number of latent self attention blocks per recurrent step, K in the paper
 ).cuda()
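
With this change, depth counts how many separate RINBlock modules the model builds, rather than how many times a single weight-tied block is reused, which is why both README examples now show it explicitly. A minimal sanity check, assuming the package is installed and that RIN is importable from rin_pytorch as the README implies:

from rin_pytorch import RIN

model = RIN(
    dim = 256,                      # model dimensions
    image_size = 128,               # image size
    patch_size = 8,                 # patch size
    depth = 6,                      # depth = number of RINBlocks
    num_latents = 128,
    latent_self_attn_depth = 4,
)

print(len(model.blocks))            # 6 - one RINBlock per unit of depth

# each block carries its own parameters; nothing is shared across depth
assert model.blocks[0].latents_attend_to_patches is not model.blocks[1].latents_attend_to_patches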

rin_pytorch/rin_pytorch.py

Lines changed: 55 additions & 41 deletions
@@ -267,14 +267,65 @@ def forward(self, x, time = None):
 
 # model
 
+class RINBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        latent_self_attn_depth,
+        **attn_kwargs
+    ):
+        super().__init__()
+
+        self.latents_attend_to_patches = Attention(dim, norm = True, norm_context = True, **attn_kwargs)
+
+        self.latent_self_attns = nn.ModuleList([])
+        for _ in range(latent_self_attn_depth):
+            self.latent_self_attns.append(nn.ModuleList([
+                Attention(dim, norm = True, **attn_kwargs),
+                FeedForward(dim)
+            ]))
+
+        self.patches_peg = PEG(dim)
+        self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
+        self.patches_self_attn_ff = FeedForward(dim)
+
+        self.patches_attend_to_latents = Attention(dim, norm = True, norm_context = True, **attn_kwargs)
+        self.patches_cross_attn_ff = FeedForward(dim)
+
+    def forward(self, patches, latents, t):
+        patches = self.patches_peg(patches) + patches
+
+        # latents extract or cluster information from the patches
+
+        latents = self.latents_attend_to_patches(latents, patches, time = t) + latents
+
+        # latent self attention
+
+        for attn, ff in self.latent_self_attns:
+            latents = attn(latents, time = t) + latents
+            latents = ff(latents, time = t) + latents
+
+        # additional patches self attention with linear attention
+
+        patches = self.patches_self_attn(patches, time = t) + patches
+        patches = self.patches_self_attn_ff(patches) + patches
+
+        # patches attend to the latents
+
+        patches = self.patches_attend_to_latents(patches, latents, time = t) + patches
+
+        patches = self.patches_cross_attn_ff(patches, time = t) + patches
+
+        return patches, latents
+
 class RIN(nn.Module):
     def __init__(
         self,
         dim,
         image_size,
         patch_size = 16,
         channels = 3,
-        depth = 6,                          # weight tied depth. weight tied layers basically is recurrent, with latents as hiddens
+        depth = 6,                          # number of RIN blocks
         latent_self_attn_depth = 2,         # how many self attentions for the latent per each round of cross attending from pixel space to latents and back
         num_latents = 256,                  # they still had to use a fair amount of latents for good results (256), in line with the Perceiver line of papers from Deepmind
         learned_sinusoidal_dim = 16,
@@ -329,25 +380,9 @@ def __init__(
 
         # the main RIN body parameters - another attention is all you need moment
 
-        self.depth = depth
-
         attn_kwargs = {**attn_kwargs, 'time_cond_dim': time_dim}
 
-        self.latents_attend_to_patches = Attention(dim, norm = True, norm_context = True, **attn_kwargs)
-
-        self.latent_self_attns = nn.ModuleList([])
-        for _ in range(latent_self_attn_depth):
-            self.latent_self_attns.append(nn.ModuleList([
-                Attention(dim, norm = True, **attn_kwargs),
-                FeedForward(dim)
-            ]))
-
-        self.patches_peg = PEG(dim)
-        self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
-        self.patches_self_attn_ff = FeedForward(dim)
-
-        self.patches_attend_to_latents = Attention(dim, norm = True, norm_context = True, **attn_kwargs)
-        self.patches_cross_attn_ff = FeedForward(dim)
+        self.blocks = nn.ModuleList([RINBlock(dim, latent_self_attn_depth = latent_self_attn_depth, **attn_kwargs) for _ in range(depth)])
 
     def forward(
         self,
@@ -386,29 +421,8 @@ def forward(
 
         # the recurrent interface network body
 
-        for _ in range(self.depth):
-            patches = self.patches_peg(patches) + patches
-
-            # latents extract or cluster information from the patches
-
-            latents = self.latents_attend_to_patches(latents, patches, time = t) + latents
-
-            # latent self attention
-
-            for attn, ff in self.latent_self_attns:
-                latents = attn(latents, time = t) + latents
-                latents = ff(latents, time = t) + latents
-
-            # additional patches self attention with linear attention
-
-            patches = self.patches_self_attn(patches, time = t) + patches
-            patches = self.patches_self_attn_ff(patches) + patches
-
-            # patches attend to the latents
-
-            patches = self.latents_attend_to_patches(patches, latents, time = t) + patches
-
-            patches = self.patches_cross_attn_ff(patches, time = t) + patches
+        for block in self.blocks:
+            patches, latents = block(patches, latents, t)
 
         # to pixels
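
To make the refactor concrete: the deleted loop applied one shared set of attention layers depth times (weight tying, i.e. a recurrence with the latents as the hidden state), while the new code stacks depth RINBlocks that each own their weights. A simplified sketch of the two patterns, using a hypothetical ToyBlock rather than the repo's actual attention modules:

import torch
from torch import nn

class ToyBlock(nn.Module):
    # stand-in for a RIN block: one residual linear layer
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x) + x

class WeightTied(nn.Module):
    # old behaviour: a single block, applied depth times
    def __init__(self, dim, depth):
        super().__init__()
        self.depth = depth
        self.block = ToyBlock(dim)

    def forward(self, x):
        for _ in range(self.depth):
            x = self.block(x)
        return x

class Unrolled(nn.Module):
    # new behaviour: depth blocks, each with its own parameters
    def __init__(self, dim, depth):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

x = torch.randn(2, 8)
tied, unrolled = WeightTied(8, depth = 6), Unrolled(8, depth = 6)
print(sum(p.numel() for p in tied.parameters()))      # parameters of a single block
print(sum(p.numel() for p in unrolled.parameters()))  # roughly 6x as many
print(tied(x).shape, unrolled(x).shape)               # same interface either way

The printed parameter counts show the trade-off: unrolling multiplies the body's parameters by roughly the depth, which is what un-weight-tying the layers amounts to.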

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'RIN-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'RIN - Recurrent Interface Network - Pytorch',
   author = 'Phil Wang',
