
[CI] Make AttentionBackend interface compatible to fix broken CI #1893


Merged 2 commits on Jul 21, 2025
11 changes: 9 additions & 2 deletions tests/ut/attention/test_attention_v1.py
@@ -3,12 +3,15 @@
import torch

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import \
AscendAttentionBackendImpl092 # isort: skip
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionBackendImpl,
AscendAttentionMetadataBuilder,
AscendAttentionState,
AscendMetadata,
CommonAttentionState)
from vllm_ascend.utils import vllm_version_is


class TestAscendAttentionBackend(TestBase):
@@ -17,8 +20,12 @@ def test_get_name(self):
self.assertEqual(AscendAttentionBackend.get_name(), "ASCEND")

def test_get_impl_cls(self):
-        self.assertEqual(AscendAttentionBackend.get_impl_cls(),
-                         AscendAttentionBackendImpl)
if vllm_version_is("0.9.2"):
self.assertEqual(AscendAttentionBackend.get_impl_cls(),
AscendAttentionBackendImpl092)
else:
self.assertEqual(AscendAttentionBackend.get_impl_cls(),
AscendAttentionBackendImpl)

def test_get_metadata_cls(self):
self.assertEqual(AscendAttentionBackend.get_metadata_cls(),
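The updated test only exercises whichever branch matches the locally installed vLLM, so a CI job on one release never runs the other assertion. A minimal sketch of how both branches could be pinned in a single run by patching vllm_version_is where attention_v1 imports it; this test is not part of the PR, just an illustration:

from unittest import TestCase
from unittest.mock import patch

from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
                                                AscendAttentionBackendImpl,
                                                AscendAttentionBackendImpl092)


class TestGetImplClsBothBranches(TestCase):

    def test_impl_cls_on_092(self):
        # Force the 0.9.2 branch regardless of the installed vLLM release.
        with patch("vllm_ascend.attention.attention_v1.vllm_version_is",
                   return_value=True):
            self.assertIs(AscendAttentionBackend.get_impl_cls(),
                          AscendAttentionBackendImpl092)

    def test_impl_cls_on_newer_vllm(self):
        # Force the non-0.9.2 branch.
        with patch("vllm_ascend.attention.attention_v1.vllm_version_is",
                   return_value=False):
            self.assertIs(AscendAttentionBackend.get_impl_cls(),
                          AscendAttentionBackendImpl)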
37 changes: 35 additions & 2 deletions vllm_ascend/attention/attention_v1.py
@@ -31,7 +31,7 @@

from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec)
nd_to_nz_2d, nd_to_nz_spec, vllm_version_is)


class AscendAttentionBackend(AttentionBackend):
@@ -43,6 +43,8 @@ def get_name() -> str:

@staticmethod
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
if vllm_version_is("0.9.2"):
return AscendAttentionBackendImpl092
return AscendAttentionBackendImpl

@staticmethod
@@ -222,7 +224,6 @@ def __init__(
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
-        blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
@@ -437,6 +438,38 @@ def forward(
return output.view(num_tokens, self.hidden_size)


class AscendAttentionBackendImpl092(AscendAttentionBackendImpl):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
sliding_window=sliding_window,
kv_cache_dtype=kv_cache_dtype,
logits_soft_cap=logits_soft_cap,
attn_type=attn_type,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
use_irope=use_irope,
)


def unified_ascend_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
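The changes above follow one pattern: newer vLLM no longer passes blocksparse_params to an attention impl's constructor, so the base AscendAttentionBackendImpl drops that parameter, while a thin AscendAttentionBackendImpl092 subclass keeps accepting (and discarding) it for vLLM 0.9.2, and get_impl_cls() picks the right class at runtime. A standalone sketch of that shim pattern, with illustrative names rather than the PR's real classes:

from typing import Any, Dict, Optional, Type


def installed_vllm_is(target: str) -> bool:
    # Stand-in for vllm_ascend.utils.vllm_version_is; assume it checks the
    # installed vLLM release string.
    return target == "0.9.2"


class BackendImpl:
    """Constructor matches the newer vLLM AttentionImpl signature."""

    def __init__(self, num_heads: int, head_size: int) -> None:
        self.num_heads = num_heads
        self.head_size = head_size


class BackendImpl092(BackendImpl):
    """Accepts the legacy 0.9.2-era keyword and silently drops it."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 blocksparse_params: Optional[Dict[str, Any]] = None) -> None:
        super().__init__(num_heads=num_heads, head_size=head_size)


class Backend:

    @staticmethod
    def get_impl_cls() -> Type[BackendImpl]:
        # Old vLLM still forwards blocksparse_params, so hand it the shim.
        if installed_vllm_is("0.9.2"):
            return BackendImpl092
        return BackendImpl


impl = Backend.get_impl_cls()(num_heads=8, head_size=128)

Keeping the base class on the new signature and isolating the legacy keyword in a *092 subclass means the compatibility path can be deleted wholesale once vLLM 0.9.2 support is dropped.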
38 changes: 36 additions & 2 deletions vllm_ascend/attention/attention_v1_torchair.py
@@ -29,7 +29,7 @@

from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d)
nd_to_nz_2d, vllm_version_is)


class AscendAttentionTorchairBackend(AttentionBackend):
@@ -41,6 +41,8 @@ def get_name() -> str:

@staticmethod
def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
if vllm_version_is("0.9.2"):
return AscendAttentionTorchairBackendImpl092
return AscendAttentionTorchairBackendImpl

@staticmethod
@@ -333,7 +335,6 @@ def __init__(
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
-        blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
@@ -501,3 +502,36 @@ def forward(
"to use ascend scheduler.")

return output.view(num_tokens, self.hidden_size)


class AscendAttentionTorchairBackendImpl092(AscendAttentionTorchairBackendImpl
):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
sliding_window=sliding_window,
kv_cache_dtype=kv_cache_dtype,
logits_soft_cap=logits_soft_cap,
attn_type=attn_type,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
use_irope=use_irope,
)
42 changes: 38 additions & 4 deletions vllm_ascend/attention/mla_v1.py
@@ -1,11 +1,12 @@
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar)

import numpy as np
import torch
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
-                                               AttentionMetadata,
AttentionMetadata, AttentionType,
MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config
@@ -20,7 +21,8 @@
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (npu_prefetch, npu_stream_switch,
npu_wait_tensor, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import InputBatch

if TYPE_CHECKING:
@@ -66,6 +68,8 @@ def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,

@staticmethod
def get_impl_cls() -> Type["MLAAttentionImpl"]:
if vllm_version_is("0.9.2"):
return AscendMLAImpl092
return AscendMLAImpl


@@ -533,7 +537,6 @@ def __init__(
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
-            blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str] = None,
@@ -1226,3 +1229,34 @@ def forward(
output[:num_decode_tokens] = output_decode

return output_padded


class AscendMLAImpl092(AscendMLAImpl):

def __init__(self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
**kwargs) -> None:
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
sliding_window=sliding_window,
kv_cache_dtype=kv_cache_dtype,
logits_soft_cap=logits_soft_cap,
attn_type=attn_type,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
use_irope=use_irope,
**kwargs)
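All three backends gate on vllm_version_is from vllm_ascend.utils, whose implementation is not part of this diff. A plausible stand-in, assuming the helper simply compares the installed vLLM release string against the requested one (the real utility may resolve the version differently):

from importlib.metadata import PackageNotFoundError, version


def vllm_version_is(target: str) -> bool:
    """Return True when the installed vLLM release matches target exactly."""
    # Hypothetical implementation; vllm_ascend.utils.vllm_version_is is the
    # source of truth and may differ.
    try:
        return version("vllm") == target
    except PackageNotFoundError:
        return False


if vllm_version_is("0.9.2"):
    print("vLLM 0.9.2 detected: the *Impl092 classes will be selected")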