Commit 33fb78a

add tests, fix bugs
1 parent c901ae5 commit 33fb78a

7 files changed: +75 lines added, -17 lines removed

7 files changed

+75
-17
lines changed
memory_efficient_attention_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 from memory_efficient_attention_pytorch.memory_efficient_attention import Attention, memory_efficient_attention
 from memory_efficient_attention_pytorch.memory_efficient_cosine_sim_attention import CosineSimAttention, numerically_unstable_memory_efficient_attention
+from memory_efficient_attention_pytorch.flash_attention import FlashAttention
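With FlashAttention re-exported at the package root, it can be constructed directly from the top-level import. A minimal usage sketch (illustrative, not part of the commit), reusing the hyperparameters that appear in the new tests below:

import torch
from memory_efficient_attention_pytorch import FlashAttention

# settings mirror the new test in tests/test.py
flash_attn = FlashAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    q_bucket_size = 64,
    k_bucket_size = 64,
    causal = True
)

x = torch.randn(2, 2048, 512)
mask = torch.ones(2, 2048).bool()

out = flash_attn(x, mask = mask)  # (2, 2048, 512)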

memory_efficient_attention_pytorch/cosine_sim_flash_attention.py

Lines changed: 3 additions & 2 deletions
@@ -42,7 +42,8 @@ def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
         if not exists(mask):
             mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
         else:
-            mask = mask.split(q_bucket_size, dim = -2)
+            mask = rearrange(mask, 'b n -> b 1 1 n')
+            mask = mask.split(q_bucket_size, dim = -1)

         row_splits = zip(
             q.split(q_bucket_size, dim = -2),
@@ -184,7 +185,7 @@ def __init__(

         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

         # memory efficient attention related parameters
         # can be overriden on forward
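For context on the mask change in both flash attention files: a boolean padding mask of shape (batch, seq) was previously split along dim = -2, i.e. along the batch axis of the 2-D tensor; the fix first lifts it to a broadcastable (batch, 1, 1, seq) shape and then chunks the sequence axis. A standalone sketch of just that reshaping, with made-up sizes (illustrative, not library code):

import math
import torch
from einops import rearrange

b, n, q_bucket_size = 2, 256, 64
mask = torch.ones(b, n).bool()             # boolean padding mask, (batch, seq)

# before the fix: mask.split(q_bucket_size, dim = -2) sliced the batch axis of
# the (b, n) tensor, which is not what the per-bucket loop expects
# after the fix: broadcast over heads and rows, then chunk the sequence axis
mask = rearrange(mask, 'b n -> b 1 1 n')
chunks = mask.split(q_bucket_size, dim = -1)

assert len(chunks) == math.ceil(n / q_bucket_size)
assert chunks[0].shape == (b, 1, 1, q_bucket_size)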

memory_efficient_attention_pytorch/flash_attention.py

Lines changed: 9 additions & 12 deletions
@@ -37,12 +37,12 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
         all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)

         scale = (q.shape[-1] ** -0.5)
-        q = q * scale

         if not exists(mask):
             mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
         else:
-            mask = mask.split(q_bucket_size, dim = -2)
+            mask = rearrange(mask, 'b n -> b 1 1 n')
+            mask = mask.split(q_bucket_size, dim = -1)

         row_splits = zip(
             q.split(q_bucket_size, dim = -2),
@@ -63,7 +63,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
             for k_ind, (kc, vc) in enumerate(col_splits):
                 k_start_index = k_ind * k_bucket_size

-                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc)
+                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                 if exists(row_mask):
                     attn_weights.masked_fill_(~row_mask, max_neg_value)
@@ -73,7 +73,6 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                     attn_weights.masked_fill_(causal_mask, max_neg_value)

                 block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
-
                 attn_weights -= block_row_maxes
                 exp_weights = torch.exp(attn_weights)

@@ -82,7 +81,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):

                 block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

-                new_row_maxes = torch.maximum(block_row_maxes, row_sums)
+                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                 exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

@@ -92,10 +91,11 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                 new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                 oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
+
                 row_maxes.copy_(new_row_maxes)
                 row_sums.copy_(new_row_sums)

-        ctx.args = (causal, mask, q_bucket_size, k_bucket_size)
+        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
         ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

         return o
@@ -105,7 +105,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
     def backward(ctx, do):
         """ Algorithm 4 in the paper """

-        causal, mask, q_bucket_size, k_bucket_size = ctx.args
+        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
         q, k, v, o, l, m = ctx.saved_tensors

         device = q.device
@@ -117,8 +117,6 @@ def backward(ctx, do):
         dk = torch.zeros_like(k)
         dv = torch.zeros_like(v)

-        scale = q.shape[-1] ** -0.5
-
         row_splits = zip(
             q.split(q_bucket_size, dim = -2),
             o.split(q_bucket_size, dim = -2),
@@ -142,8 +140,7 @@ def backward(ctx, do):
             for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                 k_start_index = k_ind * k_bucket_size

-                qc_scaled = qc * scale
-                attn_weights = einsum('... i d, ... j d -> ... i j', qc_scaled, kc)
+                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                 if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                     causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
@@ -197,7 +194,7 @@ def __init__(

         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

         # memory efficient attention related parameters
         # can be overriden on forward
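The remaining flash_attention.py changes keep the running statistics consistent: the softmax scale is now applied to the attention logits (and passed through ctx.args to backward) instead of pre-scaling q, and the running maximum is updated against row_maxes rather than row_sums. Below is a self-contained sketch of the streaming softmax accumulation this loop performs, checked against dense attention for a single head with no mask; it is an illustration of the recurrence in the diff, not the library implementation.

import torch

torch.manual_seed(0)
n, d, bucket = 256, 64, 64
q, k, v = torch.randn(3, n, d).unbind(0)

scale = d ** -0.5
reference = ((q @ k.t()) * scale).softmax(dim = -1) @ v   # dense attention for comparison

out = torch.zeros(n, d)
row_sums = torch.zeros(n, 1)
row_maxes = torch.full((n, 1), float('-inf'))

for kc, vc in zip(k.split(bucket), v.split(bucket)):
    block = (q @ kc.t()) * scale                          # scale on the logits, as in the fix
    block_row_maxes = block.amax(dim = -1, keepdim = True)
    exp_weights = torch.exp(block - block_row_maxes)
    block_row_sums = exp_weights.sum(dim = -1, keepdim = True)

    # the bug fix: compare the running maxima with the block maxima (not the sums)
    new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

    exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
    exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
    new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

    # renormalize the partial output, then add this key block's contribution
    out = (out * (row_sums / new_row_sums) * exp_row_max_diff
        + (exp_block_row_max_diff / new_row_sums) * (exp_weights @ vc))

    row_maxes, row_sums = new_row_maxes, new_row_sums

assert torch.allclose(out, reference, atol = 1e-6)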

memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ def __init__(

         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

         # memory efficient attention related parameters
         # can be overriden on forward

memory_efficient_attention_pytorch/memory_efficient_cosine_sim_attention.py

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ def __init__(

         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

         # memory efficient attention related parameters
         # can be overriden on forward

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.22',
+  version = '0.0.23',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   long_description_content_type = 'text/markdown',

tests/test.py

Lines changed: 59 additions & 0 deletions
@@ -1,6 +1,9 @@
 import torch
 from memory_efficient_attention_pytorch import Attention

+from memory_efficient_attention_pytorch.memory_efficient_attention import attention
+from memory_efficient_attention_pytorch.flash_attention import FlashAttention, FlashAttentionFunction
+
 # constants

 def isclose(a, b, atol = 1e-6):
@@ -53,3 +56,59 @@ def loss_fn(inp, **kwargs):
     mem_efficient_out_grad = x.grad.clone()

     assert isclose(out_grad, mem_efficient_out_grad, atol = 1e-5)
+
+# test flash attention
+
+def test_flash_attn_output_equal():
+    attn_kwargs = dict(
+        dim = 512,
+        dim_head = 64,
+        heads = 8,
+        q_bucket_size = 64,
+        k_bucket_size = 64,
+        causal = True
+    )
+
+    attn = Attention(**attn_kwargs)
+    flash_attn = FlashAttention(**attn_kwargs)
+
+    flash_attn.to_q = attn.to_q
+    flash_attn.to_kv = attn.to_kv
+    flash_attn.to_out = attn.to_out
+
+    x = torch.randn(2, 2048, 512)
+    mask = torch.ones(2, 2048).bool()
+
+    out = attn(x, mask = mask)
+    mem_efficient_out = flash_attn(x, mask = mask)
+
+    assert isclose(mem_efficient_out, out, atol = 1e-6)
+
+# test gradients equal
+
+def test_flash_attn_gradients_equal():
+    q = torch.randn(1, 8, 1024, 512).requires_grad_()
+    k = torch.randn(1, 8, 1024, 512).requires_grad_()
+    v = torch.randn(1, 8, 1024, 512).requires_grad_()
+
+    o = attention(q, k, v, causal = False)
+    o.sum().backward()
+
+    dq_grad = q.grad.clone()
+    dk_grad = k.grad.clone()
+    dv_grad = v.grad.clone()
+
+    q.grad.zero_()
+    k.grad.zero_()
+    v.grad.zero_()
+
+    flash_o = FlashAttentionFunction.apply(q, k, v, None, False, 64, 64)
+    flash_o.sum().backward()
+
+    flash_dq_grad = q.grad.clone()
+    flash_dk_grad = k.grad.clone()
+    flash_dv_grad = v.grad.clone()
+
+    assert isclose(flash_dq_grad, dq_grad, atol = 1e-5)
+    assert isclose(flash_dk_grad, dk_grad, atol = 1e-5)
+    assert isclose(flash_dv_grad, dv_grad, atol = 1e-5)
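Usage note (not part of the commit): with the package importable and pytest installed, both new test functions should be collected alongside the existing ones via something like: python -m pytest tests/test.py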
