 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.distributed.parallel_state import get_dp_group, get_pp_group
+from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
+                                             get_tp_group)
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import logger
@@ -146,6 +147,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
+        self.parallel_config = vllm_config.parallel_config
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         self.block_size = vllm_config.cache_config.block_size
@@ -921,8 +923,8 @@ def _process_reqs(
         cu_num_tokens = np.cumsum(num_scheduled_tokens)
         cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                     num_scheduled_tokens)
-        sample_indices = cu_num_tokens - 1
-        sample_indices = torch.from_numpy(sample_indices).to(self.device,
+        logits_indices = cu_num_tokens - 1
+        logits_indices = torch.from_numpy(logits_indices).to(self.device,
                                                              non_blocking=True)
         arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
 
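Note: the renamed `logits_indices` above is just `cu_num_tokens - 1`, i.e. the position of each request's last scheduled token in the flattened token batch, which is where next-token logits are taken in the non-speculative path. A minimal NumPy sketch of this index math (variable names mirror the diff; the token counts are made up):

```python
import numpy as np

# Hypothetical per-request scheduled token counts for one step.
num_scheduled_tokens = np.array([3, 1, 4], dtype=np.int32)

cu_num_tokens = np.cumsum(num_scheduled_tokens)    # [3, 4, 8]
logits_indices = cu_num_tokens - 1                 # [2, 3, 7]: last token of each request

# Per-token offsets used right after this to build a per-request arange
# over the flattened [total_num_scheduled_tokens] batch.
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                            num_scheduled_tokens)  # [0, 0, 0, 3, 4, 4, 4, 4]
```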
@@ -1153,14 +1155,14 @@ def _process_reqs(
 
             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens)
-            sample_indices = spec_decode_metadata.logits_indices
+            logits_indices = spec_decode_metadata.logits_indices
 
         aux_hidden_states = None
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = hidden_states
 
         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
-                total_num_scheduled_tokens, sample_indices, aux_hidden_states,
+                total_num_scheduled_tokens, logits_indices, aux_hidden_states,
                 num_scheduled_tokens)
 
     def _get_cumsum_and_arange(
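In the speculative-decoding branch the indices come from `spec_decode_metadata` instead, because each request then needs logits for several trailing positions (its draft tokens plus a bonus position), not just its last token; this is why the old name `sample_indices` was misleading. A toy illustration of that layout, assuming a "last `num_draft_tokens + 1` positions per request" convention (a sketch, not the exact implementation of `_calc_spec_decode_metadata`):

```python
import numpy as np

# Hypothetical step: request 0 carries 2 draft tokens, request 1 none,
# request 2 one draft token.
num_scheduled_tokens = np.array([3, 1, 2], dtype=np.int32)
num_draft_tokens = np.array([2, 0, 1], dtype=np.int32)
cu_num_tokens = np.cumsum(num_scheduled_tokens)    # [3, 4, 6]

# Non-speculative case: one logits position per request (its last token).
last_token_indices = cu_num_tokens - 1             # [2, 3, 5]

# Speculative case: the last (num_draft + 1) positions of each request.
logits_indices = np.concatenate([
    np.arange(end - (n + 1), end)
    for end, n in zip(cu_num_tokens, num_draft_tokens)
])                                                 # [0, 1, 2, 3, 4, 5]
```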
@@ -1397,16 +1399,42 @@ def execute_model(
             # Return empty ModelRunnerOutput if there's no work to do.
             return EMPTY_MODEL_RUNNER_OUTPUT
         (attn_metadata, hidden_states, spec_decode_metadata, positions,
-         num_scheduled_tokens, sample_indices, aux_hidden_states,
+         num_scheduled_tokens, logits_indices, aux_hidden_states,
          num_scheduled_tokens_np) = (self._process_reqs(
              scheduler_output, intermediate_tensors))
 
         with ProfileExecuteDuration().capture_async("post process"):
-            if self.input_batch.pooling_params:
-                return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np)
-            logits = self.model.compute_logits(hidden_states[sample_indices],
-                                               None)
+            # Broadcast PP output for external_launcher (torchrun)
+            # to make sure we are synced across pp ranks
+            # TODO: Support overlapping micro-batches
+            # https://github.com/vllm-project/vllm/issues/18019
+            broadcast_pp_output = \
+                self.parallel_config.distributed_executor_backend \
+                == "external_launcher" and len(get_pp_group().ranks) > 0
+            if not get_pp_group().is_last_rank:
+                # For mid-pipeline stages, return the hidden states.
+                if not broadcast_pp_output:
+                    return hidden_states
+                assert isinstance(hidden_states, IntermediateTensors)
+                get_pp_group().send_tensor_dict(
+                    hidden_states.tensors, all_gather_group=get_tp_group())
+                logits = None
+            else:
+                if self.input_batch.pooling_params:
+                    return self._pool(hidden_states, num_scheduled_tokens,
+                                      num_scheduled_tokens_np)
+                sample_hidden_states = hidden_states[logits_indices]
+                logits = self.model.compute_logits(sample_hidden_states, None)
+            if broadcast_pp_output:
+                model_output_broadcast_data = {
+                    "logits": logits.contiguous(),
+                } if logits is not None else {}
+                model_output_broadcast_data = get_pp_group(
+                ).broadcast_tensor_dict(model_output_broadcast_data,
+                                        src=len(get_pp_group().ranks) - 1)
+                assert model_output_broadcast_data is not None
+                logits = model_output_broadcast_data["logits"]
+
             # Apply structured output bitmasks if present
             if scheduler_output.grammar_bitmask is not None:
                 logits = self.apply_grammar_bitmask(scheduler_output, logits)
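The added branch handles pipeline parallelism in two modes. Normally, non-last PP stages simply return their `IntermediateTensors` and only the last stage computes logits. With the `external_launcher` (torchrun) backend, every rank runs `execute_model`, so the last PP rank broadcasts its logits and the other ranks receive them to stay in sync. A condensed, pure-Python sketch of just that control flow (the `compute_logits`, `send_to_next_stage`, and `broadcast_from_last` callables are hypothetical stand-ins for `self.model.compute_logits`, `send_tensor_dict`, and `broadcast_tensor_dict`):

```python
from typing import Optional

def postprocess_logits(hidden_states, logits_indices,
                       is_last_pp_rank: bool, pp_world_size: int,
                       distributed_executor_backend: Optional[str],
                       compute_logits, send_to_next_stage, broadcast_from_last):
    # With torchrun ("external_launcher") every rank runs execute_model,
    # so the last PP rank must share its logits with the other PP ranks.
    broadcast_pp_output = (distributed_executor_backend == "external_launcher"
                           and pp_world_size > 0)

    if not is_last_pp_rank:
        if not broadcast_pp_output:
            # Usual setup: mid-pipeline stages just hand off hidden states.
            return hidden_states
        send_to_next_stage(hidden_states)  # send, then wait for the broadcast below
        logits = None
    else:
        logits = compute_logits(hidden_states[logits_indices])

    if broadcast_pp_output:
        payload = {"logits": logits} if logits is not None else {}
        payload = broadcast_from_last(payload)  # src = last PP rank
        logits = payload["logits"]
    return logits
```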
@@ -1423,6 +1451,7 @@ def execute_model(
                 # creates a new tensor with separate storage from the original
                 # logits tensor. This means any in-place operations on bonus_logits
                 # won't affect the original logits tensor.
+                assert logits is not None
                 bonus_logits = logits[
                     spec_decode_metadata.bonus_logits_indices]
                 sampler_output = self.sampler(
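The new `assert logits is not None` makes the invariant explicit: after the pipeline-parallel change above, `logits` is optional (it stays `None` on non-last PP ranks that only send hidden states), while this speculative-decoding block assumes real logits. A minimal sketch of why the assert also helps type checkers (function and names are illustrative, not from the codebase):

```python
from typing import Optional
import torch

def take_bonus_logits(logits: Optional[torch.Tensor],
                      bonus_logits_indices: torch.Tensor) -> torch.Tensor:
    # Without this assert, a type checker flags logits[...] below,
    # since logits may be None after the pipeline-parallel branch.
    assert logits is not None
    # Indexing with a tensor copies into new storage, so in-place edits
    # on the result do not touch the original logits tensor.
    return logits[bonus_logits_indices]
```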