
Commit f6f84fc

refactor moe_comm_method selection process
Co-Authored-By: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
1 parent af2a886 commit f6f84fc

14 files changed (+176 / -349 lines)

tests/ut/ops/test_common_fused_moe.py

Lines changed: 1 addition & 50 deletions
@@ -17,56 +17,7 @@
 import torch

 from tests.ut.base import TestBase
-from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
-
-
-class TestFusedExpertsMoGE(TestBase):
-
-    def test_fused_experts_moge(self):
-        with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
-                patch('torch_npu.npu_swiglu') as mock_swiglu, \
-                patch('vllm_ascend.utils.is_310p') as mock_is_310p:
-
-            mock_is_310p.return_value = False
-
-            mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
-                torch.randn(x[0].shape[0], weight[0].shape[1])
-            ]
-
-            mock_swiglu.side_effect = lambda x: x
-
-            hidden_states = torch.randn(4, 128)
-            w1 = torch.randn(4, 256, 128)
-            w2 = torch.randn(4, 128, 128)
-            topk_weights = torch.rand(4, 1)
-            topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
-            top_k = 1
-            global_num_experts = 4
-
-            moe_parallel_config = type(
-                'MockConfig', (), {
-                    'ep_size': 1,
-                    'tp_size': 1,
-                    'dp_size': 1,
-                    'tp_rank': 0,
-                    'dp_rank': 0,
-                    'ep_rank': 0,
-                    'use_ep': True
-                })()
-
-            output = fused_experts_moge(
-                hidden_states=hidden_states,
-                w1=w1,
-                w2=w2,
-                moe_parallel_config=moe_parallel_config,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                global_num_experts=global_num_experts,
-                apply_router_weight_on_input=True,
-            )
-
-            self.assertEqual(output.shape, (4, 128))
+from vllm_ascend.ops.common_fused_moe import AscendFusedMoE


 class TestLoadWeight(TestBase):

tests/ut/ops/test_fused_ops.py

Lines changed: 23 additions & 8 deletions
@@ -23,6 +23,7 @@
 from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

 from tests.ut.base import TestBase
+from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -55,6 +56,26 @@ def mock_npu_format_cast(weight_data, format):
     return weight_data


+@pytest.fixture(autouse=True)
+def setup_vllm_config_mock(mocker: MockerFixture):
+    mock_hf_config = MagicMock()
+    mock_hf_config.model_type = "llama"
+
+    mock_model_config = MagicMock()
+    mock_model_config.hf_config = mock_hf_config
+
+    mock_vllm_config = MagicMock()
+    mock_vllm_config.model_config = mock_model_config
+    mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
+    mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
+    mock_vllm_config.model_config.max_model_len = 2048
+
+    mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
+                 return_value=mock_vllm_config)
+    mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
+                 return_value=mock_vllm_config)
+
+
 @pytest.fixture
 def mock_dist_env(mocker: MockerFixture):
     mock_moe_comm_method = MagicMock()
@@ -74,7 +95,7 @@ def mock_finalize(hidden_states, **kwargs):

     mock_forward_context_obj = MagicMock(
         moe_comm_method=mock_moe_comm_method,
-        moe_comm_method_name="mc2commimpl",
+        moe_comm_type=MoECommType.MC2,
         max_tokens_across_dp=10,
         dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
         mc2_mask=torch.zeros(16, dtype=torch.bool),
@@ -103,12 +124,6 @@ def mock_finalize(hidden_states, **kwargs):
             return_value=mock_forward_context_obj), \
         patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
             return_value=mock_forward_context_obj), \
-        patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
-            return_value=MagicMock(
-                parallel_config=MagicMock(tensor_parallel_size=2),
-                scheduler_config=MagicMock(max_num_seqs=4),
-                model_config=MagicMock(max_model_len=2048)
-            )), \
         patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
         patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
             return_value=mock_forward_context_obj), \
@@ -497,7 +512,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                      mock_get_forward_context):

         mock_forward_context = MagicMock()
-        mock_forward_context.moe_comm_method_name = "mc2commimpl"
+        mock_forward_context.moe_comm_type = MoECommType.MC2
         mock_get_forward_context.return_value = mock_forward_context

         mock_is_310p.return_value = False
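
Note: setup_vllm_config_mock is declared with autouse=True, so pytest applies it to every test in this module and the previous per-test get_current_vllm_config patch becomes unnecessary. A minimal, self-contained sketch of that mechanism; get_current_config below is a hypothetical stand-in, not a vllm_ascend function:

import sys

import pytest
from unittest.mock import MagicMock


def get_current_config():
    # Hypothetical stand-in for a "current engine config" accessor.
    raise RuntimeError("no engine is running inside unit tests")


@pytest.fixture(autouse=True)
def fake_config(monkeypatch):
    # autouse=True: pytest runs this fixture for every test in the module,
    # even when the test does not list it as a parameter.
    cfg = MagicMock()
    cfg.scheduler_config.max_num_seqs = 4
    monkeypatch.setattr(sys.modules[__name__], "get_current_config", lambda: cfg)
    return cfg


def test_config_is_patched():
    # The fixture has already replaced the accessor by the time the test runs.
    assert get_current_config().scheduler_config.max_num_seqs == 4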

tests/ut/ops/test_moe_comm_method.py

Lines changed: 24 additions & 4 deletions
@@ -24,14 +24,19 @@ def setUp(self):
         self.moe_config.dp_group = MagicMock()
         self.moe_config.num_global_redundant_experts = 0

+    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
     )
     @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
     def test_all_gather_comm_impl(self, mock_token_dispatcher,
                                   mock_prepare_finalize,
-                                  mock_get_forward_context):
+                                  mock_get_forward_context,
+                                  mock_get_current_vllm_config):
+        # Mock vLLM config
+        mock_get_current_vllm_config.return_value = MagicMock()
+
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "all_gather"
@@ -64,13 +69,18 @@ def test_all_gather_comm_impl(self, mock_token_dispatcher,
         comm_impl.finalize(h_out, reduce_results=True)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True)

+    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
     )
     @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
     def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
-                           mock_get_forward_context):
+                           mock_get_forward_context,
+                           mock_get_current_vllm_config):
+        # Mock vLLM config
+        mock_get_current_vllm_config.return_value = MagicMock()
+
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "mc2"
@@ -104,14 +114,19 @@ def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
         comm_impl.finalize(h_out, reduce_results=True)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True)

+    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
     )
     @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
     def test_alltoall_comm_impl(self, mock_token_dispatcher,
                                 mock_prepare_finalize,
-                                mock_get_forward_context):
+                                mock_get_forward_context,
+                                mock_get_current_vllm_config):
+        # Mock vLLM config
+        mock_get_current_vllm_config.return_value = MagicMock()
+
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "alltoall"
@@ -140,6 +155,7 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,
         mock_pf_instance.prepare.assert_called_once_with(
             hidden_states, router_logits, False, False, False, None)

+    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
@@ -148,7 +164,11 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,
     @patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
     def test_fused_experts_method(self, mock_unified_apply_mlp,
                                   mock_token_dispatcher, mock_prepare_finalize,
-                                  mock_get_forward_context):
+                                  mock_get_forward_context,
+                                  mock_get_current_vllm_config):
+        # Mock vLLM config
+        mock_get_current_vllm_config.return_value = MagicMock()
+
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "all_gather"
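
Note: the new @patch(...get_current_vllm_config) decorator is stacked outermost in each test, so its mock arrives as the last positional argument: unittest.mock injects stacked @patch mocks bottom-up. A standalone sketch of that ordering rule, patching standard-library functions purely for illustration:

import os
from unittest import TestCase
from unittest.mock import patch


class PatchOrderExample(TestCase):

    @patch("os.getcwd")     # outermost decorator -> last mock argument
    @patch("os.cpu_count")  # innermost decorator -> first mock argument
    def test_patch_order(self, mock_cpu_count, mock_getcwd):
        mock_cpu_count.return_value = 8
        mock_getcwd.return_value = "/tmp"
        self.assertEqual(os.cpu_count(), 8)
        self.assertEqual(os.getcwd(), "/tmp")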

tests/ut/quantization/test_w4a8_dynamic.py

Lines changed: 10 additions & 1 deletion
@@ -48,18 +48,27 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
     output_size = 56
     group_size = 2

+    @patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
     @patch('torch.distributed.get_rank', return_value=0)
     def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
-              get_current_vllm_config):
+              get_current_vllm_config, mock_get_ascend_config):
+        # Mock ascend config
+        mock_ascend_config = Mock()
+        mock_ascend_config.dynamic_eplb = False
+        mock_get_ascend_config.return_value = mock_ascend_config
+
         mock_vllm_config = Mock()
         mock_vllm_config.quant_config = Mock(quant_description={
             "group_size": self.group_size,
             "version": "0.0.0"
         })
         mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
+        mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048,
+                                                 max_model_len=2048,
+                                                 enable_chunked_prefill=False)
         get_current_vllm_config.return_value = mock_vllm_config
         self.quant_method = AscendW4A8DynamicFusedMoEMethod()

tests/ut/worker/test_model_runner_v1.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import pytest
1717

18+
from vllm_ascend.ascend_forward_context import MoECommType
1819
from vllm_ascend.utils import AscendSocVersion
1920
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
2021

@@ -24,21 +25,21 @@
2425
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
2526
[
2627
# Case 1: Expert parallel is disabled, should always be 'allgather'
27-
(AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"),
28-
(AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"),
28+
(AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER),
29+
(AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER),
2930
3031
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
31-
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"),
32-
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"),
33-
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"), # meets mc2 condition
32+
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
33+
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
34+
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition
3435
3536
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
36-
(AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"),
37-
(AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"),
37+
(AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER),
38+
(AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER),
3839
3940
# Case 4: A3 SOC
40-
(AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"),
41-
(AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"),
41+
(AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2),
42+
(AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLTOALL),
4243
])
4344
# yapf: enable
4445
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
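
Note: taken together, the parametrized cases above imply a selection policy roughly like the sketch below. It is reconstructed from the expected values only, not copied from the actual NPUModelRunner implementation; in particular, the world_size >= 16 guard for w4a8_dynamic on A2 is an assumption inferred from the 8-rank vs. 16-rank cases.

from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import AscendSocVersion


def select_moe_comm_method_sketch(soc_version, enable_expert_parallel,
                                  world_size, num_tokens,
                                  mc2_tokens_capacity, quant_type):
    # Without expert parallelism, always fall back to AllGather.
    if not enable_expert_parallel:
        return MoECommType.ALLGATHER

    # MC2 requires the batch to fit the reserved token capacity.
    fits_mc2 = num_tokens <= mc2_tokens_capacity

    if soc_version == AscendSocVersion.A2:
        if quant_type == "w4a8_dynamic":
            # Assumption: the 16-rank case is what makes MC2 eligible on A2.
            return (MoECommType.MC2
                    if fits_mc2 and world_size >= 16 else MoECommType.ALLTOALL)
        return MoECommType.ALLGATHER

    # A3: MC2 when the batch fits, otherwise AlltoAll.
    return MoECommType.MC2 if fits_mc2 else MoECommType.ALLTOALL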

vllm_ascend/ascend_forward_context.py

Lines changed: 16 additions & 2 deletions
@@ -22,6 +22,16 @@ class FusedMoEState(Enum):
     All2AllSeq = 5


+class MoECommType(Enum):
+    ALLGATHER = "AllGather"
+    MC2 = "MC2"
+    ALLTOALL = "AlltoAll"
+    NAIVE_MULTICAST = "NaiveMulticast"
+
+    def __str__(self):
+        return self.value + "CommImpl"
+
+
 # TODO(zzzzwwjj): add soc_version to choose branch
 def _get_fused_moe_state(ep_size: int, with_prefill: bool,
                          is_deepseek_v3_r1: bool):
@@ -52,7 +62,7 @@ def set_ascend_forward_context(
         with_prefill: bool = True,
         in_profile_run: bool = False,
         reserved_mc2_mask: Optional[torch.Tensor] = None,
-        moe_comm_method: str = "",
+        moe_comm_type: Optional[MoECommType] = None,
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None,
@@ -72,7 +82,11 @@ def set_ascend_forward_context(
             batch_descriptor=batch_descriptor,
     ):
         forward_context = get_forward_context()
-        forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
+
+        from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
+        forward_context.moe_comm_type = moe_comm_type
+        forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
+
         forward_context.with_prefill = with_prefill
         tp_world_size = get_tensor_model_parallel_world_size()
         ep_size = (get_ep_group().world_size if
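
Note: str(MoECommType.MC2) evaluates to "MC2CommImpl", which suggests the enum value can key a lookup of the concrete communication implementation. A minimal sketch of how get_moe_comm_method could resolve the type; the registry and implementation classes below are illustrative stand-ins, not the actual code in vllm_ascend/ops/moe/moe_comm_method.py:

from enum import Enum
from typing import Optional


class MoECommType(Enum):
    ALLGATHER = "AllGather"
    MC2 = "MC2"
    ALLTOALL = "AlltoAll"
    NAIVE_MULTICAST = "NaiveMulticast"

    def __str__(self):
        return self.value + "CommImpl"


# Illustrative placeholder implementations; the real classes live in
# vllm_ascend/ops/moe/moe_comm_method.py.
class AllGatherCommImpl: ...
class MC2CommImpl: ...
class AlltoAllCommImpl: ...
class NaiveMulticastCommImpl: ...


_IMPLS = {cls.__name__: cls for cls in
          (AllGatherCommImpl, MC2CommImpl, AlltoAllCommImpl,
           NaiveMulticastCommImpl)}


def get_moe_comm_method_sketch(moe_comm_type: Optional[MoECommType]):
    # None (e.g. when no MoE comm type was chosen) yields no comm method.
    if moe_comm_type is None:
        return None
    # str(enum) -> "<Value>CommImpl" matches the implementation class name.
    return _IMPLS[str(moe_comm_type)]()


assert str(MoECommType.MC2) == "MC2CommImpl"
assert isinstance(get_moe_comm_method_sketch(MoECommType.MC2), MC2CommImpl)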
