Add super kernel in moe #1877

Status: Open. This PR wants to merge 1 commit into base branch v0.9.1-dev.
12 changes: 12 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -63,6 +63,8 @@ def __init__(self, torchair_graph_config):
self.enable_view_optimize = torchair_graph_config.get(
"enable_view_optimize", True)
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
self.enable_super_kernel = torchair_graph_config.get(
"enable_super_kernel", False)

Collaborator review comment: Please add documentation for the new field in docs/source/user_guide/configuration/additional_config.md.

if not isinstance(self.graph_batch_sizes, list):
raise TypeError("graph_batch_sizes must be list[int]")
@@ -95,6 +97,16 @@ def __init__(self, torchair_graph_config):
raise RuntimeError(
"enable_kv_nz is valid only when Torchair graph mode is enabled"
)
if self.enable_super_kernel:
raise RuntimeError(
"enable_super_kernel is valid only when Torchair graph mode and enable_multistream_moe is enabled"
)

if not self.enable_multistream_moe:
if self.enable_super_kernel:
raise RuntimeError(
"enable_super_kernel is valid only when Torchair graph mode and enable_multistream_moe is enabled"
)


class AscendSchedulerConfig:
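In line with the review comment above, a hedged sketch of how the new option could look when documented or set through additional_config; enable_super_kernel is the key added in this diff, while the enabled and enable_multistream_moe keys are assumed from the validation logic above, and the exact shape of the config may differ:

# Illustrative additional_config fragment (not taken verbatim from the docs):
additional_config = {
    "torchair_graph_config": {
        "enabled": True,                 # super kernel requires Torchair graph mode
        "enable_multistream_moe": True,  # and multistream MoE, per the checks above
        "enable_super_kernel": True,     # new flag introduced in this PR
    },
}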
1 change: 1 addition & 0 deletions vllm_ascend/models/deepseek_v2.py
@@ -242,6 +242,7 @@ def __init__(
config.n_routed_experts,
bias=False,
quant_config=None,
params_dtype=torch.float32,
prefix=f"{prefix}.gate")
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = nn.Parameter(
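For context, a small self-contained sketch (toy sizes, illustrative names) of what params_dtype=torch.float32 on the gate implies: the routing gate holds float32 weights, so the caller casts activations up before computing router logits, matching the gate(hidden_states.float()) call added in fused_moe.py below.

import torch
import torch.nn as nn

hidden_size, n_routed_experts = 16, 8  # toy sizes for illustration only
gate = nn.Linear(hidden_size, n_routed_experts, bias=False, dtype=torch.float32)

hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16)
router_logits = gate(hidden_states.float())  # cast to match the fp32 gate weights
assert router_logits.dtype == torch.float32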
27 changes: 19 additions & 8 deletions vllm_ascend/ops/fused_moe.py
@@ -48,7 +48,7 @@
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_ascend_soc_version, npu_stream_switch,
npu_wait_tensor)
npu_wait_tensor, super_kernel)

VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER

@@ -1123,6 +1123,7 @@ def __init__(

AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter
self.prefix = prefix

if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -1179,6 +1180,9 @@ def __init__(
self.enable_multistream_moe = (
ascend_config.torchair_graph_config.enable_multistream_moe
and self.torchair_graph_enabled)
self.enable_super_kernel = (
ascend_config.torchair_graph_config.enable_super_kernel
and self.enable_multistream_moe)

if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
@@ -1264,20 +1268,25 @@ def forward(

forward_context = get_forward_context()
fused_moe_state = get_forward_context().fused_moe_state
is_prefill = get_forward_context().with_prefill
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
from vllm_ascend.quantization.w8a8_dynamic import \
AscendW8A8DynamicFusedMoEMethod

if self.enable_multistream_moe:
assert gate is not None
router_logits, _ = gate(hidden_states)
if (isinstance(self.quant_method.quant_method,
AscendW8A8DynamicFusedMoEMethod)
and fused_moe_state == FusedMoEState.MC2):
with npu_stream_switch("moe_secondary", 0):
quantized_x_for_share, dynamic_scale_for_share = (
torch_npu.npu_dynamic_quant(hidden_states))
with super_kernel(self.prefix,
"stream-fusion=1",
enabled=not is_prefill
and self.enable_super_kernel):
router_logits, _ = gate(hidden_states.float())
if (isinstance(self.quant_method.quant_method,
AscendW8A8DynamicFusedMoEMethod)
and fused_moe_state == FusedMoEState.MC2):
with npu_stream_switch("moe_secondary", 0):
quantized_x_for_share, dynamic_scale_for_share = (
torch_npu.npu_dynamic_quant(hidden_states))

if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
@@ -1354,6 +1363,8 @@ def forward(
dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=mc2_mask,
token_dispatcher=self.token_dispatcher,
prefix=self.prefix,
enable_super_kernel=self.enable_super_kernel,
)

if shared_experts:
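For clarity, the enabled= expression passed to super_kernel in forward() above means the scope is only active on decode (non-prefill) steps and only when the config flag is set; a tiny runnable restatement of that condition (function name is illustrative):

def super_kernel_enabled(is_prefill: bool, enable_super_kernel: bool) -> bool:
    # Mirrors `enabled=not is_prefill and self.enable_super_kernel` above.
    return (not is_prefill) and enable_super_kernel

assert super_kernel_enabled(is_prefill=False, enable_super_kernel=True) is True
assert super_kernel_enabled(is_prefill=True, enable_super_kernel=True) is False
assert super_kernel_enabled(is_prefill=False, enable_super_kernel=False) is False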
140 changes: 74 additions & 66 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -30,7 +30,8 @@
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_ascend_soc_version,
npu_stream_switch, npu_wait_tensor)
npu_stream_switch, npu_wait_tensor,
super_kernel)

CHUNK_SIZE: int = ascend_envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE

@@ -853,77 +854,84 @@ def apply(
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
prefix: str = "",
enable_super_kernel: bool = False,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"

# NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # top_k is currently 8
bias=e_score_correction_bias,
k_group=topk_group, # fixed: 4
group_count=num_expert_group, # fixed: 8
group_select_mode=1, # 0: max within the group; 1: sum of top-2 (fixed)
renorm=0, # 0: softmax->topk (fixed); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid (fixed)
# out_flag=False, # TODO new API; whether to output the third result
# y2_flag=False, # old API; whether to output the third result
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)

fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(quantized_x_for_share, router_logits)
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]

# This is a naive implementation of expert load balancing, meant to
# avoid accumulating too many tokens on a single rank.
# Currently it is only activated during profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

topk_weights = topk_weights.to(x.dtype)
with super_kernel(prefix,
"stream-fusion=1",
enabled=enable_super_kernel):

ApsarasX (Collaborator) commented on Jul 21, 2025: Adding with super_kernel in this location would require modifying too many lines of code. Why not add it in fused_moe.py instead?

# NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits.float(),
k=top_k, # top_k is currently 8
bias=e_score_correction_bias,
k_group=topk_group, # fixed: 4
group_count=num_expert_group, # fixed: 8
group_select_mode=1, # 0: max within the group; 1: sum of top-2 (fixed)
renorm=0, # 0: softmax->topk (fixed); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid (fixed)
# out_flag=False, # TODO new API; whether to output the third result
# y2_flag=False, # old API; whether to output the third result
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)

fused_moe_state = get_forward_context().fused_moe_state
shared_gate_up, shared_dequant_scale = None, None
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(quantized_x_for_share, router_logits)
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]

# This is a naive implementation of expert load balancing, meant to
# avoid accumulating too many tokens on a single rank.
# Currently it is only activated during profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

topk_weights = topk_weights.to(x.dtype)

if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_fp32,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
quantized_x_for_share=shared_gate_up,
dynamic_scale_for_share=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None))
with super_kernel(prefix,
"stream-fusion=1",
enabled=enable_super_kernel):
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_fp32,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
quantized_x_for_share=shared_gate_up,
dynamic_scale_for_share=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None))
elif fused_moe_state == FusedMoEState.MC2_PREFILL:
return fused_prefill_experts_with_mc2(
hidden_states=x,
5 changes: 5 additions & 0 deletions vllm_ascend/utils.py
@@ -30,6 +30,7 @@
import torchair # type: ignore[import] # noqa: F401
from packaging.version import InvalidVersion, Version
from torch_npu.npu.streams import Event
from torchair.scope import super_kernel as _super_kernel
from vllm.logger import logger

import vllm_ascend.envs as envs
@@ -296,6 +297,10 @@ def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
return _npu_stream_switch(tag, priority) if enabled else nullcontext()


def super_kernel(prefix: str, stream: str, enabled: bool = True):
return _super_kernel(prefix, stream) if enabled else nullcontext()


def npu_wait_tensor(self: torch.Tensor,
dependency: torch.Tensor,
*,
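A short usage sketch of the new helper; the prefix string below is made up, while "stream-fusion=1" matches both call sites in this PR. Because the disabled path returns nullcontext(), wrapping code with enabled=False is a plain no-op, which keeps the non-Torchair paths unchanged.

from vllm_ascend.utils import super_kernel

prefix = "model.layers.3.mlp.experts"  # illustrative module prefix

with super_kernel(prefix, "stream-fusion=1", enabled=True):
    ...  # NPU ops issued here can be fused by torchair into one super kernel

with super_kernel(prefix, "stream-fusion=1", enabled=False):
    ...  # falls back to nullcontext(), so this wrapper does nothing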