@@ -66,6 +66,7 @@ def __init__(
         heads = 8,
         dim_head = 64,
         dropout = 0.,
+        cross_attend = False,
         reuse_attention = False
     ):
         super().__init__()
@@ -74,10 +75,13 @@ def __init__(
         self.scale = dim_head ** -0.5
         self.heads = heads
         self.reuse_attention = reuse_attention
+        self.cross_attend = cross_attend
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
 
         self.norm = LayerNorm(dim) if not reuse_attention else nn.Identity()
+        self.norm_context = LayerNorm(dim) if cross_attend else nn.Identity()
+
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -99,7 +103,13 @@ def forward(
         attn = None
     ):
         x = self.norm(x)
-        context = default(context, x)
+
+        assert not (exists(context) ^ self.cross_attend)
+
+        if self.cross_attend:
+            context = self.norm_context(context)
+        else:
+            context = x
 
         v = self.to_v(context)
         v = self.split_heads(v)
@@ -179,8 +189,8 @@ def __init__(
             layers.append(ModuleList([
                 Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
                 MLP(dim = dim, factor = mlp_factor, dropout = dropout),
-                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout),
-                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, reuse_attention = True),
+                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True),
+                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True, reuse_attention = True),
                 LayerNorm(dim),
                 MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout)
             ]))
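
For reference, a minimal standalone sketch of the behaviour this patch introduces: `context` must be supplied exactly when the module is built with `cross_attend = True` (enforced by the XOR assert), and cross-attention context gets its own pre-norm while a self-attention block reuses the normalized input. The `CrossAttendGate` class name and the use of plain `nn.LayerNorm` here are placeholders for illustration, not part of the repository.

```python
import torch
from torch import nn

def exists(val):
    return val is not None

class CrossAttendGate(nn.Module):
    # hypothetical mini-module mirroring the pattern added in this commit
    def __init__(self, dim, cross_attend = False):
        super().__init__()
        self.cross_attend = cross_attend
        self.norm = nn.LayerNorm(dim)
        # context gets its own norm only when cross-attending
        self.norm_context = nn.LayerNorm(dim) if cross_attend else nn.Identity()

    def forward(self, x, context = None):
        x = self.norm(x)

        # context is required iff this block was configured for cross-attention
        assert not (exists(context) ^ self.cross_attend)

        if self.cross_attend:
            context = self.norm_context(context)
        else:
            context = x

        return x, context

# usage
self_attn = CrossAttendGate(dim = 64)
cross_attn = CrossAttendGate(dim = 64, cross_attend = True)

x = torch.randn(1, 16, 64)
ctx = torch.randn(1, 32, 64)

self_attn(x)        # ok: no context for a self-attention block
cross_attn(x, ctx)  # ok: context passed to a cross-attention block
# cross_attn(x)     # would raise AssertionError: cross_attend set but no context given
```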