
Commit 6a92a27

Author: offline0806 (committed)
Merge remote-tracking branch 'upstream_gitee/main' into main_eplb_0916
# Conflicts:
#	vllm_ascend/ops/common_fused_moe.py
#	vllm_ascend/ops/fused_moe.py
#	vllm_ascend/ops/moe/moe_comm_method.py
#	vllm_ascend/worker/model_runner_v1.py
2 parents 9b724b5 + 18ca786 commit 6a92a27

18 files changed: +527 −601 lines

tests/ut/models/test_deepseek_v2.py

Lines changed: 1 addition & 18 deletions
@@ -23,8 +23,7 @@
 
 from vllm_ascend.models.deepseek_v2 import (
     CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
-    CustomDeepseekV2MLP, CustomDeepseekV2MoE,
-    CustomDeepseekV2RowParallelLinear,
+    CustomDeepseekV2MLP, CustomDeepseekV2RowParallelLinear,
     CustomDeepseekV2RowParallelLinearReplaceAllreduce,
     CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)

@@ -213,22 +212,6 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
                               quant_config=None)
 
 
-def test_custom_deepseek_v2_moe(mock_distributed, base_config,
-                                mock_forward_context):
-    base_config.n_shared_experts = 1
-    moe = CustomDeepseekV2MoE(config=base_config,
-                              quant_config=None,
-                              prefix="mlp")
-    assert moe.top_k == 2
-
-    x = torch.randn(2, 4, 128)
-    attn_metadata = Mock(num_prefills=1)
-    with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
-               return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
-        output = moe(x, attn_metadata)
-    assert output.shape == (2, 4, 128)
-
-
 @patch("torch_npu.npu_rms_norm")
 def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
                                           base_config):

tests/ut/ops/test_ascend_forwad_context.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

tests/ut/ops/test_fused_moe_prepare_and_finalize.py

Lines changed: 67 additions & 1 deletion
@@ -6,7 +6,8 @@
 
 from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
     FusedMoEPrepareAndFinalizeWithAll2All,
-    FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2)
+    FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
+    FusedMoEPrepareAndFinalizeWithNaiveMulticast)
 
 
 class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
@@ -216,3 +217,68 @@ def mock_reduce_scatter_func(tensor, dim):
         mock_tp_all_reduce.return_value = result
         result_with_tp = layer.finalize(h_out, reduce_results=True)
         self.assertEqual(result_with_tp.shape[0], 3)
+
+    @patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
+    @patch(
+        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
+    )
+    @patch(
+        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
+    )
+    def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
+                                              mock_tp_all_reduce,
+                                              mock_get_dp_group):
+        # Mock forward context with DP metadata
+        mock_context = MagicMock()
+        mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
+            [2, 5, 7])
+        mock_get_forward_context.return_value = mock_context
+
+        # Setup DP group mock
+        mock_dp_group = MagicMock()
+        mock_dp_group.broadcast = MagicMock()
+        mock_dp_group.all_reduce = MagicMock()
+        mock_get_dp_group.return_value = mock_dp_group
+
+        # Mock all_reduce to just return input (simulate sum)
+        def mock_all_reduce(tensor):
+            return tensor * 2
+
+        mock_dp_group.all_reduce.side_effect = mock_all_reduce
+
+        # Setup config
+        self.moe_config.dp_size = 3
+        self.moe_config.dp_rank = 1
+        self.moe_config.tp_size = 1
+        self.moe_config.ep_size = 1
+
+        layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
+
+        # Local inputs
+        hidden_states = torch.randn(3, 8)
+        router_logits = torch.randn(3, 2)
+
+        # Mock gate for router logits recomputation
+        mock_gate = MagicMock()
+        mock_gate.return_value = (torch.randn(7, 2), None)
+
+        # Run prepare
+        h_out, r_out, _ = layer.prepare(hidden_states,
+                                        router_logits,
+                                        rm_router_logits=False,
+                                        gate=mock_gate)
+
+        # Should be global tensor: [7, 8] and [7, 2]
+        self.assertEqual(h_out.shape, (7, 8))
+        self.assertEqual(r_out.shape, (7, 2))
+
+        # Run finalize
+        result = layer.finalize(h_out, reduce_results=False)
+
+        # Should slice back to local: [3, 8]
+        self.assertEqual(result.shape, (3, 8))
+
+        # Test with reduce_results=True and TP/EP > 1
+        mock_tp_all_reduce.return_value = result
+        result_with_tp = layer.finalize(h_out, reduce_results=True)
+        self.assertEqual(result_with_tp.shape, (3, 8))
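The new test pins down the shape contract of the naive-multicast path: three DP ranks with cumulative token counts [2, 5, 7], where rank 1 owns rows 2..5. Below is a minimal standalone sketch of that gather-then-slice pattern; it is not the vllm_ascend implementation, and the helper names plus the zero-fill strategy are assumptions for illustration only.

    import torch

    # Hypothetical helpers sketching the pattern the test exercises; the real
    # FusedMoEPrepareAndFinalizeWithNaiveMulticast uses the DP group's
    # broadcast/all_reduce collectives instead of this single-process stand-in.
    def naive_multicast_prepare(local_hidden, cu_tokens, dp_rank):
        # Place local rows into a zero-filled global buffer; an all-reduce
        # across DP ranks would then fill in every rank's slice.
        total = int(cu_tokens[-1])
        start = 0 if dp_rank == 0 else int(cu_tokens[dp_rank - 1])
        end = int(cu_tokens[dp_rank])
        global_hidden = local_hidden.new_zeros(total, local_hidden.shape[-1])
        global_hidden[start:end] = local_hidden
        return global_hidden

    def naive_multicast_finalize(global_hidden, cu_tokens, dp_rank):
        # Slice the rank-local rows back out of the global tensor.
        start = 0 if dp_rank == 0 else int(cu_tokens[dp_rank - 1])
        end = int(cu_tokens[dp_rank])
        return global_hidden[start:end]

    cu = torch.tensor([2, 5, 7])          # cu_tokens_across_dp_cpu from the test
    local = torch.randn(3, 8)             # dp_rank=1 owns 3 of the 7 tokens
    g = naive_multicast_prepare(local, cu, dp_rank=1)
    assert g.shape == (7, 8)              # global buffer, as asserted above
    assert naive_multicast_finalize(g, cu, dp_rank=1).shape == (3, 8)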

tests/ut/ops/test_fused_ops.py

Lines changed: 32 additions & 111 deletions
@@ -22,10 +22,7 @@
 from pytest_mock import MockerFixture
 from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
 
-import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module
 from tests.ut.base import TestBase
-from vllm_ascend.ascend_forward_context import (FusedMoEState,
-                                                 _get_fused_moe_state)
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -60,68 +57,24 @@ def mock_npu_format_cast(weight_data, format):
 
 @pytest.fixture
 def mock_dist_env(mocker: MockerFixture):
-    mock_setup_token_dispatchers = MagicMock()
-    mock_token_dispatcher_with_allgather = MagicMock()
-    mock_token_dispatcher_with_all2allv = MagicMock()
-    mock_token_dispatcher_with_mc2 = MagicMock()
-
-    mock_dispatch_result_allgather = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([8, 16], dtype=torch.int64),
-        "group_list_type": 0,
-    }
-    mock_combine_result_allgather = torch.randn(16, 2)
-
-    mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
-    mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
-
-    mock_dispatch_result_all2allv = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
-        "group_list_type": 1,
-        "dynamic_scale": None,
-    }
-    mock_combine_result_all2allv = torch.randn(16, 2)
-    mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
-    mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
-
-    mock_dispatch_result_mc2 = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
-        "group_list_type": 1,
-        "dynamic_scale": None,
-        "assist_info_for_combine": torch.randn(16, 2),
-        "ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
-    }
-    mock_combine_result_mc2 = torch.randn(16, 2)
-    mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
-    mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
+    mock_moe_comm_method = MagicMock()
 
-    captured_dispatchers = {}
+    def mock_prepare(hidden_states, router_logits, **kwargs):
+        return hidden_states, router_logits
 
-    def capture_register(dispatcher_instance):
-        key = dispatcher_instance.__class__.__name__
-        captured_dispatchers[key] = dispatcher_instance
-        if key == 'TokenDispatcherWithAllGather':
-            captured_dispatchers[key] = mock_token_dispatcher_with_allgather
-        elif key == 'TokenDispatcherWithAll2AllV':
-            captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
-        elif key == 'TokenDispatcherWithMC2':
-            captured_dispatchers[key] = mock_token_dispatcher_with_mc2
+    mock_moe_comm_method.prepare.side_effect = mock_prepare
 
-    mock_register_token_dispatcher_patcher = patch(
-        'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher',
-        side_effect=capture_register)
+    mock_fused_experts_result = torch.randn(16, 2)
+    mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result
 
-    mock_get_token_dispatcher_patcher = patch(
-        'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher',
-        side_effect=lambda name: captured_dispatchers.get(name))
+    def mock_finalize(hidden_states, **kwargs):
+        return hidden_states
 
-    default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
+    mock_moe_comm_method.finalize.side_effect = mock_finalize
 
     mock_forward_context_obj = MagicMock(
-        fused_moe_state=FusedMoEState.AllGather,
-        token_dispatcher=default_mock_token_dispatcher,
+        moe_comm_method=mock_moe_comm_method,
+        moe_comm_method_name="mc2commimpl",
         max_tokens_across_dp=10,
         dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
         mc2_mask=torch.zeros(16, dtype=torch.bool),
@@ -131,14 +84,12 @@ def capture_register(dispatcher_instance):
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
         patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('torch.distributed.all_gather'), \
-        patch('torch.distributed.all_to_all_single'), \
-        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
        patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
              return_value=mock_dp_and_tp_group(mocker)), \
@@ -150,29 +101,29 @@ def capture_register(dispatcher_instance):
               return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
        patch('vllm_ascend.ops.fused_moe.get_forward_context',
              return_value=mock_forward_context_obj), \
+        patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
+              return_value=mock_forward_context_obj), \
        patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
              return_value=MagicMock(
                  parallel_config=MagicMock(tensor_parallel_size=2),
                  scheduler_config=MagicMock(max_num_seqs=4),
                  model_config=MagicMock(max_model_len=2048)
              )), \
        patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
-        patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
        patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
-              return_value=mock_forward_context_obj):
+              return_value=mock_forward_context_obj), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
+              return_value=None), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
+              return_value=None), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
+              return_value=None):
 
         yield {
             'mock_forward_context_obj': mock_forward_context_obj,
-            'mock_token_dispatcher_with_allgather':
-            mock_token_dispatcher_with_allgather,
-            'mock_token_dispatcher_with_all2allv':
-            mock_token_dispatcher_with_all2allv,
-            'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
+            'mock_moe_comm_method': mock_moe_comm_method,
         }
 
-    mock_register_token_dispatcher_patcher.stop()
-    mock_get_token_dispatcher_patcher.stop()
-
 
 @pytest.fixture
 def mock_moe_env(mocker: MockerFixture):
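The rewritten fixture funnels every communication path through a single mocked moe_comm_method whose prepare/finalize use side_effect rather than return_value, so tensors pass straight through and downstream shape assertions still track each test's inputs. A self-contained sketch of that pass-through pattern follows; the names are illustrative stand-ins, not the fixture's own objects.

    import torch
    from unittest.mock import MagicMock

    # Illustrative stand-in for the fixture's mock_moe_comm_method: side_effect
    # echoes the caller's tensors, while a fixed return_value pins the expert
    # output shape the tests assert on.
    mock_comm = MagicMock()
    mock_comm.prepare.side_effect = lambda h, r, **kw: (h, r)
    mock_comm.finalize.side_effect = lambda h, **kw: h
    mock_comm.fused_experts.return_value = torch.randn(16, 2)

    h, r = torch.randn(16, 2), torch.randn(16, 8)
    out_h, out_r = mock_comm.prepare(h, r, rm_router_logits=False)
    assert out_h is h and out_r is r          # inputs pass straight through
    assert mock_comm.finalize(h, reduce_results=True) is h
    assert mock_comm.fused_experts().shape == (16, 2)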
@@ -338,9 +289,7 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
         moe.moe_parallel_config.ep_size = 1
 
         moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
-        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
-                                                         dtype=torch.bool),
-                                    padded_num_tokens=num_tokens)
+        forward_context = mock_dist_env['mock_forward_context_obj']
         with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                    return_value=forward_context):
             output = moe.forward(inputs,
@@ -394,25 +343,10 @@ def test_process_weights_after_loading(self, moe_method, mock_dist_env):
                              [[256, 4], [128, 1], [128, 1], [128, 4]])
     def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                       mock_moe_env, others_param):
-
         global_num_experts, ep_size = others_param
         is_prefill = False
-        is_deepseek_v3_r1 = global_num_experts == 256
-
-        if ep_size == 1:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_allgather']
-        elif ep_size < 16:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_all2allv']
-        else:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_mc2']
 
-        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
-            ep_size, is_prefill, is_deepseek_v3_r1),
-                                    with_quant=False,
-                                    token_dispatcher=selected_token_dispatcher)
+        forward_context = mock_dist_env['mock_forward_context_obj']
 
         with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                    return_value=forward_context):
@@ -438,35 +372,22 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                      global_num_experts=global_num_experts,
                                      is_prefill=is_prefill)
 
-        expected_shape = (16, 2)
+        mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
+        mock_moe_comm_method.fused_experts.assert_called_once()
 
+        expected_shape = (16, 2)
         assert result.shape == expected_shape
 
     @pytest.mark.parametrize("others_param", [16, 1, 4])
     def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                    mock_moe_env, others_param):
-
         ep_size = others_param
         is_prefill = False
 
-        if ep_size == 1:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_allgather']
-        elif ep_size < 16:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_all2allv']
-        else:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_mc2']
-
-        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
-            ep_size, is_prefill, True),
-                                    with_quant=False,
-                                    token_dispatcher=selected_token_dispatcher)
+        forward_context = mock_dist_env['mock_forward_context_obj']
 
         with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
            patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
-
             expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
@@ -493,8 +414,10 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                      expert_map=expert_map,
                                      is_prefill=is_prefill)
 
-        expected_shape = (16, 2)
+        mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
+        mock_moe_comm_method.fused_experts.assert_called_once()
 
+        expected_shape = (16, 2)
         assert result.shape == expected_shape
 
 
@@ -574,7 +497,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                      mock_get_forward_context):
 
         mock_forward_context = MagicMock()
-        mock_forward_context.fused_moe_state = FusedMoEState.MC2
+        mock_forward_context.moe_comm_method_name = "mc2commimpl"
         mock_get_forward_context.return_value = mock_forward_context
 
         mock_is_310p.return_value = False
@@ -618,8 +541,6 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                       with_quant=True)
 
         mock_get_forward_context.assert_called()
-        self.assertEqual(mock_forward_context.fused_moe_state,
-                         FusedMoEState.MC2)
 
         mock_npu_dynamic_quant.assert_called()
 
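The last two hunks swap the removed FusedMoEState.MC2 enum check for the string field the mocked forward context now carries (moe_comm_method_name = "mc2commimpl"). A hedged sketch of what such a check might look like; the helper name and the exact comparison are assumptions, not the vllm_ascend API.

    # Hypothetical helper: branch on the string field the tests set instead of
    # the removed FusedMoEState enum.
    def uses_mc2(forward_context) -> bool:
        return getattr(forward_context, "moe_comm_method_name", "") == "mc2commimpl"

    class _FakeContext:                    # stand-in for the mocked context
        moe_comm_method_name = "mc2commimpl"

    assert uses_mc2(_FakeContext())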