@@ -256,6 +256,7 @@ def __init__(
256256 ff_dropout = 0 ,
257257 sparse_attn = False ,
258258 noncausal_attn_len = 0 ,
259+ ignore_index = - 100
259260 ):
260261 super ().__init__ ()
261262 assert isinstance (vae , DiscreteVAE ), 'vae must be an instance of DiscreteVAE'
@@ -279,7 +280,9 @@ def __init__(
279280 seq_len = text_seq_len + image_seq_len
280281 total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
281282 self .total_tokens = total_tokens
282-
283+
284+ self .noncausal_attn_len = noncausal_attn_len
285+
283286 self .vae = vae
284287 if exists (self .vae ):
285288 self .vae = vae
@@ -319,6 +322,8 @@ def __init__(
319322
320323 self .register_buffer ('logits_mask' , logits_mask )
321324
325+ self .ignore_index = ignore_index
326+
322327 @torch .no_grad ()
323328 @eval_decorator
324329 def generate_images (
@@ -404,9 +409,15 @@ def forward(
404409 return logits
405410
406411 assert exists (image ), 'when training, image must be supplied'
407-
412+ noncausal_attn_len = self . noncausal_attn_len
408413 offsetted_image = image + self .num_text_tokens
409414 labels = torch .cat ((text , offsetted_image ), dim = 1 )
415+
416+ if noncausal_attn_len > 0 :
417+ mask = torch .arange (seq_len , device = device )
418+ mask = mask < noncausal_attn_len
419+ labels .masked_fill_ (mask [None , :], - 100 ) # -100 is the ignore index for cross entropy loss
420+
410421 labels = F .pad (labels , (0 , 1 ), value = eos_token_id ) # last token predicts EOS
411422 loss = F .cross_entropy (rearrange (logits , 'b n c -> b c n' ), labels [:, 1 :])
412423 return loss
0 commit comments