
Commit a8b316a

wangxiyuan and Yikun authored
[CI] Make AttentionBackend interface compatible to fix broken CI (#1893)
vLLM commit vllm-project/vllm@752c6ad removed `blocksparse_params` from the attention backend interface. This PR applies the same change to make CI happy.

- vLLM version: v0.9.2
- vLLM main: vllm-project/vllm@9499e26

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
1 parent: 54f2b31 · commit: a8b316a

File tree

- tests/ut/attention/test_attention_v1.py
- vllm_ascend/attention/attention_v1.py
- vllm_ascend/attention/attention_v1_torchair.py
- vllm_ascend/attention/mla_v1.py

4 files changed: +118 additions, −10 deletions
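Before the per-file diffs, here is a minimal, self-contained sketch of the compatibility pattern they implement: the base impl class follows the new vLLM-main constructor (no `blocksparse_params`), a `*Impl092` subclass keeps the old v0.9.2 signature and simply drops `blocksparse_params`, and `get_impl_cls()` returns whichever class matches the installed vLLM version. The `Example*` names and the stubbed `vllm_version_is()` below are illustrative stand-ins, not the real vllm-ascend code (the real helper lives in `vllm_ascend.utils`).

```python
# Illustrative sketch only: Example* classes and the stubbed vllm_version_is()
# are stand-ins for the real classes in vllm_ascend/attention/*.py.
from typing import Any, Dict, Optional, Type


def vllm_version_is(version: str) -> bool:
    # Stand-in: the real helper compares `version` against the installed
    # vLLM release. Change the return value to simulate running on vLLM main.
    return version == "0.9.2"


class ExampleAttentionImpl:
    # New-style constructor: vLLM main no longer passes blocksparse_params.
    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 logits_soft_cap: Optional[float] = None) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.logits_soft_cap = logits_soft_cap


class ExampleAttentionImpl092(ExampleAttentionImpl):
    # Old-style constructor: vLLM v0.9.2 still passes blocksparse_params,
    # so accept it here and drop it before delegating to the new signature.
    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 blocksparse_params: Optional[Dict[str, Any]] = None,
                 logits_soft_cap: Optional[float] = None) -> None:
        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         logits_soft_cap=logits_soft_cap)


class ExampleAttentionBackend:

    @staticmethod
    def get_impl_cls() -> Type[ExampleAttentionImpl]:
        # Return the implementation whose __init__ matches what the
        # installed vLLM version will call.
        if vllm_version_is("0.9.2"):
            return ExampleAttentionImpl092
        return ExampleAttentionImpl


if __name__ == "__main__":
    impl_cls = ExampleAttentionBackend.get_impl_cls()
    kwargs: Dict[str, Any] = dict(num_heads=8, head_size=128)
    if vllm_version_is("0.9.2"):
        kwargs["blocksparse_params"] = None  # v0.9.2 still passes this
    impl = impl_cls(**kwargs)
    print(impl_cls.__name__, impl.num_heads, impl.head_size)
```

In the real diffs below, `get_impl_cls()` performs exactly this version check, and each `*Impl092` subclass forwards everything except `blocksparse_params` to the parent constructor.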

tests/ut/attention/test_attention_v1.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -3,12 +3,15 @@
 import torch
 
 from tests.ut.base import TestBase
+from vllm_ascend.attention.attention_v1 import \
+    AscendAttentionBackendImpl092  # isort: skip
 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
                                                 AscendAttentionBackendImpl,
                                                 AscendAttentionMetadataBuilder,
                                                 AscendAttentionState,
                                                 AscendMetadata,
                                                 CommonAttentionState)
+from vllm_ascend.utils import vllm_version_is
 
 
 class TestAscendAttentionBackend(TestBase):
@@ -17,8 +20,12 @@ def test_get_name(self):
         self.assertEqual(AscendAttentionBackend.get_name(), "ASCEND")
 
     def test_get_impl_cls(self):
-        self.assertEqual(AscendAttentionBackend.get_impl_cls(),
-                         AscendAttentionBackendImpl)
+        if vllm_version_is("0.9.2"):
+            self.assertEqual(AscendAttentionBackend.get_impl_cls(),
+                             AscendAttentionBackendImpl092)
+        else:
+            self.assertEqual(AscendAttentionBackend.get_impl_cls(),
+                             AscendAttentionBackendImpl)
 
     def test_get_metadata_cls(self):
         self.assertEqual(AscendAttentionBackend.get_metadata_cls(),
```

vllm_ascend/attention/attention_v1.py

Lines changed: 35 additions & 2 deletions
```diff
@@ -31,7 +31,7 @@
 
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec)
+                               nd_to_nz_2d, nd_to_nz_spec, vllm_version_is)
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -43,6 +43,8 @@ def get_name() -> str:
 
     @staticmethod
     def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
+        if vllm_version_is("0.9.2"):
+            return AscendAttentionBackendImpl092
         return AscendAttentionBackendImpl
 
     @staticmethod
@@ -222,7 +224,6 @@ def __init__(
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
-        blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
@@ -437,6 +438,38 @@ def forward(
         return output.view(num_tokens, self.hidden_size)
 
 
+class AscendAttentionBackendImpl092(AscendAttentionBackendImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
+        attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
+        use_irope: bool = False,
+    ) -> None:
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            alibi_slopes=alibi_slopes,
+            sliding_window=sliding_window,
+            kv_cache_dtype=kv_cache_dtype,
+            logits_soft_cap=logits_soft_cap,
+            attn_type=attn_type,
+            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
+            use_irope=use_irope,
+        )
+
+
 def unified_ascend_attention_with_output(
     query: torch.Tensor,
     key: torch.Tensor,
```

vllm_ascend/attention/attention_v1_torchair.py

Lines changed: 36 additions & 2 deletions
```diff
@@ -29,7 +29,7 @@
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d)
+                               nd_to_nz_2d, vllm_version_is)
 
 
 class AscendAttentionTorchairBackend(AttentionBackend):
@@ -41,6 +41,8 @@ def get_name() -> str:
 
     @staticmethod
     def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]:
+        if vllm_version_is("0.9.2"):
+            return AscendAttentionTorchairBackendImpl092
         return AscendAttentionTorchairBackendImpl
 
     @staticmethod
@@ -333,7 +335,6 @@ def __init__(
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
-        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
@@ -501,3 +502,36 @@ def forward(
             "to use ascend scheduler.")
 
         return output.view(num_tokens, self.hidden_size)
+
+
+class AscendAttentionTorchairBackendImpl092(AscendAttentionTorchairBackendImpl
+                                            ):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
+        attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
+        use_irope: bool = False,
+    ) -> None:
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            alibi_slopes=alibi_slopes,
+            sliding_window=sliding_window,
+            kv_cache_dtype=kv_cache_dtype,
+            logits_soft_cap=logits_soft_cap,
+            attn_type=attn_type,
+            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
+            use_irope=use_irope,
+        )
```

vllm_ascend/attention/mla_v1.py

Lines changed: 38 additions & 4 deletions
```diff
@@ -1,11 +1,12 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
+                    TypeVar)
 
 import numpy as np
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
-                                              AttentionMetadata,
+                                              AttentionMetadata, AttentionType,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
@@ -20,7 +21,8 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
+from vllm_ascend.utils import (npu_prefetch, npu_stream_switch,
+                               npu_wait_tensor, vllm_version_is)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -66,6 +68,8 @@ def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
 
     @staticmethod
     def get_impl_cls() -> Type["MLAAttentionImpl"]:
+        if vllm_version_is("0.9.2"):
+            return AscendMLAImpl092
         return AscendMLAImpl
 
 
@@ -533,7 +537,6 @@ def __init__(
         alibi_slopes: Optional[list[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
-        blocksparse_params: Optional[dict[str, Any]],
         logits_soft_cap: Optional[float],
         attn_type: str,
         kv_sharing_target_layer_name: Optional[str] = None,
@@ -1226,3 +1229,34 @@ def forward(
         output[:num_decode_tokens] = output_decode
 
         return output_padded
+
+
+class AscendMLAImpl092(AscendMLAImpl):
+
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 num_kv_heads: int,
+                 alibi_slopes: Optional[List[float]],
+                 sliding_window: Optional[int],
+                 kv_cache_dtype: str,
+                 blocksparse_params: Optional[Dict[str, Any]] = None,
+                 logits_soft_cap: Optional[float] = None,
+                 attn_type: str = AttentionType.DECODER,
+                 kv_sharing_target_layer_name: Optional[str] = None,
+                 use_irope: bool = False,
+                 **kwargs) -> None:
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            alibi_slopes=alibi_slopes,
+            sliding_window=sliding_window,
+            kv_cache_dtype=kv_cache_dtype,
+            logits_soft_cap=logits_soft_cap,
+            attn_type=attn_type,
+            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
+            use_irope=use_irope,
+            **kwargs)
```
