Skip to content

Commit 9ecb55a

Browse files
committed
max for numerical stability should be taken after masking
1 parent 99fbdf0 commit 9ecb55a

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def attention(
3030
if exists(attn_bias):
3131
sim = sim + attn_bias
3232

33-
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
3433
mask_value = -torch.finfo(sim.dtype).max
3534

3635
if exists(mask):
@@ -42,6 +41,7 @@ def attention(
4241
mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
4342
sim = sim.masked_fill(mask, mask_value)
4443

44+
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
4545
attn = sim.softmax(dim = -1)
4646

4747
out = einsum('b h i j, b h j d -> b h i d', attn, v)
@@ -55,9 +55,6 @@ def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
5555
if exists(attn_bias_chunk):
5656
weight = weight + attn_bias_chunk
5757

58-
weight_max = weight.amax(dim = -1, keepdim = True).detach()
59-
weight = weight - weight_max
60-
6158
mask_value = -torch.finfo(weight.dtype).max
6259

6360
if exists(mask):
@@ -67,6 +64,9 @@ def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
6764
if exists(causal_mask):
6865
weight = weight.masked_fill(causal_mask, mask_value)
6966

67+
weight_max = weight.amax(dim = -1, keepdim = True).detach()
68+
weight = weight - weight_max
69+
7070
exp_weight = weight.exp()
7171
weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
7272

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'memory-efficient-attention-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.12',
6+
version = '0.0.14',
77
license='MIT',
88
description = 'Memory Efficient Attention - Pytorch',
99
author = 'Phil Wang',
@@ -15,7 +15,7 @@
1515
'attention-mechanism'
1616
],
1717
install_requires=[
18-
'einops>=0.3,<0.4',
18+
'einops>=0.4.1',
1919
'torch>=1.6'
2020
],
2121
setup_requires=[

0 commit comments

Comments (0)