
Commit 50e06b8

oops, actually add cosine sim flash attention
1 parent 4023c96 commit 50e06b8

2 files changed: +219 −1 lines changed

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
import math
import torch
from functools import partial
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd.function import Function

from einops import rearrange

# constants

EPSILON = 1e-6

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def l2norm(t):
    return F.normalize(t, dim = -1)

# flash attention forwards and backwards

class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        k_len = k.shape[-2] # in cosine sim attention, row sums are bounded by key / values sequence length

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)

        q = q * scale

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = mask.split(q_bucket_size, dim = -2)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
        )

        for ind, (qc, oc, row_mask, row_sums) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc)

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                attn_weights -= scale
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                oc.add_(exp_values / k_len)
                row_sums.add_(block_row_sums)

        ctx.args = (scale, causal, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums)

        o.mul_(k_len / all_row_sums)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        scale, causal, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            l.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                qc_scaled = qc * scale
                attn_weights = einsum('... i d, ... j d -> ... i j', qc_scaled, kc)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                p = exp_attn_weights / lc

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None, None

# main class

# flash attention for cosine sim attention
# a bit less complicated, as no more need to worry about softmax numerical stability, and row sums are bounded

class FlashAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        scale = 16,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads

        self.scale = scale
        self.causal = causal

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # memory efficient attention related parameters
        # can be overriden on forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q, k = map(l2norm, (q, k))

        out = FlashAttentionFunction.apply(q, k, v, mask, self.scale, self.causal, q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
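
For quick reference, a minimal usage sketch of the module added in this commit. The import path below is an assumption, since this diff view does not show where the new file sits inside the package; adjust it to the file's actual location.

# minimal usage sketch; import path is assumed, not shown in this diff
import torch
from memory_efficient_attention_pytorch.cosine_sim_flash_attention import FlashAttention  # assumed module path

attn = FlashAttention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    causal = True,
    q_bucket_size = 512,    # bucket sizes control chunking, overridable per forward call
    k_bucket_size = 1024
)

x = torch.randn(1, 2048, 512)   # (batch, sequence length, dim)
out = attn(x)                   # (1, 2048, 512)

As the in-code comments note, queries and keys are l2-normalized and the scores are shifted down by the fixed scale before exponentiation, so every exponentiated weight is at most 1 and each row sum is bounded by the key sequence length. That bound is what lets the kernel accumulate plain row sums and rescale by k_len at the end, rather than tracking a running row maximum as in standard flash attention.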

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.20',
+  version = '0.0.22',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   long_description_content_type = 'text/markdown',
