1
1
# SPDX-License-Identifier: Apache-2.0
2
2
import copy
3
+ from dataclasses import dataclass
3
4
from typing import TYPE_CHECKING , Any , Optional
4
5
5
6
import torch
21
22
logger = init_logger (__name__ )
22
23
23
24
24
@dataclass
class MultiKVConnectorMetadata(KVConnectorMetadata):
    """Metadata bundle passed to ``MultiConnector.bind_connector_metadata``.

    Carries one metadata object per wrapped connector plus an optional
    mapping of extra async-save counts. This is a dataclass rather than a
    tuple subclass so that the extra field can travel with the
    per-connector metadata.
    """

    # One entry per wrapped connector. Consumers pair this positionally
    # (via zip) with their connector list, so order is significant.
    metadata: tuple[KVConnectorMetadata, ...]
    # Optional str -> int mapping that the receiving side merges into its
    # own `_extra_async_saves`; None when there is nothing to propagate.
    extra_async_saves: Optional[dict[str, int]] = None
27
29
28
30
29
31
class MultiConnector (KVConnectorBase_V1 ):
@@ -46,6 +48,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
46
48
# Keeps track of *additional* remaining async saves (beyond 1) to be
47
49
# finished per request. Not needed for async loads since we only allow
48
50
# a single connector to load.
51
+ # Propagated from scheduler to worker side via the connector metadata.
49
52
self ._extra_async_saves : dict [str , int ] = {}
50
53
51
54
def register_kv_caches (self , kv_caches : dict [str , torch .Tensor ]):
@@ -58,7 +61,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
58
61
def bind_connector_metadata(
        self, connector_metadata: KVConnectorMetadata) -> None:
    """Bind the bundled metadata onto every wrapped connector.

    Args:
        connector_metadata: Must be a ``MultiKVConnectorMetadata``; its
            ``metadata`` tuple is paired positionally with
            ``self._connectors``.
    """
    assert isinstance(connector_metadata, MultiKVConnectorMetadata)
    # Merge any extra async-save counts shipped with the metadata into the
    # local tracking dict before fanning out to the sub-connectors.
    if connector_metadata.extra_async_saves:
        self._extra_async_saves.update(
            connector_metadata.extra_async_saves)
    for c, cm in zip(self._connectors, connector_metadata.metadata):
        c.bind_connector_metadata(cm)
63
69
64
70
def clear_connector_metadata (self ) -> None :
@@ -144,8 +150,13 @@ def update_state_after_alloc(self, request: "Request",
144
150
def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
    """Build the combined metadata for all wrapped connectors.

    Args:
        scheduler_output: Forwarded unchanged to each wrapped connector's
            ``build_connector_meta``.

    Returns:
        A ``MultiKVConnectorMetadata`` holding one metadata object per
        connector, in ``self._connectors`` order. If any extra async-save
        counts are pending they are attached and the local dict is reset.
    """
    metadata = MultiKVConnectorMetadata(metadata=tuple(
        c.build_connector_meta(scheduler_output)
        for c in self._connectors))
    if self._extra_async_saves:
        metadata.extra_async_saves = self._extra_async_saves
        # Rebind to a fresh dict instead of calling clear(): the old dict
        # object was just handed to `metadata` and must keep its contents.
        self._extra_async_saves = {}
    return metadata
149
160
150
161
def request_finished (
151
162
self ,
0 commit comments