@@ -19,13 +19,17 @@ def attention(
     q, k, v,
     mask = None,
     causal = False,
+    attn_bias = None,
     **kwargs
 ):
     scale = q.shape[-1] ** -0.5
     q = q * scale

     sim = einsum('b h i d, b h j d -> b h i j', q, k)

+    if exists(attn_bias):
+        sim = sim + attn_bias
+
     sim = sim - sim.amax(dim = -1, keepdim = True).detach()
     mask_value = -torch.finfo(sim.dtype).max
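For context on the hunk above: the new `attn_bias` argument is added to the pre-softmax logits, so any tensor that broadcasts against the `(batch, heads, i, j)` similarity matrix works. A minimal sketch of calling the non-chunked `attention` with an ALiBi-style distance bias; the shapes and the bias construction here are illustrative, not part of this commit:

import torch

b, h, i, j, d = 2, 8, 1024, 1024, 64  # illustrative shapes
q = torch.randn(b, h, i, d)
k = torch.randn(b, h, j, d)
v = torch.randn(b, h, j, d)

# any bias broadcastable to (b, h, i, j) works, e.g. a per-head distance penalty
slopes = torch.logspace(0., -6., h).view(1, h, 1, 1)
dist = (torch.arange(i).view(i, 1) - torch.arange(j).view(1, j)).abs()
attn_bias = -slopes * dist  # shape (1, h, i, j), broadcasts over batch

out = attention(q, k, v, attn_bias = attn_bias)  # (b, h, i, d)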
@@ -45,8 +49,12 @@ def attention(

 # memory efficient attention

-def summarize_qkv_chunk(q, k, v, mask, causal_mask):
+def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
     weight = einsum('b h i d, b h j d -> b h i j', q, k)
+
+    if exists(attn_bias_chunk):
+        weight = weight + attn_bias_chunk
+
     weight_max = weight.amax(dim = -1, keepdim = True).detach()
     weight = weight - weight_max
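Note that the bias is folded into `weight` before the running-max subtraction, so each chunk's exponentiation stays numerically stable however large the bias gets. A small sketch (values illustrative) mirroring the lines above:

import torch

weight = torch.randn(1, 1, 4, 8)                 # a chunk of attention logits
attn_bias_chunk = torch.full((1, 1, 4, 8), 1e4)  # deliberately huge bias

weight = weight + attn_bias_chunk
weight_max = weight.amax(dim = -1, keepdim = True).detach()
exp_weight = (weight - weight_max).exp()
assert torch.isfinite(exp_weight).all()  # no overflow despite the 1e4 bias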
@@ -70,6 +78,7 @@ def memory_efficient_attention(
     q, k, v,
     mask = None,
     causal = False,
+    attn_bias = None,
     q_bucket_size = 512,
     k_bucket_size = 1024,
     eps = 1e-8
@@ -90,6 +99,11 @@ def memory_efficient_attention(
     causal_mask_chunks = causal_mask.split(q_bucket_size, dim = 0)
     causal_mask_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), causal_mask_chunks))

+    if exists(attn_bias):
+        i, j = attn_bias.shape[-2:]
+        attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
+        attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))
+
     # loop through all chunks and accumulate

     out = []
@@ -106,12 +120,15 @@ def memory_efficient_attention(
                 # if chunk is to be all masked out causally, skip
                 continue

+            attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None
+
             exp_weight_chunk, weighted_value_chunk, weight_max_chunk = checkpointed_summarize_qkv_chunk(
                 q_chunk,
                 k_chunk,
                 v_chunk,
                 mask_chunk,
-                causal_mask_chunk
+                causal_mask_chunk,
+                attn_bias_chunk
             )

             exp_weights.append(exp_weight_chunk)
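Since addition distributes over the row and column splits, adding `attn_bias_chunks[q_index][k_index]` inside each chunk is equivalent to adding the full bias up front. A quick sanity sketch of the two-level split used in the two hunks above, with bucket sizes shrunk for readability:

import torch

q_bucket_size, k_bucket_size = 4, 8
attn_bias = torch.randn(1, 8, 16, 32)  # (batch or 1, heads, i, j)

# same double split as in memory_efficient_attention
attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))

# chunk (q_index, k_index) lines up with q_chunks[q_index] and k_chunks[k_index]
reassembled = torch.cat([torch.cat(row, dim = -1) for row in attn_bias_chunks], dim = -2)
assert torch.equal(reassembled, attn_bias)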
@@ -173,9 +190,10 @@ def forward(
         x,
         context = None,
         mask = None,
+        attn_bias = None,
         memory_efficient = None,
         q_bucket_size = None,
-        k_bucket_size = None
+        k_bucket_size = None,
     ):
         memory_efficient = default(memory_efficient, self.memory_efficient)
         q_bucket_size = default(q_bucket_size, self.q_bucket_size)
@@ -191,7 +209,7 @@ def forward(

         attn_fn = attention if not memory_efficient else memory_efficient_attention

-        out = attn_fn(q, k, v, mask = mask, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
+        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)

         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
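End to end, the new argument threads through `forward` into either attention path. A hedged usage sketch; the `Attention` constructor arguments shown are assumptions about the surrounding module, not part of this diff:

import torch

attn = Attention(dim = 512, heads = 8, memory_efficient = True)  # constructor args assumed

x = torch.randn(1, 2048, 512)
attn_bias = torch.randn(1, 8, 2048, 2048)  # broadcasts against the (b, h, i, j) logits

out = attn(x, attn_bias = attn_bias)  # same semantics on both attention paths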