
Commit 06b7775

test out flash attention in GPT
1 parent 33fb78a commit 06b7775

4 files changed: 13 additions & 8 deletions

memory_efficient_attention_pytorch/transformer.py

Lines changed: 5 additions & 2 deletions

@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 
 from einops import rearrange
-from memory_efficient_attention_pytorch import Attention
+from memory_efficient_attention_pytorch import FlashAttention, Attention
 from memory_efficient_attention_pytorch.reversible import ReversibleSequence
 
 def exists(val):
@@ -51,6 +51,7 @@ def __init__(
         heads = 8,
         ff_mult = 4,
         ff_chunks = 1,
+        use_flash_attn = True,
         **kwargs
     ):
         super().__init__()
@@ -59,10 +60,12 @@ def __init__(
         self.token_emb = nn.Embedding(num_tokens, dim)
         self.pos_emb = nn.Embedding(max_seq_len, dim)
 
+        attn_klass = FlashAttention if use_flash_attn else partial(Attention, memory_efficient = True)
+
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)),
+                PreNorm(dim, attn_klass(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)),
                PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, chunks = ff_chunks)),
            ]))
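For anyone trying the new flag, a minimal usage sketch follows. It assumes the class defined in transformer.py is named Transformer, is importable from memory_efficient_attention_pytorch.transformer, and that its forward pass takes token ids and returns logits; everything not visible in this diff (import path, constructor defaults, forward signature) is an assumption, not something this commit shows. Note also that the non-flash fallback branch relies on functools.partial, whose import is not visible in this hunk.

import torch
# import path and class name are assumptions based on the file path shown in this diff
from memory_efficient_attention_pytorch.transformer import Transformer

model = Transformer(
    num_tokens = 256,        # vocabulary size
    max_seq_len = 1024,      # size of the positional embedding
    dim = 512,
    depth = 2,
    heads = 8,
    causal = True,
    use_flash_attn = True    # new flag: selects FlashAttention instead of the memory-efficient Attention
)

tokens = torch.randint(0, 256, (1, 1024))   # (batch, seq) of token ids
logits = model(tokens)                      # assumed to return per-token logits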

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.23',
+  version = '0.0.25',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   long_description_content_type = 'text/markdown',

tests/test.py

Lines changed: 4 additions & 2 deletions

@@ -91,7 +91,9 @@ def test_flash_attn_gradients_equal():
     k = torch.randn(1, 8, 1024, 512).requires_grad_()
     v = torch.randn(1, 8, 1024, 512).requires_grad_()
 
-    o = attention(q, k, v, causal = False)
+    mask = torch.ones(1, 1024).bool()
+
+    o = attention(q, k, v, mask = mask, causal = True)
     o.sum().backward()
 
     dq_grad = q.grad.clone()
@@ -102,7 +104,7 @@ def test_flash_attn_gradients_equal():
     k.grad.zero_()
     v.grad.zero_()
 
-    flash_o = FlashAttentionFunction.apply(q, k, v, None, False, 64, 64)
+    flash_o = FlashAttentionFunction.apply(q, k, v, mask, True, 64, 64)
     flash_o.sum().backward()
 
     flash_dq_grad = q.grad.clone()
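For orientation, the positional arguments to FlashAttentionFunction.apply in the updated test appear to be (q, k, v, mask, causal, q_bucket_size, k_bucket_size), judging by the bucket-size keywords used elsewhere in this commit; that mapping, and the import path below, are assumptions rather than something the diff states. A standalone sketch with the arguments named:

import torch
# module path is an assumption; the test file already has FlashAttentionFunction in scope
from memory_efficient_attention_pytorch.flash_attention import FlashAttentionFunction

q = torch.randn(1, 8, 1024, 512).requires_grad_()
k = torch.randn(1, 8, 1024, 512).requires_grad_()
v = torch.randn(1, 8, 1024, 512).requires_grad_()
mask = torch.ones(1, 1024).bool()   # all positions attended, same shape as in the test

causal = True          # matches the updated reference call: attention(q, k, v, mask = mask, causal = True)
q_bucket_size = 64     # assumed meaning of the 6th positional argument
k_bucket_size = 64     # assumed meaning of the 7th positional argument

out = FlashAttentionFunction.apply(q, k, v, mask, causal, q_bucket_size, k_bucket_size)
out.sum().backward()   # gradients on q, k, v can then be compared against the reference attention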

train.py

Lines changed: 3 additions & 3 deletions

@@ -18,7 +18,7 @@
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
-GENERATE_LENGTH = 4096
+GENERATE_LENGTH = 1024
 SEQ_LEN = 4096
 
 # helpers
@@ -43,10 +43,10 @@ def decode_tokens(tokens):
     depth = 6,
     heads = 8,
     causal = True,
-    memory_efficient = True,
     q_bucket_size = 256,
     k_bucket_size = 256,
-    ff_chunks = 5
+    ff_chunks = 5,
+    use_flash_attn = True
 )
 
 model = AutoregressiveWrapper(model)
