
Commit 6545e87

add patch ut
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
1 parent 9cd4ac7 commit 6545e87

File tree:

tests/ut/patch/worker/patch_common/test_patch_distributed.py
tests/ut/patch/worker/patch_common/test_patch_minicpm.py
tests/ut/patch/worker/patch_common/test_patch_utils.py

3 files changed: +269 −4 lines changed

tests/ut/patch/worker/patch_common/test_patch_distributed.py

Lines changed: 89 additions & 4 deletions
@@ -13,15 +13,100 @@
 # This file is a part of the vllm-ascend project.
 #
 
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.distributed.parallel_state import GroupCoordinator
+
 from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_common.patch_distributed import \
+    GroupCoordinatorPatch
 
 
 class TestPatchDistributed(TestBase):
 
-    def test_GroupCoordinator_patched(self):
-        from vllm.distributed.parallel_state import GroupCoordinator
+    def setUp(self):
+        self.mock_group_ranks = [[0, 1]]
+        self.mock_local_rank = 0
+        self.mock_backend = "hccl"
+        self.mock_use_device_comm = True
+
+        patcher_get_rank = patch("torch.distributed.get_rank", return_value=0)
+        patcher_new_group = patch("torch.distributed.new_group",
+                                  return_value=MagicMock())
+        patcher_is_cuda_alike = patch(
+            "vllm.platforms.current_platform.is_cuda_alike", return_value=True)
+        patcher_device_comm_cls = patch(
+            "vllm.distributed.parallel_state.resolve_obj_by_qualname",
+            return_value=MagicMock())
+
+        self.mock_get_rank = patcher_get_rank.start()
+        self.mock_new_group = patcher_new_group.start()
+        self.mock_is_cuda_alike = patcher_is_cuda_alike.start()
+        self.mock_resolve_obj = patcher_device_comm_cls.start()
 
-        from vllm_ascend.patch.worker.patch_common.patch_distributed import \
-            GroupCoordinatorPatch
+        self.addCleanup(patcher_get_rank.stop)
+        self.addCleanup(patcher_new_group.stop)
+        self.addCleanup(patcher_is_cuda_alike.stop)
+        self.addCleanup(patcher_device_comm_cls.stop)
 
+        self.group_coordinator = GroupCoordinatorPatch(
+            group_ranks=self.mock_group_ranks,
+            local_rank=self.mock_local_rank,
+            torch_distributed_backend=self.mock_backend,
+            use_device_communicator=self.mock_use_device_comm)
+
+    def test_GroupCoordinator_patched(self):
         self.assertIs(GroupCoordinator, GroupCoordinatorPatch)
+
+    def test_all_to_all_returns_input_when_world_size_1(self):
+        self.group_coordinator.world_size = 1
+        input_tensor = torch.randn(2, 3)
+        output = self.group_coordinator.all_to_all(input_tensor)
+        self.assertTrue(torch.equal(output, input_tensor))
+
+    def test_all_to_all_raises_assertion_on_invalid_scatter_dim(self):
+        input_tensor = torch.randn(2, 3)
+        with self.assertRaises(AssertionError) as cm:
+            self.group_coordinator.all_to_all(input_tensor, scatter_dim=2)
+        self.assertIn("Invalid scatter dim", str(cm.exception))
+
+    def test_all_to_all_raises_assertion_on_invalid_gather_dim(self):
+        input_tensor = torch.randn(2, 3)
+        with self.assertRaises(AssertionError) as cm:
+            self.group_coordinator.all_to_all(input_tensor, gather_dim=2)
+        self.assertIn("Invalid gather dim", str(cm.exception))
+
+    def test_all_to_all_calls_device_communicator_with_correct_args(self):
+        mock_communicator = MagicMock()
+        self.group_coordinator.device_communicator = mock_communicator
+
+        input_tensor = torch.randn(2, 3)
+        scatter_dim = 0
+        gather_dim = 1
+        scatter_sizes = [1, 1]
+        gather_sizes = [1, 1]
+
+        self.group_coordinator.all_to_all(input_tensor,
+                                          scatter_dim=scatter_dim,
+                                          gather_dim=gather_dim,
+                                          scatter_sizes=scatter_sizes,
+                                          gather_sizes=gather_sizes)
+
+        mock_communicator.all_to_all.assert_called_once_with(
+            input_tensor, scatter_dim, gather_dim, scatter_sizes, gather_sizes)
+
+    def test_all_to_all_calls_device_communicator_without_sizes(self):
+        mock_communicator = MagicMock()
+        self.group_coordinator.device_communicator = mock_communicator
+
+        input_tensor = torch.randn(2, 3)
+        scatter_dim = 0
+        gather_dim = 1
+
+        self.group_coordinator.all_to_all(input_tensor,
+                                          scatter_dim=scatter_dim,
+                                          gather_dim=gather_dim)
+
+        mock_communicator.all_to_all.assert_called_once_with(
+            input_tensor, scatter_dim, gather_dim, None, None)
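
Note: taken together, the five all_to_all tests pin down the patched method's contract. Below is a minimal sketch of that contract, reconstructed from the assertions alone; the default argument values and exact assertion messages are assumptions, and the real implementation lives in vllm_ascend/patch/worker/patch_common/patch_distributed.py.

def all_to_all(self, input_, scatter_dim=0, gather_dim=-1,
               scatter_sizes=None, gather_sizes=None):
    # Single-rank group: nothing to exchange, hand the tensor back.
    if self.world_size == 1:
        return input_
    # Both dims must index into the input tensor; scatter_dim=2 or
    # gather_dim=2 on a 2-D tensor trips these checks in the tests.
    assert -input_.dim() <= scatter_dim < input_.dim(), (
        f"Invalid scatter dim {scatter_dim}")
    assert -input_.dim() <= gather_dim < input_.dim(), (
        f"Invalid gather dim {gather_dim}")
    # Everything else is delegated verbatim to the device communicator,
    # with None passed through when sizes are not supplied.
    return self.device_communicator.all_to_all(
        input_, scatter_dim, gather_dim, scatter_sizes, gather_sizes)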
tests/ut/patch/worker/patch_common/test_patch_minicpm.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+from unittest.mock import MagicMock
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_common.patch_minicpm import forward
+
+
+class TestPatchMiniCPM(TestBase):
+
+    def setUp(self):
+        self.mock_self = MagicMock()
+
+        self.mock_self.q_size = 128
+        self.mock_self.kv_size = 128
+
+        self.mock_self.qkv_proj = MagicMock()
+        self.mock_self.rotary_emb = MagicMock()
+        self.mock_self.attn = MagicMock()
+        self.mock_self.o_proj = MagicMock()
+
+        self.positions = torch.tensor([1, 2, 3])
+        self.hidden_states = torch.randn(3, 256)
+
+        self.mock_qkv = torch.randn(3, 384)
+        self.mock_q = self.mock_qkv[:, :128]
+        self.mock_k = self.mock_qkv[:, 128:256]
+        self.mock_v = self.mock_qkv[:, 256:]
+
+        self.mock_self.qkv_proj.return_value = (self.mock_qkv, None)
+        self.mock_self.rotary_emb.return_value = (self.mock_q, self.mock_k)
+        self.mock_self.attn.return_value = torch.randn(3, 256)
+        self.mock_self.o_proj.return_value = (torch.randn(3, 256), None)
+
+    def test_forward_patched(self):
+        from vllm.model_executor.models.minicpm import MiniCPMAttention
+
+        self.assertIs(MiniCPMAttention.forward, forward)
+
+    def test_forward_function(self):
+        result = forward(self.mock_self, self.positions, self.hidden_states)
+
+        self.mock_self.qkv_proj.assert_called_once_with(self.hidden_states)
+
+        args, _ = self.mock_self.rotary_emb.call_args
+        self.assertEqual(len(args), 3)
+        self.assertTrue(torch.equal(args[0], self.positions))
+        self.assertTrue(torch.equal(args[1], self.mock_q))
+        self.assertTrue(torch.equal(args[2], self.mock_k))
+
+        args, _ = self.mock_self.attn.call_args
+        self.assertEqual(len(args), 3)
+        self.assertTrue(torch.equal(args[0], self.mock_q))
+        self.assertTrue(torch.equal(args[1], self.mock_k))
+        self.assertTrue(torch.equal(args[2], self.mock_v))
+
+        self.mock_self.o_proj.assert_called_once_with(
+            self.mock_self.attn.return_value)
+
+        self.assertEqual(result.shape, (3, 256))
+        self.assertTrue(
+            torch.equal(result, self.mock_self.o_proj.return_value[0]))
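
Note: test_forward_function exercises the patched forward as a plain function with a mocked self, so the data flow it asserts can be read off directly. A sketch of that flow (the split call and argument names are inferred from the mocks; the real code is in patch_minicpm.py):

def forward(self, positions, hidden_states):
    # Fused QKV projection; qkv_proj returns a (tensor, bias) pair.
    qkv, _ = self.qkv_proj(hidden_states)
    # With the test's sizing, 384 = q_size + 2 * kv_size = 128 + 128 + 128.
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    # Asserted to be called positionally as (positions, q, k).
    q, k = self.rotary_emb(positions, q, k)
    # Asserted to be called positionally as (q, k, v).
    attn_output = self.attn(q, k, v)
    # o_proj also returns a (tensor, bias) pair; the test compares the
    # final result against the first element.
    output, _ = self.o_proj(attn_output)
    return output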
tests/ut/patch/worker/patch_common/test_patch_utils.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+from typing import List, Optional
+from unittest.mock import MagicMock, patch
+
+import torch
+from torch.library import Library
+
+from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_common.patch_utils import \
+    ascend_direct_register_custom_op
+
+
+class TestPatchUtils(TestBase):
+
+    def setUp(self):
+        super().setUp()
+
+        self.mock_op_func = MagicMock()
+        self.mock_op_func.__annotations__ = {
+            'param1': list[int],
+            'param2': Optional[list[int]],
+            'param3': str
+        }
+
+        self.mock_fake_impl = MagicMock()
+        self.mock_lib = MagicMock(spec=Library)
+
+        self.op_name = "test_op"
+        self.mutates_args = ["arg1"]
+        self.dispatch_key = "NPU"
+        self.tags = (torch.Tag.pt2_compliant_tag, )
+
+        self.patch_infer_schema = patch(
+            'vllm_ascend.patch.worker.patch_common.patch_utils.torch.library.infer_schema'
+        )
+        self.patch_vllm_lib = patch(
+            'vllm_ascend.patch.worker.patch_common.patch_utils.vllm_lib')
+
+        self.mock_infer_schema = self.patch_infer_schema.start()
+        self.mock_vllm_lib = self.patch_vllm_lib.start()
+
+        self.addCleanup(self.patch_infer_schema.stop)
+        self.addCleanup(self.patch_vllm_lib.stop)
+
+    def test_utils_patched(self):
+        from vllm import utils
+
+        self.assertIs(utils.direct_register_custom_op, ascend_direct_register_custom_op)
+
+    def test_register_with_default_lib(self):
+        self.mock_infer_schema.return_value = "(Tensor self) -> Tensor"
+
+        ascend_direct_register_custom_op(op_name=self.op_name,
+                                         op_func=self.mock_op_func,
+                                         mutates_args=self.mutates_args,
+                                         fake_impl=self.mock_fake_impl,
+                                         dispatch_key=self.dispatch_key,
+                                         tags=self.tags)
+
+        self.assertEqual(self.mock_op_func.__annotations__['param1'],
+                         List[int])
+        self.assertEqual(self.mock_op_func.__annotations__['param2'],
+                         Optional[List[int]])
+        self.assertEqual(self.mock_op_func.__annotations__['param3'], str)
+
+        self.mock_infer_schema.assert_called_once_with(
+            self.mock_op_func, mutates_args=self.mutates_args)
+
+        self.mock_vllm_lib.define.assert_called_once_with(
+            f"{self.op_name}(Tensor self) -> Tensor", tags=self.tags)
+        self.mock_vllm_lib.impl.assert_called_once_with(
+            self.op_name, self.mock_op_func, dispatch_key=self.dispatch_key)
+        self.mock_vllm_lib._register_fake.assert_called_once_with(
+            self.op_name, self.mock_fake_impl)
+
+    def test_register_with_custom_lib(self):
+        self.mock_infer_schema.return_value = "(Tensor a, Tensor b) -> Tensor"
+
+        ascend_direct_register_custom_op(op_name=self.op_name,
+                                         op_func=self.mock_op_func,
+                                         mutates_args=self.mutates_args,
+                                         target_lib=self.mock_lib)
+
+        self.mock_lib.define.assert_called_once_with(
+            f"{self.op_name}(Tensor a, Tensor b) -> Tensor", tags=())
+        self.mock_lib.impl.assert_called_once_with(self.op_name,
+                                                   self.mock_op_func,
+                                                   dispatch_key="CUDA")
+        self.mock_lib._register_fake.assert_not_called()
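
Note: the two registration tests fix the observable behavior of ascend_direct_register_custom_op. A sketch of that behavior, read off the assertions: the defaults dispatch_key="CUDA", tags=(), and the vllm_lib fallback are inferred, and the annotation-rewriting helper below is a minimal stand-in covering only the cases the tests check.

from typing import List, Optional

import torch

def _pep585_to_typing(hint):
    # Minimal rewrite covering exactly the annotations the tests use:
    # list[int] -> List[int], Optional[list[int]] -> Optional[List[int]];
    # anything else (e.g. str) passes through unchanged.
    if hint == list[int]:
        return List[int]
    if hint == Optional[list[int]]:
        return Optional[List[int]]
    return hint

def direct_register_custom_op(op_name, op_func, mutates_args, fake_impl=None,
                              target_lib=None, dispatch_key="CUDA", tags=()):
    # Rewrite PEP 585 generics so torch.library.infer_schema accepts them.
    for name, hint in op_func.__annotations__.items():
        op_func.__annotations__[name] = _pep585_to_typing(hint)
    # infer_schema produces the "(Tensor ...) -> Tensor" signature string,
    # which is prepended with the op name when the op is defined.
    schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    # vllm_lib: the shared torch.library.Library that patch_utils uses (assumed).
    my_lib = target_lib if target_lib is not None else vllm_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    # The fake (meta) implementation is only registered when one is given,
    # hence _register_fake is never called in the custom-lib test.
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)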
