@@ -150,6 +150,7 @@ def fused_experts(
     topk_ids: torch.Tensor,
     top_k: int,
     expert_map: torch.Tensor = None,
+    apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
     """
     Fused experts with top-k routing.
@@ -188,6 +189,15 @@ def fused_experts(
     # assert dtype in [torch.float32, torch.float16, torch.bfloat16
     # ], "Only float32, float16, and bfloat16 are supported"

+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+
     if expert_map is not None:
         # Generate token indices and flatten
         token_indices = (torch.arange(num_tokens,
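The new `apply_router_weight_on_input` flag folds the router weight into the activations before the expert GEMMs instead of applying it when outputs are combined, which is only valid for top_k == 1. A minimal sketch of that pre-scaling step, outside the diff, with an illustrative helper name and assuming `topk_weights` has shape (num_tokens, 1):

import torch

def scale_input_by_router_weight(hidden_states: torch.Tensor,
                                 topk_weights: torch.Tensor) -> torch.Tensor:
    # With a single routed expert per token, multiplying the input by the
    # router weight is equivalent to scaling the expert output later.
    assert topk_weights.dim() == 2 and topk_weights.shape[1] == 1
    return hidden_states * topk_weights.to(hidden_states.dtype)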
@@ -289,14 +299,15 @@ def fused_experts(
             torch.zeros_like(weighted_down_out)).to(dtype)
         final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
     else:
+        scales = torch.ones_like(topk_weights) if apply_router_weight_on_input else topk_weights
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             down_out_list,
             skip1=None,
             skip2=None,
             bias=None,
-            scales=topk_weights,
+            scales=scales,
             expanded_src_to_dst_row=expanded_row_idx,
             export_for_source_row=topk_ids,
         )
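Because the router weight may already be folded into `hidden_states`, the finalize step must not apply it a second time; the change above passes unit scales in that case. A small sketch of that selection logic (the helper name is illustrative, not part of the diff):

import torch

def finalize_scales(topk_weights: torch.Tensor,
                    apply_router_weight_on_input: bool) -> torch.Tensor:
    # Unit scales: the finalize step then only sums the expert outputs,
    # since the router weight was applied on the input already.
    if apply_router_weight_on_input:
        return torch.ones_like(topk_weights)
    return topk_weights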
@@ -363,9 +374,6 @@ def select_experts(
     Raises:
         ValueError: If an unsupported scoring function is provided.
     """
-    if custom_routing_function is not None:
-        raise NotImplementedError(
-            "Custom routing function is not supported now")

     if scoring_func == "softmax":
         # NOTE: vLLM use dtype=torch.float here
@@ -402,9 +410,18 @@ def select_experts(
             k=top_k,
             dim=-1,
             sorted=False)
-    else:
+    elif custom_routing_function is None:
         topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
         topk_weights = topk_weights.to(hidden_states.dtype)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize)
+        # Required by npu_moe_init_routing
+        topk_ids = topk_ids.to(torch.int32)
+        return topk_weights, topk_ids

     # Required by npu_moe_init_routing
     topk_ids = topk_ids.to(torch.int32)
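The `select_experts` change now dispatches to a user-supplied `custom_routing_function` when one is given; it is expected to return `(topk_weights, topk_ids)`, and the ids are cast to int32 (as required by npu_moe_init_routing) before the early return. A hypothetical routing function with the expected call signature, as a sketch only:

import torch

def example_custom_routing(hidden_states: torch.Tensor,
                           gating_output: torch.Tensor,
                           topk: int,
                           renormalize: bool):
    # Plain softmax top-k routing returning (topk_weights, topk_ids).
    weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(hidden_states.dtype), topk_ids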