 from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
                                                   AlltoAllCommImpl, MC2CommImpl,
                                                   NaiveMulticastCommImpl)
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p,
+                               npu_stream_switch, npu_wait_stream)
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
@@ -426,24 +427,35 @@ def __init__(
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
         self.use_overlapped = use_overlapped
+        self.shared_expert_stream = None
+        ascend_config = get_ascend_config()
+        self.enable_multistream_moe = ascend_config.enable_multistream_moe
+        if self.enable_multistream_moe:
+            self.shared_expert_stream = torch.npu.Stream()
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        shared_out = self._shared_experts(hidden_states)
-
-        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-        forward_context = get_forward_context()
-        moe_comm_method_name = forward_context.moe_comm_method_name
-        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
-            shared_out = tensor_model_parallel_all_reduce(shared_out)
+        # Make sure the shared experts stream begins after hidden_states are ready.
+        npu_wait_stream(self.shared_expert_stream, torch.npu.current_stream(), enabled=self.enable_multistream_moe)
+        with npu_stream_switch(self.shared_expert_stream, enabled=self.enable_multistream_moe):
+            # Use a separate stream to run shared experts.
+            shared_out = self._shared_experts(hidden_states)
+
+            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+            forward_context = get_forward_context()
+            moe_comm_method_name = forward_context.moe_comm_method_name
+            if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+                shared_out = tensor_model_parallel_all_reduce(shared_out)
 
         fused_out = super().forward(
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
+        # Make sure the default stream waits for the shared experts stream to finish.
+        npu_wait_stream(torch.npu.current_stream(), self.shared_expert_stream, enabled=self.enable_multistream_moe)
         return shared_out, fused_out
 
 
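For context, the sketch below shows the overlap pattern from this diff in isolation; it is not the vllm-ascend implementation. It assumes npu_stream_switch and npu_wait_stream behave like the stream_switch/wait_stream helpers defined here (no-ops when enabled=False, otherwise a stream switch and a stream-to-stream wait), and it uses torch.cuda streams as a stand-in for torch.npu so it runs without an Ascend device. The OverlappedShared module and its layer names are hypothetical.

# Minimal sketch of the multistream overlap pattern; helper names, the
# OverlappedShared module, and the use of torch.cuda streams are stand-ins,
# not the vllm-ascend implementation.
import contextlib

import torch


@contextlib.contextmanager
def stream_switch(stream, *, enabled: bool):
    """Run enclosed ops on `stream` when enabled; otherwise do nothing."""
    if not enabled or stream is None:
        yield
        return
    with torch.cuda.stream(stream):
        yield


def wait_stream(waiting, waited, *, enabled: bool):
    """Make `waiting` wait for all work already queued on `waited`."""
    if enabled and waiting is not None and waited is not None:
        waiting.wait_stream(waited)


class OverlappedShared(torch.nn.Module):
    """Toy module: `shared_mlp` runs on a side stream while `routed_mlp`
    stays on the default stream, so the two can overlap."""

    def __init__(self, hidden: int, enable_overlap: bool):
        super().__init__()
        self.shared_mlp = torch.nn.Linear(hidden, hidden)
        self.routed_mlp = torch.nn.Linear(hidden, hidden)
        self.enable_overlap = enable_overlap
        self.side_stream = torch.cuda.Stream() if enable_overlap else None

    def forward(self, x: torch.Tensor):
        # Side stream must not start before `x` is ready on the default stream.
        wait_stream(self.side_stream, torch.cuda.current_stream(),
                    enabled=self.enable_overlap)
        with stream_switch(self.side_stream, enabled=self.enable_overlap):
            shared_out = self.shared_mlp(x)
        # Runs on the default stream, overlapping with the side stream above.
        routed_out = self.routed_mlp(x)
        # Default stream must not read `shared_out` before the side stream is done.
        wait_stream(torch.cuda.current_stream(), self.side_stream,
                    enabled=self.enable_overlap)
        return shared_out, routed_out


if __name__ == "__main__" and torch.cuda.is_available():
    model = OverlappedShared(64, enable_overlap=True).cuda()
    x = torch.randn(8, 64, device="cuda")
    shared, routed = model(x)
    torch.cuda.synchronize()
    print(shared.shape, routed.shape)

The two waits are the key design point: the first keeps the side stream from reading the input before the default stream has produced it, and the second keeps the default stream from returning the shared-expert output before the side stream has finished writing it; everything queued in between is free to overlap.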