
Commit e8c2d99

stability measure 3
1 parent f794ba6 commit e8c2d99

File tree

3 files changed (+27, -13 lines)

dalle_pytorch/attention.py

Lines changed: 22 additions & 8 deletions
@@ -22,16 +22,22 @@ def default(val, d):
 def max_neg_value(t):
     return -torch.finfo(t.dtype).max
 
+def stable_softmax(t, dim = -1, alpha = 32 ** 2):
+    t = t / alpha
+    t = t - torch.amax(t, dim = dim, keepdim = True)
+    return (t * alpha).softmax(dim = dim)
+
 # classes
 
 class Attention(nn.Module):
-    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
         self.seq_len = seq_len
         self.scale = dim_head ** -0.5
 
+        self.stable = stable
         self.causal = causal
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -42,6 +48,8 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
 
     def forward(self, x, mask = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax
+
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
@@ -60,7 +68,7 @@ def forward(self, x, mask = None):
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)
 
-        attn = dots.softmax(dim=-1)
+        attn = softmax(dots, dim=-1)
 
         out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -70,7 +78,7 @@ def forward(self, x, mask = None):
 # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
 
 class SparseConvCausalAttention(nn.Module):
-    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., **kwargs):
+    def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
         assert kernel_size % 2 == 1, 'kernel size must be odd'
 
@@ -82,6 +90,8 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
         self.kernel_size = kernel_size
         self.dilation = dilation
 
+        self.stable = stable
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.to_out = nn.Sequential(
@@ -91,6 +101,7 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
 
     def forward(self, x, mask = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax
 
         img_seq_len = img_size ** 2
         text_len = seq_len + 1 - img_seq_len
@@ -121,7 +132,7 @@ def forward(self, x, mask = None):
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)
 
-        attn_text = dots_text.softmax(dim = -1)
+        attn_text = softmax(dots_text, dim = -1)
         out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
 
         # image attention
@@ -163,7 +174,7 @@ def forward(self, x, mask = None):
         dots = torch.cat((dots_image_to_text, dots_image), dim = -1)
         dots.masked_fill_(mask, mask_value)
 
-        attn = dots.softmax(dim = -1)
+        attn = softmax(dots, dim = -1)
 
         # aggregate
 
@@ -185,7 +196,7 @@ def forward(self, x, mask = None):
 # sparse axial causal attention
 
 class SparseAxialCausalAttention(nn.Module):
-    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., **kwargs):
+    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
         assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
         self.axis = axis
@@ -196,6 +207,8 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
         self.scale = dim_head ** -0.5
         self.image_size = image_size
 
+        self.stable = stable
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
         self.to_out = nn.Sequential(
@@ -205,6 +218,7 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
 
     def forward(self, x, mask = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
+        softmax = torch.softmax if not self.stable else stable_softmax
 
         img_seq_len = img_size ** 2
         text_len = seq_len + 1 - img_seq_len
@@ -235,7 +249,7 @@ def forward(self, x, mask = None):
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)
 
-        attn_text = dots_text.softmax(dim = -1)
+        attn_text = softmax(dots_text, dim = -1)
         out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
 
         # image attention
@@ -267,7 +281,7 @@ def forward(self, x, mask = None):
 
         # attention.
 
-        attn = dots.softmax(dim = -1)
+        attn = softmax(dots, dim = -1)
 
         # aggregate
 
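The heart of this change is the new stable_softmax helper. Because softmax is invariant to subtracting a constant from every logit in a row, scaling the logits down by alpha, subtracting the row-wise maximum, and scaling back up yields the same distribution as a plain softmax while keeping the intermediates small, which appears to be the half-precision stability measure the commit message refers to. A minimal standalone sketch (the random test tensor and tolerance are illustrative, not from the repository):

import torch

def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    # scale down, subtract the per-row max, scale back up: the input to the
    # final softmax becomes t - max(t), so the output equals an ordinary softmax
    t = t / alpha
    t = t - torch.amax(t, dim = dim, keepdim = True)
    return (t * alpha).softmax(dim = dim)

# sanity check on made-up attention-like logits
dots = torch.randn(2, 8, 16, 16) * 10
assert torch.allclose(stable_softmax(dots), dots.softmax(dim = -1), atol = 1e-6)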

dalle_pytorch/transformer.py

Lines changed: 4 additions & 4 deletions
@@ -108,15 +108,15 @@ def __init__(
 
         for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
             if attn_type == 'full':
-                attn_class = Attention
+                attn_class = partial(Attention, stable = stable)
             elif attn_type == 'sparse':
                 attn_class = SparseAttention
             elif attn_type == 'axial_row':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size)
+                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'axial_col':
-                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size)
+                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'conv_like':
-                attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size)
+                attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
             elif attn_type == 'mlp':
                 attn_class = partial(gMLPBlock, seq_len = seq_len)
             else:
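For the full-attention, axial, and conv-like branches, the stable flag is threaded into each attention class via functools.partial, so the per-layer construction code below this dispatch does not need to change. A small sketch of that pattern, assuming a build with stable = True (the dim and seq_len values are hypothetical):

from functools import partial
from dalle_pytorch.attention import Attention

# partial pre-binds stable = True; every layer built from attn_class receives it
attn_class = partial(Attention, stable = True)
attn = attn_class(dim = 512, seq_len = 256)

# the flag reaches the layer; its forward() will then pick stable_softmax
# instead of torch.softmax
assert attn.stable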

setup.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.12.2',
+  version = '0.12.4',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
