Commit b6ebceb

add LayerScale, from the recent "Going Deeper with Image Transformers" paper, as a ward against non-convergence at greater depths
1 parent 2b416d1 commit b6ebceb

2 files changed (+21, -4)

dalle_pytorch/transformer.py

Lines changed: 20 additions & 3 deletions
@@ -24,6 +24,23 @@ def cast_tuple(val, depth = 1):
 
 # classes
 
+# https://arxiv.org/abs/2103.17239
+class LayerScale(nn.Module):
+    def __init__(self, dim, depth, fn):
+        super().__init__()
+        if depth <= 18:
+            init_eps = 0.1
+        elif depth > 18 and depth <= 24:
+            init_eps = 1e-5
+        else:
+            init_eps = 1e-6
+
+        scale = torch.zeros(1, 1, dim).fill_(init_eps)
+        self.scale = nn.Parameter(scale)
+        self.fn = fn
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) * self.scale
+
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -77,7 +94,7 @@ def __init__(
         attn_types = cast_tuple(attn_types)
         attn_type_layer = islice(cycle(attn_types), depth)
 
-        for _, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
+        for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
             if attn_type == 'full':
                 attn_class = Attention
             elif attn_type == 'sparse':
@@ -92,8 +109,8 @@ def __init__(
                 raise ValueError(f'attention type "{attn_type}" is not valid')
 
             layers.append(nn.ModuleList([
-                PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
-                PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
+                LayerScale(dim, ind + 1, PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
+                LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
             ]))
 
         execute_type = ReversibleSequence if reversible else SequentialSequence
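
For reference, here is a minimal, self-contained sketch of what the wrapper added above does in the residual stream. The toy branch below is a hypothetical stand-in for the PreNorm-wrapped attention and feed-forward blocks, and the explicit x + branch(x) residual is illustrative only; in this repo that addition happens inside SequentialSequence / ReversibleSequence.

# Illustrative sketch only (assumes PyTorch); mirrors the LayerScale class
# added in this commit and shows how it damps a residual branch at init.
import torch
from torch import nn

class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        # init schedule from the CaiT paper (https://arxiv.org/abs/2103.17239):
        # smaller starting values at greater depth
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6
        self.scale = nn.Parameter(torch.full((1, 1, dim), init_eps))
        self.fn = fn

    def forward(self, x, **kwargs):
        # scale the branch output per channel before it re-enters the residual
        return self.fn(x, **kwargs) * self.scale

# hypothetical stand-in for PreNorm(dim, attn_class(...)) at layer index 23
dim = 512
branch = LayerScale(dim, depth = 24, fn = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)))

x = torch.randn(2, 1024, dim)
out = x + branch(x)  # at init the update is ~1e-5 of the branch output
print(out.shape)     # torch.Size([2, 1024, 512])

Note that the commit passes the 1-based layer index (ind + 1) as depth, so later layers in a deep stack start out contributing almost nothing to the residual stream, which is the safeguard against non-convergence the commit message refers to.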

setup.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.7.3',
+  version = '0.8.0',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
