Commit 950af06

support moe ms in aclgraph
Signed-off-by: whx-sjtu <2952154980@qq.com>
Parent: 1f6465c

File tree: 3 files changed (+49, -9 lines)

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,8 @@ def __init__(self, vllm_config):
         self.enable_shared_expert_dp = additional_config.get(
             "enable_shared_expert_dp", False
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
+        self.enable_multistream_moe = additional_config.get(
+            "enable_multistream_moe", False)
         self.enable_prefetch = additional_config.get("enable_prefetch", False)
         self.lmhead_tensor_parallel_size = additional_config.get(
             "lmhead_tensor_parallel_size", None)
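As a usage note (not part of the diff): the new flag is read from additional_config just like the neighboring enable_shared_expert_dp and enable_prefetch options, so a minimal sketch of turning it on through vLLM's additional_config engine argument would look like the following; the model name is a placeholder:

from vllm import LLM

# Placeholder model path; any MoE checkpoint with shared experts.
llm = LLM(
    model="path/to/moe-model",
    additional_config={"enable_multistream_moe": True},
)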

vllm_ascend/ops/common_fused_moe.py

Lines changed: 20 additions & 8 deletions
@@ -37,7 +37,8 @@
 from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
                                                  AlltoAllCommImpl, MC2CommImpl,
                                                  NaiveMulticastCommImpl)
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p,
+                               npu_stream_switch, npu_wait_stream)
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
@@ -426,24 +427,35 @@ def __init__(
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
         self.use_overlapped = use_overlapped
+        self.shared_expert_stream = None
+        ascend_config = get_ascend_config()
+        self.enable_multistream_moe = ascend_config.enable_multistream_moe
+        if self.enable_multistream_moe:
+            self.shared_expert_stream = torch.npu.Stream()
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        shared_out = self._shared_experts(hidden_states)
-
-        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-        forward_context = get_forward_context()
-        moe_comm_method_name = forward_context.moe_comm_method_name
-        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
-            shared_out = tensor_model_parallel_all_reduce(shared_out)
+        # Make sure the shared experts stream begins after hidden_states are ready.
+        npu_wait_stream(self.shared_expert_stream, torch.npu.current_stream(), enabled=self.enable_multistream_moe)
+        with npu_stream_switch(self.shared_expert_stream, enabled=self.enable_multistream_moe):
+            # Use a separate stream to run shared experts.
+            shared_out = self._shared_experts(hidden_states)
+
+            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+            forward_context = get_forward_context()
+            moe_comm_method_name = forward_context.moe_comm_method_name
+            if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+                shared_out = tensor_model_parallel_all_reduce(shared_out)
 
         fused_out = super().forward(
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
+        # Make sure the default stream waits for the shared experts stream to finish.
+        npu_wait_stream(torch.npu.current_stream(), self.shared_expert_stream, enabled=self.enable_multistream_moe)
         return shared_out, fused_out
 
 
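The forward pass above is a standard fork/join two-stream overlap: the side stream waits until hidden_states is ready, the shared experts (and their all-reduce) run on that side stream while the routed experts run on the default stream, and the default stream joins back before returning. A minimal, self-contained sketch of the same pattern in plain torch.npu stream calls (an Ascend-only illustration assuming torch_npu; the arithmetic stands in for the real expert modules):

import torch
import torch_npu  # noqa: F401  # provides the torch.npu namespace (Ascend only)

side_stream = torch.npu.Stream()

def overlapped_forward(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Fork: the side stream must not read x before the default stream produced it.
    side_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side_stream):
        shared = x * 2.0  # stand-in for self._shared_experts(x)
    routed = x + 1.0  # stand-in for the routed-expert path on the default stream
    # Join: the default stream must not use `shared` until the side stream is done.
    torch.npu.current_stream().wait_stream(side_stream)
    return shared, routed

Because both synchronizations in the real code go through the enabled= keyword, turning the flag off makes them no-ops and the method degenerates to the original single-stream behavior.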
vllm_ascend/utils.py

Lines changed: 27 additions & 1 deletion
@@ -21,7 +21,7 @@
 import functools
 import math
 import os
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from enum import Enum
 from threading import Lock
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -617,3 +617,29 @@ def weak_ref_tensors(
     if isinstance(tensors, tuple):
         return tuple(weak_ref_tensor(t) for t in tensors)
     raise ValueError("Invalid type for tensors")
+
+
+def npu_stream_switch(target_stream: torch.npu.Stream, *, enabled: bool = True):
+    """
+    Switch to the target stream if enabled is True.
+    Otherwise, do nothing.
+    """
+    if not enabled:
+        return nullcontext()
+    return torch.npu.stream(target_stream)
+
+
+def npu_wait_stream(
+    current_stream: torch.npu.Stream,
+    target_stream: torch.npu.Stream,
+    *,
+    enabled: bool = True
+):
+    """
+    Make the current stream wait for the target stream if enabled is True.
+    This launches a record event on the target stream and a wait event on
+    the current stream, waiting for the record event.
+    Otherwise, do nothing.
+    """
+    if enabled:
+        current_stream.wait_stream(target_stream)
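A hypothetical end-to-end use of the two helpers (the work_stream name and the relu workload are illustrative, and running it requires an Ascend device with torch_npu installed):

import torch
import torch_npu  # noqa: F401  (Ascend only)
from vllm_ascend.utils import npu_stream_switch, npu_wait_stream

enabled = True
work_stream = torch.npu.Stream() if enabled else None

x = torch.randn(4, 8, device="npu")
# Fork: work_stream waits until the default stream has produced x.
npu_wait_stream(work_stream, torch.npu.current_stream(), enabled=enabled)
with npu_stream_switch(work_stream, enabled=enabled):
    y = torch.relu(x)  # launched on work_stream when enabled
# Join: the default stream waits for work_stream before consuming y.
npu_wait_stream(torch.npu.current_stream(), work_stream, enabled=enabled)
print(y.sum())

With enabled=False both helpers do nothing, which is why the caller in common_fused_moe.py can leave shared_expert_stream as None when the feature is disabled.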
