
Commit 0e7a4f5

fix

1. use enum instead of string
2. avoid setting forward_context in AscendFusedMoE.forward()

1 parent 7280429 commit 0e7a4f5

6 files changed: +81 additions, -51 deletions


vllm_ascend/ascend_forward_context.py
Lines changed: 17 additions & 2 deletions

@@ -22,6 +22,16 @@ class FusedMoEState(Enum):
     All2AllSeq = 5


+class MoECommImpl(Enum):
+    ALLGATHER = "AllGather"
+    MC2 = "MC2"
+    ALLTOALL = "AlltoAll"
+    NAIVE_MULTICAST = "NaiveMulticast"
+
+    def __str__(self):
+        return self.value + "CommImpl"
+
+
 # TODO(zzzzwwjj): add soc_version to choose branch
 def _get_fused_moe_state(ep_size: int, with_prefill: bool,
                          is_deepseek_v3_r1: bool):
@@ -52,7 +62,7 @@ def set_ascend_forward_context(
         with_prefill: bool = True,
         in_profile_run: bool = False,
         reserved_mc2_mask: Optional[torch.Tensor] = None,
-        moe_comm_method: str = "",
+        moe_comm_method_type: Optional[MoECommImpl] = None,
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None,
@@ -72,7 +82,12 @@
             batch_descriptor=batch_descriptor,
     ):
         forward_context = get_forward_context()
-        forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
+
+        from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
+        forward_context.moe_comm_method_type = moe_comm_method_type
+        forward_context.moe_comm_method = get_moe_comm_method(
+            moe_comm_method_type)
+
         forward_context.with_prefill = with_prefill
         tp_world_size = get_tensor_model_parallel_world_size()
         ep_size = (get_ep_group().world_size if
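
For context, the enum's __str__ encodes the naming convention that ties each member to a communication implementation class. A minimal check, assuming a vllm-ascend build that contains this commit is importable:

# str() on a MoECommImpl member yields the *CommImpl class name that the
# registry in vllm_ascend/ops/moe/moe_comm_method.py uses as its lookup key.
from vllm_ascend.ascend_forward_context import MoECommImpl

assert str(MoECommImpl.MC2) == "MC2CommImpl"
assert str(MoECommImpl.NAIVE_MULTICAST) == "NaiveMulticastCommImpl"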

vllm_ascend/ops/common_fused_moe.py
Lines changed: 18 additions & 25 deletions

@@ -29,14 +29,13 @@
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import MoECommImpl
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
                                               determine_default_log2phy_map)
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
-                                                 AlltoAllCommImpl, MC2CommImpl,
-                                                 NaiveMulticastCommImpl)
+from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
                                get_all_reduce_merge_state,
                                get_rm_router_logits_state, is_310p)
@@ -145,6 +144,8 @@ def apply(self,


 class AscendFusedMoE(FusedMoE):
+    # The moe_counter parameter is required during the initialization of EPLB
+    # to identify the current layer index within the MOE model.
     moe_counter = -1

     def __init__(self, *args, **kwargs):
@@ -172,14 +173,11 @@ def __init__(self, *args, **kwargs):

         assert self.quant_method is not None

-        AscendFusedMoE.moe_counter += 1
-        self.moe_instance_id = AscendFusedMoE.moe_counter
         self.moe_config.tp_group = get_tp_group()
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
         self.moe_config.mc2_group = get_mc2_group()
         ascend_config = get_ascend_config()
-        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.dynamic_eplb = ascend_config.dynamic_eplb
         self.expert_map_path = ascend_config.expert_map_path
         self.global_redundant_expert_num = ascend_config.init_redundancy_expert
@@ -215,13 +213,9 @@ def __init__(self, *args, **kwargs):
         if self.dynamic_eplb:
             self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)

-        for method in {
-                AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
-                NaiveMulticastCommImpl
-        }:
-            setattr(
-                self, method.__name__.lower(),
-                method(moe_config=self.moe_config))  # type: ignore[abstract]
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
+
+        setup_moe_comm_method(self.moe_config)

     def update_expert_map(self, new_expert_map):
         self.expert_map = new_expert_map
@@ -245,8 +239,8 @@ def maybe_all_reduce_tensor_model_parallel(
         outputs since each rank only has partial outputs.
         """
         forward_context = get_forward_context()
-        moe_comm_method_name = forward_context.moe_comm_method_name
-        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+        moe_comm_method_type = forward_context.moe_comm_method_type
+        if moe_comm_method_type in {MoECommImpl.ALLTOALL, MoECommImpl.MC2}:
             return final_hidden_states
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)
@@ -260,9 +254,6 @@ def forward_impl(self, hidden_states: torch.Tensor,

         forward_context = get_forward_context()
         enable_force_load_balance = forward_context.in_profile_run
-        moe_comm_method_name = forward_context.moe_comm_method_name
-
-        forward_context.moe_comm_method = getattr(self, moe_comm_method_name)

         hidden_states, router_logits = forward_context.moe_comm_method.prepare(
             hidden_states=hidden_states,
@@ -287,11 +278,13 @@
             e_score_correction_bias=self.e_score_correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
-            enable_eplb=self.enable_eplb,
-            expert_load_view=self.expert_load_view,
-            logical_to_physical_map=self.logical_to_physical_map,
-            logical_replica_count=self.logical_replica_count,
-        )
+            quantized_x_for_share=quantized_x_for_share,
+            dynamic_scale_for_share=dynamic_scale_for_share,
+            shared_experts=None,
+            enable_force_load_balance=enable_force_load_balance,
+            log2phy=self.log2phy,
+            global_redundant_expert_num=self.global_redundant_expert_num)
+
         if isinstance(final_hidden_states, tuple):
             final_hidden_states, group_list_type, expert_tokens = final_hidden_states

@@ -410,8 +403,8 @@ def forward(

         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
-        moe_comm_method_name = forward_context.moe_comm_method_name
-        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+        moe_comm_method_type = forward_context.moe_comm_method_type
+        if moe_comm_method_type in {MoECommImpl.ALLTOALL, MoECommImpl.MC2}:
             shared_out = tensor_model_parallel_all_reduce(shared_out)

         fused_out = super().forward(
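
The key behavioral change in this file is where the communication method gets resolved: previously forward_impl() looked it up on the layer via getattr() and wrote it into the forward context; now set_ascend_forward_context() installs it once per step and the layer only reads it. A minimal stub sketch of that flow (stand-in classes, not the real vLLM objects):

class _StubCommImpl:
    """Stand-in for an AllGather/MC2/AlltoAll/NaiveMulticast CommImpl."""

    def prepare(self, hidden_states, router_logits):
        # Real implementations may pad, gather, or dispatch tokens here.
        return hidden_states, router_logits


class _StubForwardContext:
    """Stand-in for vLLM's forward context."""

    def __init__(self, comm_impl, comm_type):
        # Populated by set_ascend_forward_context(), not by the MoE layer.
        self.moe_comm_method = comm_impl
        self.moe_comm_method_type = comm_type


def forward_impl_sketch(forward_context, hidden_states, router_logits):
    # Before this commit: forward_context.moe_comm_method = getattr(self, name)
    # After this commit: the layer simply consumes what the context provides.
    return forward_context.moe_comm_method.prepare(hidden_states, router_logits)


ctx = _StubForwardContext(_StubCommImpl(), comm_type=None)
forward_impl_sketch(ctx, hidden_states=[1.0, 2.0], router_logits=[0.3, 0.7])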

vllm_ascend/ops/moe/moe_comm_method.py
Lines changed: 20 additions & 1 deletion

@@ -13,15 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
+from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import torch
 from vllm.config import get_current_vllm_config
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

+from vllm_ascend.ascend_forward_context import MoECommImpl
 from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
     FusedMoEPrepareAndFinalizeWithAll2All,
     FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
@@ -32,6 +34,23 @@
     TokenDispatcherWithMC2,
     TokenDispatcherWithMoge)

+_MoECommMethods: Dict[str, MoECommMethod] = {}
+
+
+def _register_moe_comm_method(moe_comm_method: MoECommMethod):
+    _MoECommMethods[moe_comm_method.__class__.__name__] = moe_comm_method
+
+
+def get_moe_comm_method(name: MoECommImpl) -> Optional[MoECommMethod]:
+    return _MoECommMethods.get(str(name))
+
+
+def setup_moe_comm_method(moe_config):
+    _register_moe_comm_method(AlltoAllCommImpl(moe_config))
+    _register_moe_comm_method(AllGatherCommImpl(moe_config))
+    _register_moe_comm_method(MC2CommImpl(moe_config))
+    _register_moe_comm_method(NaiveMulticastCommImpl(moe_config))
+

 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""

vllm_ascend/ops/moe/moe_mlp.py
Lines changed: 2 additions & 1 deletion

@@ -21,6 +21,7 @@
 from torch.nn.functional import pad
 from vllm.forward_context import get_forward_context

+from vllm_ascend.ascend_forward_context import MoECommImpl
 from vllm_ascend.utils import dispose_tensor, is_310p


@@ -76,7 +77,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         bias1, bias2 = None, None
         _output_dtype = w2_scale.dtype

-    is_mc2 = get_forward_context().moe_comm_method_name == "mc2commimpl"
+    is_mc2 = get_forward_context().moe_comm_method_type == MoECommImpl.MC2
     if w1_scale_bias is None and is_mc2:
         if w1_scale.dtype != torch.float32:
             w1_scale = w1_scale.to(torch.float32)

vllm_ascend/torchair/models/torchair_pangu_moe.py
Lines changed: 0 additions & 3 deletions

@@ -498,9 +498,6 @@ def forward(
         global _ROUTER_SCALE
         _ROUTER_SCALE = self.router_scale

-        # TODO(angazenn): Does not support MC2 currently
-        get_forward_context().moe_comm_method_name = "allgathercommimpl"
-
         if not use_h2p():
             final_hidden_states = self.experts.forward_impl(
                 hidden_states=hidden_states, router_logits=router_logits)

vllm_ascend/worker/model_runner_v1.py
Lines changed: 24 additions & 19 deletions

@@ -88,7 +88,8 @@
     sanity_check_mm_encoder_outputs,
     scatter_mm_placeholders)

-from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_config import (MoECommImpl,
+                                       get_ascend_config)
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -1692,7 +1693,7 @@ def _pool(
         )

     def _select_moe_comm_method(self, num_tokens: int,
-                                with_prefill: bool) -> str:
+                                with_prefill: bool) -> MoECommImpl:
         """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
         are designed for expert parallelism.
         2. If expert parallel is enabled, we need to consider the soc version and the
@@ -1713,38 +1714,42 @@ def _select_moe_comm_method(self, num_tokens: int,
             ValueError: If the soc version is unsupported.

         Returns:
-            str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
+            MoECommImpl: The selected MoE communication method.
         """
         soc_version = get_ascend_soc_version()
         quant_type = getattr(self.vllm_config.model_config.hf_config,
                              'moe_quantize', None)
         model_type = self.vllm_config.model_config.hf_config.model_type

         if not self.parallel_config.enable_expert_parallel:
-            moe_comm_method = "allgather"
+            moe_comm_method = MoECommImpl.ALLGATHER
         elif soc_version in {AscendSocVersion.A2}:
-            if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16:
-                moe_comm_method = "mc2"
+            if (num_tokens <= self.mc2_tokens_capacity
+                    and self.parallel_config.world_size_across_dp >= 16):
+                moe_comm_method = MoECommImpl.MC2
             else:
                 if quant_type == "w4a8_dynamic":
-                    moe_comm_method = "alltoall"
+                    moe_comm_method = MoECommImpl.ALLTOALL
                 else:
-                    moe_comm_method = "allgather"
+                    moe_comm_method = MoECommImpl.ALLGATHER

         elif soc_version in {AscendSocVersion.A3}:
-            moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
+            moe_comm_method = (MoECommImpl.MC2
+                               if num_tokens <= self.mc2_tokens_capacity else
+                               MoECommImpl.ALLTOALL)
         else:
             raise ValueError(f"Unsupported soc_version: {soc_version}")

-        if moe_comm_method == "allgather" and with_prefill:
-            moe_comm_method = "naivemulticast"
-
+        if moe_comm_method == MoECommImpl.ALLGATHER and with_prefill:
+            moe_comm_method = MoECommImpl.NAIVE_MULTICAST
+
         if model_type == "PanguProMoE":
-            moe_comm_method = "allgather"
+            moe_comm_method = MoECommImpl.ALLGATHER

         if is_global_first_rank():
             logger.debug(f"num_tokens: {num_tokens}, "
                          f"moe_comm_method: {moe_comm_method}")
+        print("moe_comm_method = ", moe_comm_method)

         return moe_comm_method

@@ -1777,8 +1782,8 @@ def execute_model(
         if self.dynamic_eplb:
             self.eplb_updator.take_update_info_from_eplb_process()

-        moe_comm_method = self._select_moe_comm_method(num_input_tokens,
-                                                       self.with_prefill)
+        moe_comm_method_type = self._select_moe_comm_method(num_input_tokens,
+                                                            self.with_prefill)

         batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                            uniform_decode=False)
@@ -1794,7 +1799,7 @@
                 num_tokens_across_dp=num_tokens_across_dp,
                 with_prefill=self.with_prefill,
                 reserved_mc2_mask=self.reserved_mc2_mask,
-                moe_comm_method=moe_comm_method,
+                moe_comm_method_type=moe_comm_method_type,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 num_actual_tokens=scheduler_output.
@@ -2151,8 +2156,8 @@ def _dummy_run(
         (num_tokens, num_tokens_across_dp, with_prefill,
          _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)

-        moe_comm_method = self._select_moe_comm_method(num_tokens,
-                                                       with_prefill)
+        moe_comm_method_type = self._select_moe_comm_method(num_tokens,
+                                                            with_prefill)

         # If cudagraph_mode.decode_mode() == FULL and
         # cudagraph_mode.seperate_routine(). This means that we are using
@@ -2268,7 +2273,7 @@ def dummy_compute_logits(hidden_states):
                 with_prefill=with_prefill,
                 in_profile_run=self.in_profile_run,
                 reserved_mc2_mask=self.reserved_mc2_mask,
-                moe_comm_method=moe_comm_method,
+                moe_comm_method_type=moe_comm_method_type,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
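
To make the selection policy easier to scan, here is a condensed, side-effect-free restatement of _select_moe_comm_method as a standalone function. It is only a sketch: SocVersion is a stand-in for vllm-ascend's AscendSocVersion, the parameters replace the attributes read from self, and the thresholds and special cases are exactly the ones visible in the diff above.

# Hedged sketch only; not part of the commit.
from enum import Enum
from typing import Optional

from vllm_ascend.ascend_forward_context import MoECommImpl


class SocVersion(Enum):  # stand-in for AscendSocVersion
    A2 = "A2"
    A3 = "A3"


def select_moe_comm_method(num_tokens: int, with_prefill: bool,
                           enable_expert_parallel: bool, soc: SocVersion,
                           mc2_tokens_capacity: int,
                           world_size_across_dp: int,
                           quant_type: Optional[str],
                           model_type: str) -> MoECommImpl:
    # 1. Without expert parallelism, only all-gather applies.
    if not enable_expert_parallel:
        method = MoECommImpl.ALLGATHER
    # 2. A2: MC2 needs enough ranks and a small enough batch; w4a8_dynamic
    #    quantization falls back to all-to-all, everything else to all-gather.
    elif soc is SocVersion.A2:
        if num_tokens <= mc2_tokens_capacity and world_size_across_dp >= 16:
            method = MoECommImpl.MC2
        elif quant_type == "w4a8_dynamic":
            method = MoECommImpl.ALLTOALL
        else:
            method = MoECommImpl.ALLGATHER
    # 3. A3: MC2 within capacity, otherwise all-to-all.
    elif soc is SocVersion.A3:
        method = (MoECommImpl.MC2 if num_tokens <= mc2_tokens_capacity
                  else MoECommImpl.ALLTOALL)
    else:
        raise ValueError(f"Unsupported soc_version: {soc}")

    # All-gather during prefill is replaced by naive multicast.
    if method == MoECommImpl.ALLGATHER and with_prefill:
        method = MoECommImpl.NAIVE_MULTICAST

    # PanguProMoE is pinned to all-gather.
    if model_type == "PanguProMoE":
        method = MoECommImpl.ALLGATHER
    return method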
