Skip to content

Commit 9a9c370

Browse files
authored
Merge pull request #19 from lucidrains/prefix-full-attention
add ability to specify full attention for a prefix length of the sequ…
2 parents e0f3724 + 976636f commit 9a9c370

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

dalle_pytorch/dalle_pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,8 @@ def __init__(
254254
reversible = False,
255255
attn_dropout = 0.,
256256
ff_dropout = 0,
257-
sparse_attn = False
257+
sparse_attn = False,
258+
noncausal_attn_len = 0,
258259
):
259260
super().__init__()
260261
assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
@@ -294,6 +295,7 @@ def __init__(
294295
reversible = reversible,
295296
attn_dropout = attn_dropout,
296297
ff_dropout = ff_dropout,
298+
noncausal_attn_len = noncausal_attn_len,
297299
sparse_attn = sparse_attn,
298300
sparse_attn_global_indices = range(text_seq_len)
299301
)

dalle_pytorch/transformer.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,15 @@ def forward(self, x):
5454
return self.net(x)
5555

5656
class Attention(nn.Module):
57-
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
57+
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., noncausal_attn_len = 0):
5858
super().__init__()
5959
inner_dim = dim_head * heads
6060
self.heads = heads
6161
self.seq_len = seq_len
6262
self.scale = dim ** -0.5
63+
6364
self.causal = causal
65+
self.noncausal_attn_len = noncausal_attn_len
6466

6567
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
6668
self.to_out = nn.Sequential(
@@ -84,6 +86,11 @@ def forward(self, x, mask = None):
8486
if self.causal:
8587
i, j = dots.shape[-2:]
8688
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
89+
90+
if self.noncausal_attn_len > 0:
91+
ind = slice(0, self.noncausal_attn_len)
92+
mask[ind, ind] = False
93+
8794
dots.masked_fill_(mask, mask_value)
8895

8996
attn = dots.softmax(dim=-1)
@@ -146,6 +153,10 @@ def forward(self, x, mask = None):
146153
mask_value = -(torch.finfo(q.dtype).max / 2)
147154
attn_mask.masked_fill_(mask, mask_value)
148155

156+
if self.noncausal_attn_len:
157+
ind = slice(0, self.noncausal_attn_len)
158+
attn_mask[ind, ind] = 0.
159+
149160
out = self.attn_fn(q, k, v, attn_mask = attn_mask, key_padding_mask = key_pad_mask)
150161
out = rearrange(out, 'b h n d -> b n (h d)')
151162
out = self.to_out(out)
@@ -165,6 +176,7 @@ def __init__(
165176
ff_mult = 4,
166177
attn_dropout = 0.,
167178
ff_dropout = 0.,
179+
noncausal_attn_len = 0,
168180
sparse_attn = True,
169181
sparse_attn_global_indices = []
170182
):
@@ -176,7 +188,7 @@ def __init__(
176188
attn_class = Attention if not sparse_attn else partial(SparseAttention, sparse_attn_global_indices = sparse_attn_global_indices)
177189

178190
layers.append(nn.ModuleList([
179-
PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
191+
PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout, noncausal_attn_len = noncausal_attn_len)),
180192
PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
181193
]))
182194

0 commit comments

Comments
 (0)