Commit 39a54d6

[Hybrid KV] Follow up UniformTypeKVCacheSpecs

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 12bcbd0 · commit 39a54d6

3 files changed, +50 −54 lines changed

.github/workflows/vllm_ascend_test.yaml

Lines changed: 3 additions & 3 deletions

@@ -82,7 +82,7 @@ jobs:
       VLLM_USE_MODELSCOPE: True
     strategy:
       matrix:
-        vllm_version: [6d8246aaffff3ebec84767e373212a7b8da328e2, v0.10.2]
+        vllm_version: [v0.10.2]
     steps:
       - name: Install packages
         run: |
@@ -140,7 +140,7 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-aarch64-a2-1]
-        vllm_version: [6d8246aaffff3ebec84767e373212a7b8da328e2, v0.10.2]
+        vllm_version: [v0.10.2]
     name: singlecard e2e test - light
     runs-on: ${{ matrix.os }}
     container:
@@ -206,7 +206,7 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-aarch64-a2-2]
-        vllm_version: [6d8246aaffff3ebec84767e373212a7b8da328e2, v0.10.2]
+        vllm_version: [v0.10.2]
     name: multicard e2e test - light
     runs-on: ${{ matrix.os }}
     container:

.github/workflows/vllm_ascend_test_full.yaml

Lines changed: 2 additions & 2 deletions

@@ -72,7 +72,7 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-aarch64-a2-1]
-        vllm_version: [6d8246aaffff3ebec84767e373212a7b8da328e2, v0.10.2]
+        vllm_version: [v0.10.2]
     name: singlecard e2e test - full
     runs-on: ${{ matrix.os }}
     container:
@@ -156,7 +156,7 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-aarch64-a2-2]
-        vllm_version: [6d8246aaffff3ebec84767e373212a7b8da328e2, v0.10.2]
+        vllm_version: [v0.10.2]
     name: multicard e2e test - full
     runs-on: ${{ matrix.os }}
     container:

vllm_ascend/worker/model_runner_v1.py

Lines changed: 45 additions & 49 deletions

@@ -27,7 +27,7 @@
 from copy import deepcopy
 from dataclasses import dataclass
 from multiprocessing import Manager
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -72,8 +72,13 @@
 from vllm.v1.attention.backends.utils import \
     reorder_batch_to_split_decodes_and_prefills
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
-from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheSpec, MambaSpec)
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.v1.kv_cache_interface import (AttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec, UniformTypeKVCacheSpecs)
+# yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -2576,10 +2581,10 @@ def initialize_kv_cache_tensors_deepseek(
             kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
 
         kv_caches: Dict[str, torch.Tensor] = {}
-        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
-        ):
-            attn_backend = kv_cache_group.backend
-            for layer_name in kv_cache_group.layer_names:
+        for group in self._kv_cache_spec_attn_group_iterator():
+            kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
                 tensor_size = kv_cache_sizes[layer_name]
@@ -2721,10 +2726,11 @@ def initialize_kv_cache_tensors(
             )), "Some layers are not correctly initialized"
 
         kv_caches: Dict[str, torch.Tensor] = {}
-        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
+        for group in self._kv_cache_spec_attn_group_iterator(
         ):
-            attn_backend = kv_cache_group.backend
-            for layer_name in kv_cache_group.layer_names:
+            kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
 
@@ -2821,7 +2827,7 @@ def initialize_kv_cache_tensors(
 
         return kv_caches
 
-    def _kv_cache_spec_attn_group_iterator(
+    def _kv_cache_spec_attn_group_iterator_v0102(
             self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
         if not self.kv_cache_config.kv_cache_groups:
             return
@@ -2908,48 +2914,39 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """
         assert len(self.attn_groups) == 0, \
             "Attention backends are already initialized"
-
-        def get_attn_backends_for_layers(
-                layer_names: list[str]
-        ) -> dict[type[AttentionBackend], list[str]]:
-            layers = get_layers_from_vllm_config(self.vllm_config,
-                                                 AttentionLayerBase,
-                                                 layer_names)
+        class AttentionGroupKey(NamedTuple):
+            attn_backend: type[AttentionBackend]
+            kv_cache_spec: KVCacheSpec
+
+        def get_attn_backends_for_group(
+            kv_cache_group_spec: KVCacheGroupSpec,
+        ) -> dict[AttentionGroupKey, list[str]]:
+            layers = get_layers_from_vllm_config(
+                self.vllm_config, AttentionLayerBase,
+                kv_cache_group_spec.layer_names)
             attn_backends = {}
             attn_backend_layers = defaultdict(list)
             # Dedupe based on full class name; this is a bit safer than
             # using the class itself as the key because when we create dynamic
             # attention backend subclasses (e.g. ChunkedLocalAttention) unless
             # they are cached correctly, there will be different objects per
             # layer.
-            for layer_name in layer_names:
+            for layer_name in kv_cache_group_spec.layer_names:
                 attn_backend = layers[layer_name].get_attn_backend()
-                key = attn_backend.full_cls_name()
-                attn_backends[key] = attn_backend
+                full_cls_name = attn_backend.full_cls_name()
+                layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+                if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
+                    layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
+                        layer_name]
+                key = (full_cls_name, layer_kv_cache_spec)
+                attn_backends[key] = AttentionGroupKey(attn_backend,
+                                                       layer_kv_cache_spec)
                 attn_backend_layers[key].append(layer_name)
             return {
                 attn_backends[k]: v
                 for k, v in attn_backend_layers.items()
             }
 
-        def create_attn_groups_v0102(
-            attn_backends_map: dict[AttentionBackend, list[str]],
-            kv_cache_spec: KVCacheSpec,
-        ) -> list[AttentionGroup]:
-            attn_groups: list[AttentionGroup] = []
-            for attn_backend, layer_names in attn_backends_map.items():
-                attn_metadata_builder_i = attn_backend.get_builder_cls()(
-                    kv_cache_spec,
-                    layer_names,
-                    self.vllm_config,
-                    self.device,
-                )
-                attn_group = AttentionGroup(attn_backend,
-                                            attn_metadata_builder_i,
-                                            layer_names)
-                attn_groups.append(attn_group)
-            return attn_groups
-
         def create_attn_groups(
             attn_backends_map: dict[AttentionBackend, list[str]],
             kv_cache_spec: KVCacheSpec,
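Note on the hunk above: the rewritten helper buckets a KV-cache group's layers by the backend's full class name together with that layer's own spec, unwrapping a UniformTypeKVCacheSpecs container to the per-layer spec before building the key. Below is a minimal, self-contained sketch of that grouping idea; FakeUniformSpecs, GroupKey and group_layers are hypothetical stand-ins for illustration, not the real vLLM classes.

from collections import defaultdict
from typing import Dict, List, NamedTuple, Union


class FakeUniformSpecs(NamedTuple):
    # Hypothetical stand-in for UniformTypeKVCacheSpecs: a container that
    # maps each layer name to that layer's own KV-cache spec.
    kv_cache_specs: Dict[str, str]


class GroupKey(NamedTuple):
    # Hypothetical stand-in for the new AttentionGroupKey.
    attn_backend: str
    kv_cache_spec: str


def group_layers(layer_backends: Dict[str, str],
                 group_spec: Union[str, FakeUniformSpecs]) -> Dict[GroupKey, List[str]]:
    # Bucket layers by (backend class name, per-layer spec), mirroring the
    # "dedupe based on full class name" comment in the diff above.
    attn_backends: Dict[tuple, GroupKey] = {}
    attn_backend_layers: Dict[tuple, List[str]] = defaultdict(list)
    for layer_name, backend_cls_name in layer_backends.items():
        layer_spec = group_spec
        if isinstance(layer_spec, FakeUniformSpecs):
            # Unwrap the uniform container to this layer's own spec.
            layer_spec = layer_spec.kv_cache_specs[layer_name]
        key = (backend_cls_name, layer_spec)
        attn_backends[key] = GroupKey(backend_cls_name, layer_spec)
        attn_backend_layers[key].append(layer_name)
    return {attn_backends[k]: v for k, v in attn_backend_layers.items()}


if __name__ == "__main__":
    spec = FakeUniformSpecs({"layer.0": "full", "layer.1": "full", "layer.2": "swa"})
    backends = {name: "AscendAttentionBackend" for name in spec.kv_cache_specs}
    print(group_layers(backends, spec))
    # Layers 0 and 1 share one group key; layer 2 lands in its own group.

With this keying, two layers only share an attention group when both the backend and the per-layer spec match, which is what allows a single uniform-type group to still split into per-spec attention groups.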
@@ -2965,27 +2962,26 @@ def create_attn_groups(
                 ))
                 attn_group = AttentionGroup(attn_backend,
                                             attn_metadata_builders,
-                                            layer_names)
+                                            layer_names, kv_cache_spec)
                 attn_groups.append(attn_group)
             return attn_groups
 
         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            attn_backends = get_attn_backends_for_layers(
-                kv_cache_group_spec.layer_names)
-            if vllm_version_is("0.10.2"):
-                self.attn_groups.append(
-                    create_attn_groups_v0102(attn_backends, kv_cache_spec))
-            else:
-                self.attn_groups.append(
-                    create_attn_groups(attn_backends, kv_cache_spec))
+            attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
+            self.attn_groups.append(create_attn_groups(attn_backends))
 
         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()
 
     def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
         return itertools.chain.from_iterable(self.attn_groups)
 
+    def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
+        if not self.kv_cache_config.kv_cache_groups:
+            return
+        for attn_groups in self.attn_groups:
+            yield from attn_groups
+
     def calculate_reorder_batch_threshold(self) -> None:
         """
         Check that if any backends reorder batches; that the reordering
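The new _kv_cache_spec_attn_group_iterator above replaces the old (spec, group) tuple iterator: self.attn_groups is a list of per-KV-cache-group lists, and each AttentionGroup now carries its kv_cache_spec, so callers such as initialize_kv_cache_tensors read it off the group directly. A minimal sketch of that consumer pattern, using hypothetical stand-in types rather than the actual vllm-ascend classes:

from dataclasses import dataclass
from typing import Dict, Iterator, List


@dataclass
class FakeAttentionGroup:
    # Hypothetical stand-in for AttentionGroup after this change: the group
    # record carries its own kv_cache_spec alongside backend and layers.
    backend: str
    layer_names: List[str]
    kv_cache_spec: str


def iter_attn_groups(nested: List[List[FakeAttentionGroup]]) -> Iterator[FakeAttentionGroup]:
    # Flatten the per-KV-cache-group lists, like the new
    # _kv_cache_spec_attn_group_iterator.
    for groups in nested:
        yield from groups


if __name__ == "__main__":
    nested = [
        [FakeAttentionGroup("AscendAttentionBackend", ["layer.0", "layer.1"], "full")],
        [FakeAttentionGroup("AscendMLABackend", ["layer.2"], "mla")],
    ]
    kv_caches: Dict[str, str] = {}
    for group in iter_attn_groups(nested):
        # No more (kv_cache_spec, group) tuple unpacking: read attributes.
        for layer_name in group.layer_names:
            kv_caches[layer_name] = f"{group.backend}/{group.kv_cache_spec}"
    print(kv_caches)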
