diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 509619c7ec6..585d8e63db7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import copy +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import torch @@ -21,9 +22,10 @@ logger = init_logger(__name__) -class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...], - KVConnectorMetadata): - pass +@dataclass +class MultiKVConnectorMetadata(KVConnectorMetadata): + metadata: tuple[KVConnectorMetadata, ...] + extra_async_saves: Optional[dict[str, int]] = None class MultiConnector(KVConnectorBase_V1): @@ -46,6 +48,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # Keeps track of *additional* remaining async saves (beyond 1) to be # finished per request. Not needed for async loads since we only allow # a single connector to load. + # Propagated from scheduler to worker side via the connector metadata. self._extra_async_saves: dict[str, int] = {} def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): @@ -58,7 +61,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def bind_connector_metadata( self, connector_metadata: KVConnectorMetadata) -> None: assert isinstance(connector_metadata, MultiKVConnectorMetadata) - for c, cm in zip(self._connectors, connector_metadata): + if connector_metadata.extra_async_saves: + self._extra_async_saves.update( + connector_metadata.extra_async_saves) + for c, cm in zip(self._connectors, connector_metadata.metadata): c.bind_connector_metadata(cm) def clear_connector_metadata(self) -> None: @@ -144,8 +150,13 @@ def update_state_after_alloc(self, request: "Request", def build_connector_meta( self, scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: - return MultiKVConnectorMetadata( - c.build_connector_meta(scheduler_output) for c in self._connectors) + metadata = MultiKVConnectorMetadata(metadata=tuple( + c.build_connector_meta(scheduler_output) + for c in self._connectors)) + if self._extra_async_saves: + metadata.extra_async_saves = self._extra_async_saves + self._extra_async_saves = {} + return metadata def request_finished( self,