
Commit 592e5c5

add patch ut
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
1 parent 9cd4ac7 commit 592e5c5

File tree: 3 files changed (+268, -6 lines)


tests/ut/patch/worker/patch_common/test_patch_distributed.py

Lines changed: 95 additions & 6 deletions
@@ -11,17 +11,106 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
-#
+
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.distributed.parallel_state import GroupCoordinator
 
 from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_common.patch_distributed import \
+    GroupCoordinatorPatch
 
 
-class TestPatchDistributed(TestBase):
+class TestGroupCoordinatorPatch(TestBase):
 
-    def test_GroupCoordinator_patched(self):
-        from vllm.distributed.parallel_state import GroupCoordinator
+    def setUp(self):
+        self.mock_group_ranks = [[0, 1]]
+        self.mock_local_rank = 0
+        self.mock_backend = "hccl"
+        self.mock_use_device_comm = True
+
+        patcher_get_rank = patch("torch.distributed.get_rank", return_value=0)
+        patcher_new_group = patch("torch.distributed.new_group",
+                                  return_value=MagicMock())
+        patcher_is_cuda_alike = patch(
+            "vllm.platforms.current_platform.is_cuda_alike", return_value=True)
+        patcher_device_comm_cls = patch(
+            "vllm.distributed.parallel_state.resolve_obj_by_qualname",
+            return_value=MagicMock())
 
-        from vllm_ascend.patch.worker.patch_common.patch_distributed import \
-            GroupCoordinatorPatch
+        self.mock_get_rank = patcher_get_rank.start()
+        self.mock_new_group = patcher_new_group.start()
+        self.mock_is_cuda_alike = patcher_is_cuda_alike.start()
+        self.mock_resolve_obj = patcher_device_comm_cls.start()
 
+        self.addCleanup(patcher_get_rank.stop)
+        self.addCleanup(patcher_new_group.stop)
+        self.addCleanup(patcher_is_cuda_alike.stop)
+        self.addCleanup(patcher_device_comm_cls.stop)
+
+        self.group_coordinator = GroupCoordinatorPatch(
+            group_ranks=self.mock_group_ranks,
+            local_rank=self.mock_local_rank,
+            torch_distributed_backend=self.mock_backend,
+            use_device_communicator=self.mock_use_device_comm)
+
+    def test_GroupCoordinator_patched(self):
         self.assertIs(GroupCoordinator, GroupCoordinatorPatch)
+
+    def test_initialization_sets_attributes(self):
+        self.assertEqual(self.group_coordinator.world_size, 2)
+        self.assertEqual(self.group_coordinator.rank_in_group, 0)
+        self.assertTrue(hasattr(self.group_coordinator, "device_communicator"))
+
+    def test_all_to_all_returns_input_when_world_size_1(self):
+        self.group_coordinator.world_size = 1
+        input_tensor = torch.randn(2, 3)
+        output = self.group_coordinator.all_to_all(input_tensor)
+        self.assertTrue(torch.equal(output, input_tensor))
+
+    def test_all_to_all_raises_assertion_on_invalid_scatter_dim(self):
+        input_tensor = torch.randn(2, 3)
+        with self.assertRaises(AssertionError) as cm:
+            self.group_coordinator.all_to_all(input_tensor, scatter_dim=2)
+        self.assertIn("Invalid scatter dim", str(cm.exception))
+
+    def test_all_to_all_raises_assertion_on_invalid_gather_dim(self):
+        input_tensor = torch.randn(2, 3)
+        with self.assertRaises(AssertionError) as cm:
+            self.group_coordinator.all_to_all(input_tensor, gather_dim=2)
+        self.assertIn("Invalid gather dim", str(cm.exception))
+
+    def test_all_to_all_calls_device_communicator_with_correct_args(self):
+        mock_communicator = MagicMock()
+        self.group_coordinator.device_communicator = mock_communicator
+
+        input_tensor = torch.randn(2, 3)
+        scatter_dim = 0
+        gather_dim = 1
+        scatter_sizes = [1, 1]
+        gather_sizes = [1, 1]
+
+        self.group_coordinator.all_to_all(input_tensor,
+                                          scatter_dim=scatter_dim,
+                                          gather_dim=gather_dim,
+                                          scatter_sizes=scatter_sizes,
+                                          gather_sizes=gather_sizes)
+
+        mock_communicator.all_to_all.assert_called_once_with(
+            input_tensor, scatter_dim, gather_dim, scatter_sizes, gather_sizes)
+
+    def test_all_to_all_calls_device_communicator_without_sizes(self):
+        mock_communicator = MagicMock()
+        self.group_coordinator.device_communicator = mock_communicator
+
+        input_tensor = torch.randn(2, 3)
+        scatter_dim = 0
+        gather_dim = 1
+
+        self.group_coordinator.all_to_all(input_tensor,
+                                          scatter_dim=scatter_dim,
+                                          gather_dim=gather_dim)
+
+        mock_communicator.all_to_all.assert_called_once_with(
+            input_tensor, scatter_dim, gather_dim, None, None)
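The tests above pin down a small contract for GroupCoordinatorPatch.all_to_all: pass-through when the group has a single rank, bounds checks on scatter_dim and gather_dim, and delegation of the actual exchange to the device communicator. The sketch below restates that contract in plain Python; it is inferred only from the assertions, and the class name AllToAllContractSketch and its constructor are illustrative, not part of vllm-ascend.

import torch


class AllToAllContractSketch:
    """Assumed behaviour of GroupCoordinatorPatch.all_to_all (see tests above)."""

    def __init__(self, world_size, device_communicator):
        self.world_size = world_size
        self.device_communicator = device_communicator

    def all_to_all(self,
                   input_: torch.Tensor,
                   scatter_dim: int = 0,
                   gather_dim: int = -1,
                   scatter_sizes=None,
                   gather_sizes=None) -> torch.Tensor:
        if self.world_size == 1:
            # Single-rank group: nothing to exchange, return the input as-is.
            return input_
        # Dimension checks produce the "Invalid scatter dim" / "Invalid gather
        # dim" messages that the assertion tests look for.
        assert -input_.dim() <= scatter_dim < input_.dim(), (
            f"Invalid scatter dim ({scatter_dim})")
        assert -input_.dim() <= gather_dim < input_.dim(), (
            f"Invalid gather dim ({gather_dim})")
        # Device-specific communication is delegated, which is why the tests
        # can swap in a MagicMock and assert on the forwarded arguments.
        return self.device_communicator.all_to_all(input_, scatter_dim,
                                                   gather_dim, scatter_sizes,
                                                   gather_sizes)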
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+
+from unittest.mock import MagicMock
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_common.patch_minicpm import forward
+
+
+class TestPatchedMiniCPMForward(TestBase):
+
+    def setUp(self):
+        self.mock_self = MagicMock()
+
+        self.mock_self.q_size = 128
+        self.mock_self.kv_size = 128
+
+        self.mock_self.qkv_proj = MagicMock()
+        self.mock_self.rotary_emb = MagicMock()
+        self.mock_self.attn = MagicMock()
+        self.mock_self.o_proj = MagicMock()
+
+        self.positions = torch.tensor([1, 2, 3])
+        self.hidden_states = torch.randn(3, 256)  # [batch_size, hidden_size]
+
+        self.mock_qkv = torch.randn(3, 384)
+        self.mock_q = self.mock_qkv[:, :128]
+        self.mock_k = self.mock_qkv[:, 128:256]
+        self.mock_v = self.mock_qkv[:, 256:]
+
+        self.mock_self.qkv_proj.return_value = (self.mock_qkv, None)
+        self.mock_self.rotary_emb.return_value = (self.mock_q, self.mock_k)
+        self.mock_self.attn.return_value = torch.randn(3, 256)
+        self.mock_self.o_proj.return_value = (torch.randn(3, 256), None)
+
+    def test_forward_patched(self):
+        from vllm.model_executor.models.minicpm import MiniCPMAttention
+
+        self.assertIs(MiniCPMAttention.forward, forward)
+
+    def test_forward_function(self):
+        result = forward(self.mock_self, self.positions, self.hidden_states)
+
+        self.mock_self.qkv_proj.assert_called_once_with(self.hidden_states)
+
+        args, _ = self.mock_self.rotary_emb.call_args
+        self.assertEqual(len(args), 3)
+        self.assertTrue(torch.equal(args[0], self.positions))
+        self.assertTrue(torch.equal(args[1], self.mock_q))
+        self.assertTrue(torch.equal(args[2], self.mock_k))
+
+        args, _ = self.mock_self.attn.call_args
+        self.assertEqual(len(args), 3)
+        self.assertTrue(torch.equal(args[0], self.mock_q))
+        self.assertTrue(torch.equal(args[1], self.mock_k))
+        self.assertTrue(torch.equal(args[2], self.mock_v))
+
+        self.mock_self.o_proj.assert_called_once_with(
+            self.mock_self.attn.return_value)
+
+        self.assertEqual(result.shape, (3, 256))
+        self.assertTrue(
+            torch.equal(result, self.mock_self.o_proj.return_value[0]))
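From the mocked call sequence asserted here, the patched MiniCPM attention forward evidently follows the usual fused-QKV, rotary-embedding, attention, output-projection flow. A minimal sketch of that flow is below, assuming standard vLLM module interfaces where qkv_proj and o_proj return (output, bias) tuples; forward_sketch is an illustrative name, not the actual vllm-ascend patch.

import torch


def forward_sketch(self, positions: torch.Tensor,
                   hidden_states: torch.Tensor) -> torch.Tensor:
    # Fused QKV projection; the linear layer returns (output, bias).
    qkv, _ = self.qkv_proj(hidden_states)
    # Split into query/key/value along the hidden dimension
    # (q_size + kv_size + kv_size wide, e.g. 128 + 128 + 128 in the test).
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    # Rotary position embeddings are applied to q and k only.
    q, k = self.rotary_emb(positions, q, k)
    # Attention over (q, k, v), then the output projection.
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output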
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+
+from typing import List, Optional
+from unittest.mock import MagicMock, patch
+
+import torch
+from torch.library import Library
+
+from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_common.patch_utils import \
+    ascend_direct_register_custom_op
+
+
+class TestAscendDirectRegisterCustomOp(TestBase):
+
+    def setUp(self):
+        super().setUp()
+
+        self.mock_op_func = MagicMock()
+        self.mock_op_func.__annotations__ = {
+            'param1': list[int],
+            'param2': Optional[list[int]],
+            'param3': str
+        }
+
+        self.mock_fake_impl = MagicMock()
+        self.mock_lib = MagicMock(spec=Library)
+
+        self.op_name = "test_op"
+        self.mutates_args = ["arg1"]
+        self.dispatch_key = "CUDA"
+        self.tags = (torch.Tag.pt2_compliant_tag, )
+
+        self.patch_infer_schema = patch(
+            'vllm_ascend.patch.worker.patch_common.patch_utils.torch.library.infer_schema'
+        )
+        self.patch_vllm_lib = patch(
+            'vllm_ascend.patch.worker.patch_common.patch_utils.vllm_lib')
+
+        self.mock_infer_schema = self.patch_infer_schema.start()
+        self.mock_vllm_lib = self.patch_vllm_lib.start()
+
+        self.addCleanup(self.patch_infer_schema.stop)
+        self.addCleanup(self.patch_vllm_lib.stop)
+
+    def test_register_with_default_lib(self):
+        self.mock_infer_schema.return_value = "(Tensor self) -> Tensor"
+
+        ascend_direct_register_custom_op(op_name=self.op_name,
+                                         op_func=self.mock_op_func,
+                                         mutates_args=self.mutates_args,
+                                         fake_impl=self.mock_fake_impl,
+                                         dispatch_key=self.dispatch_key,
+                                         tags=self.tags)
+
+        self.assertEqual(self.mock_op_func.__annotations__['param1'],
+                         List[int])
+        self.assertEqual(self.mock_op_func.__annotations__['param2'],
+                         Optional[List[int]])
+        self.assertEqual(self.mock_op_func.__annotations__['param3'], str)
+
+        self.mock_infer_schema.assert_called_once_with(
+            self.mock_op_func, mutates_args=self.mutates_args)
+
+        self.mock_vllm_lib.define.assert_called_once_with(
+            f"{self.op_name}(Tensor self) -> Tensor", tags=self.tags)
+        self.mock_vllm_lib.impl.assert_called_once_with(
+            self.op_name, self.mock_op_func, dispatch_key=self.dispatch_key)
+        self.mock_vllm_lib._register_fake.assert_called_once_with(
+            self.op_name, self.mock_fake_impl)
+
+    def test_register_with_custom_lib(self):
+        self.mock_infer_schema.return_value = "(Tensor a, Tensor b) -> Tensor"
+
+        ascend_direct_register_custom_op(op_name=self.op_name,
+                                         op_func=self.mock_op_func,
+                                         mutates_args=self.mutates_args,
+                                         target_lib=self.mock_lib)
+
+        self.mock_lib.define.assert_called_once_with(
+            f"{self.op_name}(Tensor a, Tensor b) -> Tensor", tags=())
+        self.mock_lib.impl.assert_called_once_with(self.op_name,
+                                                   self.mock_op_func,
+                                                   dispatch_key="CUDA")
+        self.mock_lib._register_fake.assert_not_called()
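Taken together, these assertions describe the registration flow of ascend_direct_register_custom_op: PEP 585 annotations such as list[int] are rewritten to their typing equivalents, torch.library.infer_schema produces the schema string, and the op is then defined, implemented, and optionally given a fake implementation on the target library (vLLM's library by default). The sketch below restates that flow under those assumptions; register_custom_op_sketch and the vllm_lib placeholder are illustrative names, and _register_fake is a private torch.library.Library method that the tests mock rather than call for real.

from typing import List, Optional

import torch
from torch.library import Library

# Placeholder standing in for the default vLLM op library assumed by the tests.
vllm_lib = Library("vllm", "FRAGMENT")


def register_custom_op_sketch(op_name,
                              op_func,
                              mutates_args,
                              fake_impl=None,
                              dispatch_key="CUDA",
                              target_lib=None,
                              tags=()):
    # Rewrite PEP 585 generics to typing generics so schema inference accepts
    # them (this is what the annotation assertions above check).
    for param, annotation in op_func.__annotations__.items():
        if annotation == list[int]:
            op_func.__annotations__[param] = List[int]
        elif annotation == Optional[list[int]]:
            op_func.__annotations__[param] = Optional[List[int]]

    # Infer the op schema string from the (possibly rewritten) signature.
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)

    # Define and implement the op on the chosen library; register the fake
    # (meta) implementation only when one is provided.
    lib = target_lib if target_lib is not None else vllm_lib
    lib.define(op_name + schema, tags=tags)
    lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        lib._register_fake(op_name, fake_impl)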
