Commit a573cd7

attn verification checks
1 parent 068138f commit a573cd7

File tree

fast_llm/layers/transformer/attention.py
fast_llm/models/gpt/model.py

2 files changed: +16 -14 lines


fast_llm/layers/transformer/attention.py

Lines changed: 4 additions & 11 deletions
@@ -182,11 +182,7 @@ def _attn_fused(
         ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk)
 
         attn_weights = attn_weights.to(torch.float32) * self._layer_index
-
-        attn_weights = attn_weights.transpose(2, 3)
         attn_weights = torch.where(mask, attn_weights, mask_value)
-        attn_weights = attn_weights.transpose(2, 3)
-
         attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
 
         with set_generator(self._tensor_space.distributed.tp_generator):
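
Dropping the transpose(2, 3) calls around the masking appears to rely on the mask already being broadcastable against the (batch, head_groups, seq_q, heads_per_group, seq_k) layout of the fused attention weights (see the fast_llm/models/gpt/model.py change below). A minimal standalone sketch of that broadcasting, using toy shapes and a causal mask purely for illustration:

import torch

# Toy stand-ins for (batch, head_groups, seq_q, heads_per_group, seq_k);
# the real sizes come from the model configuration.
b, groups, sq, heads_per_group, sk = 2, 4, 8, 2, 8

attn_weights = torch.randn(b, groups, sq, heads_per_group, sk, dtype=torch.float32)
mask_value = torch.full((), torch.finfo(torch.float32).min)

# A (1, 1, sq, 1, sk) boolean mask broadcasts directly against the 5-D weights,
# so torch.where needs no transpose before or after the masking.
mask = torch.ones(sq, sk, dtype=torch.bool).tril()[None, None, :, None, :]

attn_weights = torch.where(mask, attn_weights, mask_value)
probs = torch.nn.functional.softmax(attn_weights, dim=-1)

# With a causal mask, future key positions get (numerically) zero probability.
print(probs.shape)                  # torch.Size([2, 4, 8, 2, 8])
print(probs[0, 0, 1, 0, 2:].max())  # tensor(0.)
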
@@ -417,13 +413,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
             diff = input_ - flash_input_
             # print(f"Element-wise difference: {diff.shape} {diff}")
             max_diff = diff.abs().max()
-            min_diff = diff.abs().min()
-            print(f"Min element-wise difference: {min_diff.item()}")
-            print(f"Max element-wise difference: {max_diff.item()}")
-            # if max_diff > 1e-3:
-            #     print("Warning: Max difference exceeds 1e-3")
-            #     import sys
-            #     sys.exit(1)
+
+            if max_diff > 1e-3:
+                print("Warning: Max difference exceeds 1e-3")
+                print(f"Max element-wise difference: {max_diff.item()}")
 
         if self._debug_transformer:
             self._debug_log(query, "query", self._QUERY_DIMS, kwargs)
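
The verification check itself is just a max-abs-difference comparison with a warning threshold. A hypothetical standalone helper reproducing the same pattern (check_close is not part of the repository, only a sketch):

import torch

def check_close(reference: torch.Tensor, candidate: torch.Tensor, tol: float = 1e-3) -> float:
    """Warn when two tensors differ by more than `tol`; return the max abs difference."""
    max_diff = (reference.float() - candidate.float()).abs().max().item()
    if max_diff > tol:
        print(f"Warning: Max difference exceeds {tol}")
        print(f"Max element-wise difference: {max_diff}")
    return max_diff

# Example: a bfloat16 round trip typically drifts by more than 1e-3 and triggers the warning.
x = torch.randn(4, 1024)
check_close(x, x.to(torch.bfloat16))
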

fast_llm/models/gpt/model.py

Lines changed: 12 additions & 3 deletions
@@ -57,7 +57,7 @@ def __init__(
         # TODO: Find a better solution.
         self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space))
 
-        if not self._config.transformer.diffusion:
+        if self._config.transformer.diffusion is None:
             if self._use_flash_attention:
                 self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space))
             else:
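
The truthiness check and the explicit `is None` check differ only when the config field can be falsy without being unset (for example 0.0 or an empty config object). A small sketch of the distinction with made-up stand-in values (the real field is self._config.transformer.diffusion):

# Made-up stand-in values; in the model the field is self._config.transformer.diffusion.
for diffusion in (None, 0.0, 0.5):
    old_check = not diffusion        # truthiness: also True for 0.0, "", empty configs
    new_check = diffusion is None    # explicit: True only when the field is unset
    print(f"{diffusion!r}: not diffusion -> {old_check}, diffusion is None -> {new_check}")
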
@@ -355,12 +355,21 @@ def preprocess(
 
         batch_size, seq_len = batch.token_ids.shape
         seq_len -= 1  # last token is dropped inputs
+        # attention_mask = torch.ones(
+        #     (batch_size, 1, seq_len, seq_len),
+        #     dtype=torch.bool,
+        #     device=self._tensor_space.distributed.device,
+        # )
+        # kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1)
         attention_mask = torch.ones(
-            (batch_size, 1, seq_len, seq_len),
+            (seq_len, seq_len),
             dtype=torch.bool,
             device=self._tensor_space.distributed.device,
         )
-        kwargs[TransformerKwargs.attention_mask] = attention_mask.unsqueeze(1).unsqueeze(1)
+        kwargs[TransformerKwargs.attention_mask] = attention_mask[
+            None, None, 0:seq_len, None, :seq_len
+        ]
+        print(f"attention_mask: {kwargs[TransformerKwargs.attention_mask]}")
         # # kwargs[TransformerKwargs.attention_mask_value] = torch.tensor(
         # #     -10000.0, device=self._tensor_space.distributed.device
         # # )
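
The reworked preprocessing builds a plain (seq_len, seq_len) all-True mask and relies on indexing alone to add the broadcast axes expected by the fused attention path. A minimal sketch of what that indexing produces, with a toy sequence length:

import torch

seq_len = 4  # toy value; the real one comes from batch.token_ids.shape

attention_mask = torch.ones((seq_len, seq_len), dtype=torch.bool)

# Three None indices insert singleton axes around the two sliced axes, turning the
# (seq_len, seq_len) mask into a (1, 1, seq_len, 1, seq_len) tensor that broadcasts
# against (batch, head_groups, seq_q, heads_per_group, seq_k) attention weights.
broadcast_mask = attention_mask[None, None, 0:seq_len, None, :seq_len]
print(broadcast_mask.shape)  # torch.Size([1, 1, 4, 1, 4])
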
