
Commit 8d7de9c

further simplification

1 parent 75469e8 · commit 8d7de9c

2 files changed (+7, −9 lines changed)


memory_efficient_attention_pytorch/flash_attention.py (6 additions, 8 deletions)
@@ -8,7 +8,7 @@

 # constants

-EPSILON = 1e-6
+EPSILON = 1e-10

 # helper functions

@@ -81,24 +81,22 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                     attn_weights.masked_fill_(causal_mask, max_neg_value)

                 block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
-                attn_weights -= block_row_maxes
-                exp_weights = torch.exp(attn_weights)
+                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
+
+                exp_weights = torch.exp(attn_weights - new_row_maxes)

                 if exists(col_mask):
                     exp_weights.masked_fill_(~col_mask, 0.)

                 block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

-                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
-
                 exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                 exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
-                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

-                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
+                new_row_sums = exp_row_max_diff * row_sums + block_row_sums

-                oc.mul_(exp_row_max_diff).add_(exp_block_row_max_diff * exp_values)
+                oc.mul_(exp_row_max_diff).add_(exp_values)

                 row_maxes.copy_(new_row_maxes)
                 row_sums.copy_(new_row_sums)
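
The simplification in this hunk folds the running-max update into the exponentiation step: by exponentiating the block's attention weights against new_row_maxes directly, the current block's contribution is already on the final scale, so the separate exp_block_row_max_diff rescale can be dropped. A minimal standalone sketch of the resulting update, reusing the variable names from the diff (the update_block helper itself is hypothetical and not part of the library):

import torch
from torch import einsum

EPSILON = 1e-10

def update_block(attn_weights, vc, oc, row_sums, row_maxes):
    # running softmax statistics for one (query block, key/value block) pair;
    # oc, row_sums and row_maxes are updated in place, as in the diff above
    block_row_maxes = attn_weights.amax(dim = -1, keepdim = True)
    new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

    # exponentiate against the updated running max, so no separate
    # exp(block_row_maxes - new_row_maxes) correction is needed later
    exp_weights = torch.exp(attn_weights - new_row_maxes)
    block_row_sums = exp_weights.sum(dim = -1, keepdim = True).clamp(min = EPSILON)

    exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

    # rescale only the previously accumulated output and row sums
    exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
    new_row_sums = exp_row_max_diff * row_sums + block_row_sums

    oc.mul_(exp_row_max_diff).add_(exp_values)
    row_maxes.copy_(new_row_maxes)
    row_sums.copy_(new_row_sums)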

setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.5',
+  version = '0.1.6',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   long_description_content_type = 'text/markdown',
