15
15
16
16
import pytest
17
17
18
- from vllm_ascend .utils import AscendSocVersion
19
18
from vllm_ascend .worker .model_runner_v1 import NPUModelRunner
20
19
21
20
24
23
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method" ,
25
24
[
26
25
# Case 1: Expert parallel is disabled, should always be 'allgather'
27
- (AscendSocVersion . A2 , False , 8 , 100 , 256 , None , "allgather" ),
28
- (AscendSocVersion . A3 , False , 16 , 500 , 256 , None , "allgather" ),
26
+ ("A2" , False , 8 , 100 , 256 , None , "allgather" ),
27
+ ("A3" , False , 16 , 500 , 256 , None , "allgather" ),
29
28
30
29
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
31
- (AscendSocVersion . A2 , True , 8 , 100 , 256 , "w4a8_dynamic" , "alltoall" ),
32
- (AscendSocVersion . A2 , True , 16 , 257 , 256 , "w4a8_dynamic" , "alltoall" ),
33
- (AscendSocVersion . A2 , True , 16 , 100 , 256 , "w4a8_dynamic" , "mc2" ), # meets mc2 condition
30
+ ("A2" , True , 8 , 100 , 256 , "w4a8_dynamic" , "alltoall" ),
31
+ ("A2" , True , 16 , 257 , 256 , "w4a8_dynamic" , "alltoall" ),
32
+ ("A2" , True , 16 , 100 , 256 , "w4a8_dynamic" , "mc2" ), # meets mc2 condition
34
33
35
34
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
36
- (AscendSocVersion . A2 , True , 8 , 100 , 256 , None , "allgather" ),
37
- (AscendSocVersion . A2 , True , 16 , 257 , 256 , None , "allgather" ),
35
+ ("A2" , True , 8 , 100 , 256 , None , "allgather" ),
36
+ ("A2" , True , 16 , 257 , 256 , None , "allgather" ),
38
37
39
38
# Case 4: A3 SOC
40
- (AscendSocVersion .A3 , True , 8 , 100 , 256 , None , "mc2" ),
41
- (AscendSocVersion .A3 , True , 8 , 257 , 256 , None , "alltoall" ),
39
+ ("A3" , True , 8 , 100 , 256 , None , "mc2" ),
40
+ ("A3" , True , 8 , 257 , 256 , None , "alltoall" ),
41
+
42
+ # Case 5: 310P SOC
43
+ ("310P" , True , 8 , 100 , 256 , None , "allgather" ),
44
+ ("310P" , True , 8 , 257 , 256 , None , "allgather" ),
42
45
])
43
46
# yapf: enable
44
47
def test_select_moe_comm_method (soc_version , enable_expert_parallel ,
@@ -64,8 +67,8 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
64
67
mock_runner .vllm_config = mock_vllm_config
65
68
66
69
# Patch the helper functions
67
- with patch ('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version ' ,
68
- return_value = soc_version ), \
70
+ with patch ('vllm_ascend._build_info.__ascend_soc_version__ ' ,
71
+ new = soc_version ), \
69
72
patch ('vllm_ascend.worker.model_runner_v1.is_global_first_rank' ,
70
73
return_value = True ):
71
74
@@ -97,8 +100,8 @@ def test_select_moe_comm_method_unsupported_soc():
97
100
98
101
unsupported_soc = "UnsupportedSOC"
99
102
100
- with patch ('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version ' ,
101
- return_value = unsupported_soc ), \
103
+ with patch ('vllm_ascend._build_info.__ascend_soc_version__ ' ,
104
+ new = unsupported_soc ), \
102
105
patch ('vllm_ascend.worker.model_runner_v1.is_global_first_rank' ,
103
106
return_value = True ), \
104
107
pytest .raises (ValueError , match = f"Unsupported soc_version: { unsupported_soc } " ):
0 commit comments