+ import functools
import time
from collections import deque
from contextlib import contextmanager
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
                    Mapping, Optional)
from typing import Sequence as GenericSequence
@@ -88,6 +89,17 @@ class SchedulerOutputState:
    last_output: Optional[SamplerOutput] = None


+ @dataclass
+ class SchedulerContext:
+     output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata],
+                               SchedulerOutputs]] = field(
+                                   default_factory=lambda: deque())
+
+     request_outputs: List[Union[RequestOutput,
+                                 EmbeddingRequestOutput]] = field(
+                                     default_factory=lambda: [])
+
+
class LLMEngine:
    """An LLM engine that receives requests and generates texts.
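For readers skimming the diff: the new SchedulerContext simply bundles the formerly engine-global async-processing state into one object per virtual engine. A minimal standalone sketch of the pattern (placeholder types, not the actual vLLM classes):

# Sketch only: one context object per virtual engine, holding the queue of
# pending step outputs and the finished request outputs for that engine.
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Tuple

@dataclass
class EngineContext:
    output_queue: Deque[Tuple[Any, Any, Any]] = field(default_factory=deque)
    request_outputs: List[Any] = field(default_factory=list)

# One independent context per pipeline-parallel virtual engine.
contexts = [EngineContext() for _ in range(2)]
contexts[0].output_queue.append(("sampler_output", "metadata", "scheduler_outputs"))
assert len(contexts[1].output_queue) == 0  # state is no longer shared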
@@ -350,9 +362,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            Scheduler(
                scheduler_config, cache_config, lora_config,
                parallel_config.pipeline_parallel_size,
-                 self._process_model_outputs
+                 functools.partial(self._process_model_outputs,
+                                   virtual_engine=v_id,
+                                   is_async=True)
                if model_config.use_async_output_proc else None)
-             for _ in range(parallel_config.pipeline_parallel_size)
+             for v_id in range(parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
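The scheduler now receives a ready-made callback instead of the bare bound method: functools.partial freezes virtual_engine and is_async, so the callee can invoke it with no arguments and without knowing which engine it serves. A self-contained illustration with toy names (not the vLLM call sites):

import functools

def process_model_outputs(virtual_engine: int, is_async: bool) -> None:
    # Stand-in for LLMEngine._process_model_outputs.
    print(f"draining outputs of engine {virtual_engine} (async={is_async})")

# One zero-argument callback per virtual engine; each remembers its engine id.
callbacks = [
    functools.partial(process_model_outputs, virtual_engine=v_id, is_async=True)
    for v_id in range(2)
]
callbacks[1]()  # prints: draining outputs of engine 1 (async=True)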
@@ -406,12 +420,17 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

-         # Async output processing pointers
-         self.output_queue: Deque[Tuple[List[SamplerOutput],
-                                        List[SequenceGroupMetadata],
-                                        SchedulerOutputs]] = deque()
-         self.request_outputs: List[Union[RequestOutput,
-                                          EmbeddingRequestOutput]] = []
+         self.scheduler_contexts = [
+             SchedulerContext()
+             for _ in range(self.parallel_config.pipeline_parallel_size)
+         ]
+
+         self.async_callback = [
+             functools.partial(self._process_model_outputs,
+                               virtual_engine=v_id,
+                               is_async=True)
+             for v_id in range(self.parallel_config.pipeline_parallel_size)
+         ]

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).
@@ -1221,32 +1240,28 @@ def _process_sequence_group_outputs(

        return

-     def _process_model_outputs(self,
-                                is_async: bool,
-                                clear_outputs: bool = True) -> None:
+     def _process_model_outputs(self, virtual_engine: int,
+                                is_async: bool) -> None:
        """Apply the model output to the sequences in the scheduled seq groups.

+         virtual_engine: The engine id to operate on
        is_async: Indicates whether this postprocessor runs in
            parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)
-         clear_outputs: Sometimes existing outputs need to be combined
-             with outputs of this call. This happens for postprocessor
-             draining at the final stage (like when sequences are finished)

        Returns RequestOutputs that can be returned to the client.
        """
        now = time.time()

-         if clear_outputs:
-             self.request_outputs.clear()
+         ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

-         if len(self.output_queue) == 0:
+         if len(ctx.output_queue) == 0:
            return None

        (outputs, seq_group_metadata_list,
-          scheduler_outputs) = self.output_queue.popleft()
+          scheduler_outputs) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
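To make the docstring concrete, here is a toy version of the drain step under the new signature (hypothetical, simplified types; the real method also applies the outputs to the scheduled sequences and logs stats):

from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Tuple

@dataclass
class EngineContext:
    output_queue: Deque[Tuple[Any, Any, Any]] = field(default_factory=deque)
    request_outputs: List[Any] = field(default_factory=list)

def process_model_outputs(contexts: List[EngineContext], virtual_engine: int,
                          is_async: bool) -> None:
    ctx = contexts[virtual_engine]
    if len(ctx.output_queue) == 0:
        return  # nothing pending for this engine
    outputs, metadata, scheduler_outputs = ctx.output_queue.popleft()
    if not is_async:
        # Sync path: this step's sampled tokens still have to be applied here;
        # in the async path they were already appended externally.
        outputs = f"{outputs}+appended"
    ctx.request_outputs.append((outputs, metadata, scheduler_outputs))

contexts = [EngineContext()]
contexts[0].output_queue.append(("out", "meta", "sched"))
process_model_outputs(contexts, virtual_engine=0, is_async=True)
assert len(contexts[0].request_outputs) == 1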
@@ -1321,11 +1336,11 @@ def _process_model_outputs(self,
                if (seq_group.is_finished()
                        if self.step_return_finished_only else True):
                    request_output = RequestOutputFactory.create(seq_group)
-                     self.request_outputs.append(request_output)
+                     ctx.request_outputs.append(request_output)

        for seq_group in scheduler_outputs.ignored_seq_groups:
            request_output = RequestOutputFactory.create(seq_group)
-             self.request_outputs.append(request_output)
+             ctx.request_outputs.append(request_output)

        if is_async:
            # Log stats.
@@ -1421,29 +1436,43 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                "Pipeline parallelism is only supported through AsyncLLMEngine "
                "as performance will be severely degraded otherwise.")

+         # For llm_engine, there is no pipeline parallel support, so the engine
+         # used is always 0
+         virtual_engine = 0
+
        # These are cached outputs from previous iterations. None if on first
        # iteration
-         cached_outputs = self.cached_scheduler_outputs[0]
+         cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

+         ctx = self.scheduler_contexts[virtual_engine]
+
        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
+
+             # Clear outputs on scheduler iteration start
+             ctx.request_outputs.clear()
+
+             # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
-              allow_async_output_proc) = self.scheduler[0].schedule()
+              allow_async_output_proc
+              ) = self.scheduler[virtual_engine].schedule()

-             if not allow_async_output_proc and len(self.output_queue) > 0:
-                 self._process_model_outputs(is_async=True)
+             # Maybe switch from async mode to sync mode
+             if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                 self._process_model_outputs(virtual_engine=virtual_engine,
+                                             is_async=True)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
-                     0, seq_group_metadata_list, scheduler_outputs,
+                     virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
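Note that the old clear_outputs flag disappears: ctx.request_outputs is cleared exactly once, at the start of a scheduling iteration, so any later drain in the same iteration (sync fallback, end-of-step, final flush) can simply append. A tiny sketch of that ordering with toy state:

from collections import deque

output_queue = deque()   # pending outputs from previous steps
request_outputs = []     # what step() returns to the client

def drain() -> None:
    # Append-only: safe to call several times within one iteration.
    while output_queue:
        request_outputs.append(output_queue.popleft())

for iteration in range(3):
    request_outputs.clear()  # cleared once, at iteration start
    output_queue.append(f"output-{iteration}")
    drain()                  # e.g. async processing was disallowed this step
    drain()                  # a later drain neither loses nor duplicates data
    assert request_outputs == [f"output-{iteration}"]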
@@ -1454,14 +1483,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
-                 0].get_and_reset_finished_requests_ids()
+                 virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
-                 self._get_last_sampled_token_ids(0)
+                 self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
@@ -1476,20 +1505,24 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
-                 execute_model_req.output_proc_callback_fn = \
-                     self._process_model_outputs
+                 execute_model_req.async_callback = self.async_callback[
+                     virtual_engine]

            output = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

-             # we need to do this here so that last step's sampled_token_ids can
+             # We need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
-                 self._update_cached_scheduler_output(0, output)
+                 self._update_cached_scheduler_output(virtual_engine, output)
        else:
-             if len(self.output_queue) > 0:
+             # Nothing scheduled => If there is pending async postprocessor,
+             # then finish it here.
+             if len(ctx.output_queue) > 0:
                assert not self.scheduler_config.is_multi_step
-                 self._process_model_outputs(is_async=True)
+                 self._process_model_outputs(virtual_engine=virtual_engine,
+                                             is_async=True)
+             # No outputs in this case
            output = []

        # Finish the current step for all the sequence groups.
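execute_model_req now carries a zero-argument async_callback rather than a raw method reference. Conceptually, the executor side can fire it once the previous step's results are available, overlapping host-side postprocessing with the ongoing forward pass; the sketch below shows only that call pattern (hypothetical toy request type, not the worker code):

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class ToyExecuteModelRequest:
    async_callback: Optional[Callable[[], None]] = None

def execute_model(req: ToyExecuteModelRequest) -> str:
    # ... kick off the forward pass for the current step ...
    if req.async_callback is not None:
        # Process the *previous* step's outputs while the device keeps working.
        req.async_callback()
    return "sampler_output"

req = ToyExecuteModelRequest(async_callback=lambda: print("draining engine 0"))
execute_model(req)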
@@ -1504,7 +1537,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

        # Add results to the output_queue
        # (for async or non-async postprocessing)
-         self.output_queue.append(
+         ctx.output_queue.append(
            (output, seq_group_metadata_list, scheduler_outputs))

        if output and allow_async_output_proc:
@@ -1515,23 +1548,27 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                output[0], seq_group_metadata_list,
                scheduler_outputs.scheduled_seq_groups)

+             # Check if need to run the usual non-async path
            if not allow_async_output_proc:
-                 self._process_model_outputs(is_async=False)
+                 self._process_model_outputs(virtual_engine=virtual_engine,
+                                             is_async=False)

                # Log stats.
                self.do_log_stats(scheduler_outputs, output)

                # Tracing
                self.do_tracing(scheduler_outputs)
        else:
-             self.request_outputs = []
+             # Multi-step case
+             ctx.request_outputs = []

        if not self.has_unfinished_requests():
-             # Drain async postprocessor
-             if len(self.output_queue) > 0:
+             # Drain async postprocessor (if exists)
+             if len(ctx.output_queue) > 0:
                assert not self.scheduler_config.is_multi_step
-                 self._process_model_outputs(is_async=True, clear_outputs=False)
-             assert len(self.output_queue) == 0
+                 self._process_model_outputs(virtual_engine=virtual_engine,
+                                             is_async=True)
+             assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
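The final block above drains whatever is still queued once no unfinished requests remain, and asserts the queue is empty before the remote worker loop is stopped. Expressed with the same toy state as earlier (a sketch, not engine code):

from collections import deque

output_queue = deque(["last_pending_output"])
request_outputs = []

def has_unfinished_requests() -> bool:
    return False  # pretend every sequence group has finished

if not has_unfinished_requests():
    # Flush the async postprocessor so the last outputs still reach the client.
    while output_queue:
        request_outputs.append(output_queue.popleft())
    assert len(output_queue) == 0
    # ... safe to stop the remote worker execution loop here ...

print(request_outputs)  # ['last_pending_output']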
@@ -1540,7 +1577,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
            # queued control plane messages, such as add/remove lora adapters.
            self.model_executor.stop_remote_worker_execution_loop()

-         return self.request_outputs
+         return ctx.request_outputs

    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]