
Commit 0b0832a

cherry-pick: engineV1 support pipeline parallel

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>

1 parent cc210f4 commit 0b0832a

File tree

2 files changed: +55, -9 lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 37 additions & 8 deletions
@@ -37,7 +37,8 @@
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.distributed.parallel_state import get_dp_group, get_pp_group
+from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
+                                             get_tp_group)
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import logger
@@ -146,6 +147,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
+        self.parallel_config = vllm_config.parallel_config
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         self.block_size = vllm_config.cache_config.block_size
@@ -921,8 +923,8 @@ def _process_reqs(
         cu_num_tokens = np.cumsum(num_scheduled_tokens)
         cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                     num_scheduled_tokens)
-        sample_indices = cu_num_tokens - 1
-        sample_indices = torch.from_numpy(sample_indices).to(self.device,
+        logits_indices = cu_num_tokens - 1
+        logits_indices = torch.from_numpy(logits_indices).to(self.device,
                                                              non_blocking=True)
         arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

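For context on the rename: cu_num_tokens - 1 is the flat index of each request's last scheduled token, i.e. exactly the rows of hidden_states that need logits, so logits_indices describes the intent better than sample_indices. A minimal NumPy sketch with made-up token counts (standalone illustration, not part of the commit):

    import numpy as np

    # Hypothetical per-request scheduled token counts for one step.
    num_scheduled_tokens = np.array([3, 1, 4])

    cu_num_tokens = np.cumsum(num_scheduled_tokens)    # [3 4 8]
    logits_indices = cu_num_tokens - 1                 # [2 3 7]: last token of each request

    # The repeat/arange pair mirrors the cumsums_offsets/arange lines above.
    cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                num_scheduled_tokens)  # [0 0 0 3 4 4 4 4]
    arange = np.arange(num_scheduled_tokens.sum()) - cumsums_offsets
    print(logits_indices, arange)                      # [2 3 7] [0 1 2 0 0 1 2 3]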
@@ -1153,14 +1155,14 @@ def _process_reqs(

             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens)
-            sample_indices = spec_decode_metadata.logits_indices
+            logits_indices = spec_decode_metadata.logits_indices

         aux_hidden_states = None
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = hidden_states

         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
-                total_num_scheduled_tokens, sample_indices, aux_hidden_states,
+                total_num_scheduled_tokens, logits_indices, aux_hidden_states,
                 num_scheduled_tokens)

     def _get_cumsum_and_arange(
@@ -1397,16 +1399,42 @@ def execute_model(
             # Return empty ModelRunnerOuptut if there's no work to do.
             return EMPTY_MODEL_RUNNER_OUTPUT
         (attn_metadata, hidden_states, spec_decode_metadata, positions,
-         num_scheduled_tokens, sample_indices, aux_hidden_states,
+         num_scheduled_tokens, logits_indices, aux_hidden_states,
          num_scheduled_tokens_np) = (self._process_reqs(
              scheduler_output, intermediate_tensors))

         with ProfileExecuteDuration().capture_async("post process"):
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
                                   num_scheduled_tokens_np)
-            logits = self.model.compute_logits(hidden_states[sample_indices],
-                                               None)
+            # Broadcast PP output for external_launcher (torchrun)
+            # to make sure we are synced across pp ranks
+            # TODO: Support overlapping micro-batches
+            # https://github.com/vllm-project/vllm/issues/18019
+            broadcast_pp_output = \
+                self.parallel_config.distributed_executor_backend \
+                == "external_launcher" and len(get_pp_group().ranks) > 0
+            if not get_pp_group().is_last_rank:
+                # For mid-pipeline stages, return the hidden states.
+                if not broadcast_pp_output:
+                    return hidden_states
+                assert isinstance(hidden_states, IntermediateTensors)
+                get_pp_group().send_tensor_dict(
+                    hidden_states.tensors, all_gather_group=get_tp_group())
+                logits = None
+            else:
+                sample_hidden_states = hidden_states[logits_indices]
+                logits = self.model.compute_logits(sample_hidden_states, None)
+            if broadcast_pp_output:
+                model_output_broadcast_data = {
+                    "logits": logits.contiguous(),
+                } if logits is not None else {}
+                model_output_broadcast_data = get_pp_group(
+                ).broadcast_tensor_dict(model_output_broadcast_data,
+                                        src=len(get_pp_group().ranks) - 1)
+                assert model_output_broadcast_data is not None
+                logits = model_output_broadcast_data["logits"]
+
             # Apply structured output bitmasks if present
             if scheduler_output.grammar_bitmask is not None:
                 logits = self.apply_grammar_bitmask(scheduler_output, logits)
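The routing added above: non-last pipeline stages normally just hand their IntermediateTensors back to the worker, the last stage computes logits only for each request's final token, and under the external_launcher (torchrun) backend every rank additionally joins a broadcast so all processes stay in lockstep. A condensed sketch of that decision, assuming hypothetical stand-in callables rather than the real vLLM group API:

    def route_pp_output(hidden_states, logits_indices, is_last_rank,
                        broadcast_pp_output, compute_logits,
                        send_downstream, broadcast_from_last):
        if not is_last_rank:
            if not broadcast_pp_output:
                # Mid-pipeline rank: pass activations to the next stage.
                return hidden_states
            send_downstream(hidden_states)  # external_launcher: still feed next stage
            logits = None
        else:
            # Only the final token of each request needs logits.
            logits = compute_logits(hidden_states[logits_indices])
        if broadcast_pp_output:
            # Non-last ranks contribute an empty dict and receive the last
            # rank's logits, keeping all torchrun processes synchronized.
            payload = {"logits": logits} if logits is not None else {}
            logits = broadcast_from_last(payload)["logits"]
        return logits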
@@ -1423,6 +1451,7 @@ def execute_model(
                 # creates a new tensor with separate storage from the original
                 # logits tensor. This means any in-place operations on bonus_logits
                 # won't affect the original logits tensor.
+                assert logits is not None
                 bonus_logits = logits[
                     spec_decode_metadata.bonus_logits_indices]
                 sampler_output = self.sampler(
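The bare assert exists because logits is now Optional: on non-last PP ranks it is set to None (or the method returns early), so the sampler path needs a narrowing before indexing. A self-contained illustration of the pattern, with hypothetical names:

    from typing import Optional

    import torch

    def pick_bonus_logits(logits: Optional[torch.Tensor],
                          bonus_indices: torch.Tensor) -> torch.Tensor:
        # Optional[Tensor] is not indexable for the type checker; the assert
        # narrows the type and documents the invariant that only the last
        # PP rank (where logits is a real tensor) reaches the sampler.
        assert logits is not None
        return logits[bonus_indices]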

vllm_ascend/worker/worker_v1.py

Lines changed: 18 additions & 1 deletion
@@ -28,8 +28,10 @@
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
 from vllm.lora.request import LoRARequest
+from vllm.sequence import IntermediateTensors
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -206,7 +208,22 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
-        output = self.model_runner.execute_model(scheduler_output)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+        parallel_config = self.vllm_config.parallel_config
+        if parallel_config.distributed_executor_backend != "external_launcher" \
+                and not get_pp_group().is_last_rank:
+            assert isinstance(output, IntermediateTensors)
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
+            return None
+        assert isinstance(output, ModelRunnerOutput)
         return output if self.is_driver_worker else None

     def load_model(self) -> None:
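Taken together, the worker change gives each scheduler step this shape across pipeline ranks: rank 0 runs with no incoming tensors, middle ranks receive then forward IntermediateTensors and report nothing to the scheduler, and the last rank returns a ModelRunnerOutput. A schematic sketch, assuming stand-in point-to-point helpers rather than the real GroupCoordinator methods:

    def worker_step(rank, pp_world_size, scheduler_output, model_runner,
                    recv_tensor_dict, send_tensor_dict):
        # Non-first ranks block until the previous stage's activations arrive.
        intermediate = None
        if rank > 0:
            intermediate = recv_tensor_dict(src=rank - 1)

        output = model_runner.execute_model(scheduler_output, intermediate)

        if rank < pp_world_size - 1:
            # Mid-pipeline: forward activations; the scheduler sees no output.
            send_tensor_dict(output.tensors, dst=rank + 1)
            return None
        return output  # last rank: sampled tokens for the scheduler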
