3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
@@ -107,6 +107,9 @@
# Whether to enable the trace recompiles from pytorch.
"VLLM_ASCEND_TRACE_RECOMPILES":
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
),
"VLLM_ASCEND_ENABLE_DBO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
# Whether to enable the model execute time observe profile. Disable it when
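For reference, the new switch follows the same convention as the other entries in the env table above: the value is parsed with int() and then bool(), so any non-zero integer enables it and it defaults to off. A minimal sketch of turning it on from Python before vllm_ascend reads its environment (only the variable name comes from this diff; the rest is illustrative):

import os

# Enable the allgather-EP fused-experts path; must be set before vllm_ascend
# evaluates its env table. "1" enables it, "0" (the default) leaves it off.
os.environ["VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP"] = "1"

# Same parsing rule as envs.py: int() first, then bool().
enabled = bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", "0")))
print(enabled)  # True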
28 changes: 21 additions & 7 deletions vllm_ascend/models/deepseek_v2.py
@@ -67,7 +67,7 @@

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
@@ -291,9 +291,17 @@ def __init__(
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()
self.etp_group = get_etp_group()

self.params_dtype = torch.get_default_dtype()

# The fusion operator torch_npu.npu_grouped_matmul_finalize_routing used by the
# allgather EP path only supports DeepSeek V3/R1 (256 routed experts).
self.fused_experts_allgather_ep_enabled = envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and \
config.n_routed_experts == 256 and \
self.ep_group.world_size > 1 and \
self.etp_group.world_size == 1

def forward(
self,
hidden_states: torch.Tensor,
@@ -317,10 +325,18 @@ def forward(
use_separated_shared_experts = (self.shared_experts is not None
and not self.enable_multistream_moe)

# torch_npu.npu_format_cast_(layer.w2_weight, 29) is not supported by
# torch_npu.npu_grouped_matmul in current release version of torch_npu
if self.fused_experts_allgather_ep_enabled:
enable_alltoall_ep = False
else:
enable_alltoall_ep = self.ep_group.world_size > 1
if not is_prefill:
enable_alltoall_ep = enable_alltoall_ep and (
VLLM_ENABLE_MC2 or not self.torchair_graph_enabled)

if self.tp_size > 1:
if (VLLM_ENABLE_MC2
and not is_prefill) or not (self.torchair_graph_enabled or
self.ep_group.world_size == 1):
if enable_alltoall_ep:
if num_tokens < self.tp_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, self.tp_size - num_tokens))
@@ -350,9 +366,7 @@ def forward(
experts_hidden_states[1])

if self.tp_size > 1:
if (VLLM_ENABLE_MC2
and not is_prefill) or not (self.torchair_graph_enabled or
self.ep_group.world_size == 1):
if enable_alltoall_ep:
dist.all_gather(list(chunk_hidden_states), hidden_states,
self.tp_group)
hidden_states = torch.cat(chunk_hidden_states, dim=0)
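In the forward path above, enable_alltoall_ep decides whether the experts are wrapped in the tensor-parallel pad / chunk / all_gather sequence: when the allgather-EP path is active the wrapping is skipped, otherwise it is applied whenever the EP group spans more than one rank (and, for non-prefill steps, only with MC2 or outside the torchair graph). A minimal single-process sketch of the pad/chunk/unpad arithmetic as I read it (no process group involved; the sizes are made up, and dist.all_gather is replaced by a plain concatenation):

import torch
import torch.nn as nn

tp_size, num_tokens, hidden = 4, 3, 8              # illustrative sizes
hidden_states = torch.randn(num_tokens, hidden)

# Pad the token dimension so it splits into tp_size equal chunks.
if num_tokens < tp_size:
    hidden_states = nn.functional.pad(
        hidden_states, (0, 0, 0, tp_size - num_tokens))

chunks = list(torch.chunk(hidden_states, tp_size, dim=0))
local_chunk = chunks[0]                            # each rank keeps chunks[rank]

# In the real path dist.all_gather + torch.cat rebuilds the full tensor;
# here the chunks are already local, so just concatenate them.
rebuilt = torch.cat(chunks, dim=0)

# Drop the padding rows to recover the original token count (the slicing back
# is this sketch's own step).
rebuilt = rebuilt[:num_tokens]
assert rebuilt.shape == (num_tokens, hidden)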
115 changes: 112 additions & 3 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -24,7 +24,7 @@

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
npu_wait_tensor)
@@ -346,6 +346,96 @@ def fused_experts_with_all2all(
return final_hidden_states


def fused_experts_with_allgather(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens = hidden_states.shape[0]
batch_size, hidden_size = hidden_states.shape
topk_weights = topk_weights.to(hidden_states.dtype)

ep_group = get_ep_group().device_group
ep_rank = torch.distributed.get_rank(group=ep_group)
ep_size = torch.distributed.get_world_size(ep_group)

global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_size

hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)

hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
scale=pertoken_scale,
offset=None,
active_num=num_tokens * top_k,
expert_num=global_num_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[
ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
],
quant_mode=-1,
row_idx_type=1)
group_list_type = 1

sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0,
expanded_x_idx)
row_index = expanded_x_idx // topk_ids.shape[-1]
row_index = row_index.to(torch.int64)
# TODO pass share_input from outside
share_input = torch.zeros((batch_size, hidden_size),
dtype=torch.bfloat16,
device="npu")

hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=expert_tokens,
output_dtype=torch.int32)[0]

# act_fn: swiglu
hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale.to(torch.float32),
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=expert_tokens,
activate_left=True,
quant_mode=1,
)

final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing(
hidden_states,
w2,
scale=w2_scale.to(torch.float32),
bias=None,
pertoken_scale=pertoken_scale.view(-1),
group_list=expert_tokens,
shared_input=share_input,
logit=sorted_topk_weight.to(torch.float32),
row_index=row_index,
output_bs=batch_size).to(torch.bfloat16)

if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)

return final_hidden_states
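The new fused_experts_with_allgather derives the per-rank expert window directly from expert_map and the EP group: len(expert_map) global experts are split evenly over ep_size ranks, and the resulting half-open interval is what gets passed to npu_moe_init_routing_v2 as active_expert_range. A small sketch of that arithmetic for the DeepSeek V3/R1 configuration this path targets (256 routed experts; the EP size is an illustrative choice):

# Illustrative EP partitioning: 256 global experts over 8 EP ranks.
global_num_experts = 256
ep_size = 8
local_num_experts = global_num_experts // ep_size   # 32 experts per rank

for ep_rank in range(ep_size):
    # The same window fused_experts_with_allgather hands to
    # npu_moe_init_routing_v2 via active_expert_range.
    active_expert_range = [ep_rank * local_num_experts,
                           (ep_rank + 1) * local_num_experts]
    # e.g. rank 0 owns experts [0, 32), rank 7 owns [224, 256)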


def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
@@ -542,6 +632,11 @@ def __init__(self):
self.transpose_weight = True

self.ep_group = get_ep_group()
self.etp_group = get_etp_group()

self.fused_experts_allgather_ep_enabled = envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and \
self.ep_group.world_size > 1 and \
self.etp_group.world_size == 1

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -623,8 +718,10 @@ def apply(
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"

is_deepseek_v3_r1 = global_num_experts == 256

# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # top_k is currently 8
@@ -660,7 +757,18 @@

topk_weights = topk_weights.to(x.dtype)

if VLLM_ENABLE_MC2 and not is_prefill:
if self.fused_experts_allgather_ep_enabled and is_deepseek_v3_r1:
return fused_experts_with_allgather(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
elif VLLM_ENABLE_MC2 and not is_prefill:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
@@ -711,6 +819,7 @@ def process_weights_after_loading(self, layer):
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
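# Cast w2_weight to the NZ weight layout (ACL format id 29, FRACTAL_NZ),
# presumably for the grouped-matmul kernels used by the allgather EP path.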
torch_npu.npu_format_cast_(layer.w2_weight, 29)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
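Taken together, apply() now selects among three fused-experts back ends in priority order: the allgather EP path (flag on, 256 routed experts, EP world size > 1, ETP world size == 1), then MC2 for non-prefill steps, then the pre-existing paths. A hedged sketch of that selection as a standalone helper (the function name, signature, and the name of the fallback are illustrative; only the conditions mirror this diff):

def select_fused_experts_backend(allgather_ep_enabled: bool,
                                 is_deepseek_v3_r1: bool,
                                 mc2_enabled: bool,
                                 is_prefill: bool) -> str:
    """Illustrative priority order of the MoE back ends after this change."""
    if allgather_ep_enabled and is_deepseek_v3_r1:
        return "fused_experts_with_allgather"
    if mc2_enabled and not is_prefill:
        return "fused_experts_with_mc2"
    return "pre_existing_all2all_or_default_path"


# Example: a decode step on DeepSeek V3/R1 with the new flag active.
assert select_fused_experts_backend(True, True, True, False) == \
    "fused_experts_with_allgather"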