Skip to content

Commit 80fba42

Browse files
committed
omit the prefix sections of the sequence undergoing full attention from the cross entropy loss
Merge pull request #20 from lucidrains/prefix-full-attention omit the prefix sections of the sequence undergoing full attention fr…
1 parent 945055a commit 80fba42

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

dalle_pytorch/dalle_pytorch.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def __init__(
256256
ff_dropout = 0,
257257
sparse_attn = False,
258258
noncausal_attn_len = 0,
259+
ignore_index = -100
259260
):
260261
super().__init__()
261262
assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
@@ -279,7 +280,9 @@ def __init__(
279280
seq_len = text_seq_len + image_seq_len
280281
total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
281282
self.total_tokens = total_tokens
282-
283+
284+
self.noncausal_attn_len = noncausal_attn_len
285+
283286
self.vae = vae
284287
if exists(self.vae):
285288
self.vae = vae
@@ -319,6 +322,8 @@ def __init__(
319322

320323
self.register_buffer('logits_mask', logits_mask)
321324

325+
self.ignore_index = ignore_index
326+
322327
@torch.no_grad()
323328
@eval_decorator
324329
def generate_images(
@@ -404,9 +409,15 @@ def forward(
404409
return logits
405410

406411
assert exists(image), 'when training, image must be supplied'
407-
412+
noncausal_attn_len = self.noncausal_attn_len
408413
offsetted_image = image + self.num_text_tokens
409414
labels = torch.cat((text, offsetted_image), dim = 1)
415+
416+
if noncausal_attn_len > 0:
417+
mask = torch.arange(seq_len, device = device)
418+
mask = mask < noncausal_attn_len
419+
labels.masked_fill_(mask[None, :], -100) # -100 is the ignore index for cross entropy loss
420+
410421
labels = F.pad(labels, (0, 1), value = eos_token_id) # last token predicts EOS
411422
loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])
412423
return loss

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'dalle-pytorch',
55
packages = find_packages(),
6-
version = '0.0.41',
6+
version = '0.0.43',
77
license='MIT',
88
description = 'DALL-E - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)