From f675b3e1068c3f602cce3be2ec06d15ca7b8f019 Mon Sep 17 00:00:00 2001
From: z00811365
Date: Tue, 16 Sep 2025 21:19:06 +0800
Subject: [PATCH] fix dp group used by has_unfinished_dp

Signed-off-by: zhaowe1936
---
 .../patch_common/patch_distributed.py         | 28 ++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/patch/platform/patch_common/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py
index 67d4797f9b..dcb47f96a5 100644
--- a/vllm_ascend/patch/platform/patch_common/patch_distributed.py
+++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py
@@ -19,7 +19,10 @@
 import torch
 import vllm.envs as envs_vllm
+from torch.distributed import ProcessGroup, ReduceOp
 from vllm.config import ParallelConfig
+from vllm.distributed.utils import \
+    stateless_init_torch_distributed_process_group
 
 from vllm_ascend.utils import is_310p
 
@@ -41,9 +44,32 @@ def parallel_config_get_dp_port(self) -> int:
     return port
 
 
-ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
+def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
+    dp_group = stateless_init_torch_distributed_process_group(
+        self.data_parallel_master_ip,
+        self.get_next_dp_init_port(),
+        self.data_parallel_rank,
+        self.data_parallel_size,
+        backend="hccl")
+    return dp_group
+
+
+def ascend_has_unfinished_dp(dp_group: "ProcessGroup",
+                             has_unfinished: bool) -> bool:
+    tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="npu")
+    # dp rank 0: has_unfinished_seqs=True
+    # dp rank 1: has_unfinished_seqs=False
+    # aggregated: has_unfinished_seqs=True
+    # so this is an OR operation, i.e. MAX in integers
+    torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
+    aggregated_has_unfinished = bool(tensor.item())
+    return aggregated_has_unfinished
+
+ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
+ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
+ParallelConfig.has_unfinished_dp = staticmethod(ascend_has_unfinished_dp)
 
 
 class NullHandle:
 
     def __init__(self):
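
Editor's note (not part of the patch): the comment block inside
ascend_has_unfinished_dp argues that ReduceOp.MAX over 0/1 int32 tensors is a
logical OR across data-parallel ranks. The standalone sketch below
demonstrates that reduction; it deliberately uses the CPU "gloo" backend and
torch.multiprocessing so it runs without an NPU, whereas the patch itself uses
"hccl" on device "npu". The two-rank setup and the port number are
illustrative assumptions.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"  # assumed free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    has_unfinished = rank == 0  # rank 0 is busy, rank 1 is idle
    tensor = torch.tensor([has_unfinished], dtype=torch.int32)
    # max(1, 0) == 1, so MAX over 0/1 flags is the OR of the local flags.
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)

    # Both ranks print True even though only rank 0 had unfinished work.
    print(f"rank {rank}: aggregated has_unfinished = {bool(tensor.item())}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2, ), nprocs=2)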
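
Editor's note (not part of the patch): once this module is imported, the three
assignments at the bottom rebind ParallelConfig, so any vLLM code path that
calls stateless_init_dp_group() gets an HCCL-backed group and
has_unfinished_dp() reduces on the NPU. The sketch below is a hypothetical
driver loop showing how the two rebound methods compose; it is not vLLM's
actual engine loop, it needs a real Ascend/torch_npu environment to run, and
engine_has_requests/step_engine are made-up stand-ins for the caller's own
work check and engine step.

from vllm.config import ParallelConfig


def engine_has_requests() -> bool:
    return False  # stand-in: report "no local work" so the loop can exit


def step_engine() -> None:
    pass  # stand-in for one engine iteration


def drain_loop(parallel_config: ParallelConfig) -> None:
    # HCCL-backed group spanning all data-parallel ranks (from the patch).
    dp_group = parallel_config.stateless_init_dp_group()
    while True:
        # Lockstep exit: stop only when every dp rank reports idle, since
        # has_unfinished_dp ORs the local flags across the dp group.
        if not ParallelConfig.has_unfinished_dp(dp_group,
                                                engine_has_requests()):
            break
        step_engine()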