
Commit d4b015f

if the chunk is to be all masked out causally, skip summarizing the block entirely, for memory efficient attention
1 parent 64f30f8 commit d4b015f

File tree

2 files changed: 5 additions, 1 deletion


memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -102,6 +102,10 @@ def memory_efficient_attention(
 
             causal_mask_chunk = causal_mask_chunks[q_index][k_index] if causal else None
 
+            if exists(causal_mask_chunk) and torch.all(causal_mask_chunk):
+                # if chunk is to be all masked out causally, skip
+                continue
+
             exp_weight_chunk, weighted_value_chunk, weight_max_chunk = checkpointed_summarize_qkv_chunk(
                 q_chunk,
                 k_chunk,
```
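
For context, below is a minimal sketch of how this early-exit plays out in a chunked causal attention loop. The function name `chunked_causal_attention`, the fixed chunk sizes, and the streaming-softmax accumulators are illustrative assumptions for this sketch, not the library's actual implementation (which additionally routes each block through gradient checkpointing via `checkpointed_summarize_qkv_chunk`).

```python
import torch

# hypothetical chunk sizes for illustration; the real library exposes
# these as arguments to memory_efficient_attention
Q_CHUNK, K_CHUNK = 64, 64

def chunked_causal_attention(q, k, v):
    # q, k, v: (batch, seq_len, dim), seq_len divisible by the chunk sizes
    b, n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)

    # boolean causal mask for the full sequence: True where key j > query i,
    # i.e. positions that must be masked out
    mask = torch.ones(n, n, dtype = torch.bool, device = q.device).triu(1)

    for qi in range(0, n, Q_CHUNK):
        # running accumulators for a numerically stable streaming softmax
        acc = torch.zeros(b, Q_CHUNK, d, device = q.device)
        row_sum = torch.zeros(b, Q_CHUNK, 1, device = q.device)
        row_max = torch.full((b, Q_CHUNK, 1), float('-inf'), device = q.device)

        for ki in range(0, n, K_CHUNK):
            mask_chunk = mask[qi:qi + Q_CHUNK, ki:ki + K_CHUNK]

            # the optimization from this commit: if every position in this
            # (query chunk, key chunk) block is causally masked, the block
            # contributes nothing, so skip computing it entirely
            if torch.all(mask_chunk):
                continue

            sim = torch.einsum('b i d, b j d -> b i j', q[:, qi:qi + Q_CHUNK], k[:, ki:ki + K_CHUNK]) * scale
            sim = sim.masked_fill(mask_chunk, float('-inf'))

            # streaming softmax update (log-sum-exp trick across chunks)
            block_max = sim.amax(dim = -1, keepdim = True)
            new_max = torch.maximum(row_max, block_max)
            exp_sim = (sim - new_max).exp()
            correction = (row_max - new_max).exp()

            acc = acc * correction + torch.einsum('b i j, b j d -> b i d', exp_sim, v[:, ki:ki + K_CHUNK])
            row_sum = row_sum * correction + exp_sim.sum(dim = -1, keepdim = True)
            row_max = new_max

        out[:, qi:qi + Q_CHUNK] = acc / row_sum

    return out

# smoke test
q, k, v = (torch.randn(2, 256, 32) for _ in range(3))
out = chunked_causal_attention(q, k, v)  # (2, 256, 32)
```

With equal chunk sizes, every block strictly above the diagonal is skipped (just under half of all blocks), which is where the memory and compute saving comes from.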

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   author = 'Phil Wang',
```
