Support pipeline parallel in V1 Engine #1700

Merged · 1 commit · Jul 11, 2025
43 changes: 43 additions & 0 deletions tests/e2e/multicard/test_pipeline_parallel.py
@@ -0,0 +1,43 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import pytest

from tests.conftest import VllmRunner

MODELS = [
"Qwen/Qwen3-0.6B",
]

TENSOR_PARALLELS = [2]
PIPELINE_PARALLELS = [2]

prompts = [
"Hello, my name is",
"The future of AI is",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
def test_models(model: str, tp_size: int, pp_size: int) -> None:
with VllmRunner(model,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
enforce_eager=True,
gpu_memory_utilization=0.7) as vllm_model:
vllm_model.generate_greedy(prompts, 64)
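The test above drives a TP=2 x PP=2 layout (four devices in total) through VllmRunner. For reference, a minimal sketch of the same configuration through vLLM's offline LLM API is shown below; it mirrors the arguments used in the test, and exact parameter names and defaults may differ across vLLM releases, so treat it as an illustration rather than part of this change.

# Minimal sketch (not part of this PR): the same TP=2 x PP=2 layout via
# vLLM's offline API; needs tensor_parallel_size * pipeline_parallel_size devices.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-0.6B",
    tensor_parallel_size=2,    # shard each layer across 2 devices
    pipeline_parallel_size=2,  # split the layer stack into 2 stages
    enforce_eager=True,        # skip graph capture, matching the test
    gpu_memory_utilization=0.7,
)

outputs = llm.generate(
    ["Hello, my name is", "The future of AI is"],
    SamplingParams(temperature=0.0, max_tokens=64),  # greedy, like generate_greedy
)
for out in outputs:
    print(out.outputs[0].text)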
51 changes: 40 additions & 11 deletions vllm_ascend/worker/model_runner_v1.py
@@ -37,7 +37,8 @@
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_dp_group, get_pp_group
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group)
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger
@@ -146,6 +147,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.block_size = vllm_config.cache_config.block_size
@@ -921,8 +923,8 @@ def _process_reqs(
cu_num_tokens = np.cumsum(num_scheduled_tokens)
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
num_scheduled_tokens)
sample_indices = cu_num_tokens - 1
sample_indices = torch.from_numpy(sample_indices).to(self.device,
logits_indices = cu_num_tokens - 1
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

@@ -1153,14 +1155,14 @@ def _process_reqs(

spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
sample_indices = spec_decode_metadata.logits_indices
logits_indices = spec_decode_metadata.logits_indices

aux_hidden_states = None
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = hidden_states

return (attn_metadata, hidden_states, spec_decode_metadata, positions,
total_num_scheduled_tokens, sample_indices, aux_hidden_states,
total_num_scheduled_tokens, logits_indices, aux_hidden_states,
num_scheduled_tokens)

def _get_cumsum_and_arange(
@@ -1397,16 +1399,42 @@ def execute_model(
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
(attn_metadata, hidden_states, spec_decode_metadata, positions,
num_scheduled_tokens, sample_indices, aux_hidden_states,
num_scheduled_tokens, logits_indices, aux_hidden_states,
num_scheduled_tokens_np) = (self._process_reqs(
scheduler_output, intermediate_tensors))

with ProfileExecuteDuration().capture_async("post process"):
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np)
logits = self.model.compute_logits(hidden_states[sample_indices],
None)
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output = \
self.parallel_config.distributed_executor_backend \
== "external_launcher" and len(get_pp_group().ranks) > 0
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
if not broadcast_pp_output:
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
get_pp_group().send_tensor_dict(
hidden_states.tensors, all_gather_group=get_tp_group())
logits = None
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
"logits": logits.contiguous(),
} if logits is not None else {}
model_output_broadcast_data = get_pp_group(
).broadcast_tensor_dict(model_output_broadcast_data,
src=len(get_pp_group().ranks) - 1)
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]

# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
logits = self.apply_grammar_bitmask(scheduler_output, logits)
@@ -1423,6 +1451,7 @@ def execute_model(
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[
spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
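The heart of the model-runner change is the rank-dependent branch in execute_model: intermediate pipeline stages return their hidden states (or, under the external_launcher backend, send them and then join a broadcast), while only the last PP rank gathers the last-token hidden states via logits_indices and projects them to logits. The standalone sketch below reproduces that control flow; post_process, compute_logits and broadcast_from_last are illustrative stand-ins, not vLLM APIs.

# Standalone sketch of the post-processing branch added above; all names
# are illustrative stand-ins for the real vLLM objects.
from typing import Callable, Optional

import torch


def post_process(
    pp_rank: int,
    pp_world_size: int,
    hidden_states: torch.Tensor,
    logits_indices: torch.Tensor,
    compute_logits: Callable[[torch.Tensor], torch.Tensor],
    broadcast_pp_output: bool,
    broadcast_from_last: Callable[[Optional[torch.Tensor]], torch.Tensor],
) -> Optional[torch.Tensor]:
    is_last_rank = pp_rank == pp_world_size - 1
    if not is_last_rank:
        if not broadcast_pp_output:
            # Mid-pipeline stage: hand the hidden states to the next stage
            # (the real code returns/sends IntermediateTensors here).
            return hidden_states
        logits = None  # filled in by the broadcast below
    else:
        # Last stage only: keep the last-token position of each request
        # and project it to vocabulary logits.
        sample_hidden_states = hidden_states[logits_indices]
        logits = compute_logits(sample_hidden_states)
    if broadcast_pp_output:
        # Under external_launcher (torchrun) every rank must end up with
        # the same logits, broadcast from the last PP rank.
        logits = broadcast_from_last(logits)
    return logits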
19 changes: 18 additions & 1 deletion vllm_ascend/worker/worker_v1.py
@@ -28,8 +28,10 @@
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.sequence import IntermediateTensors
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -206,7 +208,22 @@ def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))

output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
parallel_config = self.vllm_config.parallel_config
if parallel_config.distributed_executor_backend != "external_launcher" \
and not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return None
assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None

def load_model(self) -> None:
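At the worker level the pattern added above is receive, run, send: every rank except the first waits for the previous stage's IntermediateTensors before running, and every rank except the last forwards its output to the next stage instead of returning sampled tokens (unless external_launcher is used, in which case the model runner already sent them). The toy sketch below imitates that flow with in-process queues standing in for send_tensor_dict/recv_tensor_dict; everything in it is illustrative, not vLLM API.

# Toy sketch of the worker-level pipeline flow: queues stand in for the
# send_tensor_dict / recv_tensor_dict calls between adjacent PP ranks.
import queue
from typing import Callable, Optional

import torch


def run_stage(
    rank: int,
    world_size: int,
    recv_q: Optional["queue.Queue[torch.Tensor]"],
    send_q: Optional["queue.Queue[torch.Tensor]"],
    stage_fn: Callable[[Optional[torch.Tensor]], torch.Tensor],
) -> Optional[torch.Tensor]:
    intermediate = None
    if rank > 0:                      # not get_pp_group().is_first_rank
        intermediate = recv_q.get()   # receive from the previous stage
    output = stage_fn(intermediate)   # model_runner.execute_model(...)
    if rank < world_size - 1:         # not get_pp_group().is_last_rank
        send_q.put(output)            # forward to the next stage
        return None                   # nothing for the scheduler yet
    return output                     # last rank returns the sampled output


if __name__ == "__main__":
    q: "queue.Queue[torch.Tensor]" = queue.Queue()
    run_stage(0, 2, None, q, lambda _: torch.ones(2, 4))      # stage 0
    out = run_stage(1, 2, q, None, lambda t: t.sum(dim=-1))   # stage 1
    print(out)  # tensor([4., 4.])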