Commit fc83278

refactor moe_comm_method selection process
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
1 parent 723d460 commit fc83278

10 files changed: 92 additions, 75 deletions


tests/ut/ops/test_fused_ops.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@
 from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

 from tests.ut.base import TestBase
+from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -497,7 +498,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                      mock_get_forward_context):

         mock_forward_context = MagicMock()
-        mock_forward_context.moe_comm_method_name = "mc2commimpl"
+        mock_forward_context.moe_comm_type = MoECommType.MC2
         mock_get_forward_context.return_value = mock_forward_context

         mock_is_310p.return_value = False

tests/ut/worker/test_model_runner_v1.py

Lines changed: 10 additions & 9 deletions
@@ -15,6 +15,7 @@

 import pytest

+from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.utils import AscendSocVersion
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

@@ -24,21 +25,21 @@
     "soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
     [
         # Case 1: Expert parallel is disabled, should always be 'allgather'
-        (AscendSocVersion.A2, False, 8, 100, 256, None, "allgather"),
-        (AscendSocVersion.A3, False, 16, 500, 256, None, "allgather"),
+        (AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER),
+        (AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER),

         # Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
-        (AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", "alltoall"),
-        (AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", "alltoall"),
-        (AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", "mc2"),  # meets mc2 condition
+        (AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
+        (AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
+        (AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2),  # meets mc2 condition

         # Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
-        (AscendSocVersion.A2, True, 8, 100, 256, None, "allgather"),
-        (AscendSocVersion.A2, True, 16, 257, 256, None, "allgather"),
+        (AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER),
+        (AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER),

         # Case 4: A3 SOC
-        (AscendSocVersion.A3, True, 8, 100, 256, None, "mc2"),
-        (AscendSocVersion.A3, True, 8, 257, 256, None, "alltoall"),
+        (AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2),
+        (AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLGATHER),
     ])
 # yapf: enable
 def test_select_moe_comm_method(soc_version, enable_expert_parallel,

vllm_ascend/ascend_forward_context.py

Lines changed: 16 additions & 2 deletions
@@ -22,6 +22,16 @@ class FusedMoEState(Enum):
     All2AllSeq = 5


+class MoECommType(Enum):
+    ALLGATHER = "AllGather"
+    MC2 = "MC2"
+    ALLTOALL = "AlltoAll"
+    NAIVE_MULTICAST = "NaiveMulticast"
+
+    def __str__(self):
+        return self.value + "CommImpl"
+
+
 # TODO(zzzzwwjj): add soc_version to choose branch
 def _get_fused_moe_state(ep_size: int, with_prefill: bool,
                          is_deepseek_v3_r1: bool):
@@ -52,7 +62,7 @@ def set_ascend_forward_context(
         with_prefill: bool = True,
         in_profile_run: bool = False,
         reserved_mc2_mask: Optional[torch.Tensor] = None,
-        moe_comm_method: str = "",
+        moe_comm_type: Optional[MoECommType] = None,
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None,
@@ -72,7 +82,11 @@ def set_ascend_forward_context(
             batch_descriptor=batch_descriptor,
     ):
        forward_context = get_forward_context()
-        forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
+
+        from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
+        forward_context.moe_comm_type = moe_comm_type
+        forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
+
        forward_context.with_prefill = with_prefill
        tp_world_size = get_tensor_model_parallel_world_size()
        ep_size = (get_ep_group().world_size if
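
The enum's __str__ is what ties the new type to the existing implementation classes: str(MoECommType.MC2) evaluates to "MC2CommImpl", the class name that the registry in moe_comm_method.py is keyed by, and set_ascend_forward_context now stores both the enum and the resolved method on the forward context. A minimal standalone sketch of that name mapping (the enum body is copied from the hunk above; the loop at the bottom is purely illustrative):

from enum import Enum


class MoECommType(Enum):
    """Mirrors the enum added in vllm_ascend/ascend_forward_context.py."""
    ALLGATHER = "AllGather"
    MC2 = "MC2"
    ALLTOALL = "AlltoAll"
    NAIVE_MULTICAST = "NaiveMulticast"

    def __str__(self):
        # The string form doubles as the registry key / impl class name.
        return self.value + "CommImpl"


for comm_type in MoECommType:
    # e.g. MC2 -> "MC2CommImpl", NAIVE_MULTICAST -> "NaiveMulticastCommImpl"
    print(comm_type.name, "->", str(comm_type))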

vllm_ascend/ops/common_fused_moe.py

Lines changed: 7 additions & 18 deletions
@@ -29,14 +29,13 @@
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
                                               determine_default_log2phy_map)
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
-                                                 AlltoAllCommImpl, MC2CommImpl,
-                                                 NaiveMulticastCommImpl)
+from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p

 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -276,13 +275,7 @@ def __init__(self, *args, **kwargs):
         if self.dynamic_eplb:
             self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)

-        for method in {
-                AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
-                NaiveMulticastCommImpl
-        }:
-            setattr(
-                self, method.__name__.lower(),
-                method(moe_config=self.moe_config))  # type: ignore[abstract]
+        setup_moe_comm_method(self.moe_config)

     def update_expert_map(self, new_expert_map):
         self.expert_map = new_expert_map
@@ -306,8 +299,8 @@ def maybe_all_reduce_tensor_model_parallel(
         outputs since each rank only has partial outputs.
         """
         forward_context = get_forward_context()
-        moe_comm_method_name = forward_context.moe_comm_method_name
-        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+        moe_comm_type = forward_context.moe_comm_type
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
             return final_hidden_states
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)
@@ -317,10 +310,6 @@ def forward_impl(self, hidden_states: torch.Tensor,
         assert self.quant_method is not None

         forward_context = get_forward_context()
-        moe_comm_method_name = forward_context.moe_comm_method_name
-
-        forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
-
         hidden_states, router_logits = forward_context.moe_comm_method.prepare(
             hidden_states=hidden_states, router_logits=router_logits)

@@ -436,8 +425,8 @@ def forward(

         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
-        moe_comm_method_name = forward_context.moe_comm_method_name
-        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+        moe_comm_type = forward_context.moe_comm_type
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
             shared_out = tensor_model_parallel_all_reduce(shared_out)

         fused_out = super().forward(
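
For the layer code, the practical effect is that forward_impl no longer looks the implementation up on self via getattr; it is resolved once when the forward context is set and simply read back here, and the string comparisons against "alltoallcommimpl"/"mc2commimpl" become enum membership checks. A rough, self-contained sketch of the new call shape, using a SimpleNamespace and a stub class as hypothetical stand-ins for the real forward context and MC2CommImpl:

from types import SimpleNamespace


class _StubCommImpl:
    # Hypothetical stand-in for a comm impl such as MC2CommImpl;
    # only the prepare() hook used by forward_impl() is sketched.
    def prepare(self, hidden_states, router_logits):
        return hidden_states, router_logits


# set_ascend_forward_context() now fills in both fields up front ...
forward_context = SimpleNamespace(
    moe_comm_type="MC2",              # MoECommType.MC2 in the real code
    moe_comm_method=_StubCommImpl(),  # resolved via get_moe_comm_method()
)

# ... so the layer's forward path only reads them; no getattr(self, name) anymore.
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
    hidden_states=[1.0, 2.0], router_logits=[0.3, 0.7])
print(hidden_states, router_logits)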

vllm_ascend/ops/fused_moe.py

Lines changed: 2 additions & 13 deletions
@@ -41,9 +41,7 @@
                                           determine_default_log2phy_map)
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
-                                                 AlltoAllCommImpl, MC2CommImpl,
-                                                 NaiveMulticastCommImpl)
+from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
                                get_all_reduce_merge_state,
@@ -329,13 +327,7 @@ def __init__(
         self.moe_config.mc2_group = get_mc2_group()
         self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num

-        for method in {
-                AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
-                NaiveMulticastCommImpl
-        }:
-            setattr(
-                self, method.__name__.lower(),
-                method(moe_config=self.moe_config))  # type: ignore[abstract]
+        setup_moe_comm_method(self.moe_config)

     def update_expert_map(self, new_expert_map):
         self.expert_map = new_expert_map
@@ -402,9 +394,6 @@ def forward(self,
             mc2_mask = chunk_mc2_mask[tp_rank]
             replace_allreduce = True

-        moe_comm_method_name = forward_context.moe_comm_method_name
-        forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
-
         hidden_states, router_logits = forward_context.moe_comm_method.prepare(
             hidden_states=hidden_states,
             router_logits=router_logits,

vllm_ascend/ops/moe/moe_comm_method.py

Lines changed: 21 additions & 1 deletion
@@ -13,14 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
+from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import torch
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

+from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
     FusedMoEPrepareAndFinalizeWithAll2All,
     FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
@@ -30,6 +32,24 @@
     TokenDispatcherWithAllGather,
     TokenDispatcherWithMC2)

+_MoECommMethods: Dict[str, MoECommMethod] = {}
+
+
+def _register_moe_comm_method(moe_comm_method: MoECommMethod):
+    _MoECommMethods[moe_comm_method.__class__.__name__] = moe_comm_method
+
+
+def get_moe_comm_method(
+        name: Optional[MoECommType]) -> Optional[MoECommMethod]:
+    return _MoECommMethods.get(str(name))
+
+
+def setup_moe_comm_method(moe_config):
+    _register_moe_comm_method(AlltoAllCommImpl(moe_config))
+    _register_moe_comm_method(AllGatherCommImpl(moe_config))
+    _register_moe_comm_method(MC2CommImpl(moe_config))
+    _register_moe_comm_method(NaiveMulticastCommImpl(moe_config))
+

 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""

vllm_ascend/ops/moe/moe_mlp.py

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@
 from torch.nn.functional import pad
 from vllm.forward_context import get_forward_context

+from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.utils import dispose_tensor, is_310p


@@ -76,7 +77,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         bias1, bias2 = None, None
         _output_dtype = w2_scale.dtype

-    is_mc2 = get_forward_context().moe_comm_method_name == "mc2commimpl"
+    is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
     if w1_scale_bias is None and is_mc2:
         if w1_scale.dtype != torch.float32:
             w1_scale = w1_scale.to(torch.float32)

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 6 additions & 6 deletions
@@ -117,11 +117,11 @@ def dummy_run(self,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
                   num_tokens_across_dp: Optional[torch.Tensor] = None):
-        moe_comm_method = self.runner._select_moe_comm_method(
+        moe_comm_type = self.runner._select_moe_comm_method(
            num_tokens, with_prefill)
        with set_ascend_forward_context(None,
                                        self.vllm_config,
-                                        moe_comm_method=moe_comm_method,
+                                        moe_comm_type=moe_comm_type,
                                        num_tokens=num_tokens):
            self.model(
                input_ids=self.input_ids[:num_tokens],
@@ -454,7 +454,7 @@ def _propose(
        with_prefill = attn_metadata.attn_state not in [
            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
        ]
-        moe_comm_method = self.runner._select_moe_comm_method(
+        moe_comm_type = self.runner._select_moe_comm_method(
            num_input_tokens, with_prefill)

        # copy inputs to buffer for cudagraph
@@ -463,7 +463,7 @@
        attn_metadata.block_tables = block_table.to(device)
        with set_ascend_forward_context(attn_metadata,
                                        self.vllm_config,
-                                        moe_comm_method=moe_comm_method,
+                                        moe_comm_type=moe_comm_type,
                                        num_tokens=num_input_tokens):
            last_hidden_states, hidden_states = self.model(
                input_ids=self.input_ids[:num_input_tokens],
@@ -495,7 +495,7 @@
        else:
            input_batch_size = batch_size

-        moe_comm_method = self.runner._select_moe_comm_method(
+        moe_comm_type = self.runner._select_moe_comm_method(
            input_batch_size, False)

        attn_metadata.num_actual_tokens = batch_size
@@ -568,7 +568,7 @@
        # Run the model.
        with set_ascend_forward_context(attn_metadata,
                                        self.vllm_config,
-                                        moe_comm_method=moe_comm_method,
+                                        moe_comm_type=moe_comm_type,
                                        num_tokens=input_batch_size):

            last_hidden_states, hidden_states = self.model(

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 4 additions & 4 deletions
@@ -113,7 +113,7 @@ def dummy_run(self,
         _) = self.runner._sync_metadata_across_dp(num_tokens,
                                                   with_prefill, False)

-        moe_comm_method = self.runner._select_moe_comm_method(
+        moe_comm_type = self.runner._select_moe_comm_method(
            num_tokens, with_prefill)

        is_running_torchair = self.torchair_graph_enabled and \
@@ -146,7 +146,7 @@ def dummy_run(self,
                with_prefill=with_prefill,
                num_tokens_across_dp=num_tokens_across_dp,
                reserved_mc2_mask=self.runner.reserved_mc2_mask,
-                moe_comm_method=moe_comm_method,
+                moe_comm_type=moe_comm_type,
                in_profile_run=self.runner.in_profile_run,
                num_actual_tokens=0):
            if is_running_torchair:
@@ -416,7 +416,7 @@ def _propose(
        num_tokens_across_dp = self.runner.num_tokens_across_dp
        with_prefill = self.runner.with_prefill

-        moe_comm_method = self.runner._select_moe_comm_method(
+        moe_comm_type = self.runner._select_moe_comm_method(
            num_input_tokens, with_prefill)

        for step in range(self.num_speculative_tokens):
@@ -427,7 +427,7 @@
                with_prefill=with_prefill,
                num_tokens_across_dp=num_tokens_across_dp,
                reserved_mc2_mask=self.runner.reserved_mc2_mask,
-                moe_comm_method=moe_comm_method,
+                moe_comm_type=moe_comm_type,
                in_profile_run=self.runner.in_profile_run,
                num_actual_tokens=num_tokens):
            with ProfileExecuteDuration().capture_async('mtp_forward'):