82 changes: 82 additions & 0 deletions tests/multicard/test_fused_moe_allgather_ep.py
@@ -0,0 +1,82 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Run inference through the fused_moe_allgather_ep and fused_moe_alltoall_ep paths.

Run 'pytest tests/multicard/test_fused_moe_allgather_ep.py'.
"""

import os
from unittest.mock import patch

from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams

from tests.conftest import VllmRunner


@patch.dict(
os.environ, {
"VLLM_USE_V1": "1",
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1",
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"
})
def test_generate_with_allgather():
example_prompts = ["Hello, my name is"]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
tensor_parallel_size=16,
enforce_eager=True,
max_model_len=1024,
dtype="auto",
enable_expert_parallel=True,
additional_config={
"ascend_scheduler_config": {
"enabled": True,
"chunked_prefill_enabled": False,
},
"expert_tensor_parallel_size": 1
}) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


@patch.dict(
os.environ, {
"VLLM_USE_V1": "1",
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1"
})
def test_generate_with_alltoall():
example_prompts = ["Hello, my name is"]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
tensor_parallel_size=16,
enforce_eager=True,
max_model_len=1024,
dtype="auto",
enable_expert_parallel=True,
additional_config={
"ascend_scheduler_config": {
"enabled": True,
"chunked_prefill_enabled": False,
},
"expert_tensor_parallel_size": 1
}) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
5 changes: 5 additions & 0 deletions vllm_ascend/envs.py
@@ -99,6 +99,11 @@
# Whether to enable the trace recompiles from pytorch.
"VLLM_ASCEND_TRACE_RECOMPILES":
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
# Whether to enable fused_experts_allgather_ep. When enabled, the MoeInitRoutingV3
# and GroupedMatmulFinalizeRouting operators are combined to implement expert parallelism (EP).
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
),
"VLLM_ASCEND_ENABLE_DBO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
# Whether to enable the model execute time observe profile. Disable it when
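A minimal sketch of how this switch is expected to be used (not part of the diff; it assumes `vllm_ascend.envs` resolves these entries lazily via module `__getattr__`, as the existing flags in this file are): the variable must be exported before the worker processes are spawned, and it is then read through `envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP` in the quantization and utils code below.

import os

# Must be set before vLLM spawns its workers (the tests above do this via patch.dict).
os.environ["VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP"] = "1"

import vllm_ascend.envs as envs

# Resolves to bool(int("1")) == True through the lambda registered above.
assert envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP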
17 changes: 11 additions & 6 deletions vllm_ascend/ops/fused_moe.py
@@ -897,8 +897,9 @@ def apply(
**kwargs,
) -> torch.Tensor:

is_deepseek_v3_r1 = global_num_experts == 256
# NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
if global_num_experts == 256:
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k,  # top_k is currently 8
@@ -934,7 +935,7 @@ def apply(
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
is_prefill)
is_prefill, is_deepseek_v3_r1)
if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
@@ -1128,15 +1129,17 @@ def forward(self,
real_top_k = self.top_k

num_tokens, hidden_size = hidden_states.shape
is_deepseek_v3_r1 = self.global_num_experts == 256

fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
is_prefill)
is_prefill, is_deepseek_v3_r1)
if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
shared_hidden_states = shared_experts(hidden_states)

tp_size = get_tensor_model_parallel_world_size()
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
and fused_moe_state != FusedMoEState.AllGatherEP):
if num_tokens < tp_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1194,7 +1197,8 @@ def forward(self,
if isinstance(e_hidden_states, tuple):
e_hidden_states, shared_hidden_states = e_hidden_states

if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
and fused_moe_state != FusedMoEState.AllGatherEP):
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
@@ -1212,7 +1216,8 @@
else:
final_hidden_states = e_hidden_states

if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
or fused_moe_state == FusedMoEState.AllGatherEP):
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

111 changes: 108 additions & 3 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -22,6 +22,7 @@
import torch_npu
from vllm.distributed import GroupCoordinator

import vllm_ascend.envs as envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts
@@ -346,6 +347,95 @@ def fused_experts_with_all2all(
return final_hidden_states


def fused_experts_with_allgather(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens = hidden_states.shape[0]
batch_size, hidden_size = hidden_states.shape
topk_weights = topk_weights.to(hidden_states.dtype)

ep_group = get_ep_group().device_group
ep_rank = torch.distributed.get_rank(group=ep_group)
ep_size = torch.distributed.get_world_size(ep_group)

global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_size

hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)

hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
scale=pertoken_scale,
offset=None,
active_num=num_tokens * top_k,
expert_num=global_num_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[
ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
],
quant_mode=-1,
row_idx_type=1)
group_list_type = 1

sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0,
expanded_x_idx)
row_index = expanded_x_idx // topk_ids.shape[-1]
row_index = row_index.to(torch.int64)
share_input = torch.zeros((batch_size, hidden_size),
dtype=torch.bfloat16,
device="npu")

hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=expert_tokens,
output_dtype=torch.int32)[0]

# act_fn: swiglu
hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale.to(torch.float32),
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=expert_tokens,
activate_left=True,
quant_mode=1,
)

final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing(
hidden_states,
w2,
scale=w2_scale.to(torch.float32),
bias=None,
pertoken_scale=pertoken_scale.view(-1),
group_list=expert_tokens,
shared_input=share_input,
logit=sorted_topk_weight.to(torch.float32),
row_index=row_index,
output_bs=batch_size).to(torch.bfloat16)

if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)

return final_hidden_states
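
# --- Illustrative sketch, not part of this diff -----------------------------
# The expert-range bookkeeping used by fused_experts_with_allgather above,
# reproduced as a CPU-only helper. It only shows how expert_map, the EP rank
# and the EP world size determine the active_expert_range handed to
# npu_moe_init_routing_v2; no NPU operators are involved.
def _sketch_local_expert_range(expert_map, ep_rank: int, ep_size: int):
    """Return the [start, end) expert-id range owned by this EP rank."""
    global_num_experts = len(expert_map)
    local_num_experts = global_num_experts // ep_size
    start = ep_rank * local_num_experts
    return start, start + local_num_experts

# e.g. DeepSeek-V3/R1: 256 experts over 16 EP ranks -> 16 experts per rank.
assert _sketch_local_expert_range(range(256), ep_rank=3, ep_size=16) == (48, 64)
# -----------------------------------------------------------------------------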


def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
@@ -623,8 +713,10 @@ def apply(
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"

is_deepseek_v3_r1 = global_num_experts == 256

# NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
if global_num_experts == 256:
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k,  # top_k is currently 8
@@ -661,8 +753,19 @@ def apply(
topk_weights = topk_weights.to(x.dtype)

fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
is_prefill)
if fused_moe_state == FusedMoEState.MC2:
is_prefill, is_deepseek_v3_r1)
if fused_moe_state == FusedMoEState.AllGatherEP:
return fused_experts_with_allgather(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
elif fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
@@ -713,6 +816,8 @@ def process_weights_after_loading(self, layer):
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
if envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
    # Cast w2 to ACL format 29 (FRACTAL_NZ) for npu_grouped_matmul_finalize_routing.
    torch_npu.npu_format_cast_(layer.w2_weight, 29)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
11 changes: 9 additions & 2 deletions vllm_ascend/utils.py
@@ -283,11 +283,18 @@ class FusedMoEState(Enum):
AllGather = 0
All2All = 1
MC2 = 2
AllGatherEP = 3


# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool):
if ep_size == 1:
def get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
# The fused operator torch_npu.npu_grouped_matmul_finalize_routing used by the
# allgather-EP path only supports DeepSeek V3/R1.
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
and is_deepseek_v3_r1):
return FusedMoEState.AllGatherEP
elif ep_size == 1:
return FusedMoEState.AllGather
# NOTE: MC2 needs ep_size >= 16, and all2all cannot be used in the torchair graph.
elif ep_size < 16 or with_prefill:
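
# --- Condensed sketch, not part of this diff ---------------------------------
# Full dispatch after this change. The branches hidden by the truncated hunk
# are assumed to keep their previous behavior (All2All for small EP groups or
# prefill, MC2 otherwise).
def _sketch_get_fused_moe_state(ep_size: int, with_prefill: bool,
                                is_deepseek_v3_r1: bool,
                                allgather_ep_enabled: bool) -> FusedMoEState:
    if allgather_ep_enabled and ep_size > 1 and is_deepseek_v3_r1:
        return FusedMoEState.AllGatherEP
    if ep_size == 1:
        return FusedMoEState.AllGather
    if ep_size < 16 or with_prefill:
        return FusedMoEState.All2All  # assumed unchanged by this PR
    return FusedMoEState.MC2  # assumed unchanged by this PR
# ------------------------------------------------------------------------------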