Skip to content

Commit adf5751

Browse files
committed
fix causal mask issue
1 parent fdcf17b commit adf5751

File tree

2 files changed

+32
-13
lines changed

2 files changed

+32
-13
lines changed

memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 26 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def attention(
3131

3232
if exists(mask):
3333
mask = rearrange(mask, 'b j -> b 1 1 j')
34-
sim = sim.masked_fill(mask, mask_value)
34+
sim = sim.masked_fill(~mask, mask_value)
3535

3636
if causal:
3737
i, j = sim.shape[-2:]
@@ -50,14 +50,18 @@ def safe_sum(acc, el):
5050
return el
5151
return acc + el
5252

53-
def summarize_qkv_chunk(q, k, v, mask):
53+
def summarize_qkv_chunk(q, k, v, mask, causal_mask):
5454
weight = einsum('b h i d, b h j d -> b h i j', q, k)
55-
exp_weight = weight.exp()
55+
mask_value = -torch.finfo(weight.dtype).max
5656

5757
if exists(mask):
5858
mask = rearrange(mask, 'b j -> b 1 1 j')
59-
exp_weight = exp_weight.masked_fill(mask, 0.)
59+
weight = weight.masked_fill(~mask, mask_value)
6060

61+
if exists(causal_mask):
62+
weight = weight.masked_fill(causal_mask, mask_value)
63+
64+
exp_weight = weight.exp()
6165
weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
6266
return exp_weight.sum(dim = -1), weighted_value
6367

@@ -68,7 +72,8 @@ def memory_efficient_attention(
6872
mask = None,
6973
causal = False,
7074
q_bucket_size = 512,
71-
k_bucket_size = 1024
75+
k_bucket_size = 1024,
76+
eps = 1e-8
7277
):
7378
scale = q.shape[-1] ** -0.5
7479
q = q * scale
@@ -80,26 +85,38 @@ def memory_efficient_attention(
8085
v_chunks = v.split(k_bucket_size, dim = -2)
8186
mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))
8287

88+
if causal:
89+
i, j = q.shape[-2], k.shape[-2]
90+
causal_mask = torch.ones(i, j).triu(j - i + 1).bool()
91+
8392
# loop through all chunks and accumulate
8493

8594
out = []
86-
for q_chunk in q_chunks:
95+
for q_index, q_chunk in enumerate(q_chunks):
8796
exp_weights = None
8897
weighted_values = None
8998

90-
for k_chunk, v_chunk, mask_chunk in zip(k_chunks, v_chunks, mask_chunks):
99+
for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
100+
101+
causal_mask_chunk = None
102+
if causal:
103+
causal_mask_chunk = causal_mask[
104+
(q_index * q_bucket_size):(q_index * q_bucket_size + q_bucket_size),
105+
(k_index * k_bucket_size):(k_index * k_bucket_size + k_bucket_size),
106+
]
91107

92108
exp_weight_chunk, weighted_value_chunk = checkpointed_summarize_qkv_chunk(
93109
q_chunk,
94110
k_chunk,
95111
v_chunk,
96-
mask_chunk
112+
mask_chunk,
113+
causal_mask_chunk
97114
)
98115

99116
exp_weights = safe_sum(exp_weights, exp_weight_chunk)
100117
weighted_values = safe_sum(weighted_values, weighted_value_chunk)
101118

102-
normalized_values = weighted_values / rearrange(exp_weights, '... -> ... 1')
119+
normalized_values = weighted_values / (rearrange(exp_weights, '... -> ... 1') + eps)
103120
out.append(normalized_values)
104121

105122
return torch.cat(out, dim = -2)

tests/test.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,10 @@ def test_output_equal():
1919
)
2020

2121
x = torch.randn(2, 2048, 512)
22+
mask = torch.ones(2, 2048).bool()
2223

23-
out = attn(x)
24-
mem_efficient_out = attn(x, memory_efficient = True)
24+
out = attn(x, mask = mask)
25+
mem_efficient_out = attn(x, mask = mask, memory_efficient = True)
2526

2627
assert isclose(mem_efficient_out, out)
2728

@@ -40,12 +41,13 @@ def loss_fn(inp, **kwargs):
4041
return attn(inp, **kwargs).sum()
4142

4243
x = torch.randn(2, 2048, 512).requires_grad_()
44+
mask = torch.ones(2, 2048).bool()
4345

44-
loss_fn(x).backward()
46+
loss_fn(x, mask = mask).backward()
4547
out_grad = x.grad.clone()
4648

4749
x.grad.zero_()
48-
loss_fn(x, memory_efficient = True).backward()
50+
loss_fn(x, mask = mask, memory_efficient = True).backward()
4951
mem_efficient_out_grad = x.grad.clone()
5052

5153
assert isclose(out_grad, mem_efficient_out_grad)

0 commit comments

Comments
 (0)