diff --git a/tests/ut/patch/worker/patch_common/test_patch_distributed.py b/tests/ut/patch/worker/patch_common/test_patch_distributed.py
index 73525eefd7..4975313890 100644
--- a/tests/ut/patch/worker/patch_common/test_patch_distributed.py
+++ b/tests/ut/patch/worker/patch_common/test_patch_distributed.py
@@ -13,15 +13,100 @@
 # This file is a part of the vllm-ascend project.
 #
 
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.distributed.parallel_state import GroupCoordinator
+
 from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_common.patch_distributed import \
+    GroupCoordinatorPatch
 
 
 class TestPatchDistributed(TestBase):
 
-    def test_GroupCoordinator_patched(self):
-        from vllm.distributed.parallel_state import GroupCoordinator
+    def setUp(self):
+        self.mock_group_ranks = [[0, 1]]
+        self.mock_local_rank = 0
+        self.mock_backend = "hccl"
+        self.mock_use_device_comm = True
+
+        patcher_get_rank = patch("torch.distributed.get_rank", return_value=0)
+        patcher_new_group = patch("torch.distributed.new_group",
+                                  return_value=MagicMock())
+        patcher_is_cuda_alike = patch(
+            "vllm.platforms.current_platform.is_cuda_alike", return_value=True)
+        patcher_device_comm_cls = patch(
+            "vllm.distributed.parallel_state.resolve_obj_by_qualname",
+            return_value=MagicMock())
+
+        self.mock_get_rank = patcher_get_rank.start()
+        self.mock_new_group = patcher_new_group.start()
+        self.mock_is_cuda_alike = patcher_is_cuda_alike.start()
+        self.mock_resolve_obj = patcher_device_comm_cls.start()
 
-        from vllm_ascend.patch.worker.patch_common.patch_distributed import \
-            GroupCoordinatorPatch
+        self.addCleanup(patcher_get_rank.stop)
+        self.addCleanup(patcher_new_group.stop)
+        self.addCleanup(patcher_is_cuda_alike.stop)
+        self.addCleanup(patcher_device_comm_cls.stop)
 
+        self.group_coordinator = GroupCoordinatorPatch(
+            group_ranks=self.mock_group_ranks,
+            local_rank=self.mock_local_rank,
+            torch_distributed_backend=self.mock_backend,
+            use_device_communicator=self.mock_use_device_comm)
+
+    def test_GroupCoordinator_patched(self):
         self.assertIs(GroupCoordinator, GroupCoordinatorPatch)
