
Commit c4f5426

cherry-pick:engineV1 support pipeline parallel
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1 parent cc210f4 commit c4f5426

3 files changed: 101 additions, 12 deletions
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+import pytest
+
+from tests.conftest import VllmRunner
+
+MODELS = [
+    "Qwen/Qwen3-0.6B",
+]
+
+TENSOR_PARALLELS = [2]
+PIPELINE_PARALLELS = [2]
+
+prompts = [
+    "Hello, my name is",
+    "The future of AI is",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
+@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
+def test_models(model: str, tp_size: int, pp_size: int) -> None:
+    with VllmRunner(model,
+                    tensor_parallel_size=tp_size,
+                    pipeline_parallel_size=pp_size,
+                    enforce_eager=True,
+                    gpu_memory_utilization=0.7) as vllm_model:
+        vllm_model.generate_greedy(prompts, 64)
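
For context on what the new test exercises: VllmRunner is the repository's test helper around vLLM's offline LLM API, so a roughly equivalent standalone run of the same TP=2 / PP=2 configuration might look like the sketch below. This is a minimal illustration, not part of the commit, and it assumes a working vllm-ascend install with at least tensor_parallel_size * pipeline_parallel_size (here 4) devices visible.

# Minimal sketch (not from the commit): the test's TP=2 / PP=2 setup driven
# through vLLM's public offline API instead of the VllmRunner helper.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B",
          tensor_parallel_size=2,
          pipeline_parallel_size=2,
          enforce_eager=True,
          gpu_memory_utilization=0.7)

# generate_greedy(prompts, 64) corresponds to temperature=0.0, max_tokens=64.
sampling = SamplingParams(temperature=0.0, max_tokens=64)
for out in llm.generate(["Hello, my name is", "The future of AI is"], sampling):
    print(out.outputs[0].text)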

vllm_ascend/worker/model_runner_v1.py

Lines changed: 40 additions & 11 deletions
@@ -37,7 +37,8 @@
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.distributed.parallel_state import get_dp_group, get_pp_group
+from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
+                                             get_tp_group)
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import logger
@@ -146,6 +147,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
+        self.parallel_config = vllm_config.parallel_config
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         self.block_size = vllm_config.cache_config.block_size
@@ -921,8 +923,8 @@ def _process_reqs(
         cu_num_tokens = np.cumsum(num_scheduled_tokens)
         cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                     num_scheduled_tokens)
-        sample_indices = cu_num_tokens - 1
-        sample_indices = torch.from_numpy(sample_indices).to(self.device,
+        logits_indices = cu_num_tokens - 1
+        logits_indices = torch.from_numpy(logits_indices).to(self.device,
                                                              non_blocking=True)
         arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

@@ -1153,14 +1155,14 @@ def _process_reqs(

             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens)
-            sample_indices = spec_decode_metadata.logits_indices
+            logits_indices = spec_decode_metadata.logits_indices

         aux_hidden_states = None
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = hidden_states

         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
-                total_num_scheduled_tokens, sample_indices, aux_hidden_states,
+                total_num_scheduled_tokens, logits_indices, aux_hidden_states,
                 num_scheduled_tokens)

     def _get_cumsum_and_arange(
@@ -1397,16 +1399,42 @@ def execute_model(
             # Return empty ModelRunnerOutput if there's no work to do.
             return EMPTY_MODEL_RUNNER_OUTPUT
         (attn_metadata, hidden_states, spec_decode_metadata, positions,
-         num_scheduled_tokens, sample_indices, aux_hidden_states,
+         num_scheduled_tokens, logits_indices, aux_hidden_states,
          num_scheduled_tokens_np) = (self._process_reqs(
              scheduler_output, intermediate_tensors))

         with ProfileExecuteDuration().capture_async("post process"):
-            if self.input_batch.pooling_params:
-                return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np)
-            logits = self.model.compute_logits(hidden_states[sample_indices],
-                                               None)
+            # Broadcast PP output for external_launcher (torchrun)
+            # to make sure we are synced across pp ranks
+            # TODO: Support overlapping micro-batches
+            # https://github.com/vllm-project/vllm/issues/18019
+            broadcast_pp_output = \
+                self.parallel_config.distributed_executor_backend \
+                == "external_launcher" and len(get_pp_group().ranks) > 0
+            if not get_pp_group().is_last_rank:
+                # For mid-pipeline stages, return the hidden states.
+                if not broadcast_pp_output:
+                    return hidden_states
+                assert isinstance(hidden_states, IntermediateTensors)
+                get_pp_group().send_tensor_dict(
+                    hidden_states.tensors, all_gather_group=get_tp_group())
+                logits = None
+            else:
+                if self.input_batch.pooling_params:
+                    return self._pool(hidden_states, num_scheduled_tokens,
+                                      num_scheduled_tokens_np)
+                sample_hidden_states = hidden_states[logits_indices]
+                logits = self.model.compute_logits(sample_hidden_states, None)
+                if broadcast_pp_output:
+                    model_output_broadcast_data = {
+                        "logits": logits.contiguous(),
+                    } if logits is not None else {}
+                    model_output_broadcast_data = get_pp_group(
+                    ).broadcast_tensor_dict(model_output_broadcast_data,
+                                            src=len(get_pp_group().ranks) - 1)
+                    assert model_output_broadcast_data is not None
+                    logits = model_output_broadcast_data["logits"]
+
             # Apply structured output bitmasks if present
             if scheduler_output.grammar_bitmask is not None:
                 logits = self.apply_grammar_bitmask(scheduler_output, logits)
@@ -1423,6 +1451,7 @@ def execute_model(
                 # creates a new tensor with separate storage from the original
                 # logits tensor. This means any in-place operations on bonus_logits
                 # won't affect the original logits tensor.
+                assert logits is not None
                 bonus_logits = logits[
                     spec_decode_metadata.bonus_logits_indices]
                 sampler_output = self.sampler(
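
Taken together, the execute_model changes make only the last pipeline-parallel rank index hidden states at logits_indices and compute logits; earlier ranks either hand their IntermediateTensors back to the worker (the usual executor backends) or, under the external_launcher (torchrun) backend, send them downstream directly and then join a broadcast of the final logits so every rank stays in sync. The sketch below is an illustrative reduction of that branching, not the runner's actual code; pp_group, compute_logits, and sample are stand-ins for the real objects.

# Illustrative sketch of the per-rank post-processing branch added above.
# `pp_group`, `hidden_states`, `compute_logits`, and `sample` are stand-ins,
# not the actual model-runner attributes.
def postprocess_sketch(pp_group, broadcast_pp_output, hidden_states,
                       logits_indices, compute_logits, sample):
    if not pp_group.is_last_rank:
        if not broadcast_pp_output:
            # Mid-pipeline stage: return IntermediateTensors to the worker,
            # which forwards them to the next stage.
            return hidden_states
        # external_launcher mode: every rank sends its own activations and
        # then waits for the last rank to broadcast the final logits.
        pp_group.send_tensor_dict(hidden_states.tensors)
        logits = None
    else:
        # Last stage: compute logits only at the positions to be sampled.
        logits = compute_logits(hidden_states[logits_indices])
    if broadcast_pp_output:
        payload = {"logits": logits.contiguous()} if logits is not None else {}
        payload = pp_group.broadcast_tensor_dict(payload,
                                                 src=len(pp_group.ranks) - 1)
        logits = payload["logits"]
    return sample(logits)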

vllm_ascend/worker/worker_v1.py

Lines changed: 18 additions & 1 deletion
@@ -28,8 +28,10 @@
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
 from vllm.lora.request import LoRARequest
+from vllm.sequence import IntermediateTensors
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -206,7 +208,22 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
-        output = self.model_runner.execute_model(scheduler_output)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+        parallel_config = self.vllm_config.parallel_config
+        if parallel_config.distributed_executor_backend != "external_launcher" \
+            and not get_pp_group().is_last_rank:
+            assert isinstance(output, IntermediateTensors)
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
+            return None
+        assert isinstance(output, ModelRunnerOutput)
         return output if self.is_driver_worker else None

     def load_model(self) -> None:
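
Reading the worker change together with the model-runner change above: the worker is where adjacent pipeline stages hand activations to each other. Every non-first rank blocks on recv_tensor_dict before running its model shard, and every non-last rank (outside external_launcher mode) ships its IntermediateTensors with send_tensor_dict and returns None, so only the last stage ever returns a ModelRunnerOutput. A compressed sketch of that handoff, with the worker's real attributes replaced by placeholders:

# Sketch of the stage-to-stage handoff the new worker code implements.
# `pp_group`, `tp_group`, and `run_model_shard` are placeholders, not the
# actual worker attributes; error handling is omitted.
def worker_step_sketch(pp_group, tp_group, run_model_shard, scheduler_output):
    intermediate = None
    if not pp_group.is_first_rank:
        # Stage i (> 0) starts from the activations stage i-1 just sent.
        intermediate = pp_group.recv_tensor_dict(all_gather_group=tp_group)

    output = run_model_shard(scheduler_output, intermediate)

    if not pp_group.is_last_rank:
        # Forward activations to stage i+1; this rank produces no tokens.
        pp_group.send_tensor_dict(output.tensors, all_gather_group=tp_group)
        return None

    # Only the last stage returns real sampler output to the executor.
    return output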
