
Commit 948ebeb

wuweiqiang24 authored and yiz-liu committed
Refactor tensor_parallel and comm_utils (vllm-project#2814)
### What this PR does / why we need it?
1. Move ops/comm_utils to ops/moe/comm_utils
2. Move distributed/tensor_parallel/gather_from_sequence_parallel_region to ops/moe/comm_utils
3. Delete distributed/tensor_parallel

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
e2e & ut

- vLLM version: main
- vLLM main: vllm-project/vllm@a1213fa

---------

Signed-off-by: wuweiqiang24 <1005334931@qq.com>
Signed-off-by: wuweiqiang24 <wuweiqiang11@huawei.com>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent d2e7eab commit 948ebeb
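
For downstream code the refactor amounts to an import-path change. A minimal sketch of the migration, assuming the old module paths implied by the PR description (illustrative only, not an exhaustive list):

```python
# Old locations removed by this PR (paths inferred from the PR description):
# from vllm_ascend.ops.comm_utils import async_all_to_all
# from vllm_ascend.distributed.tensor_parallel import \
#     gather_from_sequence_parallel_region

# New location after the refactor, as exercised by the updated unit tests below:
from vllm_ascend.ops.moe.comm_utils import (async_all_to_all,
                                            gather_from_sequence_parallel_region)
```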

File tree

6 files changed: +153 −392 lines changed


tests/ut/distributed/test_distributed_tensor_parallel.py

Lines changed: 0 additions & 139 deletions
This file was deleted.

tests/ut/ops/test_comm_utils.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import pytest
import torch
from pytest_mock import MockerFixture

from tests.ut.base import PytestBase
from vllm_ascend.ops.moe.comm_utils import (
    _gather_along_first_dim, async_all_to_all,
    gather_from_sequence_parallel_region)


class TestDistributedCommunication(PytestBase):

    @pytest.fixture(autouse=True)
    def context(self, mocker: MockerFixture):
        mocker.patch("torch.npu.current_device", return_value="cpu")
        mocker.patch("torch.distributed.get_world_size", return_value=4)

        mocker.patch("torch.distributed.get_rank", return_value=0)

    @pytest.mark.parametrize(
        "input_tensor, output_split_sizes, input_split_sizes",
        [(torch.randn(8, 16), [2, 2, 2, 2], [2, 2, 2, 2]),
         (torch.randn(16, 32), None, None)])
    def test_async_all_to_all(self, input_tensor, output_split_sizes,
                              input_split_sizes, mocker: MockerFixture):
        """Test async_all_to_all"""
        mock_group = mocker.MagicMock()
        mocker.patch("torch.distributed.all_to_all_single",
                     return_value=mocker.MagicMock())

        _, a2a_out, handle = async_all_to_all(input_tensor, output_split_sizes,
                                              input_split_sizes, mock_group)

        # Check if the output tensor is created properly
        if output_split_sizes is None:
            assert a2a_out.shape == input_tensor.shape
        else:
            total_output_size = sum(output_split_sizes)
            expected_shape = [total_output_size] + list(
                input_tensor.size())[1:]
            assert a2a_out.shape == torch.Size(expected_shape)

        # Ensure handle is returned from async operation
        assert handle is not None
        assert isinstance(handle, mocker.MagicMock)

    @pytest.mark.parametrize("world_size, test_tensor, expected",
                             [(1, torch.randn(8, 16), (8, 16)),
                              (4, torch.randn(8, 16), (32, 16))])
    def test_gather_along_first_dim(self, test_tensor, expected, world_size,
                                    mocker: MockerFixture):
        """Test _gather_along_first_dim"""
        mocker.patch("torch.distributed.get_world_size",
                     return_value=world_size)

        result = _gather_along_first_dim(test_tensor, mocker.MagicMock())

        assert result.shape == expected

    @pytest.mark.parametrize("input_tensor, output_split_sizes",
                             [(torch.randn(8, 16), None),
                              (torch.randn(8, 16), [2, 2, 2, 2])])
    def test_gather_from_sequence_parallel_region(self, input_tensor,
                                                  output_split_sizes,
                                                  mocker: MockerFixture):
        """Test gather_from_sequence_parallel_region"""
        mock_group = mocker.MagicMock()

        result = gather_from_sequence_parallel_region(input_tensor, mock_group,
                                                      output_split_sizes)

        # If output_split_sizes is not provided, result should have expanded
        # first dimension by world size
        if output_split_sizes is None:
            expected_shape = [input_tensor.shape[0] * 4] + list(
                input_tensor.shape[1:])
            assert result.shape == torch.Size(expected_shape)
        else:
            # If output_split_sizes is provided, result shape is dictated by
            # sum of output_split_sizes
            expected_shape = [sum(output_split_sizes)] + list(
                input_tensor.shape[1:])
            assert result.shape == torch.Size(expected_shape)
```
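
The assertions above only pin down the shape contract of the gather helpers. A standalone helper restating that contract (hypothetical, not part of vllm_ascend.ops.moe.comm_utils, with the world size of 4 taken from the test fixture):

```python
import torch


def expected_gather_shape(input_tensor: torch.Tensor,
                          world_size: int,
                          output_split_sizes=None) -> torch.Size:
    """Shape contract checked by the tests above (illustrative helper only)."""
    if output_split_sizes is None:
        # Every rank contributes an equally sized chunk, so the gathered
        # tensor grows by world_size along the first dimension.
        first_dim = input_tensor.shape[0] * world_size
    else:
        # With explicit split sizes, the first dimension is their sum.
        first_dim = sum(output_split_sizes)
    return torch.Size([first_dim, *input_tensor.shape[1:]])


assert expected_gather_shape(torch.randn(8, 16), 4) == torch.Size([32, 16])
assert expected_gather_shape(torch.randn(8, 16), 4,
                             [2, 2, 2, 2]) == torch.Size([8, 16])
```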

tests/ut/ops/test_token_dispatcher.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -348,7 +348,7 @@ def setUp(self):
         self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)

         # Mock async_all_to_all
-        patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all')
+        patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all')
         self.mock_async_all_to_all = patcher6.start()
         self.addCleanup(patcher6.stop)
         self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
```
