@@ -50,7 +50,7 @@ def attention(
 
 # memory efficient attention
 
-def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout = 0., training = False):
+def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout):
     q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device
 
     weight = einsum('b h i d, b h j d -> b h i j', q, k)
@@ -72,8 +72,9 @@ def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices
     weight = weight - weight_max
 
     exp_weight = weight.exp()
-    if training:
-        exp_weight = F.dropout(exp_weight, p = dropout, training = training)
+
+    exp_weight = F.dropout(exp_weight, p = dropout)
+
     weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
 
     return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')
@@ -137,8 +138,7 @@ def memory_efficient_attention(
            attn_bias_chunk,
            causal,
            (q_start_index, k_start_index),
-            dropout = dropout,
-            training = training
+            dropout if training else 0.
        )
 
        exp_weights.append(exp_weight_chunk)
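The change drops the training branch inside the chunk summarizer and instead has the caller fold the training flag into the dropout probability: F.dropout with p = 0. leaves the tensor unchanged, so a separate flag is no longer needed inside the chunk. Below is a minimal sketch of that pattern, not the repository's code; the function and variable names are hypothetical, and masking, causal handling, and bias are omitted.

import torch
import torch.nn.functional as F
from torch import einsum

def summarize_chunk(q, k, v, dropout):
    # attention scores for one (q chunk, k chunk) pair
    weight = einsum('b h i d, b h j d -> b h i j', q, k)

    # subtract the running max for numerical stability before exponentiating
    weight_max = weight.amax(dim = -1, keepdim = True).detach()
    exp_weight = (weight - weight_max).exp()

    # applied unconditionally: with p = 0. this is a no-op, so the caller
    # controls dropout by choosing the rate rather than passing a flag
    exp_weight = F.dropout(exp_weight, p = dropout)

    weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
    return exp_weight.sum(dim = -1), weighted_value, weight_max.squeeze(-1)

# caller folds the training flag into the dropout rate, as in the diff's call site
q = k = v = torch.randn(1, 2, 16, 8)
training = True
row_sums, values, maxes = summarize_chunk(q, k, v, 0.1 if training else 0.)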