Skip to content

Commit 2caec59

Browse files
committed
update
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 3b9f409 commit 2caec59

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

vllm_ascend/attention/attention_mask.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,13 @@
1717

1818
def _generate_attn_mask(max_seq_len, dtype):
1919
# Construct lower triangle matrix.
20-
mask_flag = torch.ones(
21-
(max_seq_len, max_seq_len),
22-
dtype=torch.bool).tril_().view(max_seq_len, max_seq_len)
20+
mask_flag = torch.ones((max_seq_len, max_seq_len),
21+
dtype=torch.bool).tril_()
2322
# Create upper triangle matrix used to mark mask positions.
2423
mask_flag = ~mask_flag
2524
# Currently for fp16 dtype, the mask value should be set to -inf.
2625
# TODO: Eliminate this part in the future.
27-
if dtype == torch.float16:
28-
mask_value = torch.finfo(torch.float32).min
29-
else:
30-
mask_value = 1
26+
mask_value = float('-inf') if dtype == torch.float16 else 1
3127
attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \
3228
.masked_fill_(mask_flag, mask_value)
3329
return attn_mask

0 commit comments

Comments (0)