Commit e13c4dd

[Fix] Fix SharedFusedMoE (#2817)
### What this PR does / why we need it?

Strangely, `register_oot` does not work with `SharedFusedMoE`, so we have to add this patch for now.

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

This PR has no effect on DeepSeek, since we currently still stick with the old `CustomDeepseekV2`.

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@0cdd213

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 7a205db commit e13c4dd

File tree: 3 files changed, +71 -1 lines changed

vllm_ascend/ops/common_fused_moe.py

Lines changed: 49 additions & 1 deletion
@@ -20,7 +20,8 @@
 import torch
 import torch_npu
 from vllm.config import CompilationLevel, get_current_vllm_config
-from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
+from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
+                              tensor_model_parallel_all_reduce)
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.config import \
     FusedMoEParallelConfig  # isort: skip
@@ -373,6 +374,21 @@ def __init__(
                 self, method.__name__.lower(),
                 method(moe_config=self.moe_config))  # type: ignore[abstract]
 
+    def maybe_all_reduce_tensor_model_parallel(
+            self, final_hidden_states: torch.Tensor):
+        """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
+        and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
+        the outputs are already aggregated across tensor parallel ranks in the
+        `finalize` function. In `allgathercommimpl`, we still need to all-reduce the
+        outputs since each rank only has partial outputs.
+        """
+        forward_context = get_forward_context()
+        moe_comm_method_name = forward_context.moe_comm_method_name
+        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+            return final_hidden_states
+        else:
+            return tensor_model_parallel_all_reduce(final_hidden_states)
+
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None
@@ -415,6 +431,38 @@ def forward_impl(self, hidden_states: torch.Tensor,
         return final_hidden_states
 
 
+class AscendSharedFusedMoE(AscendFusedMoE):
+
+    def __init__(
+        self,
+        shared_experts: torch.nn.Module,
+        use_overlapped: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._shared_experts = shared_experts
+        self.use_overlapped = use_overlapped
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        shared_out = self._shared_experts(hidden_states)
+
+        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+        forward_context = get_forward_context()
+        moe_comm_method_name = forward_context.moe_comm_method_name
+        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+            shared_out = tensor_model_parallel_all_reduce(shared_out)
+
+        fused_out = super().forward(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+        )
+        return shared_out, fused_out
+
+
 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
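
For readers skimming the diff, here is a minimal, self-contained sketch of the complementary all-reduce rule that `maybe_all_reduce_tensor_model_parallel` and `AscendSharedFusedMoE.forward` implement above. It runs on a single process, so `all_reduce_stub` is only a placeholder for `tensor_model_parallel_all_reduce`, and the function names are illustrative, not part of the patch:

```python
import torch


def all_reduce_stub(t: torch.Tensor) -> torch.Tensor:
    # Placeholder for tensor_model_parallel_all_reduce; on a single
    # process the collective is just an identity.
    return t


def reduce_routed_output(final_hidden_states: torch.Tensor,
                         moe_comm_method_name: str) -> torch.Tensor:
    # Mirrors maybe_all_reduce_tensor_model_parallel: mc2commimpl and
    # alltoallcommimpl already aggregate the routed-expert output across
    # TP ranks in `finalize`, so only allgathercommimpl still needs the
    # explicit all-reduce.
    if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
        return final_hidden_states
    return all_reduce_stub(final_hidden_states)


def reduce_shared_output(shared_out: torch.Tensor,
                         moe_comm_method_name: str) -> torch.Tensor:
    # Mirrors AscendSharedFusedMoE.forward: the shared-expert output is
    # the opposite case, so it is all-reduced exactly when the routed
    # output is not.
    if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
        return all_reduce_stub(shared_out)
    return shared_out


if __name__ == "__main__":
    x = torch.randn(4, 8)
    for method in ("allgathercommimpl", "alltoallcommimpl", "mc2commimpl"):
        routed = reduce_routed_output(x, method)
        shared = reduce_shared_output(x, method)
        print(method, routed.shape, shared.shape)
```

Either way, exactly one all-reduce is applied to each output per forward pass, which is the invariant the two methods maintain together.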

vllm_ascend/patch/platform/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@
 #
 
 import vllm_ascend.patch.platform.patch_common.patch_distributed  # noqa
+import vllm_ascend.patch.platform.patch_common.patch_shared_fused_moe  # noqa
vllm_ascend/patch/platform/patch_common/patch_shared_fused_moe.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from vllm.model_executor.models import deepseek_v2, llama4
+
+from vllm_ascend.ops.common_fused_moe import AscendSharedFusedMoE
+
+deepseek_v2.SharedFusedMoE = AscendSharedFusedMoE
+llama4.SharedFusedMoE = AscendSharedFusedMoE
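
As a quick sanity check (a sketch, assuming vLLM and vllm-ascend are both importable in the environment), importing the patch module should make both model files resolve `SharedFusedMoE` to the Ascend override:

```python
# Sketch of a manual check; assumes vllm and vllm_ascend are installed.
import vllm_ascend.patch.platform.patch_common.patch_shared_fused_moe  # noqa: F401

from vllm.model_executor.models import deepseek_v2, llama4
from vllm_ascend.ops.common_fused_moe import AscendSharedFusedMoE

assert deepseek_v2.SharedFusedMoE is AscendSharedFusedMoE
assert llama4.SharedFusedMoE is AscendSharedFusedMoE
```

Because this is a module-attribute monkey-patch, it only affects models constructed after the patch module is imported, which is why it is pulled in from `patch_common/__init__.py` during platform patching.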
