@@ -65,7 +65,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
65
65
attn_weights .masked_fill_ (~ row_mask , max_neg_value )
66
66
67
67
if causal and q_start_index < (k_start_index + k_bucket_size - 1 ):
68
- causal_mask = torch .ones ((q_bucket_size , k_bucket_size ), dtype = torch .bool , device = device ).triu (q_start_index - k_start_index + 1 )
68
+ causal_mask = torch .ones ((qc . shape [ - 2 ], kc . shape [ - 2 ] ), dtype = torch .bool , device = device ).triu (q_start_index - k_start_index + 1 )
69
69
attn_weights .masked_fill_ (causal_mask , max_neg_value )
70
70
71
71
block_row_maxes = attn_weights .amax (dim = - 1 , keepdims = True )
@@ -143,7 +143,7 @@ def backward(ctx, do):
143
143
attn_weights = einsum ('... i d, ... j d -> ... i j' , qc_scaled , kc )
144
144
145
145
if causal and q_start_index < (k_start_index + k_bucket_size - 1 ):
146
- causal_mask = torch .ones ((q_bucket_size , k_bucket_size ), dtype = torch .bool , device = device ).triu (q_start_index - k_start_index + 1 )
146
+ causal_mask = torch .ones ((qc . shape [ - 2 ], kc . shape [ - 2 ] ), dtype = torch .bool , device = device ).triu (q_start_index - k_start_index + 1 )
147
147
attn_weights .masked_fill_ (causal_mask , max_neg_value )
148
148
149
149
exp_attn_weights = torch .exp (attn_weights - mc )
@@ -156,7 +156,7 @@ def backward(ctx, do):
156
156
dv_chunk = einsum ('... i j, ... i d -> ... j d' , p , doc )
157
157
dp = einsum ('... i d, ... j d -> ... i j' , doc , vc )
158
158
159
- D = (do * o ).sum (dim = - 1 , keepdims = True )
159
+ D = (doc * oc ).sum (dim = - 1 , keepdims = True )
160
160
ds = p * scale * (dp - D )
161
161
162
162
dq_chunk = einsum ('... i j, ... j d -> ... i d' , ds , kc )
0 commit comments