from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

-import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
-from vllm_ascend.ascend_forward_context import (FusedMoEState,
-                                                _get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                       AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -60,68 +57,24 @@ def mock_npu_format_cast(weight_data, format):

@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
-    mock_setup_token_dispatchers = MagicMock()
-    mock_token_dispatcher_with_allgather = MagicMock()
-    mock_token_dispatcher_with_all2allv = MagicMock()
-    mock_token_dispatcher_with_mc2 = MagicMock()
-
-    mock_dispatch_result_allgather = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([8, 16], dtype=torch.int64),
-        "group_list_type": 0,
-    }
-    mock_combine_result_allgather = torch.randn(16, 2)
-
-    mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
-    mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
-
-    mock_dispatch_result_all2allv = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
-        "group_list_type": 1,
-        "dynamic_scale": None,
-    }
-    mock_combine_result_all2allv = torch.randn(16, 2)
-    mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
-    mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
-
-    mock_dispatch_result_mc2 = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
-        "group_list_type": 1,
-        "dynamic_scale": None,
-        "assist_info_for_combine": torch.randn(16, 2),
-        "ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
-    }
-    mock_combine_result_mc2 = torch.randn(16, 2)
-    mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
-    mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
+    mock_moe_comm_method = MagicMock()

-    captured_dispatchers = {}
+    def mock_prepare(hidden_states, router_logits, **kwargs):
+        return hidden_states, router_logits

-    def capture_register(dispatcher_instance):
-        key = dispatcher_instance.__class__.__name__
-        captured_dispatchers[key] = dispatcher_instance
-        if key == 'TokenDispatcherWithAllGather':
-            captured_dispatchers[key] = mock_token_dispatcher_with_allgather
-        elif key == 'TokenDispatcherWithAll2AllV':
-            captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
-        elif key == 'TokenDispatcherWithMC2':
-            captured_dispatchers[key] = mock_token_dispatcher_with_mc2
+    mock_moe_comm_method.prepare.side_effect = mock_prepare

-    mock_register_token_dispatcher_patcher = patch(
-        'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher',
-        side_effect=capture_register)
+    mock_fused_experts_result = torch.randn(16, 2)
+    mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result

-    mock_get_token_dispatcher_patcher = patch(
-        'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher',
-        side_effect=lambda name: captured_dispatchers.get(name))
+    def mock_finalize(hidden_states, **kwargs):
+        return hidden_states

-    default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
+    mock_moe_comm_method.finalize.side_effect = mock_finalize
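+
+    # The comm-method mock above replaces all collective communication:
+    # prepare() and finalize() are identity pass-throughs and fused_experts()
+    # returns a fixed (16, 2) tensor, so the tests exercise expert selection
+    # and weight handling without real dispatch/combine kernels.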

    mock_forward_context_obj = MagicMock(
-        fused_moe_state=FusedMoEState.AllGather,
-        token_dispatcher=default_mock_token_dispatcher,
+        moe_comm_method=mock_moe_comm_method,
+        moe_comm_method_name="mc2commimpl",
        max_tokens_across_dp=10,
        dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
        mc2_mask=torch.zeros(16, dtype=torch.bool),
@@ -131,14 +84,12 @@ def capture_register(dispatcher_instance):
    with patch('torch.distributed.get_rank', return_value=0), \
        patch('torch.distributed.get_world_size', return_value=4), \
        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('torch.distributed.all_gather'), \
-        patch('torch.distributed.all_to_all_single'), \
-        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
        patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
              return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_ascend_config',
@@ -150,29 +101,29 @@ def capture_register(dispatcher_instance):
              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
        patch('vllm_ascend.ops.fused_moe.get_forward_context',
              return_value=mock_forward_context_obj), \
+        patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
+              return_value=mock_forward_context_obj), \
        patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
              return_value=MagicMock(
                  parallel_config=MagicMock(tensor_parallel_size=2),
                  scheduler_config=MagicMock(max_num_seqs=4),
                  model_config=MagicMock(max_model_len=2048)
              )), \
        patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
-        patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
        patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
-              return_value=mock_forward_context_obj):
+              return_value=mock_forward_context_obj), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
+              return_value=None), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
+              return_value=None), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
+              return_value=None):
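+        # _get_token_dispatcher is patched to None on every comm impl; the
+        # mocked comm method above already stubs out dispatch and combine.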

        yield {
            'mock_forward_context_obj': mock_forward_context_obj,
-            'mock_token_dispatcher_with_allgather':
-            mock_token_dispatcher_with_allgather,
-            'mock_token_dispatcher_with_all2allv':
-            mock_token_dispatcher_with_all2allv,
-            'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
+            'mock_moe_comm_method': mock_moe_comm_method,
        }

-    mock_register_token_dispatcher_patcher.stop()
-    mock_get_token_dispatcher_patcher.stop()
-

@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
@@ -338,9 +289,7 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
        moe.moe_parallel_config.ep_size = 1

        moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
-        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
-                                                         dtype=torch.bool),
-                                    padded_num_tokens=num_tokens)
+        forward_context = mock_dist_env['mock_forward_context_obj']
        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                   return_value=forward_context):
            output = moe.forward(inputs,
@@ -394,25 +343,10 @@ def test_process_weights_after_loading(self, moe_method, mock_dist_env):
                             [[256, 4], [128, 1], [128, 1], [128, 4]])
    def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                      mock_moe_env, others_param):
-
        global_num_experts, ep_size = others_param
        is_prefill = False
-        is_deepseek_v3_r1 = global_num_experts == 256
-
-        if ep_size == 1:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_allgather']
-        elif ep_size < 16:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_all2allv']
-        else:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_mc2']

-        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
-            ep_size, is_prefill, is_deepseek_v3_r1),
-                                    with_quant=False,
-                                    token_dispatcher=selected_token_dispatcher)
+        forward_context = mock_dist_env['mock_forward_context_obj']
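+        # The shared fixture-level forward context replaces the per-case
+        # MagicMock; comm-method selection no longer depends on ep_size here.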

        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                   return_value=forward_context):
@@ -438,35 +372,22 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                global_num_experts=global_num_experts,
                is_prefill=is_prefill)

-        expected_shape = (16, 2)
+        mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
+        mock_moe_comm_method.fused_experts.assert_called_once()
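+        # The comm-method mock is the single integration point now, so one
+        # fused_experts() call confirms apply() routed through it.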

+        expected_shape = (16, 2)
        assert result.shape == expected_shape

    @pytest.mark.parametrize("others_param", [16, 1, 4])
    def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                   mock_moe_env, others_param):
-
        ep_size = others_param
        is_prefill = False

-        if ep_size == 1:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_allgather']
-        elif ep_size < 16:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_all2allv']
-        else:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_mc2']
-
-        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
-            ep_size, is_prefill, True),
-                                    with_quant=False,
-                                    token_dispatcher=selected_token_dispatcher)
+        forward_context = mock_dist_env['mock_forward_context_obj']

        with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
             patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
-
            expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
@@ -493,8 +414,10 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                expert_map=expert_map,
                is_prefill=is_prefill)

-        expected_shape = (16, 2)
+        mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
+        mock_moe_comm_method.fused_experts.assert_called_once()

+        expected_shape = (16, 2)
        assert result.shape == expected_shape

@@ -574,7 +497,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                     mock_get_forward_context):

        mock_forward_context = MagicMock()
-        mock_forward_context.fused_moe_state = FusedMoEState.MC2
+        mock_forward_context.moe_comm_method_name = "mc2commimpl"
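+        # "mc2commimpl" is the lowercased MC2CommImpl class name; the mlp
+        # path reads it from the forward context to select the MC2 branch.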
        mock_get_forward_context.return_value = mock_forward_context

        mock_is_310p.return_value = False
@@ -618,8 +541,6 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
            with_quant=True)

        mock_get_forward_context.assert_called()
-        self.assertEqual(mock_forward_context.fused_moe_state,
-                         FusedMoEState.MC2)

        mock_npu_dynamic_quant.assert_called()