
Commit 8326f15

[CustomOp] Register AscendSharedFusedMoE custom op (#2980)
### What this PR does / why we need it?

Register `AscendSharedFusedMoE` custom op.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

`DeepSeek-V2-Lite` is a MoE model with shared experts.

Test:

```bash
vllm serve /root/.cache/modelscope/hub/models/deepseek-ai/DeepSeek-V2-Lite \
  --trust-remote-code \
  --enforce-eager \
  --no-enable-prefix-caching \
  --gpu-memory-utilization 0.95

curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/root/.cache/modelscope/hub/models/deepseek-ai/DeepSeek-V2-Lite",
    "messages": [
      {"role": "user", "content": "介绍一下联通公司?"}
    ],
    "stream": false,
    "max_tokens": 100
  }'
```

Output:

```bash
中国联合网络通信集团有限公司(简称“中国联通”)于2009年1月6日在原中国网通和原中国联通的基础上合并组建而成,在国内31个省(自治区、直辖市)和境外多个国家和地区设有分支机构,是中国唯一一家在纽约、香港、上海三地同时上市的电信运营企业,连续多年入选“世界500强企业”。\n\n中国联通主要经营固定通信业务,移动通信业务,国内
```

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@486c559

---------

Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
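For convenience, the same smoke test can be driven from Python. The following is a minimal sketch under the assumptions that the `vllm serve` command above is already listening on `localhost:8000` and that the `openai` client package is installed; the model path simply mirrors the curl example.

```python
# Minimal sketch: reproduce the curl request above with the OpenAI Python client.
# Assumes the `vllm serve` command shown earlier is already running on localhost:8000.
from openai import OpenAI

MODEL = "/root/.cache/modelscope/hub/models/deepseek-ai/DeepSeek-V2-Lite"

# vLLM's OpenAI-compatible server does not validate the API key, so a placeholder works.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": "介绍一下联通公司?"}],
    stream=False,
    max_tokens=100,
)
print(resp.choices[0].message.content)
```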
1 parent 05a700d

File tree: 4 files changed (+18, -26 lines)

vllm_ascend/ops/common_fused_moe.py

Lines changed: 15 additions & 3 deletions
```diff
@@ -27,6 +27,7 @@
     FusedMoEParallelConfig  # isort: skip
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -415,15 +416,15 @@ def _load_w2(self,
         expert_data.copy_(loaded_weight)


-class AscendSharedFusedMoE(AscendFusedMoE):
+class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):

     def __init__(
         self,
         shared_experts: torch.nn.Module,
         use_overlapped: bool = True,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        AscendFusedMoE.__init__(self, **kwargs)
         self._shared_experts = shared_experts
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
@@ -452,7 +453,8 @@ def forward(
         if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
             shared_out = tensor_model_parallel_all_reduce(shared_out)

-        fused_out = super().forward(
+        _, fused_out = AscendFusedMoE.forward(
+            self,
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
@@ -461,6 +463,16 @@ def forward(
             torch.npu.current_stream().wait_stream(self.shared_expert_stream)
         return shared_out, fused_out

+    def forward_impl(self, hidden_states: torch.Tensor,
+                     router_logits: torch.Tensor):
+        shared_output = torch.empty(1)
+        fused_output = AscendFusedMoE.forward_impl(
+            self,
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+        )
+        return shared_output, fused_output
+

 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
```
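Note how the diff routes `__init__`, `forward`, and `forward_impl` explicitly through `AscendFusedMoE` rather than `super()`: with `AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE)`, `SharedFusedMoE` sits earlier in the MRO, so a cooperative `super()` call would dispatch to the upstream shared-expert implementation instead of the Ascend one. The following self-contained sketch illustrates that behavior with placeholder classes (none of these names are the real vLLM or vllm-ascend types):

```python
# Toy illustration of the multiple-inheritance pattern used above.
# Placeholder classes only; the real types live in vLLM / vllm-ascend.
class FusedBase:
    """Stands in for the common upstream base class."""


class SharedBase(FusedBase):
    def __init__(self, **kwargs):
        self.init_path = "shared"

    def forward(self, x):
        return ("shared", x)


class AscendBase(FusedBase):
    def __init__(self, **kwargs):
        self.init_path = "ascend"

    def forward(self, x):
        return ("ascend", x)


class Child(SharedBase, AscendBase):
    """MRO: Child -> SharedBase -> AscendBase -> FusedBase -> object."""

    def __init__(self, **kwargs):
        # super().__init__(**kwargs) would run SharedBase.__init__ first;
        # the explicit call keeps initialization on the Ascend path.
        AscendBase.__init__(self, **kwargs)

    def forward(self, x):
        # Mirrors `AscendFusedMoE.forward(self, ...)` in the diff.
        return AscendBase.forward(self, x)


c = Child()
print([cls.__name__ for cls in Child.__mro__])
# ['Child', 'SharedBase', 'AscendBase', 'FusedBase', 'object']
print(c.init_path)     # ascend
print(c.forward("h"))  # ('ascend', 'h')
```

Inheriting from `SharedFusedMoE` is still what lets the op be registered under the `SharedFusedMoE` name (see the `vllm_ascend/utils.py` change below), while the explicit base-class calls keep the compute path on the Ascend implementation.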

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -18,4 +18,3 @@
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_logits  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
-import vllm_ascend.patch.worker.patch_common.patch_shared_fused_moe  # noqa
```

vllm_ascend/patch/worker/patch_common/patch_shared_fused_moe.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

vllm_ascend/utils.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -498,7 +498,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):

     from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
     from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
-    from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+    from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
+                                                  AscendSharedFusedMoE)
     from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
     from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
                                         AscendMergedColumnParallelLinear,
@@ -525,6 +526,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "LogitsProcessor": AscendLogitsProcessor,
         "RMSNorm": AscendRMSNorm,
         "FusedMoE": AscendFusedMoE,
+        "SharedFusedMoE": AscendSharedFusedMoE,
         "MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
     }
```
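The added `"SharedFusedMoE": AscendSharedFusedMoE` entry extends the name-to-class table that `register_ascend_customop` uses to swap in Ascend implementations for upstream vLLM layers. A rough, self-contained sketch of that lookup pattern follows; the class bodies, the `ASCEND_OVERRIDES` table, and `resolve_layer` are illustrative placeholders rather than the actual vllm-ascend registration code:

```python
# Illustrative sketch of a name -> implementation override table.
# Everything here is a placeholder; only the mapping pattern mirrors the diff above.
from typing import Dict, Type


class FusedMoE: ...                  # stand-in for the upstream vLLM layer
class SharedFusedMoE(FusedMoE): ...  # stand-in for vLLM's shared-expert layer
class AscendFusedMoE(FusedMoE): ...
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): ...


# Upstream layer name -> Ascend implementation (the second entry is the one
# this commit adds in register_ascend_customop).
ASCEND_OVERRIDES: Dict[str, Type] = {
    "FusedMoE": AscendFusedMoE,
    "SharedFusedMoE": AscendSharedFusedMoE,
}

UPSTREAM: Dict[str, Type] = {
    "FusedMoE": FusedMoE,
    "SharedFusedMoE": SharedFusedMoE,
}


def resolve_layer(name: str) -> Type:
    """Return the Ascend override for `name` if registered, else the upstream class."""
    return ASCEND_OVERRIDES.get(name, UPSTREAM[name])


assert resolve_layer("SharedFusedMoE") is AscendSharedFusedMoE
assert resolve_layer("FusedMoE") is AscendFusedMoE
```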
