82 changes: 82 additions & 0 deletions tests/multicard/test_fused_moe_allgather_ep.py
@@ -0,0 +1,82 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Run inference through the fused_moe_allgather_ep and fused_moe_alltoall_ep paths.

Run 'pytest tests/multicard/test_fused_moe_allgather_ep.py'.
"""

import os
from unittest.mock import patch

from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams

from tests.conftest import VllmRunner


@patch.dict(
os.environ, {
"VLLM_USE_V1": "1",
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1",
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"
})
def test_generate_with_allgather():
example_prompts = ["Hello, my name is"]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
tensor_parallel_size=16,
enforce_eager=True,
max_model_len=1024,
dtype="auto",
enable_expert_parallel=True,
additional_config={
"ascend_scheduler_config": {
"enabled": True,
"chunked_prefill_enabled": False,
},
"expert_tensor_parallel_size": 1
}) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


@patch.dict(
os.environ, {
"VLLM_USE_V1": "1",
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1"
})
def test_generate_with_alltoall():
example_prompts = ["Hello, my name is"]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
tensor_parallel_size=16,
enforce_eager=True,
max_model_len=1024,
dtype="auto",
enable_expert_parallel=True,
additional_config={
"ascend_scheduler_config": {
"enabled": True,
"chunked_prefill_enabled": False,
},
"expert_tensor_parallel_size": 1
}) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
@@ -107,6 +107,9 @@
    # Whether to enable tracing of recompiles from PyTorch.
"VLLM_ASCEND_TRACE_RECOMPILES":
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
),
"VLLM_ASCEND_ENABLE_DBO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
# Whether to enable the model execute time observe profile. Disable it when
27 changes: 20 additions & 7 deletions vllm_ascend/models/deepseek_v2.py
@@ -67,7 +67,7 @@

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
@@ -291,9 +291,17 @@ def __init__(
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()
self.etp_group = get_etp_group()

self.params_dtype = torch.get_default_dtype()

        # The fusion operator torch_npu.npu_grouped_matmul_finalize_routing used by
        # the allgather EP path only supports DeepSeek V3/R1 (256 routed experts).
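        # The path additionally requires expert parallelism to be enabled
        # (EP world size > 1) and expert tensor parallelism to be disabled
        # (ETP world size == 1), as checked below.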
self.fused_experts_allgather_ep_enabled = envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and \
config.n_routed_experts == 256 and \
self.ep_group.world_size > 1 and \
self.etp_group.world_size == 1

def forward(
self,
hidden_states: torch.Tensor,
@@ -317,10 +325,17 @@ def forward(
use_separated_shared_experts = (self.shared_experts is not None
and not self.enable_multistream_moe)

        # torch_npu.npu_format_cast_(layer.w2_weight, 29) is not supported by
        # torch_npu.npu_grouped_matmul in the current release of torch_npu.
if self.fused_experts_allgather_ep_enabled:
enable_alltoall_ep = False
else:
enable_alltoall_ep = ((VLLM_ENABLE_MC2 and not is_prefill)
or not (self.torchair_graph_enabled
or self.ep_group.world_size == 1))

if self.tp_size > 1:
if (VLLM_ENABLE_MC2
and not is_prefill) or not (self.torchair_graph_enabled or
self.ep_group.world_size == 1):
if enable_alltoall_ep:
if num_tokens < self.tp_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, self.tp_size - num_tokens))
@@ -350,9 +365,7 @@ def forward(
experts_hidden_states[1])

if self.tp_size > 1:
if (VLLM_ENABLE_MC2
and not is_prefill) or not (self.torchair_graph_enabled or
self.ep_group.world_size == 1):
if enable_alltoall_ep:
dist.all_gather(list(chunk_hidden_states), hidden_states,
self.tp_group)
hidden_states = torch.cat(chunk_hidden_states, dim=0)
114 changes: 111 additions & 3 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -24,7 +24,7 @@

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
npu_wait_tensor)
@@ -346,6 +346,95 @@ def fused_experts_with_all2all(
return final_hidden_states


def fused_experts_with_allgather(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens = hidden_states.shape[0]
batch_size, hidden_size = hidden_states.shape
topk_weights = topk_weights.to(hidden_states.dtype)

ep_group = get_ep_group().device_group
ep_rank = torch.distributed.get_rank(group=ep_group)
ep_size = torch.distributed.get_world_size(ep_group)

global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_size

hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)

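    # Permute tokens so that those routed to the same expert are contiguous,
    # restricted to this EP rank's expert range; also returns the gather indices
    # (expanded_x_idx), per-expert token counts and per-token quant scales.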
hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
scale=pertoken_scale,
offset=None,
active_num=num_tokens * top_k,
expert_num=global_num_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[
ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
],
quant_mode=-1,
row_idx_type=1)
group_list_type = 1

sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0,
expanded_x_idx)
row_index = expanded_x_idx // topk_ids.shape[-1]
row_index = row_index.to(torch.int64)
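    # The finalize-routing kernel can fold in a shared-experts contribution via
    # shared_input; it is zeroed here, so shared experts are presumably combined
    # outside this function.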
share_input = torch.zeros((batch_size, hidden_size),
dtype=torch.bfloat16,
device="npu")

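    # Gate/up projection (w13) as a grouped matmul over the per-expert token
    # groups; the int32 output is dequantized inside the fused swiglu below.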
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=expert_tokens,
output_dtype=torch.int32)[0]

# act_fn: swiglu
hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale.to(torch.float32),
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=expert_tokens,
activate_left=True,
quant_mode=1,
)

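    # Down projection (w2) fused with routing finalization: expert outputs are
    # weighted by their routing logits and scattered back to the original token
    # order via row_index, yielding the combined MoE output in a single kernel.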
final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing(
hidden_states,
w2,
scale=w2_scale.to(torch.float32),
bias=None,
pertoken_scale=pertoken_scale.view(-1),
group_list=expert_tokens,
shared_input=share_input,
logit=sorted_topk_weight.to(torch.float32),
row_index=row_index,
output_bs=batch_size).to(torch.bfloat16)

if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)

return final_hidden_states


def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
@@ -542,6 +631,11 @@ def __init__(self):
self.transpose_weight = True

self.ep_group = get_ep_group()
self.etp_group = get_etp_group()

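        # Allgather EP fused-experts path: requires expert parallelism (EP > 1)
        # and no expert tensor parallelism (ETP == 1); the DeepSeek V3/R1 check
        # (256 routed experts) is applied later in apply().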
self.fused_experts_allgather_ep_enabled = envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and \
self.ep_group.world_size > 1 and \
self.etp_group.world_size == 1

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -623,8 +717,10 @@ def apply(
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"

is_deepseek_v3_r1 = global_num_experts == 256

        # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
if global_num_experts == 256:
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
                k=top_k,  # topk is currently hard-coded to 8
@@ -660,7 +756,18 @@

topk_weights = topk_weights.to(x.dtype)

if VLLM_ENABLE_MC2 and not is_prefill:
if self.fused_experts_allgather_ep_enabled and is_deepseek_v3_r1:
return fused_experts_with_allgather(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
elif VLLM_ENABLE_MC2 and not is_prefill:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
@@ -711,6 +818,7 @@ def process_weights_after_loading(self, layer):
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
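        # Cast w2 to NPU private format 29 (assumed to be ACL_FORMAT_FRACTAL_NZ)
        # so the fused grouped-matmul finalize-routing kernel can consume it.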
torch_npu.npu_format_cast_(layer.w2_weight, 29)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(