
Commit 6a0c5cd

Merge remote-tracking branch 'nm-fork/multi-kv-connectors' into llm-d-launch-branch
2 parents: ed1af74 + 59ba66f

File tree: 4 files changed (+416, -2 lines)

Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
# SPDX-License-Identifier: Apache-2.0
import filecmp
import shutil
import tempfile
from collections import defaultdict
from pathlib import Path

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
    SharedStorageConnector)

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

PROMPT_CONTEXT = "Hi " * 100
PROMPTS = [
    PROMPT_CONTEXT + "Hello, my name is",
    PROMPT_CONTEXT + "The capital of France is",
]

SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)


class TestSharedStorageConnector(SharedStorageConnector):

    def __init__(self, config: VllmConfig, role):
        self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
        self._connector = SharedStorageConnector(config, role)
        self.call_record: dict[str, int] = defaultdict(int)
        # Use a unique temp file per connector
        self._event_file = (tempfile.gettempdir() +
                            f"/connector_{self.name}_events.log")
        # Start with an empty file
        with open(self._event_file, "w") as _:
            pass

    def __getattribute__(self, name):
        if name in ("_connector", "call_record", "name", "_event_file",
                    "__class__", "__dict__", "__getattribute__",
                    "__init__"):  # avoid recursion
            return object.__getattribute__(self, name)
        if not hasattr(self._connector, name):
            return object.__getattribute__(self, name)
        attr = getattr(self._connector, name)

        # Intercept calls to the connector interface and write an event
        # for each one to a file, which can be read back in the main test proc.
        if callable(attr):

            def wrapper(*args, **kwargs):
                self.call_record[name] += 1
                # Log the event as a line to the file
                try:
                    with open(self._event_file, "a") as f:
                        f.write(name + "\n")
                except Exception as e:
                    print(f"[ERROR] Could not log event {name} "
                          f"for {self.name}: {e}")
                return attr(*args, **kwargs)

            return wrapper
        return attr


KVConnectorFactory.register_connector("TestSharedStorageConnector",
                                      TestSharedStorageConnector.__module__,
                                      TestSharedStorageConnector.__name__)


# Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool:
    """Compares two directories recursively for identical content."""
    dcmp = filecmp.dircmp(dir1, dir2)
    if dcmp.left_only or dcmp.right_only or dcmp.diff_files:
        print(f"Differences found between {dir1} and {dir2}:")
        print(f"  Left only: {dcmp.left_only}")
        print(f"  Right only: {dcmp.right_only}")
        print(f"  Different files: {dcmp.diff_files}")
        return False
    for sub_dir in dcmp.common_dirs:
        if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir):
            return False
    return True


def test_multi_shared_storage_connector_consistency():
    """
    Tests that MultiConnector with two SharedStorageConnectors saves
    identical KV cache data to separate storage locations.
    """
    storage_1_path = Path("storage_1/")
    storage_2_path = Path("storage_2/")
    shutil.rmtree(storage_1_path, ignore_errors=True)
    shutil.rmtree(storage_2_path, ignore_errors=True)
    storage_1_path.mkdir()
    storage_2_path.mkdir()

    # Configure MultiConnector with two SharedStorageConnectors
    kv_transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "connectors": [{
                "kv_connector": "TestSharedStorageConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "shared_storage_path": str(storage_1_path),
                    "name": "storage1",
                }
            }, {
                "kv_connector": "TestSharedStorageConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "shared_storage_path": str(storage_2_path),
                    "name": "storage2",
                }
            }]
        },
    )

    llm = LLM(
        model=MODEL_NAME,
        enforce_eager=True,
        gpu_memory_utilization=0.5,
        kv_transfer_config=kv_transfer_config,
    )
    # Run generation - this should trigger saving KV cache
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    # --- Verification ---

    # Check that both storage directories were populated
    local_subdirs = list(storage_1_path.iterdir())
    external_subdirs = list(storage_2_path.iterdir())

    assert len(local_subdirs) > 0, (
        f"Local storage path {storage_1_path} is empty after generation.")
    assert len(external_subdirs) > 0, (
        f"External storage path {storage_2_path} is empty after generation.")
    assert len(local_subdirs) == len(external_subdirs), (
        f"Mismatch in number of cache entries: "
        f"Local={len(local_subdirs)}, External={len(external_subdirs)}")

    # The subdirectories should correspond to the prompt hashes.
    # Since the prompts are the same, the hash directory names should match.
    local_subdir_names = sorted([d.name for d in local_subdirs])
    external_subdir_names = sorted([d.name for d in external_subdirs])
    assert local_subdir_names == external_subdir_names, (
        "Cache directory names do not match between local and external storage"
    )

    # Compare the contents of each corresponding cache directory
    for subdir_name in local_subdir_names:
        print(f"Comparing contents of cache directory: {subdir_name}")
        assert _compare_directories(storage_1_path / subdir_name,
                                    storage_2_path / subdir_name), \
            (f"Contents differ for cache directory '{subdir_name}' between "
             f"{storage_1_path} and {storage_2_path}")

    events = get_connector_events()
    # get_num_new_matched_tokens will be called on each connector in turn.
    # Neither of them has a hit, so update_state_after_alloc won't be called.
    assert events["storage1"][:3] == [
        'get_num_new_matched_tokens', 'build_connector_meta',
        'bind_connector_metadata'
    ]
    assert events["storage2"][:3] == [
        'get_num_new_matched_tokens', 'build_connector_meta',
        'bind_connector_metadata'
    ]

    # Reset the prefix cache, or else we'll just get the tokens back from
    # there instead of from the connectors.
    llm.reset_prefix_cache()

    # Run generation again - this should trigger loading from the first
    # connector.
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    events = get_connector_events()
    # get_num_new_matched_tokens will return new tokens from the first
    # connector, so update_state_after_alloc will be called once blocks
    # are allocated for the first connector.
    # get_num_new_matched_tokens *won't* be called on the second connector
    # in this case.
    assert events["storage1"][:4] == [
        'get_num_new_matched_tokens', 'update_state_after_alloc',
        'build_connector_meta', 'bind_connector_metadata'
    ]
    assert events["storage2"][:2] == [
        'build_connector_meta', 'bind_connector_metadata'
    ]

    # Delete storage1 connector state
    shutil.rmtree(storage_1_path)

    # Reset the prefix cache, or else we'll just get the tokens back from
    # there instead of from the connectors.
    llm.reset_prefix_cache()

    # Run generation again - with storage1's cache deleted, this should
    # trigger loading from the second connector.
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    events = get_connector_events()
    # get_num_new_matched_tokens will be called for the first connector, but
    # it won't have a hit, so update_state_after_alloc won't be called.
    # get_num_new_matched_tokens will also be called on the second connector,
    # which should have a hit, so update_state_after_alloc will be called.
    assert events["storage1"][:3] == [
        'get_num_new_matched_tokens', 'build_connector_meta',
        'bind_connector_metadata'
    ]
    assert events["storage2"][:4] == [
        'get_num_new_matched_tokens', 'update_state_after_alloc',
        'build_connector_meta', 'bind_connector_metadata'
    ]

    # Clean up
    shutil.rmtree(storage_1_path)
    shutil.rmtree(storage_2_path)


def get_connector_events() -> dict[str, list[str]]:
    # Read in connector events and reset the files.
    import glob
    event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log")
    connector_events = {}
    for fname in event_files:
        name = fname.split("connector_")[1].split("_events.log")[0]
        try:
            with open(fname, "r+") as f:
                connector_events[name] = [
                    line.strip() for line in f if line.strip()
                ]
                f.truncate(0)
        except Exception as e:
            print(f"[ERROR] Could not read connector events for {name}: {e}")

    return connector_events
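
For readers tracing the assertions above: get_connector_events() returns, per connector name, the ordered list of interface methods that were intercepted by the test wrapper. The snippet below is illustrative only and not part of this commit; it sketches the rough shape of the result after the first generate() call, with an Ellipsis standing in for additional events the test does not assert on.

# Illustrative shape of get_connector_events() after the first generate()
# call (not part of the diff). The trailing Ellipsis stands in for further
# events that the test deliberately leaves unasserted.
example_events: dict[str, list] = {
    "storage1": ['get_num_new_matched_tokens', 'build_connector_meta',
                 'bind_connector_metadata', ...],
    "storage2": ['get_num_new_matched_tokens', 'build_connector_meta',
                 'bind_connector_metadata', ...],
}
assert example_events["storage1"][:3] == [
    'get_num_new_matched_tokens', 'build_connector_meta',
    'bind_connector_metadata'
]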

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 5 additions & 0 deletions
@@ -110,3 +110,8 @@ def create_connector_v1(
     "NixlConnector",
     "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
     "NixlConnector")
+
+KVConnectorFactory.register_connector(
+    "MultiConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
+    "MultiConnector")

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 0 additions & 2 deletions
@@ -22,7 +22,6 @@
 
 import enum
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
@@ -48,7 +47,6 @@ class KVConnectorRole(enum.Enum):
     WORKER = 1
 
 
-@dataclass
 class KVConnectorMetadata:
     """
     Abstract Metadata used to communicate between the
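
A plausible reading of this removal is that KVConnectorMetadata becomes a plain abstract base, so concrete connectors that want dataclass semantics opt in on their own subclass. The sketch below is a hypothetical illustration rather than code from this commit; the class and field names are invented.

from dataclasses import dataclass, field

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata)


# Hypothetical subclass: the base no longer imposes dataclass semantics,
# so a concrete connector can decorate its own metadata type as needed.
@dataclass
class ExampleConnectorMetadata(KVConnectorMetadata):
    # Invented field for illustration only.
    request_ids: list[str] = field(default_factory=list)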
