
Commit c9ec430

remove npu_wait_stream
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent: e8bb2cd

Files changed: 2 (+7, -21 lines)


vllm_ascend/ops/common_fused_moe.py
Lines changed: 7 additions & 5 deletions

@@ -37,8 +37,7 @@
 from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
                                                  AlltoAllCommImpl, MC2CommImpl,
                                                  NaiveMulticastCommImpl)
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p,
-                               npu_stream_switch, npu_wait_stream)
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch

 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__

@@ -439,8 +438,10 @@ def forward(
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Make sure the shared experts stream begins after hidden_states are ready.
-        npu_wait_stream(self.shared_expert_stream, torch.npu.current_stream(), enabled=self.multistream_overlap_shared_expert)
-        with npu_stream_switch(self.shared_expert_stream, enabled=self.multistream_overlap_shared_expert):
+        if self.multistream_overlap_shared_expert:
+            self.shared_expert_stream.wait_stream(torch.npu.current_stream())
+        with npu_stream_switch(self.shared_expert_stream,
+                               enabled=self.multistream_overlap_shared_expert):
             # Use a separate stream to run shared experts.
             shared_out = self._shared_experts(hidden_states)

@@ -455,7 +456,8 @@ def forward(
             router_logits=router_logits,
         )
         # Make sure the default stream waits for the shared experts stream to finish.
-        npu_wait_stream(torch.npu.current_stream(), self.shared_expert_stream, enabled=self.multistream_overlap_shared_expert)
+        if self.multistream_overlap_shared_expert:
+            torch.npu.current_stream().wait_stream(self.shared_expert_stream)
         return shared_out, fused_out
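
For context, here is a minimal sketch of the overlap pattern that forward() now expresses directly instead of going through the removed helper: the side stream waits for the default stream before the shared experts run, and the default stream waits for the side stream before the results are returned. The function name and the callables shared_experts and fused_experts are placeholders for illustration, not the real module attributes; the sketch assumes a torch_npu environment where torch.npu.Stream, current_stream, and wait_stream behave like their torch.cuda counterparts.

from contextlib import nullcontext

import torch
import torch_npu  # noqa: F401  # registers the torch.npu device/stream namespace


def overlapped_shared_expert_forward(hidden_states, router_logits,
                                     shared_experts, fused_experts,
                                     shared_expert_stream, overlap=True):
    """Run shared experts on a side stream while routed experts use the default stream."""
    if overlap:
        # The side stream must not read hidden_states before the default
        # stream has finished producing it.
        shared_expert_stream.wait_stream(torch.npu.current_stream())
    ctx = torch.npu.stream(shared_expert_stream) if overlap else nullcontext()
    with ctx:
        # Shared experts run on the side stream (or inline when overlap is off).
        shared_out = shared_experts(hidden_states)
    # Routed experts are launched on the default stream and overlap the side stream.
    fused_out = fused_experts(hidden_states, router_logits)
    if overlap:
        # Do not let the default stream consume shared_out before the side
        # stream has finished computing it.
        torch.npu.current_stream().wait_stream(shared_expert_stream)
    return shared_out, fused_out

Inlining the guard at the call sites, as the diff does, keeps the two synchronization points visible in forward() and removes a thin wrapper.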

vllm_ascend/utils.py
Lines changed: 0 additions & 16 deletions

@@ -631,19 +631,3 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
-
-
-def npu_wait_stream(current_stream: torch.npu.Stream,
-                    target_stream: torch.npu.Stream,
-                    *,
-                    enabled: bool = True):
-    """
-    Make current stream wait for the target stream if enabled is True.
-    This operation will launch a record event on the target stream,
-    and launch a wait event on current stream, waitint for the record event.
-    Otherwise, do nothing.
-    """
-    if enabled:
-        assert current_stream is not None
-        assert target_stream is not None
-        current_stream.wait_stream(target_stream)
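
For reference, the removed helper reduced to a guarded wait_stream call, so inlining that guard at the two call sites above is behaviorally equivalent apart from the dropped assertions. A minimal sketch of that equivalence, with cur and side as hypothetical stream objects used only for illustration:

import torch
import torch_npu  # noqa: F401  # registers the torch.npu namespace

# Hypothetical streams, for illustration only.
cur = torch.npu.current_stream()
side = torch.npu.Stream()
overlap = True

# The removed npu_wait_stream(cur, side, enabled=overlap) amounted to:
if overlap:
    # Future work queued on cur waits for work already queued on side.
    cur.wait_stream(side)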

0 commit comments
