Commit c901ae5

point to algorithms in paper
1 parent d5e968b commit c901ae5

1 file changed: 7 additions, 4 deletions

memory_efficient_attention_pytorch/flash_attention.py

Lines changed: 7 additions & 4 deletions
@@ -20,10 +20,14 @@ def default(val, d):
 
 # flash attention forwards and backwards
 
+# https://arxiv.org/abs/2205.14135
+
 class FlashAttentionFunction(Function):
     @staticmethod
     @torch.no_grad()
     def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+        """ Algorithm 2 in the paper """
+
         device = q.device
         max_neg_value = -torch.finfo(q.dtype).max
         qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
@@ -87,10 +91,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
 
                 new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
 
-                out = (row_sums / new_row_sums) * exp_row_max_diff * oc + \
-                      (exp_block_row_max_diff / new_row_sums) * exp_values
-
-                oc.copy_(out)
+                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
 
                 row_maxes.copy_(new_row_maxes)
                 row_sums.copy_(new_row_sums)
 
@@ -102,6 +103,8 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
     @staticmethod
     @torch.no_grad()
     def backward(ctx, do):
+        """ Algorithm 4 in the paper """
+
         causal, mask, q_bucket_size, k_bucket_size = ctx.args
         q, k, v, o, l, m = ctx.saved_tensors
 
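Aside from the added comments pointing at Algorithms 2 and 4 of the paper, the one functional change is in the middle hunk of forward: the rescale-and-accumulate of the output chunk is now done in place with mul_/add_ instead of building a temporary out tensor and copying it back with oc.copy_. Below is a minimal standalone sketch checking that the two forms agree; the tensor shapes (a 4-row query chunk with head dimension 8) are hypothetical and chosen only for illustration, while the variable names mirror those in the diff.

import torch

# Hypothetical stand-ins for one (query chunk, key/value chunk) update step.
# Shapes are illustrative only: q_bucket_size = 4, head dimension = 8.
oc = torch.randn(4, 8)                      # running output chunk
exp_values = torch.randn(4, 8)              # exponentiated weights @ value chunk for the current block
row_sums = torch.rand(4, 1) + 1.0           # running softmax denominators
block_row_sums = torch.rand(4, 1) + 1.0     # denominator contribution of the current block
exp_row_max_diff = torch.rand(4, 1)         # exp(row_maxes - new_row_maxes)
exp_block_row_max_diff = torch.rand(4, 1)   # exp(block_row_maxes - new_row_maxes)

new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

# Previous form: materialize a fresh tensor, then copy it into the output chunk.
out = (row_sums / new_row_sums) * exp_row_max_diff * oc + \
      (exp_block_row_max_diff / new_row_sums) * exp_values

# New form: rescale the running chunk and accumulate in place, avoiding the temporary.
oc_inplace = oc.clone()
oc_inplace.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

assert torch.allclose(out, oc_inplace, atol=1e-6)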
