Commit 61dee65

use ep/tp size in fusedmoe parallel config

Signed-off-by: MengqingCao <cmq0113@163.com>

1 parent: f73bbbc

2 files changed: +11 -8 lines

vllm_ascend/ops/common_fused_moe.py

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ def forward_oot(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
+            moe_parallel_config=self.moe.moe_parallel_config,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             top_k=top_k,

vllm_ascend/ops/fused_moe.py

Lines changed: 10 additions & 8 deletions
@@ -124,6 +124,7 @@ def fused_experts_with_mc2(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     top_k: int,
+    moe_parallel_config: FusedMoEParallelConfig,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None
@@ -142,11 +143,10 @@ def fused_experts_with_mc2(
     rank = torch.distributed.get_rank()

     quant_mode = 0
-    ep_group = get_ep_group()
-    ep_rank_id = ep_group.rank_in_group
-    ep_world_size = ep_group.world_size
+    ep_rank_id = moe_parallel_config.ep_rank
+    ep_world_size = moe_parallel_config.ep_size

-    tp_world_size = get_tp_group().world_size
+    tp_world_size = moe_parallel_config.tp_size
     tp_rank = rank % tp_world_size

     stage1_kwargs = {
@@ -559,6 +559,7 @@ def fused_experts_moge(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
+    moe_parallel_config: FusedMoEParallelConfig,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     top_k: int,
@@ -580,7 +581,7 @@ def fused_experts_moge(
     Returns:
         hidden_states: Hidden states after routing.
     """
-    ep_size = get_ep_group().world_size
+    ep_size = moe_parallel_config.ep_size
     local_num_experts = global_num_experts // ep_size
     local_num_group = top_k // ep_size

@@ -981,7 +982,7 @@ def __init__(self, moe: FusedMoEConfig = None):
         vllm_config = get_current_vllm_config()

         self.ep_group = get_ep_group()
-        self.ep_size = self.ep_group.world_size
+        self.ep_size = self.moe.moe_parallel_config.ep_size
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size
         self.max_model_len = vllm_config.model_config.max_model_len
@@ -1073,13 +1074,14 @@ def apply(
         if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

-        fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
-                                              is_prefill, is_deepseek_v3_r1)
+        fused_moe_state = get_fused_moe_state(self.ep_size, is_prefill,
+                                              is_deepseek_v3_r1)
         if fused_moe_state == FusedMoEState.MC2:
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
+                moe_parallel_config=self.moe.moe_parallel_config,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 top_k=top_k,
