27
27
FusedMoEParallelConfig # isort: skip
28
28
from vllm .model_executor .layers .fused_moe .layer import (
29
29
FusedMoE , UnquantizedFusedMoEMethod , determine_expert_map )
30
+ from vllm .model_executor .layers .shared_fused_moe import SharedFusedMoE
30
31
31
32
from vllm_ascend .ascend_config import get_ascend_config
32
33
from vllm_ascend .distributed .parallel_state import get_mc2_group
@@ -415,15 +416,15 @@ def _load_w2(self,
415
416
expert_data .copy_ (loaded_weight )
416
417
417
418
418
- class AscendSharedFusedMoE (AscendFusedMoE ):
419
+ class AscendSharedFusedMoE (SharedFusedMoE , AscendFusedMoE ):
419
420
420
421
def __init__ (
421
422
self ,
422
423
shared_experts : torch .nn .Module ,
423
424
use_overlapped : bool = True ,
424
425
** kwargs ,
425
426
):
426
- super () .__init__ (** kwargs )
427
+ AscendFusedMoE .__init__ (self , ** kwargs )
427
428
self ._shared_experts = shared_experts
428
429
self .use_overlapped = use_overlapped
429
430
self .shared_expert_stream = None
@@ -452,7 +453,8 @@ def forward(
452
453
if moe_comm_method_name in {"alltoallcommimpl" , "mc2commimpl" }:
453
454
shared_out = tensor_model_parallel_all_reduce (shared_out )
454
455
455
- fused_out = super ().forward (
456
+ _ , fused_out = AscendFusedMoE .forward (
457
+ self ,
456
458
hidden_states = hidden_states ,
457
459
router_logits = router_logits ,
458
460
)
@@ -461,6 +463,16 @@ def forward(
461
463
torch .npu .current_stream ().wait_stream (self .shared_expert_stream )
462
464
return shared_out , fused_out
463
465
466
+ def forward_impl (self , hidden_states : torch .Tensor ,
467
+ router_logits : torch .Tensor ):
468
+ shared_output = torch .empty (1 )
469
+ fused_output = AscendFusedMoE .forward_impl (
470
+ self ,
471
+ hidden_states = hidden_states ,
472
+ router_logits = router_logits ,
473
+ )
474
+ return shared_output , fused_output
475
+
464
476
465
477
UnquantizedFusedMoEMethod .__init__ = unquantized_fused_moe_init_func
466
478
UnquantizedFusedMoEMethod .process_weights_after_loading = process_weights_after_loading
0 commit comments