
Commit 4b2c00c

when cross attending in look vit, make sure context tokens are normalized
1 parent ec6c48b commit 4b2c00c

2 files changed: +14 −4 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.2',
+  version = '1.7.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/look_vit.py

Lines changed: 13 additions & 3 deletions
@@ -66,6 +66,7 @@ def __init__(
         heads = 8,
         dim_head = 64,
         dropout = 0.,
+        cross_attend = False,
         reuse_attention = False
     ):
         super().__init__()
@@ -74,10 +75,13 @@ def __init__(
         self.scale = dim_head ** -0.5
         self.heads = heads
         self.reuse_attention = reuse_attention
+        self.cross_attend = cross_attend

         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)

         self.norm = LayerNorm(dim) if not reuse_attention else nn.Identity()
+        self.norm_context = LayerNorm(dim) if cross_attend else nn.Identity()
+
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)

@@ -99,7 +103,13 @@ def forward(
         attn = None
     ):
         x = self.norm(x)
-        context = default(context, x)
+
+        assert not (exists(context) ^ self.cross_attend)
+
+        if self.cross_attend:
+            context = self.norm_context(context)
+        else:
+            context = x

         v = self.to_v(context)
         v = self.split_heads(v)
@@ -179,8 +189,8 @@ def __init__(
             layers.append(ModuleList([
                 Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
                 MLP(dim = dim, factor = mlp_factor, dropout = dropout),
-                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout),
-                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, reuse_attention = True),
+                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True),
+                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True, reuse_attention = True),
                 LayerNorm(dim),
                 MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout)
             ]))
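
For orientation, here is a minimal sketch of the pattern this commit enforces: in a pre-norm cross-attention block, the query tokens and the context tokens each pass through their own LayerNorm before being projected to queries and keys/values. This is an illustrative standalone module, not the Attention class in look_vit.py; the class name SimpleCrossAttention and its projection layout are assumptions. In the actual module, as the diff shows, the context norm is gated behind a cross_attend flag and an assert requires that context is passed exactly when cross_attend is set.

# Illustrative sketch only -- not the vit-pytorch Attention module.
# Shows why the context tokens need their own norm when cross attending.

import torch
from torch import nn

class SimpleCrossAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)          # normalizes the query tokens
        self.norm_context = nn.LayerNorm(dim)  # normalizes the context tokens (the point of this commit)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, context):
        b, h = x.shape[0], self.heads

        x = self.norm(x)
        context = self.norm_context(context)   # without this, keys / values are built from unnormalized tokens

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        q, k, v = (t.reshape(b, -1, h, t.shape[-1] // h).transpose(1, 2) for t in (q, k, v))

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = out.transpose(1, 2)              # (b, n, h, d)
        out = out.reshape(b, out.shape[1], -1) # (b, n, h * d)
        return self.to_out(out)

# usage: high resolution tokens attend to normalized low resolution context tokens
cross_attn = SimpleCrossAttention(dim = 256)
tokens = torch.randn(2, 64, 256)    # e.g. 64 high-res tokens
context = torch.randn(2, 16, 256)   # e.g. 16 low-res context tokens
out = cross_attn(tokens, context)   # (2, 64, 256)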
