File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -150,8 +150,8 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
150
150
log2phy : torch .Tensor = None ,
151
151
global_redundant_expert_num : int = 0 ,
152
152
** kwargs ) -> torch .Tensor :
153
-
154
- topk_ids = log2phy [topk_ids ]
153
+ if log2phy :
154
+ topk_ids = log2phy [topk_ids ]
155
155
global_bs = 0
156
156
moe_expert_num = len (expert_map ) + global_redundant_expert_num
157
157
# hidden_states = hidden_states.bfloat16()
@@ -278,7 +278,8 @@ def fused_experts_with_all2all(
278
278
log2phy : torch .Tensor = None ,
279
279
global_redundant_expert_num : int = 0 ,
280
280
):
281
- topk_ids = log2phy [topk_ids ]
281
+ if log2phy :
282
+ topk_ids = log2phy [topk_ids ]
282
283
original_shape = hidden_states .shape
283
284
if len (original_shape ) == 3 :
284
285
hidden_states = hidden_states .view (- 1 , hidden_states .shape [- 1 ])
You can’t perform that action at this time.
0 commit comments