Commit 64c1753

MengqingCao authored and Yikun committed
fix cann version barrier for npu_fused_infer_attention_score
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent b1b1fe7 commit 64c1753
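
The change swaps a prefix check on the CANN version for a plain string comparison against the 8.3.RC1.alpha002 release tag, so only builds newer than that tag take the npu_fused_infer_attention_score path. A minimal sketch of the gate, assuming a torch_npu environment where torch.version.cann is populated; the helper name is chosen here for illustration:

import torch
import torch_npu  # noqa: F401  (torch_npu exposes the CANN version via torch.version.cann)


def fused_infer_attention_available() -> bool:
    # Same check introduced by this commit: a lexicographic string comparison
    # against the CANN release tag, not a semantic-version parse.
    return torch.version.cann > "8.3.RC1.alpha002"

Under the old startswith("8.3") check, any 8.3 build, including 8.3.RC1.alpha002 itself, took the fused-op path; the new comparison excludes that alpha build and anything older.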

3 files changed: +6 additions, -6 deletions


vllm_ascend/attention/attention_mask.py

Lines changed: 2 additions & 2 deletions

@@ -50,7 +50,7 @@ def __init__(
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
         self.device = device
-        if torch.version.cann.startswith("8.3"):
+        if torch.version.cann > "8.3.RC1.alpha002":
             assigned_mask_dim = 2048
             self.chunked_prefill_attn_mask = torch.triu(
                 torch.ones(assigned_mask_dim, assigned_mask_dim),
@@ -82,7 +82,7 @@ def get_splitfuse_attn_mask(
         dtype: torch.dtype = None,
         device: torch.device = None,
     ) -> torch.Tensor:
-        if torch.version.cann.startswith("8.3"):
+        if torch.version.cann > "8.3.RC1.alpha002":
             return self.chunked_prefill_attn_mask
         else:
             if dtype not in [torch.float16, torch.bfloat16]:
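
For context on the branch this gate selects: on newer CANN the builder precomputes one fixed 2048 x 2048 upper-triangular mask in __init__, and get_splitfuse_attn_mask simply returns it. A minimal sketch of that construction; the diagonal offset and any dtype/device conversion are assumptions, since the hunk is cut off before them:

import torch

assigned_mask_dim = 2048
# Causal mask for chunked prefill: ones strictly above the diagonal mark
# positions a token must not attend to. diagonal=1 is an assumption here;
# the original call is truncated in the hunk above.
chunked_prefill_attn_mask = torch.triu(
    torch.ones(assigned_mask_dim, assigned_mask_dim), diagonal=1)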

vllm_ascend/attention/attention_v1.py

Lines changed: 2 additions & 2 deletions

@@ -483,7 +483,7 @@ def _forward_v1_style(
         attn_metadata.seq_lens = \
             attn_metadata.seq_lens.to(device=query.device)

-        if torch.version.cann.startswith("8.3"):
+        if torch.version.cann > "8.3.RC1.alpha002":
             # TODO:The npu_fused_infer_attention_score op is planned to
             # be utilized in a wider range in upcoming versions.
             num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
@@ -613,7 +613,7 @@ def forward(
                 output)
         # Normal V1 situation.
         else:
-            if torch.version.cann.startswith("8.3"):
+            if torch.version.cann > "8.3.RC1.alpha002":
                 # npu_fused_infer_attention_score does not support cases
                 # where query.shape[0] != attn_metadata.query_start_loc[-1].
                 # Thus we need unpad it here.
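
The comment in the second hunk explains why this path is gated: npu_fused_infer_attention_score rejects inputs where query.shape[0] differs from attn_metadata.query_start_loc[-1], so the padded tail of the query has to be stripped first. A hedged sketch of just that unpad step, built only from what the diff states; the helper name is hypothetical and the fused-op call itself is omitted:

import torch


def unpad_query(query: torch.Tensor, query_start_loc: torch.Tensor) -> torch.Tensor:
    # query_start_loc[-1] is the total number of real tokens in the batch;
    # rows beyond it are padding that the fused op cannot accept.
    num_actual_tokens = int(query_start_loc[-1])
    return query[:num_actual_tokens]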

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions

@@ -314,7 +314,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             use_mla=self.model_config.use_mla,
         )

-        if torch.version.cann.startswith("8.3"):
+        if torch.version.cann > "8.3.RC1.alpha002":
             self.attn_mask_builder = AttentionMaskBuilder(
                 self.scheduler_config.max_num_batched_tokens, self.dtype,
                 self.device)
@@ -881,7 +881,7 @@ def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
-            if torch.version.cann.startswith("8.3"):
+            if torch.version.cann > "8.3.RC1.alpha002":
                 return self.attn_mask_builder.get_splitfuse_attn_mask()
             else:
                 return self.attn_mask_builder.get_splitfuse_attn_mask(
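
In _make_attention_mask the same gate decides whether the runner can return the precomputed mask directly or must build one from batch metadata. A sketch of that branch as a free function; seq_lens and position come from the signature shown above, while passing them (plus dtype and device) to the legacy call is an assumption, because the diff truncates the argument list:

import torch


def make_chunked_prefill_mask(attn_mask_builder, seq_lens, position, dtype, device):
    if torch.version.cann > "8.3.RC1.alpha002":
        # Newer CANN: reuse the fixed 2048 x 2048 mask built once in __init__.
        return attn_mask_builder.get_splitfuse_attn_mask()
    # Older CANN: assemble a mask per batch (argument list assumed; the
    # original call is truncated in the hunk above).
    return attn_mask_builder.get_splitfuse_attn_mask(seq_lens, position, dtype, device)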
