@@ -50,7 +50,9 @@ def attention(
50
50
51
51
# memory efficient attention
52
52
53
- def summarize_qkv_chunk (q , k , v , mask , causal_mask , attn_bias_chunk ):
53
+ def summarize_qkv_chunk (q , k , v , mask , attn_bias_chunk , causal , qk_start_indices ):
54
+ q_start_index , k_start_index , q_chunk_size , k_chunk_size , device = * qk_start_indices , q .shape [- 2 ], k .shape [- 2 ], q .device
55
+
54
56
weight = einsum ('b h i d, b h j d -> b h i j' , q , k )
55
57
56
58
if exists (attn_bias_chunk ):
@@ -62,7 +64,10 @@ def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
62
64
mask = rearrange (mask , 'b j -> b 1 1 j' )
63
65
weight = weight .masked_fill (~ mask , mask_value )
64
66
65
- if exists (causal_mask ):
67
+ if causal and q_start_index < (k_start_index + k_chunk_size - 1 ):
68
+ q_range = torch .arange (q_start_index , q_start_index + q_chunk_size , device = device )
69
+ k_range = torch .arange (k_start_index , k_start_index + k_chunk_size , device = device )
70
+ causal_mask = rearrange (q_range , 'i -> i 1' ) < rearrange (k_range , 'j -> 1 j' )
66
71
weight = weight .masked_fill (causal_mask , mask_value )
67
72
68
73
exp_weight = weight .exp ()
@@ -91,12 +96,6 @@ def numerically_unstable_memory_efficient_attention(
91
96
v_chunks = v .split (k_bucket_size , dim = - 2 )
92
97
mask_chunks = mask .split (k_bucket_size , dim = - 1 ) if exists (mask ) else ((None ,) * len (k_chunks ))
93
98
94
- if causal :
95
- i , j = q .shape [- 2 ], k .shape [- 2 ]
96
- causal_mask = torch .ones (i , j , device = q .device , dtype = torch .bool ).triu (j - i + 1 )
97
- causal_mask_chunks = causal_mask .split (q_bucket_size , dim = 0 )
98
- causal_mask_chunks = list (map (lambda t : t .split (k_bucket_size , dim = - 1 ), causal_mask_chunks ))
99
-
100
99
if exists (attn_bias ):
101
100
i , j = attn_bias .shape [- 2 :]
102
101
attn_bias_chunks = attn_bias .split (q_bucket_size , dim = - 2 )
@@ -106,14 +105,14 @@ def numerically_unstable_memory_efficient_attention(
106
105
107
106
out = []
108
107
for q_index , q_chunk in enumerate (q_chunks ):
108
+ q_start_index = q_index * q_bucket_size
109
109
exp_weights = []
110
110
weighted_values = []
111
111
112
112
for k_index , (k_chunk , v_chunk , mask_chunk ) in enumerate (zip (k_chunks , v_chunks , mask_chunks )):
113
+ k_start_index = k_index * k_bucket_size
113
114
114
- causal_mask_chunk = causal_mask_chunks [q_index ][k_index ] if causal else None
115
-
116
- if exists (causal_mask_chunk ) and torch .all (causal_mask_chunk ):
115
+ if causal and k_start_index > (q_start_index + q_chunk .shape [- 2 ] - 1 ):
117
116
# if chunk is to be all masked out causally, skip
118
117
continue
119
118
@@ -124,8 +123,9 @@ def numerically_unstable_memory_efficient_attention(
124
123
k_chunk ,
125
124
v_chunk ,
126
125
mask_chunk ,
127
- causal_mask_chunk ,
128
- attn_bias_chunk
126
+ attn_bias_chunk ,
127
+ causal ,
128
+ (q_start_index , k_start_index )
129
129
)
130
130
131
131
exp_weights .append (exp_weight_chunk )
0 commit comments