|
15 | 15 |
|
16 | 16 | import pytest
|
17 | 17 |
|
| 18 | +from vllm_ascend.ascend_forward_context import MoECommType |
18 | 19 | from vllm_ascend.utils import AscendSocVersion
|
19 | 20 | from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
20 | 21 |
|
|
24 | 25 | "soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
|
25 | 26 | [
|
26 | 27 | # Case 1: Expert parallel is disabled, should always be 'allgather'
|
27 |
| - (AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"), |
28 |
| - (AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"), |
| 28 | + (AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER), |
| 29 | + (AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER), |
29 | 30 |
|
30 | 31 | # Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
|
31 |
| - (AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"), |
32 |
| - (AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"), |
33 |
| - (AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition |
| 32 | + (AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL), |
| 33 | + (AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL), |
| 34 | + (AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition |
34 | 35 |
|
35 | 36 | # Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
|
36 |
| - (AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"), |
37 |
| - (AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"), |
| 37 | + (AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER), |
| 38 | + (AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER), |
38 | 39 |
|
39 | 40 | # Case 4: A3 SOC
|
40 |
| - (AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"), |
41 |
| - (AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"), |
| 41 | + (AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2), |
| 42 | + (AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLTOALL), |
42 | 43 | ])
|
43 | 44 | # yapf: enable
|
44 | 45 | def test_select_moe_comm_method(soc_version, enable_expert_parallel,
|
|
0 commit comments