Skip to content

Commit 0407ffc

Browse files
Yikun authored and wangxiyuan committed
Fix ut
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
1 parent 6008f12 commit 0407ffc

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

tests/ut/attention/test_attention_v1.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
import torch
44

55
from tests.ut.base import TestBase
6+
from vllm_ascend.attention.attention_v1 import \
7+
AscendAttentionBackendImpl092 # isort: skip
68
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
79
AscendAttentionBackendImpl,
810
AscendAttentionMetadataBuilder,
911
AscendAttentionState,
1012
AscendMetadata,
1113
CommonAttentionState)
14+
from vllm_ascend.utils import vllm_version_is
1215

1316

1417
class TestAscendAttentionBackend(TestBase):
@@ -17,8 +20,12 @@ def test_get_name(self):
1720
self.assertEqual(AscendAttentionBackend.get_name(), "ASCEND")
1821

1922
def test_get_impl_cls(self):
20-
self.assertEqual(AscendAttentionBackend.get_impl_cls(),
21-
AscendAttentionBackendImpl)
23+
if vllm_version_is("0.9.2"):
24+
self.assertEqual(AscendAttentionBackend.get_impl_cls(),
25+
AscendAttentionBackendImpl092)
26+
else:
27+
self.assertEqual(AscendAttentionBackend.get_impl_cls(),
28+
AscendAttentionBackendImpl)
2229

2330
def test_get_metadata_cls(self):
2431
self.assertEqual(AscendAttentionBackend.get_metadata_cls(),

vllm_ascend/attention/mla_v1.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,21 +1233,20 @@ def forward(
12331233

12341234
class AscendMLAImpl092(AscendMLAImpl):
12351235

1236-
def __init__(
1237-
self,
1238-
num_heads: int,
1239-
head_size: int,
1240-
scale: float,
1241-
num_kv_heads: int,
1242-
alibi_slopes: Optional[List[float]],
1243-
sliding_window: Optional[int],
1244-
kv_cache_dtype: str,
1245-
blocksparse_params: Optional[Dict[str, Any]] = None,
1246-
logits_soft_cap: Optional[float] = None,
1247-
attn_type: str = AttentionType.DECODER,
1248-
kv_sharing_target_layer_name: Optional[str] = None,
1249-
use_irope: bool = False,
1250-
) -> None:
1236+
def __init__(self,
1237+
num_heads: int,
1238+
head_size: int,
1239+
scale: float,
1240+
num_kv_heads: int,
1241+
alibi_slopes: Optional[List[float]],
1242+
sliding_window: Optional[int],
1243+
kv_cache_dtype: str,
1244+
blocksparse_params: Optional[Dict[str, Any]] = None,
1245+
logits_soft_cap: Optional[float] = None,
1246+
attn_type: str = AttentionType.DECODER,
1247+
kv_sharing_target_layer_name: Optional[str] = None,
1248+
use_irope: bool = False,
1249+
**kwargs) -> None:
12511250
super().__init__(
12521251
num_heads=num_heads,
12531252
head_size=head_size,
@@ -1260,4 +1259,4 @@ def __init__(
12601259
attn_type=attn_type,
12611260
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
12621261
use_irope=use_irope,
1263-
)
1262+
**kwargs)

0 commit comments

Comments (0)