Commit f508e03

[Core] Async_output_proc: Add virtual engine support (towards pipeline parallel) (#7911)

1 parent 51f86bf commit f508e03

6 files changed: +122 −67 lines changed

vllm/core/scheduler.py

Lines changed: 5 additions & 6 deletions
@@ -302,7 +302,7 @@ def __init__(
         cache_config: CacheConfig,
         lora_config: Optional[LoRAConfig],
         pipeline_parallel_size: int = 1,
-        output_proc_callback_fn: Optional[Callable] = None,
+        output_proc_callback: Optional[Callable] = None,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
@@ -376,8 +376,8 @@ def __init__(
         # iterations. I.e. since the output processing is lagged one step,
         # we cannot reuse the cached objects immediately when the schedule()
         # is called again, but only when schedule() is called the second time.
-        self.output_proc_callback_fn = output_proc_callback_fn
-        self.use_async_output_proc = self.output_proc_callback_fn is not None
+        self.output_proc_callback = output_proc_callback
+        self.use_async_output_proc = self.output_proc_callback is not None
         self.num_cache_iters = 2 if self.use_async_output_proc else 1

         self.cache_id = 0
@@ -573,8 +573,8 @@ def _schedule_running(
                     seq_group):
                 tmp = self.running
                 self.running = orig_running
-                assert self.output_proc_callback_fn is not None
-                self.output_proc_callback_fn(is_async=True)
+                assert self.output_proc_callback is not None
+                self.output_proc_callback()
                 self.running = tmp

             while not self._can_append_slots(seq_group):
@@ -1091,7 +1091,6 @@ def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
         no_beam_search = seq_group.sampling_params is None or (
             seq_group.sampling_params.best_of == 1
             and not seq_group.sampling_params.use_beam_search)
-
         return no_beam_search

     def schedule(
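Aside: the comment above about output processing lagging one step is why num_cache_iters becomes 2. Below is a minimal standalone sketch of that double-buffering idea, with illustrative names (DoubleBufferedCache is not part of vLLM):

class DoubleBufferedCache:
    def __init__(self, use_async_output_proc: bool) -> None:
        # Two cache slots when output processing lags one step behind the
        # forward pass, otherwise a single slot is enough.
        self.num_cache_iters = 2 if use_async_output_proc else 1
        self.cache_id = 0
        self._caches = [{} for _ in range(self.num_cache_iters)]

    def current(self) -> dict:
        # Objects in this slot are safe to reuse for the current iteration.
        return self._caches[self.cache_id]

    def advance(self) -> None:
        # Flip to the other slot; the previous one may still be read while the
        # async postprocessor drains last step's outputs.
        self.cache_id = (self.cache_id + 1) % self.num_cache_iters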

vllm/engine/async_llm_engine.py

Lines changed: 27 additions & 10 deletions
@@ -279,19 +279,26 @@ async def step_async(
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

+        ctx = self.scheduler_contexts[virtual_engine]
+
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
+
+            # Clear outputs on scheduler iteration start
+            ctx.request_outputs.clear()
+
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()

             # If current scheduler iteration has no async postprocessor,
             # then we need first to drain the pending async postprocessor
             # before moving forward
-            if not allow_async_output_proc and len(self.output_queue) > 0:
-                self._process_model_outputs(is_async=True)
+            if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
@@ -332,8 +339,8 @@ async def step_async(
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                execute_model_req.output_proc_callback_fn = \
-                    self._process_model_outputs
+                execute_model_req.async_callback = self.async_callback[
+                    virtual_engine]

             # Execute the model.
             output = await self.model_executor.execute_model_async(
@@ -343,9 +350,10 @@ async def step_async(
         if self.scheduler_config.is_multi_step:
             self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(self.output_queue) > 0:
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
             output = []

         # Finish the current step for all the sequence groups.
@@ -360,7 +368,7 @@ async def step_async(
                     virtual_engine] = SchedulerOutputState()

             # Cache results in engine
-            self.output_queue.append(
+            ctx.output_queue.append(
                 (output, seq_group_metadata_list, scheduler_outputs))

             if output and allow_async_output_proc:
@@ -372,7 +380,8 @@ async def step_async(
                     scheduler_outputs.scheduled_seq_groups)

             if not allow_async_output_proc:
-                self._process_model_outputs(is_async=False)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=False)

                 # Log stats.
                 self.do_log_stats(scheduler_outputs, output)
@@ -381,9 +390,17 @@ async def step_async(
                 self.do_tracing(scheduler_outputs)

         else:
-            self.request_outputs = []
+            ctx.request_outputs = []
+
+        if not self.has_unfinished_requests():
+            # Drain async postprocessor (if exists)
+            if len(ctx.output_queue) > 0:
+                assert not self.scheduler_config.is_multi_step
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            assert len(ctx.output_queue) == 0

-        return self.request_outputs
+        return ctx.request_outputs

     async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""

vllm/engine/llm_engine.py

Lines changed: 79 additions & 42 deletions
@@ -1,7 +1,8 @@
+import functools
 import time
 from collections import deque
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
                     Mapping, Optional)
 from typing import Sequence as GenericSequence
@@ -88,6 +89,17 @@ class SchedulerOutputState:
     last_output: Optional[SamplerOutput] = None


+@dataclass
+class SchedulerContext:
+    output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
+                              SchedulerOutputs]] = field(
+                                  default_factory=lambda: deque())
+
+    request_outputs: List[Union[RequestOutput,
+                                EmbeddingRequestOutput]] = field(
+                                    default_factory=lambda: [])
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.

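Aside: a minimal, self-contained illustration of why SchedulerContext declares its members with field(default_factory=...): each virtual engine must get its own deque and list rather than sharing a single mutable default (the Context name below is illustrative, not vLLM code).

from collections import deque
from dataclasses import dataclass, field
from typing import Deque, List


@dataclass
class Context:
    output_queue: Deque[int] = field(default_factory=deque)
    request_outputs: List[int] = field(default_factory=list)


contexts = [Context() for _ in range(2)]   # one per virtual engine
contexts[0].output_queue.append(1)
assert len(contexts[1].output_queue) == 0  # state is not shared across engines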
@@ -350,9 +362,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             Scheduler(
                 scheduler_config, cache_config, lora_config,
                 parallel_config.pipeline_parallel_size,
-                self._process_model_outputs
+                functools.partial(self._process_model_outputs,
+                                  virtual_engine=v_id,
+                                  is_async=True)
                 if model_config.use_async_output_proc else None)
-            for _ in range(parallel_config.pipeline_parallel_size)
+            for v_id in range(parallel_config.pipeline_parallel_size)
         ]

         # Metric Logging.
@@ -406,12 +420,17 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]

-        # Async output processing pointers
-        self.output_queue: Deque[Tuple[List[SamplerOutput],
-                                       List[SequenceGroupMetadata],
-                                       SchedulerOutputs]] = deque()
-        self.request_outputs: List[Union[RequestOutput,
-                                         EmbeddingRequestOutput]] = []
+        self.scheduler_contexts = [
+            SchedulerContext()
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+
+        self.async_callback = [
+            functools.partial(self._process_model_outputs,
+                              virtual_engine=v_id,
+                              is_async=True)
+            for v_id in range(self.parallel_config.pipeline_parallel_size)
+        ]

     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -1221,32 +1240,28 @@ def _process_sequence_group_outputs(

         return

-    def _process_model_outputs(self,
-                               is_async: bool,
-                               clear_outputs: bool = True) -> None:
+    def _process_model_outputs(self, virtual_engine: int,
+                               is_async: bool) -> None:
         """Apply the model output to the sequences in the scheduled seq groups.

+        virtual_engine: The engine id to operate on
         is_async: Indicates whether this postprocessor runs in
             parallel with the GPU forward pass and is processing
             tokens from the previous step. If this is true, then
             no tokens need to be appended since it is already done
             externally (before the next schedule() call)
-        clear_outputs: Sometimes existing outputs need to be combined
-            with outputs of this call. This happens for postprocessor
-            draining at the final stage (like when sequences are finished)

         Returns RequestOutputs that can be returned to the client.
         """
         now = time.time()

-        if clear_outputs:
-            self.request_outputs.clear()
+        ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

-        if len(self.output_queue) == 0:
+        if len(ctx.output_queue) == 0:
             return None

         (outputs, seq_group_metadata_list,
-         scheduler_outputs) = self.output_queue.popleft()
+         scheduler_outputs) = ctx.output_queue.popleft()

         # Sanity check
         assert len(seq_group_metadata_list) == len(
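Aside: with the clear_outputs flag gone, this method only ever appends to the per-engine request_outputs; clearing is now done by the caller at the start of a scheduler iteration. A toy sketch of that contract (names are illustrative):

from collections import deque
from typing import Deque, List

output_queue: Deque[List[str]] = deque([["req-1"], ["req-2"]])
request_outputs: List[str] = []


def drain_one() -> None:
    # Pop a single queued step and append its outputs; never clear here.
    if output_queue:
        request_outputs.extend(output_queue.popleft())


request_outputs.clear()   # done once, at scheduler iteration start
drain_one()
drain_one()
assert request_outputs == ["req-1", "req-2"]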
@@ -1321,11 +1336,11 @@ def _process_model_outputs(self,
             if (seq_group.is_finished()
                     if self.step_return_finished_only else True):
                 request_output = RequestOutputFactory.create(seq_group)
-                self.request_outputs.append(request_output)
+                ctx.request_outputs.append(request_output)

         for seq_group in scheduler_outputs.ignored_seq_groups:
             request_output = RequestOutputFactory.create(seq_group)
-            self.request_outputs.append(request_output)
+            ctx.request_outputs.append(request_output)

         if is_async:
             # Log stats.
@@ -1421,29 +1436,43 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 "Pipeline parallelism is only supported through AsyncLLMEngine "
                 "as performance will be severely degraded otherwise.")

+        # For llm_engine, there is no pipeline parallel support, so the engine
+        # used is always 0
+        virtual_engine = 0
+
         # These are cached outputs from previous iterations. None if on first
         # iteration
-        cached_outputs = self.cached_scheduler_outputs[0]
+        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
         seq_group_metadata_list = cached_outputs.seq_group_metadata_list
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

+        ctx = self.scheduler_contexts[virtual_engine]
+
         # Skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
+
+            # Clear outputs on scheduler iteration start
+            ctx.request_outputs.clear()
+
+            # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
-             allow_async_output_proc) = self.scheduler[0].schedule()
+             allow_async_output_proc
+             ) = self.scheduler[virtual_engine].schedule()

-            if not allow_async_output_proc and len(self.output_queue) > 0:
-                self._process_model_outputs(is_async=True)
+            # Maybe switch from async mode to sync mode
+            if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have
                 # lookahead slots
                 self._cache_scheduler_outputs_for_multi_step(
-                    0, seq_group_metadata_list, scheduler_outputs,
+                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                     allow_async_output_proc)

         assert seq_group_metadata_list is not None
@@ -1454,14 +1483,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
-                0].get_and_reset_finished_requests_ids()
+                virtual_engine].get_and_reset_finished_requests_ids()

             # Check if we have a cached last_output from the previous iteration.
             # For supporting PP this is probably the best way to pass the
             # sampled_token_ids, as a separate broadcast over all the PP stages
             # will cause one virtual engine's microbatch to block the pipeline.
             last_sampled_token_ids = \
-                self._get_last_sampled_token_ids(0)
+                self._get_last_sampled_token_ids(virtual_engine)

             execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
@@ -1476,20 +1505,24 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                execute_model_req.output_proc_callback_fn = \
-                    self._process_model_outputs
+                execute_model_req.async_callback = self.async_callback[
+                    virtual_engine]

             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)

-            # we need to do this here so that last step's sampled_token_ids can
+            # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
-                self._update_cached_scheduler_output(0, output)
+                self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(self.output_queue) > 0:
+            # Nothing scheduled => If there is pending async postprocessor,
+            # then finish it here.
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            # No outputs in this case
             output = []

         # Finish the current step for all the sequence groups.
@@ -1504,7 +1537,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

             # Add results to the output_queue
             # (for async or non-async postprocessing)
-            self.output_queue.append(
+            ctx.output_queue.append(
                 (output, seq_group_metadata_list, scheduler_outputs))

             if output and allow_async_output_proc:
@@ -1515,23 +1548,27 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                     output[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)

+            # Check if need to run the usual non-async path
             if not allow_async_output_proc:
-                self._process_model_outputs(is_async=False)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=False)

                 # Log stats.
                 self.do_log_stats(scheduler_outputs, output)

                 # Tracing
                 self.do_tracing(scheduler_outputs)
         else:
-            self.request_outputs = []
+            # Multi-step case
+            ctx.request_outputs = []

         if not self.has_unfinished_requests():
-            # Drain async postprocessor
-            if len(self.output_queue) > 0:
+            # Drain async postprocessor (if exists)
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True, clear_outputs=False)
-            assert len(self.output_queue) == 0
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            assert len(ctx.output_queue) == 0

             # Stop the execute model loop in parallel workers until there are
             # more requests to process. This avoids waiting indefinitely in
@@ -1540,7 +1577,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             # queued control plane messages, such as add/remove lora adapters.
             self.model_executor.stop_remote_worker_execution_loop()

-        return self.request_outputs
+        return ctx.request_outputs

     def _has_remaining_steps(
             self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
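Aside: a compact sketch of the end-of-step invariant introduced above: once the engine has no unfinished requests, any lagged async outputs are drained so the per-engine queue is empty before step() returns (variable names are illustrative).

from collections import deque
from typing import Deque, List

ctx_output_queue: Deque[List[str]] = deque([["last-step-output"]])
ctx_request_outputs: List[str] = []

has_unfinished_requests = False
if not has_unfinished_requests:
    while ctx_output_queue:                  # drain async postprocessor
        ctx_request_outputs.extend(ctx_output_queue.popleft())
    assert len(ctx_output_queue) == 0        # nothing may remain queued

print(ctx_request_outputs)                   # ['last-step-output']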
