Commit 9bce9f8

committed
fix causal mask leading to high mem consumption for 65536 length
1 parent 5b60e2d commit 9bce9f8
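The change builds the causal mask as a bool tensor from the start instead of materializing a float32 mask and casting it with .bool() afterwards. At the sequence length named in the commit message (65536), the float32 intermediate alone is roughly 65536 * 65536 * 4 bytes ≈ 16 GiB before the bool copy is even made, whereas the bool mask is about 4 GiB. A minimal sketch of the before/after construction (the small size below is illustrative so the snippet runs anywhere, and is not taken from the repo):

import torch

# illustrative small size; the commit concerns i = j = 65536, where the old
# construction's float32 intermediate alone is ~16 GiB vs ~4 GiB for bool
i = j = 1024

# old construction: float32 mask first, then a separate bool copy
old_mask = torch.ones(i, j).triu(j - i + 1).bool()

# new construction: bool from the start, no float32 intermediate
new_mask = torch.ones(i, j, dtype = torch.bool).triu(j - i + 1)

assert torch.equal(old_mask, new_mask)  # identical mask, lower peak memory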

File tree

3 files changed: +5 -5 lines changed


memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ def attention(
 
     if causal:
         i, j = sim.shape[-2:]
-        mask = torch.ones(i, j, device = q.device).triu(j - i + 1).bool()
+        mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
         sim = sim.masked_fill(mask, mask_value)
 
     attn = sim.softmax(dim = -1)
@@ -95,7 +95,7 @@ def memory_efficient_attention(
 
     if causal:
         i, j = q.shape[-2], k.shape[-2]
-        causal_mask = torch.ones(i, j, device = q.device).triu(j - i + 1).bool()
+        causal_mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
         causal_mask_chunks = causal_mask.split(q_bucket_size, dim = 0)
         causal_mask_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), causal_mask_chunks))
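In the chunked path, the bool causal mask is split along the query dimension by q_bucket_size and along the key dimension by k_bucket_size, so each query/key bucket pair gets its own boolean block. A minimal sketch of how the chunks line up (bucket sizes here are arbitrary examples, not the repo's defaults):

import torch

i, j = 8, 8
q_bucket_size, k_bucket_size = 4, 2  # example bucket sizes, not the library defaults

causal_mask = torch.ones(i, j, dtype = torch.bool).triu(j - i + 1)
causal_mask_chunks = causal_mask.split(q_bucket_size, dim = 0)
causal_mask_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), causal_mask_chunks))

# chunk [q_idx][k_idx] is the (q_bucket_size, k_bucket_size) block masking the
# corresponding query/key chunk pair inside the attention loop
print(causal_mask_chunks[0][0].shape)  # torch.Size([4, 2])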

memory_efficient_attention_pytorch/memory_efficient_cosine_sim_attention.py

Lines changed: 2 additions & 2 deletions
@@ -40,7 +40,7 @@ def attention(
 
     if causal:
         i, j = sim.shape[-2:]
-        mask = torch.ones(i, j, device = q.device).triu(j - i + 1).bool()
+        mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
         sim = sim.masked_fill(mask, mask_value)
 
     attn = sim.softmax(dim = -1)
@@ -90,7 +90,7 @@ def numerically_unstable_memory_efficient_attention(
 
     if causal:
         i, j = q.shape[-2], k.shape[-2]
-        causal_mask = torch.ones(i, j, device = q.device).triu(j - i + 1).bool()
+        causal_mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
         causal_mask_chunks = causal_mask.split(q_bucket_size, dim = 0)
         causal_mask_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), causal_mask_chunks))

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   author = 'Phil Wang',
