@@ -20,10 +20,14 @@ def default(val, d):
 
 # flash attention forwards and backwards
 
+# https://arxiv.org/abs/2205.14135
+
 class FlashAttentionFunction(Function):
     @staticmethod
     @torch.no_grad()
     def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+        """ Algorithm 2 in the paper """
+
         device = q.device
         max_neg_value = -torch.finfo(q.dtype).max
         qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
@@ -87,10 +91,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
 
                 new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
 
-                out = (row_sums / new_row_sums) * exp_row_max_diff * oc + \
-                    (exp_block_row_max_diff / new_row_sums) * exp_values
-
-                oc.copy_(out)
+                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
 
                 row_maxes.copy_(new_row_maxes)
                 row_sums.copy_(new_row_sums)
 
@@ -102,6 +103,8 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
     @staticmethod
     @torch.no_grad()
     def backward(ctx, do):
+        """ Algorithm 4 in the paper """
+
         causal, mask, q_bucket_size, k_bucket_size = ctx.args
         q, k, v, o, l, m = ctx.saved_tensors
 
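The substantive change above swaps an out-of-place output update for an in-place mul_/add_ chain, avoiding the temporary `out` tensor while computing the same rescaled accumulator from FlashAttention's Algorithm 2. The following is a minimal standalone sketch (not part of the commit) checking that equivalence; the tensor shapes and the comments on what each factor holds are illustrative assumptions based on the variable names, not taken from the diff.

import torch

# Illustrative shapes: (batch, heads, rows-in-bucket, head-dim); the real
# code operates on per-bucket tiles inside nested loops.
b, h, n, d = 1, 2, 8, 16

oc = torch.randn(b, h, n, d)                     # running output accumulator
exp_values = torch.randn(b, h, n, d)             # exp(scores) @ v for the current k/v block
row_sums = torch.rand(b, h, n, 1) + 1.0          # running softmax denominators
block_row_sums = torch.rand(b, h, n, 1) + 1.0    # denominators for the current block
exp_row_max_diff = torch.rand(b, h, n, 1)        # presumably exp(old_row_max - new_row_max)
exp_block_row_max_diff = torch.rand(b, h, n, 1)  # presumably exp(block_row_max - new_row_max)

new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

# old form: materialize a temporary, then copy it back into the accumulator
out = (row_sums / new_row_sums) * exp_row_max_diff * oc + \
      (exp_block_row_max_diff / new_row_sums) * exp_values

# new form: rescale the accumulator in place, skipping the temporary
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

assert torch.allclose(out, oc, atol=1e-6)

Both forms apply the same two broadcasted scale factors, so the in-place version trades one full-size temporary allocation per inner-loop iteration for two in-place ops on the existing accumulator.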