
Commit ee594d3 (1 parent: d4b015f)

add ability to add attention bias, for dynamic positional bias and extrapolating to greater sequence lengths at inference

File tree: 2 files changed, +23 −5 lines


memory_efficient_attention_pytorch/memory_efficient_attention.py (22 additions, 4 deletions)

@@ -19,13 +19,17 @@ def attention(
     q, k, v,
     mask = None,
     causal = False,
+    attn_bias = None,
     **kwargs
 ):
     scale = q.shape[-1] ** -0.5
     q = q * scale

     sim = einsum('b h i d, b h j d -> b h i j', q, k)

+    if exists(attn_bias):
+        sim = sim + attn_bias
+
     sim = sim - sim.amax(dim = -1, keepdim = True).detach()
     mask_value = -torch.finfo(sim.dtype).max

@@ -45,8 +49,12 @@ def attention(

 # memory efficient attention

-def summarize_qkv_chunk(q, k, v, mask, causal_mask):
+def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
     weight = einsum('b h i d, b h j d -> b h i j', q, k)
+
+    if exists(attn_bias_chunk):
+        weight = weight + attn_bias_chunk
+
     weight_max = weight.amax(dim = -1, keepdim = True).detach()
     weight = weight - weight_max

@@ -70,6 +78,7 @@ def memory_efficient_attention(
     q, k, v,
     mask = None,
     causal = False,
+    attn_bias = None,
     q_bucket_size = 512,
     k_bucket_size = 1024,
     eps = 1e-8

@@ -90,6 +99,11 @@ def memory_efficient_attention(
         causal_mask_chunks = causal_mask.split(q_bucket_size, dim = 0)
         causal_mask_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), causal_mask_chunks))

+    if exists(attn_bias):
+        i, j = attn_bias.shape[-2:]
+        attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
+        attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))
+
     # loop through all chunks and accumulate

     out = []

@@ -106,12 +120,15 @@ def memory_efficient_attention(
                 # if chunk is to be all masked out causally, skip
                 continue

+            attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None
+
             exp_weight_chunk, weighted_value_chunk, weight_max_chunk = checkpointed_summarize_qkv_chunk(
                 q_chunk,
                 k_chunk,
                 v_chunk,
                 mask_chunk,
-                causal_mask_chunk
+                causal_mask_chunk,
+                attn_bias_chunk
             )

             exp_weights.append(exp_weight_chunk)

@@ -173,9 +190,10 @@ def forward(
         x,
         context = None,
         mask = None,
+        attn_bias = None,
         memory_efficient = None,
         q_bucket_size = None,
-        k_bucket_size = None
+        k_bucket_size = None,
     ):
         memory_efficient = default(memory_efficient, self.memory_efficient)
         q_bucket_size = default(q_bucket_size, self.q_bucket_size)

@@ -191,7 +209,7 @@ def forward(

         attn_fn = attention if not memory_efficient else memory_efficient_attention

-        out = attn_fn(q, k, v, mask = mask, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
+        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)

         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
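For context, a minimal usage sketch of the new argument (not part of this commit): it assumes the (batch, heads, seq, dim_head) layout implied by the einsum pattern above, and a bias broadcastable against the (batch, heads, q_len, k_len) attention matrix. The distance-based bias here is purely illustrative.

import torch
from memory_efficient_attention_pytorch.memory_efficient_attention import memory_efficient_attention

b, h, n, d = 1, 8, 2048, 64

# requires_grad so the gradient-checkpointed chunk summarizer has something to checkpoint
q = torch.randn(b, h, n, d, requires_grad = True)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

# illustrative distance-based bias, shaped (1, 1, n, n) so it broadcasts over batch and heads;
# the bias is only ever split along its last two dimensions, in step with the q / k bucket sizes
rel_dist = torch.arange(n)[None, :] - torch.arange(n)[:, None]
attn_bias = -0.1 * rel_dist.abs().float()[None, None, :, :]

out = memory_efficient_attention(
    q, k, v,
    attn_bias = attn_bias,
    q_bucket_size = 512,
    k_bucket_size = 1024
)

print(out.shape)  # torch.Size([1, 8, 2048, 64])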

setup.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@

 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   author = 'Phil Wang',
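Since both attention paths now accept attn_bias, a quick sanity check (a sketch, not taken from the repo's test suite) is to confirm the chunked path still matches the naive one when a bias is supplied:

import torch
from memory_efficient_attention_pytorch.memory_efficient_attention import (
    attention,
    memory_efficient_attention
)

q = torch.randn(1, 8, 1024, 64, requires_grad = True)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# an arbitrary bias over the full (q_len, k_len) attention matrix
attn_bias = torch.randn(1, 8, 1024, 1024)

out_full    = attention(q, k, v, attn_bias = attn_bias)
out_chunked = memory_efficient_attention(q, k, v, attn_bias = attn_bias, q_bucket_size = 256, k_bucket_size = 256)

assert torch.allclose(out_full, out_chunked, atol = 1e-5)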
