Skip to content

Commit 7e5f06a

Browse files
committed
support gatingtopk
1 parent 76fc832 commit 7e5f06a

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

vllm_ascend/ops/moe/experts_selector.py

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -183,29 +183,33 @@ def _select_experts_with_fusion_ops(
183183
global_num_experts: int = -1):
184184

185185
topk_weights, topk_ids, row_idx = None, None, None
186-
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
187-
is_deepseek_v3_r1 = global_num_experts == 256
188-
if is_deepseek_v3_r1:
186+
if scoring_func == "softmax":
187+
norm_type = 0
188+
topk_group = 1
189+
num_expert_group = 1
190+
else:
191+
norm_type = 1
192+
if custom_routing_function is None:
193+
if e_score_correction_bias is not None and \
194+
e_score_correction_bias.dtype != router_logits.dtype:
195+
e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
189196
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
190197
router_logits,
191-
k=top_k, # topk currently 8
198+
k=top_k,
192199
bias=e_score_correction_bias,
193-
k_group=topk_group, # fix: 4
194-
group_count=num_expert_group, # fix 8
200+
k_group=topk_group,
201+
group_count=num_expert_group,
195202
group_select_mode=
196203
1, # 0: the maximum in the group; 1: topk2.sum(fix)
197204
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
198-
norm_type=1, # 0: softmax; 1: sigmoid(fix)
205+
norm_type=norm_type, # 0: softmax; 1: sigmoid
199206
# out_flag=False, # todo new api; should the third output be output
200207
# y2_flag=False, # old api; should the third output be output
201208
routed_scaling_factor=1,
202209
eps=float(1e-20))
203210
row_idx = return_row_idx(hidden_states, top_k)
204-
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
205-
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
206-
x=router_logits, finished=None, k=top_k)
207-
topk_ids = topk_ids.to(torch.int32)
208-
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
211+
if scoring_func == "softmax":
212+
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
209213

210214
return topk_weights, topk_ids, row_idx
211215

0 commit comments

Comments
 (0)