16
16
import pytest
17
17
18
18
from vllm_ascend .ascend_forward_context import MoECommType
19
- from vllm_ascend .utils import AscendSocVersion
20
19
from vllm_ascend .worker .model_runner_v1 import NPUModelRunner
21
20
22
21
25
24
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method" ,
26
25
[
27
26
# Case 1: Expert parallel is disabled, should always be 'allgather'
28
- (AscendSocVersion . A2 , False , 8 , 100 , 256 , None , MoECommType . ALLGATHER ),
29
- (AscendSocVersion . A3 , False , 16 , 500 , 256 , None , MoECommType . ALLGATHER ),
27
+ ("A2" , False , 8 , 100 , 256 , None , "allgather" ),
28
+ ("A3" , False , 16 , 500 , 256 , None , "allgather" ),
30
29
31
30
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
32
- (AscendSocVersion . A2 , True , 8 , 100 , 256 , "w4a8_dynamic" , MoECommType . ALLTOALL ),
33
- (AscendSocVersion . A2 , True , 16 , 257 , 256 , "w4a8_dynamic" , MoECommType . ALLTOALL ),
34
- (AscendSocVersion . A2 , True , 16 , 100 , 256 , "w4a8_dynamic" , MoECommType . MC2 ), # meets mc2 condition
31
+ ("A2" , True , 8 , 100 , 256 , "w4a8_dynamic" , "alltoall" ),
32
+ ("A2" , True , 16 , 257 , 256 , "w4a8_dynamic" , "alltoall" ),
33
+ ("A2" , True , 16 , 100 , 256 , "w4a8_dynamic" , "mc2" ), # meets mc2 condition
35
34
36
35
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
37
- (AscendSocVersion . A2 , True , 8 , 100 , 256 , None , MoECommType . ALLGATHER ),
38
- (AscendSocVersion . A2 , True , 16 , 257 , 256 , None , MoECommType . ALLGATHER ),
36
+ ("A2" , True , 8 , 100 , 256 , None , "allgather" ),
37
+ ("A2" , True , 16 , 257 , 256 , None , "allgather" ),
39
38
40
39
# Case 4: A3 SOC
41
- (AscendSocVersion .A3 , True , 8 , 100 , 256 , None , MoECommType .MC2 ),
42
- (AscendSocVersion .A3 , True , 8 , 257 , 256 , None , MoECommType .ALLTOALL ),
40
+ ("A3" , True , 8 , 100 , 256 , None , "mc2" ),
41
+ ("A3" , True , 8 , 257 , 256 , None , "alltoall" ),
42
+
43
+ # Case 5: P3 SOC
44
+ ("310P" , True , 8 , 100 , 256 , None , "allgather" ),
45
+ ("310P" , True , 8 , 257 , 256 , None , "allgather" ),
43
46
])
44
47
# yapf: enable
45
48
def test_select_moe_comm_method (soc_version , enable_expert_parallel ,
@@ -65,8 +68,8 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
65
68
mock_runner .vllm_config = mock_vllm_config
66
69
67
70
# Patch the helper functions
68
- with patch ('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version ' ,
69
- return_value = soc_version ), \
71
+ with patch ('vllm_ascend._build_info.__ascend_soc_version__ ' ,
72
+ new = soc_version ), \
70
73
patch ('vllm_ascend.worker.model_runner_v1.is_global_first_rank' ,
71
74
return_value = True ):
72
75
@@ -98,8 +101,8 @@ def test_select_moe_comm_method_unsupported_soc():
98
101
99
102
unsupported_soc = "UnsupportedSOC"
100
103
101
- with patch ('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version ' ,
102
- return_value = unsupported_soc ), \
104
+ with patch ('vllm_ascend._build_info.__ascend_soc_version__ ' ,
105
+ new = unsupported_soc ), \
103
106
patch ('vllm_ascend.worker.model_runner_v1.is_global_first_rank' ,
104
107
return_value = True ), \
105
108
pytest .raises (ValueError , match = f"Unsupported soc_version: { unsupported_soc } " ):
0 commit comments