Skip to content

Commit e4d0998

Browse files
committed
directly calculate triu shift value for causal mask for memory savings
1 parent 4be8244 commit e4d0998

File tree

3 files changed

+3
-7
lines changed

3 files changed

+3
-7
lines changed

memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices
6464
weight = weight.masked_fill(~mask, mask_value)
6565

6666
if causal and q_start_index < (k_start_index + k_chunk_size - 1):
67-
q_range = torch.arange(q_start_index, q_start_index + q_chunk_size, device = device)
68-
k_range = torch.arange(k_start_index, k_start_index + k_chunk_size, device = device)
69-
causal_mask = rearrange(q_range, 'i -> i 1') < rearrange(k_range, 'j -> 1 j')
67+
causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
7068
weight = weight.masked_fill(causal_mask, mask_value)
7169

7270
weight_max = weight.amax(dim = -1, keepdim = True).detach()

memory_efficient_attention_pytorch/memory_efficient_cosine_sim_attention.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices
6565
weight = weight.masked_fill(~mask, mask_value)
6666

6767
if causal and q_start_index < (k_start_index + k_chunk_size - 1):
68-
q_range = torch.arange(q_start_index, q_start_index + q_chunk_size, device = device)
69-
k_range = torch.arange(k_start_index, k_start_index + k_chunk_size, device = device)
70-
causal_mask = rearrange(q_range, 'i -> i 1') < rearrange(k_range, 'j -> 1 j')
68+
causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
7169
weight = weight.masked_fill(causal_mask, mask_value)
7270

7371
exp_weight = weight.exp()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'memory-efficient-attention-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.16',
6+
version = '0.0.17',
77
license='MIT',
88
description = 'Memory Efficient Attention - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments (0)