diff --git a/tests/e2e/multicard/test_pipeline_parallel.py b/tests/e2e/multicard/test_pipeline_parallel.py new file mode 100644 index 0000000000..a7070b6889 --- /dev/null +++ b/tests/e2e/multicard/test_pipeline_parallel.py @@ -0,0 +1,43 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +import pytest + +from tests.conftest import VllmRunner + +MODELS = [ + "Qwen/Qwen3-0.6B", +] + +TENSOR_PARALLELS = [2] +PIPELINE_PARALLELS = [2] + +prompts = [ + "Hello, my name is", + "The future of AI is", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS) +@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS) +def test_models(model: str, tp_size: int, pp_size: int) -> None: + with VllmRunner(model, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + enforce_eager=True, + gpu_memory_utilization=0.7) as vllm_model: + vllm_model.generate_greedy(prompts, 64) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 47e827498d..eabcdbcc19 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -37,7 +37,8 @@ from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import get_dp_group, get_pp_group +from vllm.distributed.parallel_state import (get_dp_group, get_pp_group, + get_tp_group) from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger @@ -146,6 +147,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config + self.parallel_config = vllm_config.parallel_config self.scheduler_config = vllm_config.scheduler_config self.speculative_config = vllm_config.speculative_config self.block_size = vllm_config.cache_config.block_size @@ -921,8 +923,8 @@ def _process_reqs( cu_num_tokens = np.cumsum(num_scheduled_tokens) cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens) - sample_indices = cu_num_tokens - 1 - sample_indices = torch.from_numpy(sample_indices).to(self.device, + logits_indices = cu_num_tokens - 1 + logits_indices = torch.from_numpy(logits_indices).to(self.device, non_blocking=True) arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets @@ -1153,14 +1155,14 @@ def _process_reqs( spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) - sample_indices = spec_decode_metadata.logits_indices + logits_indices = spec_decode_metadata.logits_indices aux_hidden_states = None if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = hidden_states return (attn_metadata, hidden_states, spec_decode_metadata, positions, - total_num_scheduled_tokens, sample_indices, aux_hidden_states, + total_num_scheduled_tokens, logits_indices, aux_hidden_states, num_scheduled_tokens) def _get_cumsum_and_arange( @@ -1397,16 +1399,42 @@ def execute_model( # Return empty ModelRunnerOuptut if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT (attn_metadata, hidden_states, spec_decode_metadata, positions, - num_scheduled_tokens, sample_indices, aux_hidden_states, + num_scheduled_tokens, logits_indices, aux_hidden_states, num_scheduled_tokens_np) = (self._process_reqs( scheduler_output, intermediate_tensors)) with ProfileExecuteDuration().capture_async("post process"): - if self.input_batch.pooling_params: - return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np) - logits = self.model.compute_logits(hidden_states[sample_indices], - None) + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + broadcast_pp_output = \ + self.parallel_config.distributed_executor_backend \ + == "external_launcher" and len(get_pp_group().ranks) > 0 + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. + if not broadcast_pp_output: + return hidden_states + assert isinstance(hidden_states, IntermediateTensors) + get_pp_group().send_tensor_dict( + hidden_states.tensors, all_gather_group=get_tp_group()) + logits = None + else: + if self.input_batch.pooling_params: + return self._pool(hidden_states, num_scheduled_tokens, + num_scheduled_tokens_np) + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if broadcast_pp_output: + model_output_broadcast_data = { + "logits": logits.contiguous(), + } if logits is not None else {} + model_output_broadcast_data = get_pp_group( + ).broadcast_tensor_dict(model_output_broadcast_data, + src=len(get_pp_group().ranks) - 1) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: logits = self.apply_grammar_bitmask(scheduler_output, logits) @@ -1423,6 +1451,7 @@ def execute_model( # creates a new tensor with separate storage from the original # logits tensor. This means any in-place operations on bonus_logits # won't affect the original logits tensor. + assert logits is not None bonus_logits = logits[ spec_decode_metadata.bonus_logits_indices] sampler_output = self.sampler( diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 02094f5c58..df03d508e4 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -28,8 +28,10 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import logger from vllm.lora.request import LoRARequest +from vllm.sequence import IntermediateTensors from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec @@ -206,7 +208,22 @@ def execute_model( self, scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: - output = self.model_runner.execute_model(scheduler_output) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group())) + + output = self.model_runner.execute_model(scheduler_output, + intermediate_tensors) + parallel_config = self.vllm_config.parallel_config + if parallel_config.distributed_executor_backend != "external_launcher" \ + and not get_pp_group().is_last_rank: + assert isinstance(output, IntermediateTensors) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) + return None + assert isinstance(output, ModelRunnerOutput) return output if self.is_driver_worker else None def load_model(self) -> None: