Skip to content

Commit 1b5b021

Browse files
committed
add unit test for torchair_fused_experts_with_mc2 (A2 expert_scales path)
Signed-off-by: Liccol <740821011@qq.com>
1 parent dc30870 commit 1b5b021

File tree

1 file changed

+57
-2
lines changed

1 file changed

+57
-2
lines changed

tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import torch
44

55
from tests.ut.base import TestBase
6-
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
7-
torchair_fused_experts_with_all2all
6+
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
7+
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
8+
from vllm_ascend.utils import AscendSocVersion
89

910

1011
class TestAscendW8A8FusedMoEMethod(TestBase):
@@ -73,3 +74,57 @@ def test_torchair_fused_experts_with_all2all(
7374
self.assertIsNotNone(result)
7475
self.assertEqual(result.dtype, torch.bfloat16)
7576
self.assertEqual(result.shape, (128, 128))
77+
78+
    # Force the HCCL env combination used by the A2 mc2 code path:
    # RoCE intra-node disabled, PCIe intra-node enabled.
    @patch.dict('os.environ', {
        'HCCL_INTRA_ROCE_ENABLE': '0',
        'HCCL_INTRA_PCIE_ENABLE': '1'
    })
    # NOTE: @patch decorators apply bottom-up, so the innermost
    # (torchair_apply_mlp_decode) maps to the first mock parameter
    # and get_ascend_soc_version to the last.
    @patch(
        "vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_ascend_soc_version"
    )
    @patch(
        'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_mc2_group'
    )
    @patch('torch_npu.npu_moe_distribute_combine_v2')
    @patch('torch_npu.npu_moe_distribute_dispatch_v2')
    @patch(
        'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.torchair_apply_mlp_decode'
    )
    def test_torchair_fused_experts_with_mc2_a2_optimization(
            self, mock_mlp_decode, mock_dispatch, mock_combine, mock_get_group,
            mock_ascend_soc_version):
        """Test expert_scales is passed in A2 SOC version with mc2 optimization"""
        # Setup mocks
        # Pretend we run on an A2 SoC so the A2-specific dispatch branch
        # (which forwards `expert_scales`) is taken.
        mock_ascend_soc_version.return_value = AscendSocVersion.A2

        # Minimal fake communication group: rank 0 of a 4-rank world.
        mock_group = MagicMock()
        mock_group.rank_in_group = 0
        mock_group.world_size = 4
        mock_get_group.return_value = mock_group

        # Combine returns the placeholder tensor unchanged, so the final
        # result shape can be checked against it below.
        mock_combine.return_value = self.placeholder

        # 7-tuple mirroring npu_moe_distribute_dispatch_v2's outputs;
        # presumably (expanded_x, dynamic_scale, expand_idx,
        # expert_token_nums, ep_recv_counts, tp_recv_counts, expand_scales)
        # — TODO confirm against the production unpacking order.
        mock_dispatch.return_value = (torch.randn(32, 1024), torch.randn(1),
                                      torch.randint(0, 32, (32, )),
                                      torch.randint(1, 5, (8, )),
                                      torch.randint(1, 5, (4, )), None,
                                      torch.randn(32))
        mock_mlp_decode.return_value = self.placeholder

        # Exercise the function under test; all tensor args share the same
        # placeholder since the heavy NPU ops are mocked out.
        result = torchair_fused_experts_with_mc2(
            hidden_states=self.placeholder,
            w1=self.placeholder,
            w2=self.placeholder,
            w1_scale=self.placeholder,
            w2_scale=self.placeholder,
            topk_weights=self.placeholder,
            topk_ids=self.placeholder,
            top_k=2,
            mc2_mask=self.placeholder)

        # Check that expert_scales was passed to dispatch
        # call_args[1] is the kwargs dict of the (single) dispatch call.
        call_args = mock_dispatch.call_args[1]
        self.assertIn('expert_scales', call_args)

        # Result flows through the mocked combine, so it should keep the
        # placeholder's type and shape.
        self.assertIsInstance(result, torch.Tensor)
        self.assertEqual(result.shape, self.placeholder.shape)

0 commit comments

Comments
 (0)