 import torch

 from tests.ut.base import TestBase
-from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
-    torchair_fused_experts_with_all2all
+from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
+    torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
+from vllm_ascend.utils import AscendSocVersion


 class TestAscendW8A8FusedMoEMethod(TestBase):
@@ -73,3 +74,63 @@ def test_torchair_fused_experts_with_all2all(
         self.assertIsNotNone(result)
         self.assertEqual(result.dtype, torch.bfloat16)
         self.assertEqual(result.shape, (128, 128))
+
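+    # Pin HCCL to the intra-node PCIe transport; combined with the A2 SoC
+    # mock below, this exercises the branch that forwards expert_scales.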
+    @patch.dict('os.environ', {
+        'HCCL_INTRA_ROCE_ENABLE': '0',
+        'HCCL_INTRA_PCIE_ENABLE': '1'
+    })
+    @patch(
+        "vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_ascend_soc_version"
+    )
+    @patch(
+        'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_mc2_group'
+    )
+    @patch('torch_npu.npu_moe_distribute_combine_v2')
+    @patch('torch_npu.npu_moe_distribute_dispatch_v2')
+    @patch(
+        'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.torchair_apply_mlp_decode'
+    )
+    def test_torchair_fused_experts_with_mc2_a2_optimization(
+            self, mock_mlp_decode, mock_dispatch, mock_combine, mock_get_group,
+            mock_ascend_soc_version):
+        """Test that expert_scales is passed on A2 SoC with mc2 optimization."""
+        # Setup mocks
+        mock_ascend_soc_version.return_value = AscendSocVersion.A2
+
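+        # Minimal mc2 communicator stub: 4-way expert-parallel group, rank 0.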
+        mock_group = MagicMock()
+        mock_group.rank_in_group = 0
+        mock_group.world_size = 4
+        mock_get_group.return_value = mock_group
+
+        mock_combine.return_value = self.placeholder
+
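+        # Stand-in for npu_moe_distribute_dispatch_v2's output tuple; the last
+        # element plays the expand_scales role the A2 path consumes, the rest
+        # are shape-only placeholders.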
+        mock_dispatch.return_value = (torch.randn(32, 1024), torch.randn(1),
+                                      torch.randint(0, 32, (32, )),
+                                      torch.randint(1, 5, (8, )),
+                                      torch.randint(1, 5, (4, )), None,
+                                      torch.randn(32))
+        mock_mlp_decode.return_value = self.placeholder
+
+        result = torchair_fused_experts_with_mc2(
+            hidden_states=self.placeholder,
+            w1=self.placeholder,
+            w2=self.placeholder,
+            w1_scale=self.placeholder,
+            w2_scale=self.placeholder,
+            topk_weights=self.placeholder,
+            topk_ids=self.placeholder,
+            top_k=2,
+            mc2_mask=self.placeholder)
+
+        # Check that expert_scales was passed to dispatch
+        call_args = mock_dispatch.call_args[1]
+        self.assertIn('expert_scales', call_args)
+
+        self.assertIsInstance(result, torch.Tensor)
+        self.assertEqual(result.shape, self.placeholder.shape)
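
For context, a minimal sketch of the gating this test pins down, assuming the env-var names patched above and that get_ascend_soc_version sits alongside AscendSocVersion in vllm_ascend.utils; the helper name is hypothetical, not the vllm-ascend source:

import os

from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version


def _needs_expert_scales() -> bool:
    # Hypothetical predicate: on an A2 SoC with HCCL pinned to the intra-node
    # PCIe transport, the mc2 dispatch call is expected to receive the extra
    # expert_scales keyword this test asserts on.
    return (get_ascend_soc_version() == AscendSocVersion.A2
            and os.environ.get("HCCL_INTRA_ROCE_ENABLE") == "0"
            and os.environ.get("HCCL_INTRA_PCIE_ENABLE") == "1")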