Commit 804a202

complete flash attention algorithm in plain pytorch (for educational purposes, performant version will be at https://github.com/HazyResearch/flash-attention)
1 parent e4d0998 commit 804a202
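
The forward pass introduced below streams over blocks of keys and values and keeps, for every query row, a running max m and a running sum ℓ of exponentiated scores, so the full attention matrix is never materialized. As a rough sketch (notation mine, with S_j the scaled query-key dot products of the current key block against one query row), each key block updates the running statistics and the partial output o as:

```latex
\begin{aligned}
m_{\mathrm{new}}    &= \max\big(m_{\mathrm{old}},\ \max_j S_j\big) \\
\ell_{\mathrm{new}} &= e^{m_{\mathrm{old}} - m_{\mathrm{new}}}\,\ell_{\mathrm{old}} + \sum_j e^{S_j - m_{\mathrm{new}}} \\
o_{\mathrm{new}}    &= \frac{\ell_{\mathrm{old}}\, e^{m_{\mathrm{old}} - m_{\mathrm{new}}}}{\ell_{\mathrm{new}}}\, o_{\mathrm{old}}
                     + \frac{1}{\ell_{\mathrm{new}}} \sum_j e^{S_j - m_{\mathrm{new}}}\, v_j
\end{aligned}
```

After the last key block, o equals the exact softmax attention output for that row.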

3 files changed: +238 -4 lines changed

README.md

Lines changed: 10 additions & 3 deletions
@@ -52,9 +52,6 @@ mask = torch.ones(1, 65536).bool().cuda()
 out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
 ```
 
-- [ ] benchmark and see how much torch jit helps
-- [ ] look at Triton and Keops and see if either can be a fit
-
 ## Citations
 
 ```bibtex
@@ -78,3 +75,13 @@ out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
     primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@article{Dao2022FlashAttentionFA,
+    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
+    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2205.14135}
+}
+```
Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
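# FlashAttention (Dao et al., 2022) written with plain PyTorch tensor ops as a custom
# torch.autograd.Function: queries and keys/values are processed bucket by bucket and
# the softmax statistics (row maxes and row sums) are accumulated online, so the full
# attention matrix is never held in memory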
import math
import torch
from functools import partial
from torch import nn, einsum
from torch.autograd.function import Function

from einops import rearrange

# constants

EPSILON = 1e-6

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# flash attention forwards and backwards

class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)

        scale = (q.shape[-1] ** -0.5)
        q = q * scale

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = mask.split(q_bucket_size, dim = -2)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
            all_row_maxes.split(q_bucket_size, dim = -2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc)

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((q_bucket_size, k_bucket_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)

                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes) # running max is the larger of the previous row maxes and this block's maxes

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                out = (row_sums / new_row_sums) * exp_row_max_diff * oc + \
                      (exp_block_row_max_diff / new_row_sums) * exp_values

                oc.copy_(out)
                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        causal, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        scale = q.shape[-1] ** -0.5

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            l.split(q_bucket_size, dim = -2),
            m.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # qc was already scaled by `scale` in the forward pass before being saved
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((q_bucket_size, k_bucket_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                p = exp_attn_weights / lc

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                D = (doc * oc).sum(dim = -1, keepdims = True) # row sums of (dO * O) for this query chunk
                ds = p * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) * scale # gradient w.r.t. the unscaled queries
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) # qc already carries the scale factor

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None

# main class

# just flash attention in plain pytorch
# it will be way slower than implementing it in CUDA
# for tinkering and educational purposes

class FlashAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads

        self.causal = causal

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # memory efficient attention related parameters
        # can be overridden on forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
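
For reference, a minimal usage sketch of the new module. The import path and the tensor shapes are assumptions on my part (the new file's path is not shown in this view); the constructor and forward signatures come from the code above.

```python
import torch
from memory_efficient_attention_pytorch.flash_attention import FlashAttention  # assumed path to the new file

# hyperparameters mirror the defaults above
attn = FlashAttention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    causal = True,
    q_bucket_size = 512,
    k_bucket_size = 1024
)

x = torch.randn(1, 4096, 512)

out = attn(x)           # (1, 4096, 512)
out.sum().backward()    # gradients flow through the hand-written backward pass
```

Sequence length and bucket sizes are chosen here so the buckets divide the sequence evenly, which is what the fixed-size causal mask above expects.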

setup.py

Lines changed: 2 additions & 1 deletion
@@ -3,9 +3,10 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.17',
+  version = '0.0.18',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
+  long_description_content_type = 'text/markdown',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
   url = 'https://github.com/lucidrains/memory-efficient-attention-pytorch',
