@@ -95,8 +95,10 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
95
95
row_maxes .copy_ (new_row_maxes )
96
96
row_sums .copy_ (new_row_sums )
97
97
98
+ lse = all_row_sums .log () + all_row_maxes
99
+
98
100
ctx .args = (causal , scale , mask , q_bucket_size , k_bucket_size )
99
- ctx .save_for_backward (q , k , v , o , all_row_sums , all_row_maxes )
101
+ ctx .save_for_backward (q , k , v , o , lse )
100
102
101
103
return o
102
104
@@ -106,7 +108,7 @@ def backward(ctx, do):
106
108
""" Algorithm 4 in the paper """
107
109
108
110
causal , scale , mask , q_bucket_size , k_bucket_size = ctx .args
109
- q , k , v , o , l , m = ctx .saved_tensors
111
+ q , k , v , o , lse = ctx .saved_tensors
110
112
111
113
device = q .device
112
114
@@ -122,12 +124,11 @@ def backward(ctx, do):
122
124
o .split (q_bucket_size , dim = - 2 ),
123
125
do .split (q_bucket_size , dim = - 2 ),
124
126
mask ,
125
- l .split (q_bucket_size , dim = - 2 ),
126
- m .split (q_bucket_size , dim = - 2 ),
127
+ lse .split (q_bucket_size , dim = - 2 ),
127
128
dq .split (q_bucket_size , dim = - 2 )
128
129
)
129
130
130
- for ind , (qc , oc , doc , row_mask , lc , mc , dqc ) in enumerate (row_splits ):
131
+ for ind , (qc , oc , doc , row_mask , lsec , dqc ) in enumerate (row_splits ):
131
132
q_start_index = ind * q_bucket_size - qk_len_diff
132
133
133
134
col_splits = zip (
@@ -146,12 +147,10 @@ def backward(ctx, do):
146
147
causal_mask = torch .ones ((qc .shape [- 2 ], kc .shape [- 2 ]), dtype = torch .bool , device = device ).triu (q_start_index - k_start_index + 1 )
147
148
attn_weights .masked_fill_ (causal_mask , max_neg_value )
148
149
149
- exp_attn_weights = torch .exp (attn_weights - mc )
150
+ p = torch .exp (attn_weights - lsec )
150
151
151
152
if exists (row_mask ):
152
- exp_attn_weights .masked_fill_ (~ row_mask , 0. )
153
-
154
- p = exp_attn_weights / lc
153
+ p .masked_fill_ (~ row_mask , 0. )
155
154
156
155
dv_chunk = einsum ('... i j, ... i d -> ... j d' , p , doc )
157
156
dp = einsum ('... i d, ... j d -> ... i j' , doc , vc )
0 commit comments