Skip to content

Commit eb205d9

Browse files
[P/D][BugFix]Mooncake timeout release bug fix (#2899)
### What this PR does / why we need it? In the P node timeout release mechanism during PD separation, the req_id that requires timeout release is transmitted from the scheduler to the worker. If the KV cache between PDs is transferred too quickly, the P node's req_id may be released twice. The first release is when the D node notifies the P node that the KV cache has been pulled, and the second release is when the scheduler transmits the timeout release to the worker. To address this bug, an intermediate component is introduced to manage the release of req_ids. Pull kv and forward2 may occur one after the other in timing. The previous timeout defaulted to forward2 being before pull_kv. ### How was this patch tested? - vLLM version: v0.10.2 - vLLM main: vllm-project/vllm@f225ea7 --------- Signed-off-by: baxingpiaochong <771405853@qq.com>
1 parent 6995a7b commit eb205d9

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

tests/ut/kv_connector/test_mooncake_connector.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import types
88
import unittest
99
from collections import defaultdict, deque
10+
from typing import OrderedDict
1011
from unittest.mock import MagicMock, patch
1112

1213
import msgspec
@@ -34,7 +35,7 @@ def test_init_basic_properties(self):
3435
tracker = KVCacheTaskTracker()
3536
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
3637
self.assertIsInstance(tracker.finished_requests, set)
37-
self.assertIsInstance(tracker.delayed_free_requests, deque)
38+
self.assertIsInstance(tracker.delayed_free_requests, OrderedDict)
3839

3940

4041
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
@@ -495,18 +496,42 @@ def setUp(self):
495496
def test_update_done_task_count(self):
496497
self.assertEqual(len(self.tracker.finished_requests), 0)
497498
self.assertEqual(len(self.tracker.delayed_free_requests), 0)
499+
self.assertEqual(len(self.tracker.record_finished_requests), 0)
498500

499501
current_time = time.time()
500502
self.tracker.add_delayed_request("req_1", current_time)
501503
result = self.tracker.delayed_free_requests
504+
result_record = self.tracker.record_finished_requests
502505
self.assertEqual(len(result), 1)
503-
self.assertEqual(result[0], ("req_1", current_time))
506+
self.assertEqual(result["req_1"], current_time)
507+
self.assertEqual(len(result_record), 0)
504508

505509
self.tracker.update_done_task_count("req_1")
506510
result_finished = self.tracker.finished_requests
507511
result_delayed = self.tracker.delayed_free_requests
512+
result_record = self.tracker.record_finished_requests
508513
self.assertEqual(result_finished, {"req_1"})
509514
self.assertEqual(len(result_delayed), 0)
515+
self.assertEqual(len(result_record), 0)
516+
517+
self.tracker.update_done_task_count("req_2")
518+
result_finished = self.tracker.finished_requests
519+
result_delayed = self.tracker.delayed_free_requests
520+
result_record = self.tracker.record_finished_requests
521+
self.assertEqual(result_finished, {"req_1", "req_2"})
522+
self.assertEqual(len(result_delayed), 0)
523+
self.assertEqual(len(result_record), 1)
524+
self.assertEqual(result_record, {"req_2"})
525+
526+
def test_updtate_add_delayed_request(self) -> None:
527+
self.tracker.update_done_task_count("req2")
528+
result_start_record = self.tracker.record_finished_requests
529+
self.assertEqual(len(result_start_record), 1)
530+
self.tracker.add_delayed_request("req2", time.time())
531+
result_delayed = self.tracker.delayed_free_requests
532+
result_end_record = self.tracker.record_finished_requests
533+
self.assertEqual(len(result_delayed), 0)
534+
self.assertEqual(len(result_end_record), 0)
510535

511536
def test_retrieve_expired_requests(self):
512537
current_time = time.time()
@@ -518,7 +543,7 @@ def test_retrieve_expired_requests(self):
518543
})
519544
result_delay = self.tracker.delayed_free_requests
520545
self.assertEqual(len(result_delay), 1)
521-
self.assertEqual(result_delay[0], ("req_2", current_time))
546+
self.assertIn("req_2", result_delay)
522547

523548
def test_duplicate_task_update(self):
524549
self.tracker.update_done_task_count("req1")

vllm_ascend/distributed/mooncake_connector.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Iterator
1212
from concurrent.futures import ThreadPoolExecutor
1313
from dataclasses import dataclass
14-
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
14+
from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple
1515

1616
import msgspec
1717
import numpy as np
@@ -68,12 +68,16 @@ def __init__(self):
6868
# intentionally delayed. Each entry is a tuple of (request_id,
6969
# timestamp). If a request remains in this queue for too long, it will
7070
# be force-freed.
71-
self.delayed_free_requests: deque[Tuple[str, float]] = deque()
71+
self.record_finished_requests: set[str] = set()
72+
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
7273

7374
def update_done_task_count(self, request_id: str):
7475
with self.done_task_lock:
7576
self.finished_requests.add(request_id)
76-
self._remove_delayed_requests(request_id)
77+
if request_id in self.delayed_free_requests:
78+
self._remove_delayed_requests(request_id)
79+
else:
80+
self.record_finished_requests.add(request_id)
7781

7882
def get_and_clear_finished_requests(self) -> set[str]:
7983
"""
@@ -91,18 +95,22 @@ def get_and_clear_finished_requests(self) -> set[str]:
9195
def add_delayed_request(self, request_id: str, delay_start_time: float):
9296
"""Add a delayed free request."""
9397
with self.done_task_lock:
94-
self.delayed_free_requests.append((request_id, delay_start_time))
98+
if request_id not in self.record_finished_requests:
99+
self.delayed_free_requests[request_id] = delay_start_time
100+
else:
101+
self.record_finished_requests.discard(request_id)
95102

96103
def _retrieve_expired_requests(self):
97104
"""Retrieve all expired delayed requests."""
98105
expired_requests: set[str] = set()
99106
# Free delayed requests if they exceed the timeout
100107
current_time = time.time()
101108
while self.delayed_free_requests:
102-
request_id, delay_start_time = self.delayed_free_requests[0]
109+
request_id = next(iter(self.delayed_free_requests))
110+
delay_start_time = self.delayed_free_requests[request_id]
103111
if (current_time - delay_start_time
104112
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
105-
self.delayed_free_requests.popleft()
113+
self.delayed_free_requests.popitem(last=False)
106114
expired_requests.add(request_id)
107115
logger.info("Force freed request: %s", request_id)
108116
else:
@@ -111,8 +119,7 @@ def _retrieve_expired_requests(self):
111119

112120
def _remove_delayed_requests(self, request_id: str):
113121
"""Remove all delayed free requests matching the given request_id."""
114-
self.delayed_free_requests = deque(
115-
(r, t) for r, t in self.delayed_free_requests if r != request_id)
122+
self.delayed_free_requests.pop(request_id)
116123

117124

118125
class KVCacheSendingThread(threading.Thread):

0 commit comments

Comments
 (0)