@@ -31,7 +31,7 @@ def attention(
 
     if exists(mask):
         mask = rearrange(mask, 'b j -> b 1 1 j')
-        sim = sim.masked_fill(mask, mask_value)
+        sim = sim.masked_fill(~mask, mask_value)
 
     if causal:
         i, j = sim.shape[-2:]
@@ -50,14 +50,18 @@ def safe_sum(acc, el):
         return el
     return acc + el
 
-def summarize_qkv_chunk(q, k, v, mask):
+def summarize_qkv_chunk(q, k, v, mask, causal_mask):
     weight = einsum('b h i d, b h j d -> b h i j', q, k)
-    exp_weight = weight.exp()
+    mask_value = -torch.finfo(weight.dtype).max
 
     if exists(mask):
         mask = rearrange(mask, 'b j -> b 1 1 j')
-        exp_weight = exp_weight.masked_fill(mask, 0.)
+        weight = weight.masked_fill(~mask, mask_value)
 
+    if exists(causal_mask):
+        weight = weight.masked_fill(causal_mask, mask_value)
+
+    exp_weight = weight.exp()
     weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
     return exp_weight.sum(dim = -1), weighted_value
 
@@ -68,7 +72,8 @@ def memory_efficient_attention(
     mask = None,
     causal = False,
     q_bucket_size = 512,
-    k_bucket_size = 1024
+    k_bucket_size = 1024,
+    eps = 1e-8
 ):
     scale = q.shape[-1] ** -0.5
     q = q * scale
@@ -80,26 +85,38 @@ def memory_efficient_attention(
     v_chunks = v.split(k_bucket_size, dim = -2)
     mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))
 
+    if causal:
+        i, j = q.shape[-2], k.shape[-2]
+        causal_mask = torch.ones(i, j).triu(j - i + 1).bool()
+
     # loop through all chunks and accumulate
 
     out = []
-    for q_chunk in q_chunks:
+    for q_index, q_chunk in enumerate(q_chunks):
         exp_weights = None
         weighted_values = None
 
-        for k_chunk, v_chunk, mask_chunk in zip(k_chunks, v_chunks, mask_chunks):
+        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
+
+            causal_mask_chunk = None
+            if causal:
+                causal_mask_chunk = causal_mask[
+                    (q_index * q_bucket_size):(q_index * q_bucket_size + q_bucket_size),
+                    (k_index * k_bucket_size):(k_index * k_bucket_size + k_bucket_size),
+                ]
 
             exp_weight_chunk, weighted_value_chunk = checkpointed_summarize_qkv_chunk(
                 q_chunk,
                 k_chunk,
                 v_chunk,
-                mask_chunk
+                mask_chunk,
+                causal_mask_chunk
            )
 
             exp_weights = safe_sum(exp_weights, exp_weight_chunk)
             weighted_values = safe_sum(weighted_values, weighted_value_chunk)
 
-        normalized_values = weighted_values / rearrange(exp_weights, '... -> ... 1')
+        normalized_values = weighted_values / (rearrange(exp_weights, '... -> ... 1') + eps)
         out.append(normalized_values)
 
     return torch.cat(out, dim = -2)
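
For reference, a minimal usage sketch of the patched memory_efficient_attention (not part of this commit). It assumes q, k, v are the leading positional arguments ahead of the keyword parameters shown in the hunks above, that the function is importable from memory_efficient_attention_pytorch.memory_efficient_attention (the import path is not shown in this diff), and it stays on CPU, since the new causal mask is built with torch.ones(i, j) on the default device.

import torch

# assumed import path -- the diff does not show the module name
from memory_efficient_attention_pytorch.memory_efficient_attention import memory_efficient_attention

b, h, n, d = 2, 8, 2048, 64

q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

# boolean key-padding mask of shape (b, j); True marks positions to keep,
# matching the new masked_fill(~mask, mask_value) convention
mask = torch.ones(b, n).bool()

out = memory_efficient_attention(
    q, k, v,
    mask = mask,
    causal = True,            # exercises the per-chunk causal mask added here
    q_bucket_size = 512,
    k_bucket_size = 1024
)

assert out.shape == (b, h, n, d)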