Commit d79f510

oops, actually add cosine sim flash attention
1 parent 4023c96 commit d79f510

File tree

2 files changed: +220, -1 lines changed
Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
import math
import torch
from functools import partial
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd.function import Function

from einops import rearrange

# constants

EPSILON = 1e-6

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def l2norm(t):
    return F.normalize(t, dim = -1)

# flash attention forwards and backwards

class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        k_len = k.shape[-2] # in cosine sim attention, row sums are bounded by key / values sequence length

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)

        scale = (q.shape[-1] ** -0.5)
        q = q * scale

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = mask.split(q_bucket_size, dim = -2)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
        )

        for ind, (qc, oc, row_mask, row_sums) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc)

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                attn_weights -= scale
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                oc.add_(exp_values / k_len)
                row_sums.add_(block_row_sums)

        ctx.args = (scale, causal, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums)

        o.mul_(k_len / all_row_sums)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        scale, causal, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            l.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                qc_scaled = qc * scale
                attn_weights = einsum('... i d, ... j d -> ... i j', qc_scaled, kc)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                p = exp_attn_weights / lc

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None, None

# main class

# flash attention for cosine sim attention
# a bit less complicated, as no more need to worry about softmax numerical stability, and row sums are bounded

class FlashAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        scale = 16,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads

        self.scale = scale
        self.causal = causal

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # memory efficient attention related parameters
        # can be overridden on forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q, k = map(l2norm, (q, k))

        out = FlashAttentionFunction.apply(q, k, v, mask, self.scale, self.causal, q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
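
For reference, a minimal usage sketch of the FlashAttention module added above. The batch size, sequence length, and settings below are illustrative assumptions, not part of the commit, and the class is assumed to be importable from wherever this file lands in the package.

# usage sketch (shapes and hyperparameters here are assumptions for illustration only)
import torch

attn = FlashAttention(
    dim = 512,              # model dimension of the input tokens
    heads = 8,
    dim_head = 64,
    causal = True,
    q_bucket_size = 512,    # smaller buckets use less memory at the cost of speed
    k_bucket_size = 1024
)

x = torch.randn(1, 2048, 512)   # (batch, sequence length, dim)
out = attn(x)                   # -> (1, 2048, 512), attention over l2-normalized q and k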

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.20',
+  version = '0.0.21',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   long_description_content_type = 'text/markdown',
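
The version bump in setup.py suggests a new release; assuming version 0.0.21 is published under the package name declared above, it would be installed with:

pip install memory-efficient-attention-pytorch==0.0.21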
