@@ -141,7 +141,11 @@ def fused_experts_with_mc2(
     is_torchair: bool = False,
     hidden_states_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
+    log2phy: Optional[torch.Tensor] = None,
+    global_redundant_expert_num: int = 0
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    if log2phy is not None:
+        topk_ids = log2phy[topk_ids]
     quant_mode = 0
     ep_group = get_mc2_group()
     ep_rank_id = ep_group.rank_in_group
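
Note on the added log2phy remapping: the sketch below (not part of the PR) shows how a log2phy table translates the router's logical expert ids into physical expert slots, so that redundant replicas created for load balancing can be addressed. The table contents and sizes are hypothetical.

    import torch

    # Hypothetical layout: 4 logical experts plus 2 redundant physical replicas,
    # so physical slots 0..5 exist. Logical experts 0 and 1 are assumed to be
    # served from replica slots 4 and 5 on this rank.
    log2phy = torch.tensor([4, 5, 2, 3])

    topk_ids = torch.tensor([[0, 2], [1, 3]])  # logical ids from the router
    physical_ids = log2phy[topk_ids]           # same indexing as in the hunk above
    print(physical_ids)                        # tensor([[4, 2], [5, 3]])
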
@@ -163,7 +167,7 @@ def fused_experts_with_mc2(
 
     enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
 
-    moe_expert_num = len(expert_map)
+    moe_expert_num = len(expert_map) + global_redundant_expert_num
     kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
@@ -349,17 +353,16 @@ def apply_mlp(
 
 # currently expert parallelism implemented with all2all
 # is under-optimized.
-def fused_experts_with_all2all(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    ep_group: GroupCoordinator = None,
-    max_num_tokens: Optional[int] = None,
-):
+def fused_experts_with_all2all(hidden_states: torch.Tensor,
+                               w1: torch.Tensor,
+                               w2: torch.Tensor,
+                               topk_weights: torch.Tensor,
+                               topk_ids: torch.Tensor,
+                               top_k: int,
+                               expert_map: torch.Tensor = None,
+                               ep_group: GroupCoordinator = None,
+                               max_num_tokens: Optional[int] = None,
+                               global_redundant_expert_num: int = 0):
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -369,7 +372,7 @@ def fused_experts_with_all2all(
     device = hidden_states.device
 
     if expert_map is not None:
-        global_num_experts = len(expert_map)
+        global_num_experts = len(expert_map) + global_redundant_expert_num
         local_num_experts = global_num_experts // ep_group.world_size
         row_idx_len = num_tokens * top_k
         row_idx = (torch.arange(0,
@@ -639,7 +642,10 @@ def fused_experts_with_all2allv(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
+    log2phy: Optional[torch.Tensor] = None,
 ):
+    if log2phy is not None:
+        routing_map = log2phy[routing_map]
     # Enable moe alltoallv, it's a balanced policy for precision and efficiency.
     (share_experts_output, dispatched_input,
      tokens_per_expert) = (token_dispatcher.token_permutation(
@@ -824,8 +830,8 @@ def fused_experts(
         expanded_src_to_dst_row=expanded_row_idx,
         export_for_source_row=topk_ids,
     )
-
-    return final_hidden_states
+    group_list_type = 0
+    return final_hidden_states, expert_tokens, group_list_type
 
 
 def native_grouped_topk(
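
The hunk above now returns expert_tokens together with a group_list_type flag. As a rough illustration only, assuming the torch_npu grouped-matmul convention that group_list_type 0 denotes cumulative token offsets and 1 denotes per-expert counts (an assumption, not stated in this diff), the two representations relate as follows:

    import torch

    tokens_per_expert = torch.tensor([3, 0, 5, 2])      # hypothetical per-expert counts
    cumulative_offsets = torch.cumsum(tokens_per_expert, dim=0)

    group_list_type = 0                                  # as returned by fused_experts above
    expert_tokens = cumulative_offsets if group_list_type == 0 else tokens_per_expert
    print(expert_tokens)                                 # tensor([3, 3, 8, 10])
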
@@ -1015,6 +1021,8 @@ def apply(
         enable_force_load_balance: bool = False,
         hidden_states_for_share: Optional[Any] = None,
         shared_experts: Optional[Any] = None,
+        log2phy: Optional[Any] = None,
+        global_redundant_expert_num: int = 0,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -1071,6 +1079,8 @@ def apply(
                 is_torchair=self.torchair_graph_enabled,
                 hidden_states_for_share=hidden_states_for_share,
                 mc2_mask=mc2_mask,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
             )
         elif fused_moe_state == FusedMoEState.AllGather:
             max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
@@ -1105,18 +1115,20 @@ def apply(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-            )
+                log2phy=log2phy)
         else:
             max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group(),
-                                              max_num_tokens=max_num_tokens)
+            return fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=get_ep_group(),
+                max_num_tokens=max_num_tokens,
+                global_redundant_expert_num=global_redundant_expert_num)
 
 
 class AscendFusedMoE(FusedMoE):
@@ -1273,6 +1285,10 @@ def __init__(
         if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
                 self.quant_method, AscendUnquantizedFusedMoEMethod):
             self.reduce_results = False
+        if expert_map_path and os.path.exists(expert_map_path):
+            self.global_num_experts = self.global_num_experts + self.global_redundant_expert_num
+            self.local_num_experts = self.global_num_experts // self.ep_size
+
 
         moe_dispatcher_config = (
             MoEDispatcherConfig().set_num_moe_experts(
                 self.global_num_experts).set_num_local_experts(
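
Worked example (hypothetical numbers, not from the PR) of the expert-count bookkeeping added in __init__ above: when a redundant-expert map is loaded, the physical expert total grows by global_redundant_expert_num and the per-rank share is recomputed from that enlarged total.

    global_num_experts = 256            # logical experts in the model (assumed)
    global_redundant_expert_num = 16    # redundant replicas from the expert map (assumed)
    ep_size = 16                        # expert-parallel world size (assumed)

    global_num_experts = global_num_experts + global_redundant_expert_num   # 272
    local_num_experts = global_num_experts // ep_size                       # 17 experts per rank
    print(global_num_experts, local_num_experts)
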