@@ -35,7 +35,7 @@
 from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
     setup_token_dispatchers
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
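The new import is the version gate used at the bottom of the patch. As a hedged sketch (not the exact `vllm_ascend.utils` implementation), `vllm_version_is` is assumed to do an exact string comparison against the installed vllm version, which is why `"0.10.1"` and `"0.10.1.1"` each have to be checked separately later:

```python
# Hedged sketch of the assumed behavior of vllm_version_is; the real
# helper lives in vllm_ascend.utils and may also honor env overrides.
import vllm

def vllm_version_is_sketch(target: str) -> bool:
    # Exact match: checking "0.10.1" does not cover "0.10.1.1".
    return vllm.__version__ == target
```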
@@ -246,6 +246,67 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
                          and not vllm_config.model_config.enforce_eager)
 
 
+def forward_oot_v01011(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    topk_weights, topk_ids, _ = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        top_k=top_k,
+        use_grouped_topk=use_grouped_topk,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        scoring_func=scoring_func,
+        routed_scaling_factor=1.0,
+        e_score_correction_bias=e_score_correction_bias,
+        global_num_experts=global_num_experts)
+
+    if topk_ids.shape[1] < top_k or is_310p():
+        assert global_num_experts is not None
+        return fused_experts_moge(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            moe_parallel_config=self.moe.moe_parallel_config,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            apply_router_weight_on_input=apply_router_weight_on_input)
+
+    return fused_experts(
+        hidden_states=x,
+        w1=layer.w13_weight,
+        w2=layer.w2_weight,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+    )
+
+
 def forward_oot(
         self,
         layer: torch.nn.Module,
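Two details in `forward_oot_v01011` are worth noting: `select_experts` may return fewer than `top_k` experts per token (e.g. after grouped top-k pruning), and Ascend 310P devices need the MoGE kernel, so both conditions route into `fused_experts_moge`. A self-contained illustration of that dispatch condition (the helper name here is ours, not the repo's):

```python
import torch

def needs_moge_path(topk_ids: torch.Tensor, top_k: int, on_310p: bool) -> bool:
    # topk_ids has shape [num_tokens, experts_selected_per_token]; a second
    # dimension smaller than top_k means the selector pruned some experts.
    return topk_ids.shape[1] < top_k or on_310p

assert needs_moge_path(torch.zeros((4, 2), dtype=torch.long), top_k=4, on_310p=False)
assert not needs_moge_path(torch.zeros((4, 4), dtype=torch.long), top_k=4, on_310p=False)
```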
@@ -258,6 +319,7 @@ def forward_oot(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         global_num_experts: int = -1,
         expert_map: Optional[torch.Tensor] = None,
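`routed_scaling_factor` is the knob this change threads through: the legacy path pins it to `1.0` inside `forward_oot_v01011`, while the current `forward_oot` forwards the caller's value into `select_experts` (next hunk). In DeepSeek-style routing the factor conventionally rescales the routing weights; a minimal sketch of that convention, not necessarily where `select_experts` actually applies the multiply:

```python
import torch

def apply_routed_scaling(topk_weights: torch.Tensor,
                         routed_scaling_factor: float) -> torch.Tensor:
    # With the 1.0 default (what the legacy path hardcodes) this is a no-op.
    return topk_weights * routed_scaling_factor

w = torch.tensor([[0.6, 0.4]])
assert torch.equal(apply_routed_scaling(w, 1.0), w)
```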
@@ -278,6 +340,7 @@ def forward_oot(
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         scoring_func=scoring_func,
+        routed_scaling_factor=routed_scaling_factor,
         e_score_correction_bias=e_score_correction_bias,
         global_num_experts=global_num_experts)
 
@@ -441,4 +504,8 @@ def forward_impl(self, hidden_states: torch.Tensor,
 
 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
-UnquantizedFusedMoEMethod.forward_oot = forward_oot
+
+if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
+    UnquantizedFusedMoEMethod.forward_oot = forward_oot_v01011
+else:
+    UnquantizedFusedMoEMethod.forward_oot = forward_oot
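The net effect is a version-gated monkey-patch: vllm 0.10.1 and 0.10.1.1 keep the old `forward_oot` signature (no `routed_scaling_factor`), so they get the frozen `forward_oot_v01011`, while newer versions get the updated `forward_oot`. A stripped-down, stand-in illustration of the pattern:

```python
# Stand-in names only; this mirrors the dispatch above, not vllm's API.
class Method:
    pass

def forward_oot_legacy(self):
    return "legacy signature (vllm 0.10.1 / 0.10.1.1)"

def forward_oot_current(self):
    return "current signature (routed_scaling_factor supported)"

installed = "0.10.1.1"  # stand-in for the detected vllm version
Method.forward_oot = (forward_oot_legacy
                      if installed in ("0.10.1", "0.10.1.1")
                      else forward_oot_current)

assert Method().forward_oot() == "legacy signature (vllm 0.10.1 / 0.10.1.1)"
```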