+
+    def test_all_to_all_returns_input_when_world_size_1(self):
+        self.group_coordinator.world_size = 1
+        input_tensor = torch.randn(2, 3)
+        output = self.group_coordinator.all_to_all(input_tensor)
+        self.assertTrue(torch.equal(output, input_tensor))
+
+    def test_all_to_all_raises_assertion_on_invalid_scatter_dim(self):
+        input_tensor = torch.randn(2, 3)
+        with self.assertRaises(AssertionError) as cm:
+            self.group_coordinator.all_to_all(input_tensor, scatter_dim=2)
+        self.assertIn("Invalid scatter dim", str(cm.exception))
+
+    def test_all_to_all_raises_assertion_on_invalid_gather_dim(self):
+        input_tensor = torch.randn(2, 3)
+        with self.assertRaises(AssertionError) as cm:
+            self.group_coordinator.all_to_all(input_tensor, gather_dim=2)
+        self.assertIn("Invalid gather dim", str(cm.exception))
+
+    def test_all_to_all_calls_device_communicator_with_correct_args(self):
+        mock_communicator = MagicMock()
+        self.group_coordinator.device_communicator = mock_communicator
+
+        input_tensor = torch.randn(2, 3)
+        scatter_dim = 0
+        gather_dim = 1
+        scatter_sizes = [1, 1]
+        gather_sizes = [1, 1]
+
+        self.group_coordinator.all_to_all(input_tensor,
+                                          scatter_dim=scatter_dim,
+                                          gather_dim=gather_dim,
+                                          scatter_sizes=scatter_sizes,
+                                          gather_sizes=gather_sizes)
+
+        mock_communicator.all_to_all.assert_called_once_with(
+            input_tensor, scatter_dim, gather_dim, scatter_sizes, gather_sizes)
+
+    def test_all_to_all_calls_device_communicator_without_sizes(self):
+        mock_communicator = MagicMock()
+        self.group_coordinator.device_communicator = mock_communicator
+
+        input_tensor = torch.randn(2, 3)
+        scatter_dim = 0
+        gather_dim = 1
+
+        self.group_coordinator.all_to_all(input_tensor,
+                                          scatter_dim=scatter_dim,
+                                          gather_dim=gather_dim)
+
+        mock_communicator.all_to_all.assert_called_once_with(
+            input_tensor, scatter_dim, gather_dim, None, None)
diff --git a/tests/ut/patch/worker/patch_common/test_patch_minicpm.py b/tests/ut/patch/worker/patch_common/test_patch_minicpm.py
new file mode 100644
index 0000000000..47d195715d
--- /dev/null
+++ b/tests/ut/patch/worker/patch_common/test_patch_minicpm.py
@@ -0,0 +1,77 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+from unittest.mock import MagicMock
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_common.patch_minicpm import forward
+
+
+class TestPatchMiniCPM(TestBase):
+
+    def setUp(self):
+        self.mock_self = MagicMock()
+
+        self.mock_self.q_size = 128
+        self.mock_self.kv_size = 128
+
+        self.mock_self.qkv_proj = MagicMock()
+        self.mock_self.rotary_emb = MagicMock()
+        self.mock_self.attn = MagicMock()
+        self.mock_self.o_proj = MagicMock()
+
+        self.positions = torch.tensor([1, 2, 3])
+        self.hidden_states = torch.randn(3, 256)
+
+        self.mock_qkv = torch.randn(3, 384)
+        self.mock_q = self.mock_qkv[:, :128]
+        self.mock_k = self.mock_qkv[:, 128:256]
+        self.mock_v = self.mock_qkv[:, 256:]
+
+        self.mock_self.qkv_proj.return_value = (self.mock_qkv, None)
+        self.mock_self.rotary_emb.return_value = (self.mock_q, self.mock_k)
+        self.mock_self.attn.return_value = torch.randn(3, 256)
+        self.mock_self.o_proj.return_value = (torch.randn(3, 256), None)
+
+    def test_forward_patched(self):
+        from vllm.model_executor.models.minicpm import MiniCPMAttention
+
+        self.assertIs(MiniCPMAttention.forward, forward)
+
+    def test_forward_function(self):
+        result = forward(self.mock_self, self.positions, self.hidden_states)
+
+        self.mock_self.qkv_proj.assert_called_once_with(self.hidden_states)
+
+        args, _ = self.mock_self.rotary_emb.call_args
+        self.assertEqual(len(args), 3)
+        self.assertTrue(torch.equal(args[0], self.positions))
+        self.assertTrue(torch.equal(args[1], self.mock_q))
+        self.assertTrue(torch.equal(args[2], self.mock_k))
+
+        args, _ = self.mock_self.attn.call_args
+        self.assertEqual(len(args), 3)
+        self.assertTrue(torch.equal(args[0], self.mock_q))
+        self.assertTrue(torch.equal(args[1], self.mock_k))
+        self.assertTrue(torch.equal(args[2], self.mock_v))
+
+        self.mock_self.o_proj.assert_called_once_with(
+            self.mock_self.attn.return_value)
+
+        self.assertEqual(result.shape, (3, 256))
+        self.assertTrue(
+            torch.equal(result, self.mock_self.o_proj.return_value[0]))
diff --git a/tests/ut/patch/worker/patch_common/test_patch_utils.py b/tests/ut/patch/worker/patch_common/test_patch_utils.py
new file mode 100644
index 0000000000..d64e83346f
--- /dev/null
+++ b/tests/ut/patch/worker/patch_common/test_patch_utils.py
@@ -0,0 +1,104 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+from typing import List, Optional
+from unittest.mock import MagicMock, patch
+
+import torch
+from torch.library import Library
+
+from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_common.patch_utils import \
+    ascend_direct_register_custom_op
+
+
+class TestPatchUtils(TestBase):
+
+    def setUp(self):
+        super().setUp()
+
+        self.mock_op_func = MagicMock()
+        self.mock_op_func.__annotations__ = {
+            'param1': list[int],
+            'param2': Optional[list[int]],
+            'param3': str
+        }
+
+        self.mock_fake_impl = MagicMock()
+        self.mock_lib = MagicMock(spec=Library)
+
+        self.op_name = "test_op"
+        self.mutates_args = ["arg1"]
+        self.dispatch_key = "NPU"
+        self.tags = (torch.Tag.pt2_compliant_tag, )
+
+        self.patch_infer_schema = patch(
+            'vllm_ascend.patch.worker.patch_common.patch_utils.torch.library.infer_schema'
+        )
+        self.patch_vllm_lib = patch(
+            'vllm_ascend.patch.worker.patch_common.patch_utils.vllm_lib')
+
+        self.mock_infer_schema = self.patch_infer_schema.start()
+        self.mock_vllm_lib = self.patch_vllm_lib.start()
+
+        self.addCleanup(self.patch_infer_schema.stop)
+        self.addCleanup(self.patch_vllm_lib.stop)
+
+    def test_utils_patched(self):
+        from vllm import utils
+
+        self.assertIs(utils.direct_register_custom_op,
+                      ascend_direct_register_custom_op)
+
+    def test_register_with_default_lib(self):
+        self.mock_infer_schema.return_value = "(Tensor self) -> Tensor"
+
+        ascend_direct_register_custom_op(op_name=self.op_name,
+                                         op_func=self.mock_op_func,
+                                         mutates_args=self.mutates_args,
+                                         fake_impl=self.mock_fake_impl,
+                                         dispatch_key=self.dispatch_key,
+                                         tags=self.tags)
+
+        self.assertEqual(self.mock_op_func.__annotations__['param1'],
+                         List[int])
+        self.assertEqual(self.mock_op_func.__annotations__['param2'],
+                         Optional[List[int]])
+        self.assertEqual(self.mock_op_func.__annotations__['param3'], str)
+
+        self.mock_infer_schema.assert_called_once_with(
+            self.mock_op_func, mutates_args=self.mutates_args)
+
+        self.mock_vllm_lib.define.assert_called_once_with(
+            f"{self.op_name}(Tensor self) -> Tensor", tags=self.tags)
+        self.mock_vllm_lib.impl.assert_called_once_with(
+            self.op_name, self.mock_op_func, dispatch_key=self.dispatch_key)
+        self.mock_vllm_lib._register_fake.assert_called_once_with(
+            self.op_name, self.mock_fake_impl)
+
+    def test_register_with_custom_lib(self):
+        self.mock_infer_schema.return_value = "(Tensor a, Tensor b) -> Tensor"
+
+        ascend_direct_register_custom_op(op_name=self.op_name,
+                                         op_func=self.mock_op_func,
+                                         mutates_args=self.mutates_args,
+                                         target_lib=self.mock_lib)
+
+        self.mock_lib.define.assert_called_once_with(
+            f"{self.op_name}(Tensor a, Tensor b) -> Tensor", tags=())
+        self.mock_lib.impl.assert_called_once_with(self.op_name,
+                                                   self.mock_op_func,
+                                                   dispatch_key="CUDA")
+        self.mock_lib._register_fake.assert_not_called()