
Commit 7ce4d63

fix bugs
1 parent 1f6f659 commit 7ce4d63

2 files changed: +4 -4 lines changed

memory_efficient_attention_pytorch/flash_attention.py

Lines changed: 3 additions & 3 deletions
@@ -65,7 +65,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
-                   causal_mask = torch.ones((q_bucket_size, k_bucket_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
+                   causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
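
Note on the causal mask change (it appears here in the forward pass and once more in the backward pass below): the mask was previously allocated with the full (q_bucket_size, k_bucket_size) shape, which stops matching attn_weights whenever the last query or key chunk is shorter than a full bucket, i.e. when the sequence length is not a multiple of the bucket size. Deriving the shape from qc.shape[-2] and kc.shape[-2] keeps the mask aligned with the actual chunk. A minimal sketch of the failure mode, using made-up sizes rather than the library's API:

import torch

# made-up sizes: seq_len = 10 is not a multiple of bucket_size = 4,
# so the last chunk along the sequence dimension has only 2 rows / columns
seq_len, bucket_size, dim = 10, 4, 8
q = torch.randn(1, seq_len, dim)
k = torch.randn(1, seq_len, dim)

qc = q.split(bucket_size, dim = -2)[-1]    # partial query chunk, 2 rows
kc = k.split(bucket_size, dim = -2)[-1]    # partial key chunk, 2 columns
attn_weights = torch.einsum('b i d, b j d -> b i j', qc, kc)    # shape (1, 2, 2)

q_start_index = k_start_index = 2 * bucket_size

# a (bucket_size, bucket_size) mask, i.e. (4, 4), would not broadcast against (1, 2, 2);
# building it from the chunks themselves keeps everything aligned
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool).triu(q_start_index - k_start_index + 1)
attn_weights.masked_fill_(causal_mask, float('-inf'))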
@@ -143,7 +143,7 @@ def backward(ctx, do):
                attn_weights = einsum('... i d, ... j d -> ... i j', qc_scaled, kc)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
-                   causal_mask = torch.ones((q_bucket_size, k_bucket_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
+                   causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)
@@ -156,7 +156,7 @@ def backward(ctx, do):
                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

-               D = (do * o).sum(dim = -1, keepdims = True)
+               D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
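
Note on the last change: in the backward pass, D = rowsum(dO * O) has to be computed from the current query chunk (doc, oc) so that it lines up row for row with dp for that chunk; taking it from the full do and o tensors only happens to line up when the whole sequence fits in a single bucket. A minimal sketch with made-up shapes, not the library's API:

import torch

batch, heads, seq_len, dim = 1, 2, 10, 8    # made-up shapes
q_bucket_size = 4

o  = torch.randn(batch, heads, seq_len, dim)
do = torch.randn(batch, heads, seq_len, dim)

for oc, doc in zip(o.split(q_bucket_size, dim = -2), do.split(q_bucket_size, dim = -2)):
    # per-chunk D has shape (batch, heads, chunk_len, 1) and matches this chunk's rows;
    # (do * o).sum(...) would have seq_len rows and no longer broadcast against dp
    D = (doc * oc).sum(dim = -1, keepdim = True)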

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
setup(
  name = 'memory-efficient-attention-pytorch',
  packages = find_packages(exclude=[]),
- version = '0.0.18',
+ version = '0.0.19',
  license='MIT',
  description = 'Memory Efficient Attention - Pytorch',
  long_description_content_type = 'text/markdown',
