
Commit 4be8244

materialize causal mask only when needed, to reduce peak memory usage even more
1 parent 9db7dc8 commit 4be8244

3 files changed: +27, -27 lines

3 files changed

+27
-27
lines changed

memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 13 additions & 13 deletions
@@ -49,7 +49,9 @@ def attention(

 # memory efficient attention

-def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
+def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices):
+    q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device
+
     weight = einsum('b h i d, b h j d -> b h i j', q, k)

     if exists(attn_bias_chunk):
@@ -61,7 +63,10 @@ def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
         mask = rearrange(mask, 'b j -> b 1 1 j')
         weight = weight.masked_fill(~mask, mask_value)

-    if exists(causal_mask):
+    if causal and q_start_index < (k_start_index + k_chunk_size - 1):
+        q_range = torch.arange(q_start_index, q_start_index + q_chunk_size, device = device)
+        k_range = torch.arange(k_start_index, k_start_index + k_chunk_size, device = device)
+        causal_mask = rearrange(q_range, 'i -> i 1') < rearrange(k_range, 'j -> 1 j')
         weight = weight.masked_fill(causal_mask, mask_value)

     weight_max = weight.amax(dim = -1, keepdim = True).detach()
@@ -98,12 +103,6 @@ def memory_efficient_attention(
     v_chunks = v.split(k_bucket_size, dim = -2)
     mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))

-    if causal:
-        i, j = q.shape[-2], k.shape[-2]
-        causal_mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
-        causal_mask_chunks = causal_mask.split(q_bucket_size, dim = 0)
-        causal_mask_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), causal_mask_chunks))
-
     if exists(attn_bias):
         i, j = attn_bias.shape[-2:]
         attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
@@ -118,10 +117,10 @@ def memory_efficient_attention(
         weight_maxes = []

         for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
+            q_start_index = q_index * q_bucket_size
+            k_start_index = k_index * k_bucket_size

-            causal_mask_chunk = causal_mask_chunks[q_index][k_index] if causal else None
-
-            if exists(causal_mask_chunk) and torch.all(causal_mask_chunk):
+            if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
                 # if chunk is to be all masked out causally, skip
                 continue

@@ -132,8 +131,9 @@ def memory_efficient_attention(
                 k_chunk,
                 v_chunk,
                 mask_chunk,
-                causal_mask_chunk,
-                attn_bias_chunk
+                attn_bias_chunk,
+                causal,
+                (q_start_index, k_start_index)
             )

             exp_weights.append(exp_weight_chunk)
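
For context, a minimal standalone sketch (illustrative sizes and variable names, not part of the commit) showing that the chunk mask built from the start indices matches the corresponding slice of the full causal mask that was previously materialized up front:

# sketch: per-chunk causal mask vs. the old full-sequence mask
import torch
from einops import rearrange

n = 1024                                   # illustrative sequence length
q_bucket_size, k_bucket_size = 512, 1024   # illustrative chunk sizes

# old approach: hold an (n, n) boolean mask in memory, then split it into chunks
full_causal_mask = torch.ones(n, n, dtype = torch.bool).triu(1)

# new approach: build only the (q_chunk, k_chunk) slice needed for the current chunk pair
q_start_index, k_start_index = 512, 0      # example chunk offsets
q_range = torch.arange(q_start_index, q_start_index + q_bucket_size)
k_range = torch.arange(k_start_index, k_start_index + k_bucket_size)
chunk_causal_mask = rearrange(q_range, 'i -> i 1') < rearrange(k_range, 'j -> 1 j')

# the lazily built chunk mask equals the matching slice of the full mask
assert torch.equal(
    chunk_causal_mask,
    full_causal_mask[q_start_index:q_start_index + q_bucket_size, k_start_index:k_start_index + k_bucket_size]
)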

memory_efficient_attention_pytorch/memory_efficient_cosine_sim_attention.py

Lines changed: 13 additions & 13 deletions
@@ -50,7 +50,9 @@ def attention(

 # memory efficient attention

-def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
+def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices):
+    q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device
+
     weight = einsum('b h i d, b h j d -> b h i j', q, k)

     if exists(attn_bias_chunk):
@@ -62,7 +64,10 @@ def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
         mask = rearrange(mask, 'b j -> b 1 1 j')
         weight = weight.masked_fill(~mask, mask_value)

-    if exists(causal_mask):
+    if causal and q_start_index < (k_start_index + k_chunk_size - 1):
+        q_range = torch.arange(q_start_index, q_start_index + q_chunk_size, device = device)
+        k_range = torch.arange(k_start_index, k_start_index + k_chunk_size, device = device)
+        causal_mask = rearrange(q_range, 'i -> i 1') < rearrange(k_range, 'j -> 1 j')
         weight = weight.masked_fill(causal_mask, mask_value)

     exp_weight = weight.exp()
@@ -91,12 +96,6 @@ def numerically_unstable_memory_efficient_attention(
     v_chunks = v.split(k_bucket_size, dim = -2)
     mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))

-    if causal:
-        i, j = q.shape[-2], k.shape[-2]
-        causal_mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
-        causal_mask_chunks = causal_mask.split(q_bucket_size, dim = 0)
-        causal_mask_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), causal_mask_chunks))
-
     if exists(attn_bias):
         i, j = attn_bias.shape[-2:]
         attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
@@ -106,14 +105,14 @@ def numerically_unstable_memory_efficient_attention(

     out = []
     for q_index, q_chunk in enumerate(q_chunks):
+        q_start_index = q_index * q_bucket_size
         exp_weights = []
         weighted_values = []

         for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
+            k_start_index = k_index * k_bucket_size

-            causal_mask_chunk = causal_mask_chunks[q_index][k_index] if causal else None
-
-            if exists(causal_mask_chunk) and torch.all(causal_mask_chunk):
+            if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
                 # if chunk is to be all masked out causally, skip
                 continue

@@ -124,8 +123,9 @@ def numerically_unstable_memory_efficient_attention(
                 k_chunk,
                 v_chunk,
                 mask_chunk,
-                causal_mask_chunk,
-                attn_bias_chunk
+                attn_bias_chunk,
+                causal,
+                (q_start_index, k_start_index)
             )

             exp_weights.append(exp_weight_chunk)
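
The skip test above can be stated on its own; here is a small sketch (hypothetical helper name and illustrative sizes, not part of the commit) of the condition under which a key/value chunk is entirely causally masked and can be skipped:

# sketch: a key chunk is skippable when its first key position lies strictly
# after the last query position of the current query chunk
q_bucket_size, k_bucket_size = 512, 1024   # illustrative chunk sizes

def chunk_fully_masked(q_index, k_index, q_chunk_len):
    q_start_index = q_index * q_bucket_size    # first query position in the chunk
    k_start_index = k_index * k_bucket_size    # first key position in the chunk
    last_q_position = q_start_index + q_chunk_len - 1
    return k_start_index > last_q_position

# queries 0..511 never attend to keys 1024..2047, so that chunk is skipped
assert chunk_fully_masked(q_index = 0, k_index = 1, q_chunk_len = 512)

# queries 512..1023 do attend to keys 0..1023, so that chunk is computed
assert not chunk_fully_masked(q_index = 1, k_index = 0, q_chunk_len = 512)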

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.15',
+  version = '0.0.16',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   author = 'Phil Wang',
