
Commit 2c0fc02

bring in the further simplification to flash attention that @tridao discovered, saving only logsumexp instead of rowsums + maximum for backwards
1 parent 35559a0 commit 2c0fc02

2 files changed: 9 additions & 10 deletions

memory_efficient_attention_pytorch/flash_attention.py

Lines changed: 8 additions & 9 deletions
@@ -95,8 +95,10 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                 row_maxes.copy_(new_row_maxes)
                 row_sums.copy_(new_row_sums)
 
+        lse = all_row_sums.log() + all_row_maxes
+
         ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
-        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
+        ctx.save_for_backward(q, k, v, o, lse)
 
         return o

@@ -106,7 +108,7 @@ def backward(ctx, do):
         """ Algorithm 4 in the paper """
 
         causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
-        q, k, v, o, l, m = ctx.saved_tensors
+        q, k, v, o, lse = ctx.saved_tensors
 
         device = q.device

@@ -122,12 +124,11 @@ def backward(ctx, do):
             o.split(q_bucket_size, dim = -2),
             do.split(q_bucket_size, dim = -2),
             mask,
-            l.split(q_bucket_size, dim = -2),
-            m.split(q_bucket_size, dim = -2),
+            lse.split(q_bucket_size, dim = -2),
             dq.split(q_bucket_size, dim = -2)
         )
 
-        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
+        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
             q_start_index = ind * q_bucket_size - qk_len_diff
 
             col_splits = zip(
@@ -146,12 +147,10 @@ def backward(ctx, do):
                     causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                     attn_weights.masked_fill_(causal_mask, max_neg_value)
 
-                exp_attn_weights = torch.exp(attn_weights - mc)
+                p = torch.exp(attn_weights - lsec)
 
                 if exists(row_mask):
-                    exp_attn_weights.masked_fill_(~row_mask, 0.)
-
-                p = exp_attn_weights / lc
+                    p.masked_fill_(~row_mask, 0.)
 
                 dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                 dp = einsum('... i d, ... j d -> ... i j', doc, vc)
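
The change rests on a simple identity: with row max m and row sum l of exp(s - m), the backward pass can recover the attention probabilities from the single saved tensor lse = log(l) + m, because exp(s - m) / l = exp(s - (log(l) + m)) = exp(s - lse). A minimal standalone sketch of that equivalence (illustrative only, not code from this repository; tensor names are hypothetical):

import torch

# Hypothetical, self-contained check of the identity behind the commit:
# exp(s - m) / l  ==  exp(s - (log(l) + m))  ==  exp(s - lse)
s = torch.randn(4, 8)                                # attention scores for one row block
m = s.amax(dim = -1, keepdim = True)                 # row maxes
l = torch.exp(s - m).sum(dim = -1, keepdim = True)   # row sums of exp(s - m)
lse = l.log() + m                                    # the only tensor forward now saves

p_old = torch.exp(s - m) / l    # previous backward: needed both l and m
p_new = torch.exp(s - lse)      # new backward: needs only lse

assert torch.allclose(p_old, p_new, atol = 1e-6)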

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.26',
+  version = '0.0.27',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   long_description_content_type = 'text/markdown',
