@@ -61,9 +61,13 @@ def summarize_qkv_chunk(q, k, v, mask, causal_mask):
61
61
if exists (causal_mask ):
62
62
weight = weight .masked_fill (causal_mask , mask_value )
63
63
64
+ weight_max = weight .amax (dim = - 1 , keepdim = True ).detach ()
65
+ weight = weight - weight_max
66
+
64
67
exp_weight = weight .exp ()
65
68
weighted_value = einsum ('b h i j, b h j d -> b h i d' , exp_weight , v )
66
- return exp_weight .sum (dim = - 1 ), weighted_value
69
+
70
+ return exp_weight .sum (dim = - 1 ), weighted_value , rearrange (weight_max , '... 1 -> ...' )
67
71
68
72
checkpointed_summarize_qkv_chunk = partial (checkpoint , summarize_qkv_chunk )
69
73
@@ -93,8 +97,9 @@ def memory_efficient_attention(
93
97
94
98
out = []
95
99
for q_index , q_chunk in enumerate (q_chunks ):
96
- exp_weights = None
97
- weighted_values = None
100
+ exp_weights = []
101
+ weighted_values = []
102
+ weight_maxes = []
98
103
99
104
for k_index , (k_chunk , v_chunk , mask_chunk ) in enumerate (zip (k_chunks , v_chunks , mask_chunks )):
100
105
@@ -105,18 +110,33 @@ def memory_efficient_attention(
105
110
(k_index * k_bucket_size ):(k_index * k_bucket_size + k_bucket_size ),
106
111
]
107
112
108
- exp_weight_chunk , weighted_value_chunk = checkpointed_summarize_qkv_chunk (
113
+ exp_weight_chunk , weighted_value_chunk , weight_max_chunk = checkpointed_summarize_qkv_chunk (
109
114
q_chunk ,
110
115
k_chunk ,
111
116
v_chunk ,
112
117
mask_chunk ,
113
118
causal_mask_chunk
114
119
)
115
120
116
- exp_weights = safe_sum (exp_weights , exp_weight_chunk )
117
- weighted_values = safe_sum (weighted_values , weighted_value_chunk )
121
+ exp_weights .append (exp_weight_chunk )
122
+ weighted_values .append (weighted_value_chunk )
123
+ weight_maxes .append (weight_max_chunk )
124
+
125
+ weight_maxes = torch .stack (weight_maxes , dim = - 1 )
126
+
127
+ weighted_values = torch .stack (weighted_values , dim = - 1 )
128
+ exp_weights = torch .stack (exp_weights , dim = - 1 )
129
+
130
+ global_max = weight_maxes .amax (dim = - 1 , keepdim = True )
131
+ renorm_factor = (weight_maxes - global_max ).exp ().detach ()
132
+
133
+ exp_weights = exp_weights * renorm_factor
134
+ weighted_values = weighted_values * rearrange (renorm_factor , '... c -> ... 1 c' )
135
+
136
+ all_values = weighted_values .sum (dim = - 1 )
137
+ all_weights = exp_weights .sum (dim = - 1 )
118
138
119
- normalized_values = weighted_values / (rearrange (exp_weights , '... -> ... 1' ) + eps )
139
+ normalized_values = all_values / (rearrange (all_weights , '... -> ... 1' ) + eps )
120
140
out .append (normalized_values )
121
141
122
142
return torch .cat (out , dim = - 2 )
0 commit comments