 from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
                                                   AlltoAllCommImpl, MC2CommImpl,
                                                   NaiveMulticastCommImpl)
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p,
-                               npu_stream_switch, npu_wait_stream)
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
 
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
@@ -439,8 +438,10 @@ def forward(
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Make sure the shared experts stream begins after hidden_states are ready.
-        npu_wait_stream(self.shared_expert_stream, torch.npu.current_stream(), enabled=self.multistream_overlap_shared_expert)
-        with npu_stream_switch(self.shared_expert_stream, enabled=self.multistream_overlap_shared_expert):
+        if self.multistream_overlap_shared_expert:
+            self.shared_expert_stream.wait_stream(torch.npu.current_stream())
+        with npu_stream_switch(self.shared_expert_stream,
+                               enabled=self.multistream_overlap_shared_expert):
             # Use a separate stream to run shared experts.
             shared_out = self._shared_experts(hidden_states)
 
@@ -455,7 +456,8 @@ def forward(
             router_logits=router_logits,
         )
         # Make sure the default stream waits for the shared experts stream to finish.
-        npu_wait_stream(torch.npu.current_stream(), self.shared_expert_stream, enabled=self.multistream_overlap_shared_expert)
+        if self.multistream_overlap_shared_expert:
+            torch.npu.current_stream().wait_stream(self.shared_expert_stream)
         return shared_out, fused_out
 
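For reference, the synchronization pattern this diff settles on is: the shared-expert stream waits for the default stream before reading hidden_states, the shared experts run under that side stream while the routed experts proceed on the default stream, and the default stream waits for the side stream before the outputs are returned. Below is a minimal sketch of that pattern, assuming torch_npu is installed (so torch.npu streams are available) and reusing the npu_stream_switch helper kept in the import above; the standalone function and the shared_experts/routed_experts callables are illustrative names, not part of the patch.

import torch
import torch_npu  # noqa: F401  # assumed to register the torch.npu backend
from vllm_ascend.utils import npu_stream_switch


def overlapped_shared_expert(shared_experts, routed_experts, hidden_states,
                             shared_expert_stream, overlap_enabled=True):
    if overlap_enabled:
        # The side stream must not start until hidden_states is ready on the
        # default stream (replaces the removed npu_wait_stream(side, default, ...)).
        shared_expert_stream.wait_stream(torch.npu.current_stream())
    with npu_stream_switch(shared_expert_stream, enabled=overlap_enabled):
        # Shared experts run on the side stream when overlap is enabled.
        shared_out = shared_experts(hidden_states)
    # Routed experts keep running on the default stream, overlapping the above.
    fused_out = routed_experts(hidden_states)
    if overlap_enabled:
        # The default stream must not read shared_out until the side stream is
        # done (replaces the removed npu_wait_stream(default, side, ...)).
        torch.npu.current_stream().wait_stream(shared_expert_stream)
    return shared_out, fused_out

Calling wait_stream directly on the stream objects drops the need for the npu_wait_stream wrapper, and guarding those calls with the overlap flag keeps the single-stream path free of extra synchronization.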