
Commit 6ef548d

qyqc731 authored and Angazenn committed
fix by review
Signed-off-by: tangtianyi <tangtianyi4@huawei.com>
1 parent 8bdda1d commit 6ef548d


3 files changed: +8, -3 lines changed

vllm_ascend/attention/attention_mask.py

Lines changed: 5 additions & 2 deletions
@@ -41,14 +41,17 @@ def __init__(
         dtype: torch.dtype,
         device: torch.device = None,
     ):
+        # NOTE: The device argument specifies the target NPU
+        # to be used for the newly added FIA operator.
+        # Only pass this parameter when using the new FIA operator.
+
         attn_mask = _generate_attn_mask(max_seq_len, dtype)
 
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
         self.device = device
         if self.device:
-            #NOTE: New compressed mask needs to be sent to certain device,
-            # so device needs to be passed here.
+
             assigned_mask_dim = 2048
             self.chunked_prefill_attn_mask = torch.triu(torch.ones(assigned_mask_dim, assigned_mask_dim), diagonal=1
                                                         ).to(torch.int8).to(device)
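
For reference, a minimal standalone sketch of the compressed chunked-prefill mask this hunk builds, assuming a plain CPU tensor in place of the NPU device that the real code receives; it only illustrates the shape and int8 upper-triangular layout, not the full AttentionMaskBuilder class.

import torch

assigned_mask_dim = 2048   # fixed mask size used by the compressed-mask path
device = "cpu"             # the real code passes the target NPU device here

# Strictly upper-triangular mask: entry (i, j) is 1 when j > i, marking the
# "future" positions that must be masked out, stored compactly as int8.
chunked_prefill_attn_mask = torch.triu(
    torch.ones(assigned_mask_dim, assigned_mask_dim), diagonal=1
).to(torch.int8).to(device)

print(chunked_prefill_attn_mask.shape)    # torch.Size([2048, 2048])
print(chunked_prefill_attn_mask[:3, :3])  # [[0, 1, 1], [0, 0, 1], [0, 0, 0]]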

vllm_ascend/attention/attention_v1.py

Lines changed: 2 additions & 0 deletions
@@ -460,6 +460,8 @@ def _forward_v1_style(
             attn_metadata.seq_lens.to(device=query.device)
 
         if self.compressed_mask:
+            # TODO:The npu_fused_infer_attention_score op is planned to
+            # be utilized in a wider range in upcoming versions.
             num_block, block_size, head_num, head_dim = self.key_cache.shape
             key = self.key_cache.view(num_block, block_size, -1)
             value = self.value_cache.view(num_block, block_size, -1)
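
The fused attention call itself (npu_fused_infer_attention_score, a torch_npu operator) is not shown in this hunk; the sketch below only reproduces the cache reshape visible above, with made-up dimensions standing in for self.key_cache / self.value_cache.

import torch

# Assumed toy dimensions; in the real code these come from self.key_cache.shape.
num_block, block_size, head_num, head_dim = 4, 128, 8, 64
key_cache = torch.zeros(num_block, block_size, head_num, head_dim)
value_cache = torch.zeros(num_block, block_size, head_num, head_dim)

# Collapse the per-head dims so each cache block row becomes a flat
# (head_num * head_dim) vector, the layout consumed in this branch.
key = key_cache.view(num_block, block_size, -1)
value = value_cache.view(num_block, block_size, -1)

print(key.shape, value.shape)  # torch.Size([4, 128, 512]) torch.Size([4, 128, 512])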

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion
@@ -825,7 +825,7 @@ def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
-            if selkf.compressed_mas:
+            if self.compressed_mask:
                 return self.attn_mask_builder.get_splitfuse_attn_mask()
             else:
                 return self.attn_mask_builder.get_splitfuse_attn_mask(
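
This hunk only corrects a typo: the previous condition referenced the undefined name selkf, which would raise NameError the first time the chunked-prefill branch ran. A hypothetical, self-contained sketch of the branch it restores is below; MaskBuilderStub and the argument list on the fallback call are stand-ins for illustration, not the real vllm-ascend API (the diff truncates the actual call).

import torch

class MaskBuilderStub:
    # Stand-in for AttentionMaskBuilder, exposing only the two call shapes used below.
    def __init__(self):
        self.compressed = torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)

    def get_splitfuse_attn_mask(self, seq_lens=None, position=None):
        if seq_lens is None:
            # Compressed path: return the prebuilt mask unchanged.
            return self.compressed
        # Fallback path: slice a mask sized for the current batch (illustration only).
        max_len = int(max(seq_lens))
        return self.compressed[:max_len, :max_len]

builder = MaskBuilderStub()
compressed_mask = True  # corresponds to self.compressed_mask in the fixed line

if compressed_mask:
    mask = builder.get_splitfuse_attn_mask()
else:
    mask = builder.get_splitfuse_attn_mask(seq_lens=[5, 9], position=None)

print(mask.shape)  # torch.Size([2048, 2048])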
