@@ -124,6 +124,7 @@ def fused_experts_with_mc2(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     top_k: int,
+    moe_parallel_config: FusedMoEParallelConfig,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None
@@ -142,11 +143,10 @@ def fused_experts_with_mc2(
     rank = torch.distributed.get_rank()
 
     quant_mode = 0
-    ep_group = get_ep_group()
-    ep_rank_id = ep_group.rank_in_group
-    ep_world_size = ep_group.world_size
+    ep_rank_id = moe_parallel_config.ep_rank
+    ep_world_size = moe_parallel_config.ep_size
 
-    tp_world_size = get_tp_group().world_size
+    tp_world_size = moe_parallel_config.tp_size
     tp_rank = rank % tp_world_size
 
     stage1_kwargs = {
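
The hunks above replace call-time lookups on the distributed process groups (get_ep_group(), get_tp_group()) with fields read off the new moe_parallel_config argument. A minimal sketch of the idea, assuming a dataclass-style FusedMoEParallelConfig exposing ep_rank, ep_size and tp_size; only those field names come from the hunk, the class body and helper below are illustrative, not the actual vllm-ascend code:

    from dataclasses import dataclass

    @dataclass
    class FusedMoEParallelConfig:
        ep_rank: int  # this rank's index within the expert-parallel group
        ep_size: int  # expert-parallel world size
        tp_size: int  # tensor-parallel world size

    def resolve_parallel_ids(rank: int, cfg: FusedMoEParallelConfig):
        # Before: ep_group = get_ep_group(); ep_rank_id = ep_group.rank_in_group; ...
        # After: every value comes from the config object the caller passes in.
        ep_rank_id = cfg.ep_rank
        ep_world_size = cfg.ep_size
        tp_world_size = cfg.tp_size
        tp_rank = rank % tp_world_size  # same derivation as in the hunk
        return ep_rank_id, ep_world_size, tp_world_size, tp_rank
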
@@ -559,6 +559,7 @@ def fused_experts_moge(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
+    moe_parallel_config: FusedMoEParallelConfig,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     top_k: int,
@@ -580,7 +581,7 @@ def fused_experts_moge(
     Returns:
         hidden_states: Hidden states after routing.
     """
-    ep_size = get_ep_group().world_size
+    ep_size = moe_parallel_config.ep_size
     local_num_experts = global_num_experts // ep_size
     local_num_group = top_k // ep_size
 
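
In fused_experts_moge the expert-parallel size now also comes from the config and drives the local partitioning shown above. A quick worked example of those two divisions; the concrete numbers are illustrative, not taken from the patch:

    global_num_experts = 64
    top_k = 8
    ep_size = 4  # moe_parallel_config.ep_size
    local_num_experts = global_num_experts // ep_size  # 64 // 4 = 16 experts per EP rank
    local_num_group = top_k // ep_size  # 8 // 4 = 2 routed groups handled locally
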
@@ -981,7 +982,7 @@ def __init__(self, moe: FusedMoEConfig = None):
         vllm_config = get_current_vllm_config()
 
         self.ep_group = get_ep_group()
-        self.ep_size = self.ep_group.world_size
+        self.ep_size = self.moe.moe_parallel_config.ep_size
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size
         self.max_model_len = vllm_config.model_config.max_model_len
@@ -1073,13 +1074,14 @@ def apply(
         if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
-        fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
-                                              is_prefill, is_deepseek_v3_r1)
+        fused_moe_state = get_fused_moe_state(self.ep_size, is_prefill,
+                                              is_deepseek_v3_r1)
         if fused_moe_state == FusedMoEState.MC2:
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
+                moe_parallel_config=self.moe.moe_parallel_config,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 top_k=top_k,