@@ -183,29 +183,33 @@ def _select_experts_with_fusion_ops(
183
183
global_num_experts : int = - 1 ):
184
184
185
185
topk_weights , topk_ids , row_idx = None , None , None
186
- # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
187
- is_deepseek_v3_r1 = global_num_experts == 256
188
- if is_deepseek_v3_r1 :
186
+ if scoring_func == "softmax" :
187
+ norm_type = 0
188
+ topk_group = 1
189
+ num_expert_group = 1
190
+ else :
191
+ norm_type = 1
192
+ if custom_routing_function is None :
193
+ if e_score_correction_bias is not None and \
194
+ e_score_correction_bias .dtype != router_logits .dtype :
195
+ e_score_correction_bias = e_score_correction_bias .to (router_logits .dtype )
189
196
topk_weights , topk_ids , _ = torch_npu .npu_moe_gating_top_k (
190
197
router_logits ,
191
- k = top_k , # topk currently 8
198
+ k = top_k ,
192
199
bias = e_score_correction_bias ,
193
- k_group = topk_group , # fix: 4
194
- group_count = num_expert_group , # fix 8
200
+ k_group = topk_group ,
201
+ group_count = num_expert_group ,
195
202
group_select_mode =
196
203
1 , # 0: the maximum in the group; 1: topk2.sum(fix)
197
204
renorm = 0 , # 0: softmax->topk(fix); 1: topk->softmax
198
- norm_type = 1 , # 0: softmax; 1: sigmoid(fix)
205
+ norm_type = norm_type , # 0: softmax; 1: sigmoid
199
206
# out_flag=False, # TODO: new API; whether to emit the third output
200
207
# y2_flag=False, # old API; whether to emit the third output
201
208
routed_scaling_factor = 1 ,
202
209
eps = float (1e-20 ))
203
210
row_idx = return_row_idx (hidden_states , top_k )
204
- if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" :
205
- topk_weights , topk_ids , row_idx = torch_npu .npu_moe_gating_top_k_softmax (
206
- x = router_logits , finished = None , k = top_k )
207
- topk_ids = topk_ids .to (torch .int32 )
208
- topk_weights = _renormalize_topk_weights (topk_weights , renormalize )
211
+ if scoring_func == "softmax" :
212
+ topk_weights = _renormalize_topk_weights (topk_weights , renormalize )
209
213
210
214
return topk_weights , topk_ids , row_idx
211
215
0 commit comments