 import torch
 import torch_npu
 from vllm.config import CompilationLevel, get_current_vllm_config
-from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
+from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
+                              tensor_model_parallel_all_reduce)
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.config import \
     FusedMoEParallelConfig  # isort: skip
@@ -373,6 +374,21 @@ def __init__(
                 self, method.__name__.lower(),
                 method(moe_config=self.moe_config))  # type: ignore[abstract]

+    def maybe_all_reduce_tensor_model_parallel(
+            self, final_hidden_states: torch.Tensor):
+        """NOTE(Yizhou): This overrides the parent class method. In `mc2commimpl`
+        and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
+        the outputs are already aggregated across tensor parallel ranks in the
+        `finalize` function. In `allgathercommimpl`, we still need to all-reduce the
+        outputs since each rank only has partial outputs.
+        """
+        forward_context = get_forward_context()
+        moe_comm_method_name = forward_context.moe_comm_method_name
+        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+            return final_hidden_states
+        else:
+            return tensor_model_parallel_all_reduce(final_hidden_states)
+
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None
@@ -415,6 +431,38 @@ def forward_impl(self, hidden_states: torch.Tensor,
         return final_hidden_states


+class AscendSharedFusedMoE(AscendFusedMoE):
+
+    def __init__(
+        self,
+        shared_experts: torch.nn.Module,
+        use_overlapped: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._shared_experts = shared_experts
+        self.use_overlapped = use_overlapped
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        shared_out = self._shared_experts(hidden_states)
+
+        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+        forward_context = get_forward_context()
+        moe_comm_method_name = forward_context.moe_comm_method_name
+        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+            shared_out = tensor_model_parallel_all_reduce(shared_out)
+
+        fused_out = super().forward(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+        )
+        return shared_out, fused_out
+
+
 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
