
Commit 9db7dc8

chunk feedforward for entirely memory efficient transformer

1 parent df75876

3 files changed (+26, -12 lines)

memory_efficient_attention_pytorch/transformer.py

Lines changed: 20 additions & 7 deletions

@@ -19,12 +19,24 @@ def forward(self, x, **kwargs):
         x = self.norm(x)
         return self.fn(x, **kwargs)
 
-def FeedForward(dim, mult = 4):
-    return nn.Sequential(
-        nn.Linear(dim, dim * mult),
-        nn.GELU(),
-        nn.Linear(dim * mult, dim)
-    )
+class FeedForward(nn.Module):
+    def __init__(self, dim, mult = 4, chunks = 1):
+        super().__init__()
+        self.chunks = chunks
+
+        self.net = nn.Sequential(
+            nn.Linear(dim, dim * mult),
+            nn.GELU(),
+            nn.Linear(dim * mult, dim)
+        )
+
+    def forward(self, x):
+        if self.chunks <= 1:
+            return self.net(x)
+
+        chunks = x.chunk(self.chunks, dim = 1)
+        out = [self.net(chunk) for chunk in chunks]
+        return torch.cat(out, dim = 1)
 
 class Transformer(nn.Module):
     def __init__(
@@ -38,6 +50,7 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        ff_chunks = 1,
         **kwargs
     ):
         super().__init__()
@@ -50,7 +63,7 @@
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)),
-                PreNorm(dim, FeedForward(dim = dim, mult = ff_mult)),
+                PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, chunks = ff_chunks)),
             ]))
 
         self.net = ReversibleSequence(self.layers)
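Since the feedforward is applied position-wise, splitting the sequence into chunks and running them through the network one at a time gives exactly the same output; only the peak size of the dim * mult intermediate activation shrinks, roughly by the chunk count. A minimal sketch of that equivalence, using a standalone copy of the class above so no import path is assumed:

import torch
from torch import nn

class ChunkedFeedForward(nn.Module):
    # standalone copy of the FeedForward in the diff above, for illustration only
    def __init__(self, dim, mult = 4, chunks = 1):
        super().__init__()
        self.chunks = chunks
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        if self.chunks <= 1:
            return self.net(x)
        # the network is position-wise, so splitting along the sequence
        # dimension and concatenating the outputs is exact
        chunks = x.chunk(self.chunks, dim = 1)
        return torch.cat([self.net(c) for c in chunks], dim = 1)

ff = ChunkedFeedForward(dim = 512, chunks = 5).eval()
x = torch.randn(1, 4096, 512)

with torch.no_grad():
    # chunked and unchunked paths agree
    assert torch.allclose(ff(x), ff.net(x), atol = 1e-6)

Paired with the ReversibleSequence already used here, which recomputes layer activations during the backward pass instead of storing them, this is presumably what the commit message means by an "entirely memory efficient" transformer.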

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.14',
+  version = '0.0.15',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   author = 'Phil Wang',

train.py

Lines changed: 5 additions & 4 deletions

@@ -18,7 +18,7 @@
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
-GENERATE_LENGTH = 2048
+GENERATE_LENGTH = 4096
 SEQ_LEN = 4096
 
 # helpers
@@ -40,12 +40,13 @@ def decode_tokens(tokens):
     num_tokens = 256,
     dim = 512,
     max_seq_len = SEQ_LEN,
-    depth = 8,
+    depth = 6,
     heads = 8,
     causal = True,
     memory_efficient = True,
-    q_bucket_size = 512,
-    k_bucket_size = 512
+    q_bucket_size = 256,
+    k_bucket_size = 256,
+    ff_chunks = 5
 )
 
 model = AutoregressiveWrapper(model)
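One detail worth noting: ff_chunks = 5 does not divide SEQ_LEN = 4096 evenly. That is fine, since torch.chunk simply makes the last chunk smaller rather than requiring equal splits or padding:

import torch

x = torch.randn(1, 4096, 512)
print([c.shape[1] for c in x.chunk(5, dim = 1)])
# prints [820, 820, 820, 820, 816]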
