Commit 867ebc1

complete numerical stability for memory efficient attention

1 parent adf5751 commit 867ebc1
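This commit makes the chunked softmax numerically stable: each query-key chunk subtracts its own row max before exponentiating, returns that max alongside its partial sums, and the caller rescales every chunk onto the global max before normalizing. A minimal standalone sketch (not from the repo) of why the max subtraction matters:

import torch

def naive_softmax(logits):
    # exp() overflows to inf for fp32 logits above ~88, giving inf / inf = nan
    exp = logits.exp()
    return exp / exp.sum(dim = -1, keepdim = True)

def stable_softmax(logits):
    # subtracting the (detached) row max leaves the softmax unchanged but bounds exp() at 1
    exp = (logits - logits.amax(dim = -1, keepdim = True).detach()).exp()
    return exp / exp.sum(dim = -1, keepdim = True)

logits = torch.tensor([[100., 101., 102.]])
print(naive_softmax(logits))   # tensor([[nan, nan, nan]])
print(stable_softmax(logits))  # tensor([[0.0900, 0.2447, 0.6652]])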

2 files changed: +28 -8 lines changed

memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 27 additions & 7 deletions
@@ -61,9 +61,13 @@ def summarize_qkv_chunk(q, k, v, mask, causal_mask):
     if exists(causal_mask):
         weight = weight.masked_fill(causal_mask, mask_value)

+    weight_max = weight.amax(dim = -1, keepdim = True).detach()
+    weight = weight - weight_max
+
     exp_weight = weight.exp()
     weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
-    return exp_weight.sum(dim = -1), weighted_value
+
+    return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')

 checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)
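Each chunk's exp sums are now taken against its own local max, so they can only be added after rescaling onto a shared global max; that is why the chunk must also return weight_max. A hypothetical check (names not from the repo) of the identity the renormalization relies on, exp(w - m_local) * exp(m_local - m_global) == exp(w - m_global):

import torch

w1, w2 = torch.randn(4, 8) * 30, torch.randn(4, 8) * 30   # logits from two key chunks

m1 = w1.amax(dim = -1, keepdim = True)
m2 = w2.amax(dim = -1, keepdim = True)
s1 = (w1 - m1).exp().sum(dim = -1, keepdim = True)         # per-chunk partial exp sums
s2 = (w2 - m2).exp().sum(dim = -1, keepdim = True)

m = torch.maximum(m1, m2)                                  # global max across both chunks
combined = s1 * (m1 - m).exp() + s2 * (m2 - m).exp()       # rescale, then add

reference = (torch.cat((w1, w2), dim = -1) - m).exp().sum(dim = -1, keepdim = True)
assert torch.allclose(combined, reference, atol = 1e-5)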

@@ -93,8 +97,9 @@ def memory_efficient_attention(

     out = []
     for q_index, q_chunk in enumerate(q_chunks):
-        exp_weights = None
-        weighted_values = None
+        exp_weights = []
+        weighted_values = []
+        weight_maxes = []

         for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):

@@ -105,18 +110,33 @@ def memory_efficient_attention(
                 (k_index * k_bucket_size):(k_index * k_bucket_size + k_bucket_size),
             ]

-            exp_weight_chunk, weighted_value_chunk = checkpointed_summarize_qkv_chunk(
+            exp_weight_chunk, weighted_value_chunk, weight_max_chunk = checkpointed_summarize_qkv_chunk(
                 q_chunk,
                 k_chunk,
                 v_chunk,
                 mask_chunk,
                 causal_mask_chunk
             )

-            exp_weights = safe_sum(exp_weights, exp_weight_chunk)
-            weighted_values = safe_sum(weighted_values, weighted_value_chunk)
+            exp_weights.append(exp_weight_chunk)
+            weighted_values.append(weighted_value_chunk)
+            weight_maxes.append(weight_max_chunk)
+
+        weight_maxes = torch.stack(weight_maxes, dim = -1)
+
+        weighted_values = torch.stack(weighted_values, dim = -1)
+        exp_weights = torch.stack(exp_weights, dim = -1)
+
+        global_max = weight_maxes.amax(dim = -1, keepdim = True)
+        renorm_factor = (weight_maxes - global_max).exp().detach()
+
+        exp_weights = exp_weights * renorm_factor
+        weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')
+
+        all_values = weighted_values.sum(dim = -1)
+        all_weights = exp_weights.sum(dim = -1)

-        normalized_values = weighted_values / (rearrange(exp_weights, '... -> ... 1') + eps)
+        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
         out.append(normalized_values)

     return torch.cat(out, dim = -2)
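Putting the pieces together, here is a self-contained sketch of the accumulation pattern this hunk implements, checked against plain softmax attention. It is simplified (no masking, no causal handling, no gradient checkpointing, and only the key/value sequence is chunked); chunked_attention and its parameters are illustrative names, not the library's API:

import torch
from torch import einsum
from einops import rearrange

def chunked_attention(q, k, v, k_bucket_size = 16, eps = 1e-8):
    exp_weights, weighted_values, weight_maxes = [], [], []

    for k_chunk, v_chunk in zip(k.split(k_bucket_size, dim = -2), v.split(k_bucket_size, dim = -2)):
        weight = einsum('b h i d, b h j d -> b h i j', q, k_chunk)
        weight_max = weight.amax(dim = -1, keepdim = True).detach()
        exp_weight = (weight - weight_max).exp()

        exp_weights.append(exp_weight.sum(dim = -1))
        weighted_values.append(einsum('b h i j, b h j d -> b h i d', exp_weight, v_chunk))
        weight_maxes.append(rearrange(weight_max, '... 1 -> ...'))

    weight_maxes = torch.stack(weight_maxes, dim = -1)
    weighted_values = torch.stack(weighted_values, dim = -1)
    exp_weights = torch.stack(exp_weights, dim = -1)

    # rescale every chunk's partial results onto the global max, then reduce
    global_max = weight_maxes.amax(dim = -1, keepdim = True)
    renorm_factor = (weight_maxes - global_max).exp().detach()

    exp_weights = exp_weights * renorm_factor
    weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')

    all_values = weighted_values.sum(dim = -1)
    all_weights = exp_weights.sum(dim = -1)
    return all_values / (rearrange(all_weights, '... -> ... 1') + eps)

q, k, v = (torch.randn(1, 2, 64, 16) for _ in range(3))
reference = einsum('b h i j, b h j d -> b h i d',
                   einsum('b h i d, b h j d -> b h i j', q, k).softmax(dim = -1), v)
assert torch.allclose(chunked_attention(q, k, v), reference, atol = 1e-5)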

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   author = 'Phil Wang',
