Commit c37fbd2

Merge pull request #6 from usryokousha/main
Added dropout support to memory efficient variant
2 parents 8b013f6 + 70521cd

File tree

1 file changed: +13 −5 lines


memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 13 additions & 5 deletions
@@ -2,6 +2,7 @@
 from functools import partial
 from torch import nn, einsum
 from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F
 
 from einops import rearrange
 
@@ -49,7 +50,7 @@ def attention(
 
 # memory efficient attention
 
-def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices):
+def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout=0., training=False):
     q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device
 
     weight = einsum('b h i d, b h j d -> b h i j', q, k)
@@ -71,6 +72,8 @@ def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices
     weight = weight - weight_max
 
     exp_weight = weight.exp()
+    if training:
+        exp_weight = F.dropout(exp_weight, p=dropout, training=training)
     weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
 
     return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')
@@ -84,7 +87,9 @@ def memory_efficient_attention(
     attn_bias = None,
     q_bucket_size = 512,
     k_bucket_size = 1024,
-    eps = 1e-8
+    eps = 1e-8,
+    dropout = 0.,
+    training = False
 ):
     scale = q.shape[-1] ** -0.5
     q = q * scale
@@ -131,7 +136,9 @@
                 mask_chunk,
                 attn_bias_chunk,
                 causal,
-                (q_start_index, k_start_index)
+                (q_start_index, k_start_index),
+                dropout = dropout,
+                training = training
             )
 
             exp_weights.append(exp_weight_chunk)
@@ -175,7 +182,7 @@ def __init__(
        super().__init__()
        self.heads = heads
        self.causal = causal
-
+       self.dropout = dropout
        inner_dim = heads * dim_head
 
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
@@ -212,7 +219,8 @@ def forward(
 
        attn_fn = attention if not memory_efficient else memory_efficient_attention
 
-       out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
+       out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size,
+                     k_bucket_size = k_bucket_size, dropout = self.dropout, training = self.training)
 
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
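
For context, a minimal usage sketch of the new arguments (the tensor shapes and hyperparameter values below are illustrative, not part of this commit): memory_efficient_attention now accepts dropout and training, and the dropout is applied to the exponentiated chunk weights inside summarize_qkv_chunk before the chunks are aggregated and renormalized. The Attention module forwards self.dropout and self.training on its own, so explicit flags are only needed when calling the function directly.

# minimal sketch of calling the updated function directly
# layout is (batch, heads, seq, dim_head) to match the einsum patterns above
import torch
from memory_efficient_attention_pytorch.memory_efficient_attention import memory_efficient_attention

q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# training = True applies dropout to the exponentiated weights of every q/k chunk
out = memory_efficient_attention(q, k, v, causal = True, dropout = 0.1, training = True)

# with training = False (the default), the dropout branch is skipped entirely
out = memory_efficient_attention(q, k, v, causal = True, dropout = 0.1, training = False)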
