
Commit d041958

fix tests
1 parent d76fba8 commit d041958

2 files changed: 6 additions and 6 deletions

memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 5 additions & 5 deletions
@@ -50,7 +50,7 @@ def attention(
 
 # memory efficient attention
 
-def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout=0., training=False):
+def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout):
     q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device
 
     weight = einsum('b h i d, b h j d -> b h i j', q, k)
@@ -72,8 +72,9 @@ def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices
     weight = weight - weight_max
 
     exp_weight = weight.exp()
-    if training:
-        exp_weight = F.dropout(exp_weight, p=dropout, training=training)
+
+    exp_weight = F.dropout(exp_weight, p = dropout)
+
     weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
 
     return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')
@@ -137,8 +138,7 @@ def memory_efficient_attention(
                 attn_bias_chunk,
                 causal,
                 (q_start_index, k_start_index),
-                dropout = dropout,
-                training = training
+                dropout if training else 0.
             )
 
             exp_weights.append(exp_weight_chunk)
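
A note on the memory_efficient_attention.py change: summarize_qkv_chunk appears to be run through torch.utils.checkpoint (presumably the memory-saving trick in this repo), and the reentrant checkpoint of that era forwarded positional arguments only, so keyword arguments like dropout= and training= could not be threaded through the checkpointed call. Resolving the gate at the call site as `dropout if training else 0.` keeps the chunk function purely positional, and F.dropout with p = 0. is a no-op at eval time. A minimal sketch of the pattern follows, with chunk_fn as a hypothetical stand-in for the real chunk summarizer:

import torch
import torch.nn.functional as F
from functools import partial
from torch.utils.checkpoint import checkpoint

# hypothetical stand-in for summarize_qkv_chunk: takes only positional args
def chunk_fn(x, dropout):
    # with p = 0. at eval time, F.dropout is an identity, so no training flag is needed inside
    return F.dropout(x.exp(), p = dropout)

checkpointed_chunk_fn = partial(checkpoint, chunk_fn)

x = torch.randn(2, 8, 64, 64, requires_grad = True)
training, dropout = True, 0.1

# the training decision is resolved before the checkpointed call,
# mirroring `dropout if training else 0.` in the hunk above
out = checkpointed_chunk_fn(x, dropout if training else 0.)

One design note: checkpoint stashes and restores the RNG state by default (preserve_rng_state = True), so the dropout mask sampled during the recomputed forward in the backward pass matches the one from the original forward.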

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.1.1',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   long_description_content_type = 'text/markdown',
