@@ -57,7 +57,7 @@ def __init__(
             # TODO: Find a better solution.
             self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space))

-        if not self._config.transformer.diffusion:
+        if self._config.transformer.diffusion is None:
             if self._use_flash_attention:
                 self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space))
             else:
@@ -355,12 +355,21 @@ def preprocess(

         batch_size, seq_len = batch.token_ids.shape
         seq_len -= 1  # last token is dropped inputs
+        # attention_mask = torch.ones(
+        #     (batch_size, 1, seq_len, seq_len),
+        #     dtype=torch.bool,
+        #     device=self._tensor_space.distributed.device,
+        # )
+        # kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1)
         attention_mask = torch.ones(
-            (batch_size, 1, seq_len, seq_len),
+            (seq_len, seq_len),
             dtype=torch.bool,
             device=self._tensor_space.distributed.device,
         )
-        kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1)
+        kwargs[TransformerKwargs.attention_mask] = attention_mask[
+            None, None, 0:seq_len, None, :seq_len
+        ]
+        print(f"attention_mask: {kwargs[TransformerKwargs.attention_mask]}")
         # # kwargs[TransformerKwargs.attention_mask_value] = torch.tensor(
         # #     -10000.0, device=self._tensor_space.distributed.device
         # # )
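For reference, a minimal standalone sketch of the mask shapes before and after this change, assuming plain PyTorch with illustrative sizes (this is not the Fast-LLM preprocessing code itself, just the indexing pattern from the diff):

```python
import torch

# Illustrative sizes only; the real values come from batch.token_ids.shape.
batch_size, seq_len = 2, 5

# Old construction: per-batch 4-D all-ones mask, then two unsqueeze(1) calls.
old_mask = torch.ones((batch_size, 1, seq_len, seq_len), dtype=torch.bool)
old_mask = old_mask.unsqueeze(1).unsqueeze(1)
print(old_mask.shape)  # torch.Size([2, 1, 1, 1, 5, 5])

# New construction: a single 2-D all-ones mask, expanded with None indices
# into a broadcastable tensor with no per-batch copy.
new_mask = torch.ones((seq_len, seq_len), dtype=torch.bool)
new_mask = new_mask[None, None, 0:seq_len, None, :seq_len]
print(new_mask.shape)  # torch.Size([1, 1, 5, 1, 5])
```

The new mask is a single broadcastable all-ones tensor rather than a mask materialized per batch element.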