1
1
import torch
2
+ from functools import partial
2
3
from torch import nn , einsum
3
4
from torch .utils .checkpoint import checkpoint
4
5
@@ -49,10 +50,7 @@ def safe_sum(acc, el):
49
50
return el
50
51
return acc + el
51
52
52
- def summarize_qkv_chunk (
53
- q , k , v ,
54
- mask = None
55
- ):
53
+ def summarize_qkv_chunk (q , k , v , mask ):
56
54
weight = einsum ('b h i d, b h j d -> b h i j' , q , k )
57
55
exp_weight = weight .exp ()
58
56
@@ -63,6 +61,8 @@ def summarize_qkv_chunk(
63
61
weighted_value = einsum ('b h i j, b h j d -> b h i d' , exp_weight , v )
64
62
return exp_weight .sum (dim = - 1 ), weighted_value
65
63
64
+ checkpointed_summarize_qkv_chunk = partial (checkpoint , summarize_qkv_chunk )
65
+
66
66
def memory_efficient_attention (
67
67
q , k , v ,
68
68
mask = None ,
@@ -89,11 +89,11 @@ def memory_efficient_attention(
89
89
90
90
for k_chunk , v_chunk , mask_chunk in zip (k_chunks , v_chunks , mask_chunks ):
91
91
92
- exp_weight_chunk , weighted_value_chunk = summarize_qkv_chunk (
93
- q = q_chunk ,
94
- k = k_chunk ,
95
- v = v_chunk ,
96
- mask = mask_chunk
92
+ exp_weight_chunk , weighted_value_chunk = checkpointed_summarize_qkv_chunk (
93
+ q_chunk ,
94
+ k_chunk ,
95
+ v_chunk ,
96
+ mask_chunk
97
97
)
98
98
99
99
exp_weights = safe_sum (exp_weights , exp_weight_chunk )
0 commit comments