@@ -27,7 +27,7 @@
 from copy import deepcopy
 from dataclasses import dataclass
 from multiprocessing import Manager
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -72,8 +72,13 @@
 from vllm.v1.attention.backends.utils import \
     reorder_batch_to_split_decodes_and_prefills
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
-from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheSpec, MambaSpec)
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.v1.kv_cache_interface import (AttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec, UniformTypeKVCacheSpecs)
+# yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -2576,10 +2581,10 @@ def initialize_kv_cache_tensors_deepseek(
             kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
 
         kv_caches: Dict[str, torch.Tensor] = {}
-        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
-        ):
-            attn_backend = kv_cache_group.backend
-            for layer_name in kv_cache_group.layer_names:
+        for group in self._kv_cache_spec_attn_group_iterator():
+            kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
                 tensor_size = kv_cache_sizes[layer_name]
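
Both this hunk and the next switch the KV-cache tensor loops from unpacking (kv_cache_spec, kv_cache_group) tuples to iterating plain AttentionGroup objects that now carry their own spec. Below is a minimal standalone sketch of the new loop shape; GroupSketch and collect_layer_specs are illustrative stand-ins, modeling only the attributes this diff actually reads (backend, layer_names, kv_cache_spec, and the runner_only_attn_layers skip):

    # Illustrative stand-in types, not the vLLM classes themselves.
    from dataclasses import dataclass
    from typing import Any, Dict, Iterable, List, Set


    @dataclass
    class GroupSketch:
        backend: Any            # attention backend class for these layers
        layer_names: List[str]  # layers served by this group
        kv_cache_spec: Any      # spec is now carried by the group itself


    def collect_layer_specs(groups: Iterable[GroupSketch],
                            runner_only_layers: Set[str]) -> Dict[str, Any]:
        # Old style unpacked (kv_cache_spec, group) pairs from the iterator;
        # new style reads the spec off each group inside the loop body.
        layer_specs: Dict[str, Any] = {}
        for group in groups:
            kv_cache_spec = group.kv_cache_spec
            for layer_name in group.layer_names:
                if layer_name in runner_only_layers:
                    continue
                layer_specs[layer_name] = kv_cache_spec
        return layer_specs
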
@@ -2721,10 +2726,11 @@ def initialize_kv_cache_tensors(
         )), "Some layers are not correctly initialized"
 
         kv_caches: Dict[str, torch.Tensor] = {}
-        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
+        for group in self._kv_cache_spec_attn_group_iterator(
         ):
-            attn_backend = kv_cache_group.backend
-            for layer_name in kv_cache_group.layer_names:
+            kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
 
@@ -2821,7 +2827,7 @@ def initialize_kv_cache_tensors(
 
         return kv_caches
 
-    def _kv_cache_spec_attn_group_iterator(
+    def _kv_cache_spec_attn_group_iterator_v0102(
             self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
         if not self.kv_cache_config.kv_cache_groups:
             return
@@ -2908,48 +2914,39 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """
         assert len(self.attn_groups) == 0, \
             "Attention backends are already initialized"
-
-        def get_attn_backends_for_layers(
-                layer_names: list[str]
-        ) -> dict[type[AttentionBackend], list[str]]:
-            layers = get_layers_from_vllm_config(self.vllm_config,
-                                                 AttentionLayerBase,
-                                                 layer_names)
+        class AttentionGroupKey(NamedTuple):
+            attn_backend: type[AttentionBackend]
+            kv_cache_spec: KVCacheSpec
+
+        def get_attn_backends_for_group(
+            kv_cache_group_spec: KVCacheGroupSpec,
+        ) -> dict[AttentionGroupKey, list[str]]:
+            layers = get_layers_from_vllm_config(
+                self.vllm_config, AttentionLayerBase,
+                kv_cache_group_spec.layer_names)
             attn_backends = {}
             attn_backend_layers = defaultdict(list)
             # Dedupe based on full class name; this is a bit safer than
             # using the class itself as the key because when we create dynamic
             # attention backend subclasses (e.g. ChunkedLocalAttention) unless
             # they are cached correctly, there will be different objects per
             # layer.
-            for layer_name in layer_names:
+            for layer_name in kv_cache_group_spec.layer_names:
                 attn_backend = layers[layer_name].get_attn_backend()
-                key = attn_backend.full_cls_name()
-                attn_backends[key] = attn_backend
+                full_cls_name = attn_backend.full_cls_name()
+                layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+                if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
+                    layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
+                        layer_name]
+                key = (full_cls_name, layer_kv_cache_spec)
+                attn_backends[key] = AttentionGroupKey(attn_backend,
+                                                       layer_kv_cache_spec)
                 attn_backend_layers[key].append(layer_name)
             return {
                 attn_backends[k]: v
                 for k, v in attn_backend_layers.items()
             }
 
-        def create_attn_groups_v0102(
-            attn_backends_map: dict[AttentionBackend, list[str]],
-            kv_cache_spec: KVCacheSpec,
-        ) -> list[AttentionGroup]:
-            attn_groups: list[AttentionGroup] = []
-            for attn_backend, layer_names in attn_backends_map.items():
-                attn_metadata_builder_i = attn_backend.get_builder_cls()(
-                    kv_cache_spec,
-                    layer_names,
-                    self.vllm_config,
-                    self.device,
-                )
-                attn_group = AttentionGroup(attn_backend,
-                                            attn_metadata_builder_i,
-                                            layer_names)
-                attn_groups.append(attn_group)
-            return attn_groups
-
         def create_attn_groups(
             attn_backends_map: dict[AttentionBackend, list[str]],
             kv_cache_spec: KVCacheSpec,
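
The rewritten grouping helper above keys attention groups on a composite of the backend's full class name and the layer's own KV-cache spec, unwrapping UniformTypeKVCacheSpecs to the per-layer spec, so layers that share a backend but need different cache layouts land in different groups. A self-contained sketch of that grouping technique follows; ToyBackend, ToySpec and group_layers are assumed toy types for illustration, not vLLM API:

    from collections import defaultdict
    from typing import Dict, List, NamedTuple, Tuple


    class ToySpec(NamedTuple):
        block_size: int
        num_kv_heads: int


    class ToyBackend:
        @classmethod
        def full_cls_name(cls) -> Tuple[str, str]:
            return (cls.__module__, cls.__qualname__)


    class AttentionGroupKey(NamedTuple):
        attn_backend: type
        kv_cache_spec: ToySpec


    def group_layers(layer_backends: Dict[str, type],
                     layer_specs: Dict[str, ToySpec]) -> Dict[AttentionGroupKey, List[str]]:
        # Dedupe on (full class name, per-layer spec): dynamically created
        # backend subclasses can compare unequal even when they denote the
        # same backend, so the class name is the stable half of the key; the
        # spec half separates layers whose caches must be laid out differently.
        attn_backends: Dict[tuple, AttentionGroupKey] = {}
        attn_backend_layers: Dict[tuple, List[str]] = defaultdict(list)
        for layer_name, backend in layer_backends.items():
            key = (backend.full_cls_name(), layer_specs[layer_name])
            attn_backends[key] = AttentionGroupKey(backend, layer_specs[layer_name])
            attn_backend_layers[key].append(layer_name)
        return {attn_backends[k]: v for k, v in attn_backend_layers.items()}


    # Two layers share the backend but use different block sizes, so they end
    # up in two groups; the third layer shares its key with the first.
    groups = group_layers(
        {"l0": ToyBackend, "l1": ToyBackend, "l2": ToyBackend},
        {"l0": ToySpec(16, 8), "l1": ToySpec(128, 8), "l2": ToySpec(16, 8)},
    )
    assert [sorted(v) for v in groups.values()] == [["l0", "l2"], ["l1"]]
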
@@ -2965,27 +2962,26 @@ def create_attn_groups(
                 ))
                 attn_group = AttentionGroup(attn_backend,
                                             attn_metadata_builders,
-                                            layer_names)
+                                            layer_names, kv_cache_spec)
                 attn_groups.append(attn_group)
             return attn_groups
 
         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            attn_backends = get_attn_backends_for_layers(
-                kv_cache_group_spec.layer_names)
-            if vllm_version_is("0.10.2"):
-                self.attn_groups.append(
-                    create_attn_groups_v0102(attn_backends, kv_cache_spec))
-            else:
-                self.attn_groups.append(
-                    create_attn_groups(attn_backends, kv_cache_spec))
+            attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
+            self.attn_groups.append(create_attn_groups(attn_backends))
 
         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()
 
     def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
         return itertools.chain.from_iterable(self.attn_groups)
 
+    def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
+        if not self.kv_cache_config.kv_cache_groups:
+            return
+        for attn_groups in self.attn_groups:
+            yield from attn_groups
+
     def calculate_reorder_batch_threshold(self) -> None:
         """
         Check that if any backends reorder batches; that the reordering
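
This last hunk also adds the new _kv_cache_spec_attn_group_iterator, which simply flattens the nested self.attn_groups list (one sub-list of AttentionGroups per KV-cache group) and yields the groups directly, which is what lets the callers earlier in the diff drop the tuple unpacking. A small standalone sketch of that flattening next to the itertools.chain.from_iterable form already used by _attn_group_iterator; the emptiness check here merely stands in for the kv_cache_config.kv_cache_groups guard in the real method:

    import itertools
    from typing import Iterator, List


    def flat_iter(attn_groups: List[List[object]]) -> Iterator[object]:
        # Generator form: one sub-list per KV-cache group, flattened lazily.
        if not attn_groups:
            return
        for groups in attn_groups:
            yield from groups


    def flat_iter_chain(attn_groups: List[List[object]]) -> Iterator[object]:
        # Equivalent flattening via itertools, as _attn_group_iterator does.
        return itertools.chain.from_iterable(attn_groups)


    nested = [["g0", "g1"], ["g2"]]
    assert list(flat_iter(nested)) == list(flat_iter_chain(nested)) == ["g0", "g1", "g2"]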