Commit fdcf17b

add checkpointing for memory savings
1 parent 785a26a · commit fdcf17b

1 file changed: +9 −9 lines changed

memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 9 additions & 9 deletions
@@ -1,4 +1,5 @@
 import torch
+from functools import partial
 from torch import nn, einsum
 from torch.utils.checkpoint import checkpoint
 
@@ -49,10 +50,7 @@ def safe_sum(acc, el):
         return el
     return acc + el
 
-def summarize_qkv_chunk(
-    q, k, v,
-    mask = None
-):
+def summarize_qkv_chunk(q, k, v, mask):
     weight = einsum('b h i d, b h j d -> b h i j', q, k)
     exp_weight = weight.exp()
 
@@ -63,6 +61,8 @@ def summarize_qkv_chunk(
     weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
     return exp_weight.sum(dim = -1), weighted_value
 
+checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)
+
 def memory_efficient_attention(
     q, k, v,
     mask = None,
@@ -89,11 +89,11 @@ def memory_efficient_attention(
 
         for k_chunk, v_chunk, mask_chunk in zip(k_chunks, v_chunks, mask_chunks):
 
-            exp_weight_chunk, weighted_value_chunk = summarize_qkv_chunk(
-                q = q_chunk,
-                k = k_chunk,
-                v = v_chunk,
-                mask = mask_chunk
+            exp_weight_chunk, weighted_value_chunk = checkpointed_summarize_qkv_chunk(
+                q_chunk,
+                k_chunk,
+                v_chunk,
+                mask_chunk
             )
 
             exp_weights = safe_sum(exp_weights, exp_weight_chunk)
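
Why the call site changed: torch.utils.checkpoint.checkpoint recomputes the wrapped function during the backward pass instead of storing its intermediate activations, so each chunk's full exp_weight matrix no longer has to stay alive across the whole loop. checkpoint forwards arguments to the wrapped function positionally, which is why the diff also drops the keyword arguments (and the mask = None default) from summarize_qkv_chunk.

A minimal, self-contained sketch of the pattern this commit introduces. The tensor shapes and the condensed masking logic here are illustrative assumptions, not the repository's exact code:

import torch
from functools import partial
from torch.utils.checkpoint import checkpoint

# condensed stand-in for the repository's summarize_qkv_chunk: the
# intermediate exp_weight matrix is the memory hot spot being checkpointed
def summarize_qkv_chunk(q, k, v, mask):
    weight = torch.einsum('b h i d, b h j d -> b h i j', q, k)
    exp_weight = weight.exp()
    if mask is not None:
        exp_weight = exp_weight.masked_fill(~mask[:, None, None, :], 0.)
    weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v)
    return exp_weight.sum(dim = -1), weighted_value

# partial(checkpoint, fn) gives a drop-in wrapper: forward runs fn without
# saving intermediates; backward re-runs fn to rebuild them on the fly
checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)

q = torch.randn(1, 8, 512, 64, requires_grad = True)   # (batch, heads, q_len, dim) - assumed shapes
k = torch.randn(1, 8, 1024, 64, requires_grad = True)  # (batch, heads, k_len, dim)
v = torch.randn(1, 8, 1024, 64, requires_grad = True)
mask = torch.ones(1, 1024).bool()                      # (batch, k_len)

# arguments must be positional, mirroring the call-site change in the diff
exp_weights, weighted_values = checkpointed_summarize_qkv_chunk(q, k, v, mask)
(weighted_values.sum() + exp_weights.sum()).backward()

The trade is the usual one for gradient checkpointing: each chunk's forward runs twice (once going forward, once during backward), in exchange for not holding every chunk's attention weights in memory at once.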
