
Commit 36fe57f

add patch ut
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
1 parent 9cd4ac7 commit 36fe57f

4 files changed: 422 additions, 6 deletions


tests/ut/patch/worker/patch_common/test_patch_distributed.py

Lines changed: 95 additions & 6 deletions
```diff
@@ -11,17 +11,106 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
-#
+
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.distributed.parallel_state import GroupCoordinator
 
 from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_common.patch_distributed import \
+    GroupCoordinatorPatch
 
 
-class TestPatchDistributed(TestBase):
+class TestGroupCoordinatorPatch(TestBase):
 
-    def test_GroupCoordinator_patched(self):
-        from vllm.distributed.parallel_state import GroupCoordinator
+    def setUp(self):
+        self.mock_group_ranks = [[0, 1]]
+        self.mock_local_rank = 0
+        self.mock_backend = "hccl"
+        self.mock_use_device_comm = True
+
+        patcher_get_rank = patch("torch.distributed.get_rank", return_value=0)
+        patcher_new_group = patch("torch.distributed.new_group",
+                                  return_value=MagicMock())
+        patcher_is_cuda_alike = patch(
+            "vllm.platforms.current_platform.is_cuda_alike", return_value=True)
+        patcher_device_comm_cls = patch(
+            "vllm.distributed.parallel_state.resolve_obj_by_qualname",
+            return_value=MagicMock())
 
-        from vllm_ascend.patch.worker.patch_common.patch_distributed import \
-            GroupCoordinatorPatch
+        self.mock_get_rank = patcher_get_rank.start()
+        self.mock_new_group = patcher_new_group.start()
+        self.mock_is_cuda_alike = patcher_is_cuda_alike.start()
+        self.mock_resolve_obj = patcher_device_comm_cls.start()
 
+        self.addCleanup(patcher_get_rank.stop)
+        self.addCleanup(patcher_new_group.stop)
+        self.addCleanup(patcher_is_cuda_alike.stop)
+        self.addCleanup(patcher_device_comm_cls.stop)
+
+        self.group_coordinator = GroupCoordinatorPatch(
+            group_ranks=self.mock_group_ranks,
+            local_rank=self.mock_local_rank,
+            torch_distributed_backend=self.mock_backend,
+            use_device_communicator=self.mock_use_device_comm)
+
+    def test_GroupCoordinator_patched(self):
         self.assertIs(GroupCoordinator, GroupCoordinatorPatch)
+
+    def test_initialization_sets_attributes(self):
+        self.assertEqual(self.group_coordinator.world_size, 2)
+        self.assertEqual(self.group_coordinator.rank_in_group, 0)
+        self.assertTrue(hasattr(self.group_coordinator, "device_communicator"))
+
+    def test_all_to_all_returns_input_when_world_size_1(self):
+        self.group_coordinator.world_size = 1
+        input_tensor = torch.randn(2, 3)
+        output = self.group_coordinator.all_to_all(input_tensor)
+        self.assertTrue(torch.equal(output, input_tensor))
+
+    def test_all_to_all_raises_assertion_on_invalid_scatter_dim(self):
+        input_tensor = torch.randn(2, 3)
+        with self.assertRaises(AssertionError) as cm:
+            self.group_coordinator.all_to_all(input_tensor, scatter_dim=2)
+        self.assertIn("Invalid scatter dim", str(cm.exception))
+
+    def test_all_to_all_raises_assertion_on_invalid_gather_dim(self):
+        input_tensor = torch.randn(2, 3)
+        with self.assertRaises(AssertionError) as cm:
+            self.group_coordinator.all_to_all(input_tensor, gather_dim=2)
+        self.assertIn("Invalid gather dim", str(cm.exception))
+
+    def test_all_to_all_calls_device_communicator_with_correct_args(self):
+        mock_communicator = MagicMock()
+        self.group_coordinator.device_communicator = mock_communicator
+
+        input_tensor = torch.randn(2, 3)
+        scatter_dim = 0
+        gather_dim = 1
+        scatter_sizes = [1, 1]
+        gather_sizes = [1, 1]
+
+        self.group_coordinator.all_to_all(input_tensor,
+                                          scatter_dim=scatter_dim,
+                                          gather_dim=gather_dim,
+                                          scatter_sizes=scatter_sizes,
+                                          gather_sizes=gather_sizes)
+
+        mock_communicator.all_to_all.assert_called_once_with(
+            input_tensor, scatter_dim, gather_dim, scatter_sizes, gather_sizes)
+
+    def test_all_to_all_calls_device_communicator_without_sizes(self):
+        mock_communicator = MagicMock()
+        self.group_coordinator.device_communicator = mock_communicator
+
+        input_tensor = torch.randn(2, 3)
+        scatter_dim = 0
+        gather_dim = 1
+
+        self.group_coordinator.all_to_all(input_tensor,
+                                          scatter_dim=scatter_dim,
+                                          gather_dim=gather_dim)
+
+        mock_communicator.all_to_all.assert_called_once_with(
+            input_tensor, scatter_dim, gather_dim, None, None)
```
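Read together, these cases pin down the `all_to_all` contract of `GroupCoordinatorPatch`: identity on single-rank groups, dimension validation, then delegation to the device communicator. A minimal sketch of that contract, inferred from the assertions above rather than taken from the vllm-ascend source (defaults and message wording are assumptions; in the real class this is a method):

```python
import torch


# Sketch only: inferred from the test assertions above, not copied from
# vllm_ascend.patch.worker.patch_common.patch_distributed.
def all_to_all(self,
               input_: torch.Tensor,
               scatter_dim: int = 0,
               gather_dim: int = -1,
               scatter_sizes=None,
               gather_sizes=None) -> torch.Tensor:
    # Single-rank group: nothing to exchange, return the input unchanged.
    if self.world_size == 1:
        return input_
    # Both dims must index into the input tensor; the tests expect an
    # AssertionError naming the offending dimension.
    assert -input_.dim() <= scatter_dim < input_.dim(), (
        f"Invalid scatter dim ({scatter_dim}) for {input_.dim()}-D input")
    assert -input_.dim() <= gather_dim < input_.dim(), (
        f"Invalid gather dim ({gather_dim}) for {input_.dim()}-D input")
    # Delegate to the device communicator; None sizes mean even splits.
    return self.device_communicator.all_to_all(
        input_, scatter_dim, gather_dim, scatter_sizes, gather_sizes)
```

The last two tests then only need to verify the positional order of the delegated arguments: `(input_, scatter_dim, gather_dim, scatter_sizes, gather_sizes)`.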
tests/ut/patch/worker/patch_common/test_patch_minicpm.py

Lines changed: 76 additions & 0 deletions (new file)
```python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from unittest.mock import MagicMock

import torch

from tests.ut.base import TestBase
from vllm_ascend.patch.worker.patch_common.patch_minicpm import forward


class TestPatchedMiniCPMForward(TestBase):

    def setUp(self):
        self.mock_self = MagicMock()

        self.mock_self.q_size = 128
        self.mock_self.kv_size = 128

        self.mock_self.qkv_proj = MagicMock()
        self.mock_self.rotary_emb = MagicMock()
        self.mock_self.attn = MagicMock()
        self.mock_self.o_proj = MagicMock()

        self.positions = torch.tensor([1, 2, 3])
        self.hidden_states = torch.randn(3, 256)  # [batch_size, hidden_size]

        self.mock_qkv = torch.randn(3, 384)
        self.mock_q = self.mock_qkv[:, :128]
        self.mock_k = self.mock_qkv[:, 128:256]
        self.mock_v = self.mock_qkv[:, 256:]

        self.mock_self.qkv_proj.return_value = (self.mock_qkv, None)
        self.mock_self.rotary_emb.return_value = (self.mock_q, self.mock_k)
        self.mock_self.attn.return_value = torch.randn(3, 256)
        self.mock_self.o_proj.return_value = (torch.randn(3, 256), None)

    def test_forward_patched(self):
        from vllm.model_executor.models.minicpm import MiniCPMAttention

        self.assertIs(MiniCPMAttention.forward, forward)

    def test_forward_function(self):
        result = forward(self.mock_self, self.positions, self.hidden_states)

        self.mock_self.qkv_proj.assert_called_once_with(self.hidden_states)

        args, _ = self.mock_self.rotary_emb.call_args
        self.assertEqual(len(args), 3)
        self.assertTrue(torch.equal(args[0], self.positions))
        self.assertTrue(torch.equal(args[1], self.mock_q))
        self.assertTrue(torch.equal(args[2], self.mock_k))

        args, _ = self.mock_self.attn.call_args
        self.assertEqual(len(args), 3)
        self.assertTrue(torch.equal(args[0], self.mock_q))
        self.assertTrue(torch.equal(args[1], self.mock_k))
        self.assertTrue(torch.equal(args[2], self.mock_v))

        self.mock_self.o_proj.assert_called_once_with(
            self.mock_self.attn.return_value)

        self.assertEqual(result.shape, (3, 256))
        self.assertTrue(
            torch.equal(result, self.mock_self.o_proj.return_value[0]))
```
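The mocks describe the standard decomposed attention pipeline: fused QKV projection, rotary embedding on q/k, attention, output projection. As a rough sketch of the data flow `test_forward_function` verifies (reconstructed from the mocks above, not the actual `patch_minicpm.forward`):

```python
import torch


# Sketch of the call chain asserted above; the split sizes follow the
# mocked q_size/kv_size (128 each), so the fused qkv tensor is 384 wide.
def forward(self, positions: torch.Tensor,
            hidden_states: torch.Tensor) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)  # projection returns (tensor, bias)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(positions, q, k)  # positional args, as asserted
    attn_output = self.attn(q, k, v)         # v bypasses the rotary embedding
    output, _ = self.o_proj(attn_output)     # (tensor, bias) again
    return output
```

Because `rotary_emb` is mocked to return the original slices, the tests can compare the `attn` arguments against `mock_q`/`mock_k`/`mock_v` with `torch.equal`.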
tests/ut/patch/worker/patch_common/test_patch_multi_step_worker.py

Lines changed: 154 additions & 0 deletions (new file)
```python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from unittest.mock import MagicMock, patch

import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata

from tests.ut.base import TestBase
from vllm_ascend.patch.worker.patch_common.patch_multi_step_worker import \
    sampler_output


class TestPatchedMultiStepWorkerSamplerOutput(TestBase):

    def setUp(self):
        self.mock_self = MagicMock()

        self.mock_self.device = torch.device("cpu")

        self.mock_self._raise_if_unsupported = MagicMock()
        self.mock_self._expand_execute_model_request = MagicMock()
        self.mock_self.execute_model = MagicMock()
        self.mock_self._maybe_update_previous_hidden_states = MagicMock()
        self.mock_self._append_new_tokens = MagicMock()
        self.mock_self._filter_model_output = MagicMock()

        self.execute_model_req = ExecuteModelRequest(
            seq_group_metadata_list=[MagicMock(spec=SequenceGroupMetadata)],
            num_steps=1,
            blocks_to_swap_in={},
            blocks_to_swap_out={},
            blocks_to_copy={},
            num_lookahead_slots=0)
        self.sample_len = 3
        self.seq_ids_with_bonus_token = {1, 2, 3}

        self.expanded_request = MagicMock(spec=ExecuteModelRequest)
        self.indices_of_seq_with_bonus_tokens = [0, 1, 2]
        self.mock_self._expand_execute_model_request.return_value = (
            self.expanded_request, self.indices_of_seq_with_bonus_tokens)

        self.filtered_output = [
            MagicMock(spec=SamplerOutput),
            MagicMock(spec=SamplerOutput),
            MagicMock(spec=SamplerOutput)
        ]
        self.mock_self._filter_model_output.return_value = self.filtered_output

    def test_sampler_output_patched(self):
        from vllm.spec_decode.multi_step_worker import MultiStepWorker

        wrapped_func = MultiStepWorker.sampler_output.__wrapped__
        self.assertIs(
            wrapped_func, sampler_output,
            "Wrapped function does not match the expected implementation")

    def test_gpu_multi_step_path(self):
        mock_model_runner = MagicMock()
        mock_model_runner.supports_gpu_multi_step.return_value = True

        self.mock_self.model_runner = mock_model_runner
        with patch(
                'vllm_ascend.patch.worker.patch_common.patch_multi_step_worker.isinstance'
        ) as mock_isinstance:
            mock_isinstance.return_value = True

            mock_outputs = [
                MagicMock(spec=SamplerOutput),
                MagicMock(spec=SamplerOutput),
                MagicMock(spec=SamplerOutput)
            ]
            self.mock_self.execute_model.return_value = mock_outputs

            result, need_transpose = sampler_output(
                self.mock_self, self.execute_model_req, self.sample_len,
                self.seq_ids_with_bonus_token)

            self.mock_self._raise_if_unsupported.assert_called_once_with(
                self.execute_model_req)
            self.mock_self._expand_execute_model_request.assert_called_once_with(
                self.execute_model_req, self.seq_ids_with_bonus_token)

            mock_model_runner.supports_gpu_multi_step.assert_called_once_with(
                self.expanded_request)
            self.assertEqual(self.expanded_request.num_steps, self.sample_len)
            mock_model_runner.set_indices_of_seq_with_bonus_tokens.assert_called_once_with(
                self.indices_of_seq_with_bonus_tokens)
            self.mock_self.execute_model.assert_called_once_with(
                execute_model_req=self.expanded_request)

            self.assertEqual(result, self.filtered_output)
            self.assertTrue(need_transpose)

            self.mock_self._maybe_update_previous_hidden_states.assert_not_called()
            self.mock_self._append_new_tokens.assert_not_called()

    def test_cpu_multi_step_path(self):
        mock_model_runner = MagicMock()
        mock_model_runner.supports_gpu_multi_step.return_value = False

        self.mock_self.model_runner = mock_model_runner
        self.mock_self.worker = MagicMock()

        mock_step_output = MagicMock(spec=SamplerOutput)
        self.mock_self.worker.execute_model.return_value = [[mock_step_output]]

        result, need_transpose = sampler_output(self.mock_self,
                                                self.execute_model_req,
                                                self.sample_len,
                                                self.seq_ids_with_bonus_token)

        self.assertEqual(self.mock_self.worker.execute_model.call_count,
                         self.sample_len)
        self.mock_self._append_new_tokens.assert_called()
        self.assertEqual(self.mock_self._append_new_tokens.call_count,
                         self.sample_len)

        self.mock_self._filter_model_output.assert_called_once()
        self.assertEqual(result, self.filtered_output)
        self.assertTrue(need_transpose)

    def test_cpu_path_with_hidden_states(self):
        self.expanded_request.previous_hidden_states = MagicMock()

        mock_model_runner = MagicMock()
        mock_model_runner.supports_gpu_multi_step.return_value = False
        self.mock_self.model_runner = mock_model_runner
        self.mock_self.worker = MagicMock()

        self.mock_self.worker.model_runner = MagicMock()
        self.mock_self.worker.model_runner.return_hidden_states = False

        mock_step_output = MagicMock(spec=SamplerOutput)
        self.mock_self.worker.execute_model.return_value = [[mock_step_output]]

        sampler_output(self.mock_self, self.execute_model_req, self.sample_len,
                       self.seq_ids_with_bonus_token)

        self.assertTrue(
            self.mock_self.worker.model_runner.return_hidden_states)
        self.mock_self._maybe_update_previous_hidden_states.assert_called()
```
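Together the three path tests fix the control flow of the patched `sampler_output`; the `__wrapped__` check additionally tells us the patch is installed behind a wrapping decorator (e.g. `torch.inference_mode`). A hedged reconstruction of that flow, inferred from the assertions only (the `TP1DraftModelRunner` isinstance target, its import path, and the `[0]` indexing of the per-step output are assumptions, not confirmed by the diff):

```python
from vllm.sequence import ExecuteModelRequest
# Assumed isinstance target; the test patches isinstance() itself.
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner


# Sketch only: control flow reconstructed from the assertions above, not
# copied from vllm_ascend.patch.worker.patch_common.patch_multi_step_worker.
def sampler_output(self, execute_model_req: ExecuteModelRequest,
                   sample_len: int, seq_ids_with_bonus_token_in_last_step):
    self._raise_if_unsupported(execute_model_req)
    # Duplicate sequences that received a bonus token in the last step.
    expanded_request, bonus_indices = self._expand_execute_model_request(
        execute_model_req, seq_ids_with_bonus_token_in_last_step)

    if isinstance(self.model_runner, TP1DraftModelRunner) and \
            self.model_runner.supports_gpu_multi_step(expanded_request):
        # GPU path: run all sample_len steps inside one execute_model call.
        expanded_request.num_steps = sample_len
        self.model_runner.set_indices_of_seq_with_bonus_tokens(bonus_indices)
        model_outputs = self.execute_model(execute_model_req=expanded_request)
    else:
        # Fallback path: one worker.execute_model call per step, with
        # hidden-state forwarding enabled when the request carries them.
        if expanded_request.previous_hidden_states is not None:
            self.worker.model_runner.return_hidden_states = True
        model_outputs = []
        for _ in range(sample_len):
            step_output = self.worker.execute_model(
                execute_model_req=expanded_request)[0]
            self._maybe_update_previous_hidden_states(
                step_output, expanded_request)
            self._append_new_tokens(
                step_output, expanded_request.seq_group_metadata_list,
                bonus_indices)
            model_outputs.append(step_output)

    # Drop outputs of the duplicated sequences; the True flag tells the
    # caller the step-major outputs still need transposing.
    return self._filter_model_output(model_outputs, bonus_indices), True
```

This explains the shared `(result, need_transpose)` unpacking in all three tests: every path funnels through `_filter_model_output` and returns `True` as the transpose flag.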
