
Commit 75469e8

add new trick from flash attention 2 that saves on division
1 parent 8f85587 commit 75469e8

3 files changed: +13 −3 lines changed

README.md

Lines changed: 9 additions & 1 deletion
@@ -2,7 +2,7 @@
 
 Implementation of a memory efficient multi-head attention as proposed in the paper, <a href="https://arxiv.org/abs/2112.05682">Self-attention Does Not Need O(n²) Memory</a>. In addition, the module will take care of masking, causal masking, as well as cross attention.
 
-This repository also contains a <a href="https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py">naive non-CUDA implementation</a> of the improvements made by <a href="https://tridao.me/">Tri Dao</a> with his <a href="https://github.com/HazyResearch/flash-attention">Flash Attention</a> paper, for educational purposes. It is a game changer for attention and building long-context transformers.
+This repository also contains a <a href="https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py">naive non-CUDA implementation</a> of the improvements made by <a href="https://tridao.me/">Tri Dao</a> with his <a href="https://github.com/HazyResearch/flash-attention">Flash Attention 2</a> paper, for educational purposes. It is a game changer for attention and building long-context transformers.
 
 Update: from now on, you should just be using the <a href="https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html?highlight=scaled_dot_product#torch.nn.functional.scaled_dot_product_attention">`F.scaled_dot_product_attention`</a> function in Pytorch 2.0
 
@@ -89,3 +89,11 @@ out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
 volume = {abs/2205.14135}
 }
 ```
+
+```bibtex
+@article{dao2023flashattention2,
+    title = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
+    author = {Dao, Tri},
+    year = {2023}
+}
+```
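
As context for the README's pointer above, here is a minimal sketch (not part of this commit; shapes and arguments are illustrative) of calling PyTorch 2.0's built-in `F.scaled_dot_product_attention`, which dispatches to a fused memory-efficient or flash kernel when one is available for the inputs:

```python
# sketch only - not part of this commit
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64)  # (batch, heads, seq_len, dim_head)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# is_causal = True applies a causal mask; an explicit attn_mask can be passed instead
out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

print(out.shape)  # torch.Size([1, 8, 1024, 64])
```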

memory_efficient_attention_pytorch/flash_attention.py

Lines changed: 3 additions & 1 deletion
@@ -98,11 +98,13 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
 
 new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
 
-oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
+oc.mul_(exp_row_max_diff).add_(exp_block_row_max_diff * exp_values)
 
 row_maxes.copy_(new_row_maxes)
 row_sums.copy_(new_row_sums)
 
+oc.div_(row_sums)
+
 lse = all_row_sums.log() + all_row_maxes
 
 ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
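
The diff above is the division-saving trick from Flash Attention 2 named in the commit message: the output accumulator `oc` is kept unnormalized while iterating over key/value blocks (only rescaled by `exp_row_max_diff`), and the division by the softmax row sums happens once per query bucket via `oc.div_(row_sums)` instead of once per key block. Below is a self-contained sketch of the same idea, independent of this repo's actual class and tensor layout (function name, shapes, and block size are illustrative):

```python
# sketch only - illustrates the deferred-normalization trick, not the repo's exact code
import torch

def streaming_attention(q, k, v, block_size = 128):
    scale = q.shape[-1] ** -0.5

    out       = torch.zeros_like(q)                             # unnormalized output accumulator
    row_sums  = torch.zeros(*q.shape[:-1], 1)                   # running softmax denominators
    row_maxes = torch.full((*q.shape[:-1], 1), float('-inf'))   # running row maxima

    for kc, vc in zip(k.split(block_size, dim = -2), v.split(block_size, dim = -2)):
        attn = (q @ kc.transpose(-1, -2)) * scale

        block_maxes = attn.amax(dim = -1, keepdim = True)
        new_maxes   = torch.maximum(row_maxes, block_maxes)

        exp_attn = torch.exp(attn - new_maxes)
        exp_diff = torch.exp(row_maxes - new_maxes)              # rescales contributions of previous blocks

        row_sums  = exp_diff * row_sums + exp_attn.sum(dim = -1, keepdim = True)
        out       = exp_diff * out + exp_attn @ vc               # note: no division per block
        row_maxes = new_maxes

    return out / row_sums                                        # single division at the end

q, k, v = (torch.randn(2, 1024, 64) for _ in range(3))
ref = torch.softmax((q @ k.transpose(-1, -2)) * 64 ** -0.5, dim = -1) @ v
assert torch.allclose(streaming_attention(q, k, v), ref, atol = 1e-4)
```

Deferring the division gives the same result as renormalizing every block, since the accumulator and the row sums are rescaled by the same `exp_diff` factor at each step; the final quotient is unchanged, but the per-block work drops.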

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.4',
+  version = '0.1.5',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   long_description_content_type = 'text/markdown',
