Skip to content

Commit 3032553

Browse files
authored
[BugFix] Fix multi async save in MultiConnector (#90)
The MultiKVConnector impl keeps track of cases where multiple connectors are async saving the same request, but this state needs to be shared from the scheduler side to the worker side. Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent f7faa01 commit 3032553

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import copy
3+
from dataclasses import dataclass
34
from typing import TYPE_CHECKING, Any, Optional
45

56
import torch
@@ -21,9 +22,10 @@
2122
logger = init_logger(__name__)
2223

2324

24-
class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...],
25-
KVConnectorMetadata):
26-
pass
25+
@dataclass
26+
class MultiKVConnectorMetadata(KVConnectorMetadata):
27+
metadata: tuple[KVConnectorMetadata, ...]
28+
extra_async_saves: Optional[dict[str, int]] = None
2729

2830

2931
class MultiConnector(KVConnectorBase_V1):
@@ -46,6 +48,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
4648
# Keeps track of *additional* remaining async saves (beyond 1) to be
4749
# finished per request. Not needed for async loads since we only allow
4850
# a single connector to load.
51+
# Propagated from scheduler to worker side via the connector metadata.
4952
self._extra_async_saves: dict[str, int] = {}
5053

5154
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
@@ -58,7 +61,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
5861
def bind_connector_metadata(
5962
self, connector_metadata: KVConnectorMetadata) -> None:
6063
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
61-
for c, cm in zip(self._connectors, connector_metadata):
64+
if connector_metadata.extra_async_saves:
65+
self._extra_async_saves.update(
66+
connector_metadata.extra_async_saves)
67+
for c, cm in zip(self._connectors, connector_metadata.metadata):
6268
c.bind_connector_metadata(cm)
6369

6470
def clear_connector_metadata(self) -> None:
@@ -144,8 +150,13 @@ def update_state_after_alloc(self, request: "Request",
144150
def build_connector_meta(
145151
self,
146152
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
147-
return MultiKVConnectorMetadata(
148-
c.build_connector_meta(scheduler_output) for c in self._connectors)
153+
metadata = MultiKVConnectorMetadata(metadata=tuple(
154+
c.build_connector_meta(scheduler_output)
155+
for c in self._connectors))
156+
if self._extra_async_saves:
157+
metadata.extra_async_saves = self._extra_async_saves
158+
self._extra_async_saves = {}
159+
return metadata
149160

150161
def request_finished(
151162
self,

0 commit comments

Comments
 (0)