 from functools import partial
 from torch import nn, einsum
 from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F
 
 from einops import rearrange
 
@@ -49,7 +50,7 @@ def attention(
 
 # memory efficient attention
 
-def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices):
+def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout = 0., training = False):
     q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device
 
     weight = einsum('b h i d, b h j d -> b h i j', q, k)
@@ -71,6 +72,8 @@ def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices
     weight = weight - weight_max
 
     exp_weight = weight.exp()
+    if training:
+        exp_weight = F.dropout(exp_weight, p = dropout, training = training)
     weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
 
     return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')
@@ -84,7 +87,9 @@ def memory_efficient_attention(
     attn_bias = None,
     q_bucket_size = 512,
     k_bucket_size = 1024,
-    eps = 1e-8
+    eps = 1e-8,
+    dropout = 0.,
+    training = False
 ):
     scale = q.shape[-1] ** -0.5
     q = q * scale
@@ -131,7 +136,9 @@ def memory_efficient_attention(
                 mask_chunk,
                 attn_bias_chunk,
                 causal,
-                (q_start_index, k_start_index)
+                (q_start_index, k_start_index),
+                dropout = dropout,
+                training = training
             )
 
             exp_weights.append(exp_weight_chunk)
@@ -175,7 +182,7 @@ def __init__(
         super().__init__()
         self.heads = heads
         self.causal = causal
-
+        self.dropout = dropout
         inner_dim = heads * dim_head
 
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
@@ -212,7 +219,8 @@ def forward(
 
         attn_fn = attention if not memory_efficient else memory_efficient_attention
 
-        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
+        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size,
+                      k_bucket_size = k_bucket_size, dropout = self.dropout, training = self.training)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
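
Below is a minimal usage sketch of the new dropout / training arguments on memory_efficient_attention. It is not part of the commit: the import path is an assumption (the diff does not name the module file), and the tensor shapes simply follow the 'b h i d' / 'b h j d' layout used by the einsum calls above.

import torch

# Hypothetical import path -- assumes the function lives in
# memory_efficient_attention_pytorch/memory_efficient_attention.py
from memory_efficient_attention_pytorch.memory_efficient_attention import memory_efficient_attention

# (batch, heads, sequence length, head dimension), matching the 'b h i d' layout above
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 2048, 64)
v = torch.randn(1, 8, 2048, 64)

# training = False skips the dropout branch, so the output is deterministic
out_eval_1 = memory_efficient_attention(q, k, v, dropout = 0.1, training = False)
out_eval_2 = memory_efficient_attention(q, k, v, dropout = 0.1, training = False)
assert torch.allclose(out_eval_1, out_eval_2)

# training = True applies F.dropout to the exponentiated chunk weights,
# so two calls differ (with overwhelming probability)
out_train_1 = memory_efficient_attention(q, k, v, dropout = 0.1, training = True)
out_train_2 = memory_efficient_attention(q, k, v, dropout = 0.1, training = True)
assert not torch.allclose(out_train_1, out_train_2)

When the attention module itself is used, forward now passes self.dropout and self.training through to the attention function, so the same behaviour follows the usual module.train() / module.eval() switch.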