
Commit c018c84

Add super kernel in moe
Signed-off-by: NNUCJ <616151263@qq.com>
1 parent: 507dce5

3 files changed: 138 additions, 124 deletions

vllm_ascend/ops/fused_moe.py

Lines changed: 12 additions & 8 deletions
@@ -48,7 +48,7 @@
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_ascend_soc_version, npu_stream_switch,
-                               npu_wait_tensor)
+                               npu_wait_tensor, super_kernel)
 
 VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER
 
@@ -1123,6 +1123,7 @@ def __init__(
 
         AscendFusedMoE.moe_counter += 1
         self.moe_instance_id = AscendFusedMoE.moe_counter
+        self.prefix = prefix
 
         if params_dtype is None:
            params_dtype = torch.get_default_dtype()
@@ -1264,20 +1265,22 @@ def forward(
 
         forward_context = get_forward_context()
         fused_moe_state = get_forward_context().fused_moe_state
+        is_prefill = get_forward_context().with_prefill
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
         quantized_x_for_share, dynamic_scale_for_share = None, None
         from vllm_ascend.quantization.w8a8_dynamic import \
             AscendW8A8DynamicFusedMoEMethod
 
         if self.enable_multistream_moe:
             assert gate is not None
-            router_logits, _ = gate(hidden_states)
-            if (isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod)
-                    and fused_moe_state == FusedMoEState.MC2):
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = (
-                        torch_npu.npu_dynamic_quant(hidden_states))
+            with super_kernel(self.prefix, "stream-fusion=1", not is_prefill):
+                router_logits, _ = gate(hidden_states)
+                if (isinstance(self.quant_method.quant_method,
+                               AscendW8A8DynamicFusedMoEMethod)
+                        and fused_moe_state == FusedMoEState.MC2):
+                    with npu_stream_switch("moe_secondary", 0):
+                        quantized_x_for_share, dynamic_scale_for_share = (
+                            torch_npu.npu_dynamic_quant(hidden_states))
 
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
@@ -1354,6 +1357,7 @@ def forward(
             dynamic_scale_for_share=dynamic_scale_for_share,
             mc2_mask=mc2_mask,
             token_dispatcher=self.token_dispatcher,
+            prefix=self.prefix,
         )
 
         if shared_experts:
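
The hunk above wraps the multistream gate + dynamic-quant path in a torchair super-kernel scope that is only entered on decode steps (`not is_prefill`). Below is a minimal, self-contained sketch of that pattern, not the vllm-ascend implementation itself: `super_kernel` is stubbed with `nullcontext` so the snippet runs without torchair, and the linear gate and prefix string are illustrative placeholders.

# Minimal sketch of the decode-only fusion pattern introduced above.
# Assumptions: `super_kernel` is stubbed here; in vllm-ascend it wraps
# torchair.scope.super_kernel and is imported from vllm_ascend.utils.
from contextlib import nullcontext

import torch


def super_kernel(prefix: str, stream: str, enabled: bool = True):
    # Stand-in for the helper added in vllm_ascend/utils.py (see below).
    return nullcontext()


def route(gate: torch.nn.Linear, hidden_states: torch.Tensor,
          prefix: str, is_prefill: bool) -> torch.Tensor:
    # Enter the fusion scope only on decode steps, mirroring
    # `not is_prefill` in the diff above.
    with super_kernel(prefix, "stream-fusion=1", not is_prefill):
        router_logits = gate(hidden_states)
    return router_logits


gate = torch.nn.Linear(16, 4)
logits = route(gate, torch.randn(2, 16), "model.layers.0.mlp.experts",
               is_prefill=False)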

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 121 additions & 116 deletions
@@ -30,7 +30,8 @@
 from vllm_ascend.ops.fused_moe import select_experts
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version,
-                               npu_stream_switch, npu_wait_tensor)
+                               npu_stream_switch, npu_wait_tensor,
+                               super_kernel)
 
 CHUNK_SIZE: int = ascend_envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE
 
@@ -853,125 +854,129 @@ def apply(
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        prefix: str = "",
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
-
-        # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # top_k is currently 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix: 8
-                group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fix)
-                renorm=0,  # 0: softmax->topk (fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid (fix)
-                # out_flag=False,  # todo new api; whether to output the third tensor
-                # y2_flag=False,  # old api; whether to output the third tensor
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
-
-        fused_moe_state = get_forward_context().fused_moe_state
-        shared_gate_up, shared_dequant_scale = None, None
-        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-            with npu_stream_switch("moe_secondary", 0):
-                npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-
-        # this is a naive implementation for experts load balance so as
-        # to avoid accumulating too many tokens on a single rank.
-        # currently it is only activated when doing profile runs.
-        if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
-
-        topk_weights = topk_weights.to(x.dtype)
-
-        if fused_moe_state == FusedMoEState.MC2:
-            return fused_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale,
-                mc2_mask=kwargs.get("mc2_mask", None))
-        elif fused_moe_state == FusedMoEState.MC2_PREFILL:
-            return fused_prefill_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale,
-                mc2_mask=kwargs.get("mc2_mask", None))
-        elif fused_moe_state == FusedMoEState.AllGather:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w1_scale=layer.w13_weight_scale,
-                                 w2=layer.w2_weight,
-                                 w2_scale=layer.w2_weight_scale,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
-        else:
-            # The current implementation of deepseek moe splits hidden_states
-            # according to tp_size before they are fed into the fused_moe module.
-            # Therefore, all2all is needed no matter how dp/tp is set so as to
-            # dispatch/combine tokens.
-            return fused_experts_with_all2all(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w1_scale=layer.w13_weight_scale,
-                w2=layer.w2_weight,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                ep_group=self.ep_group,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-            )
+        if shared_experts is not None:
+            router_logits = router_logits.float()
+        with super_kernel(prefix, "stream-fusion=1",
+                          shared_experts is not None):
+            # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
+            if global_num_experts == 256:
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # top_k is currently 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fix: 4
+                    group_count=num_expert_group,  # fix: 8
+                    group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fix)
+                    renorm=0,  # 0: softmax->topk (fix); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid (fix)
+                    # out_flag=False,  # todo new api; whether to output the third tensor
+                    # y2_flag=False,  # old api; whether to output the third tensor
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
+            else:
+                topk_weights, topk_ids = select_experts(
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    top_k=top_k,
+                    use_grouped_topk=use_grouped_topk,
+                    renormalize=renormalize,
+                    topk_group=topk_group,
+                    num_expert_group=num_expert_group,
+                    custom_routing_function=custom_routing_function,
+                    scoring_func=scoring_func,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+
+            fused_moe_state = get_forward_context().fused_moe_state
+            shared_gate_up, shared_dequant_scale = None, None
+            if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    npu_wait_tensor(quantized_x_for_share, router_logits)
+                    share_up_out, _ = shared_experts.gate_up_proj(
+                        (quantized_x_for_share, dynamic_scale_for_share))
+                    shared_gate_up, shared_dequant_scale = share_up_out[
+                        0], share_up_out[1]
+
+            # this is a naive implementation for experts load balance so as
+            # to avoid accumulating too many tokens on a single rank.
+            # currently it is only activated when doing profile runs.
+            if enable_force_load_balance:
+                topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+            topk_weights = topk_weights.to(x.dtype)
+
+            if fused_moe_state == FusedMoEState.MC2:
+                return fused_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    quantized_x_for_share=shared_gate_up,
+                    dynamic_scale_for_share=shared_dequant_scale,
+                    mc2_mask=kwargs.get("mc2_mask", None))
+            elif fused_moe_state == FusedMoEState.MC2_PREFILL:
+                return fused_prefill_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    quantized_x_for_share=shared_gate_up,
+                    dynamic_scale_for_share=shared_dequant_scale,
+                    mc2_mask=kwargs.get("mc2_mask", None))
+            elif fused_moe_state == FusedMoEState.AllGather:
+                return fused_experts(hidden_states=x,
+                                     w1=layer.w13_weight,
+                                     w1_scale=layer.w13_weight_scale,
+                                     w2=layer.w2_weight,
+                                     w2_scale=layer.w2_weight_scale,
+                                     topk_weights=topk_weights,
+                                     topk_ids=topk_ids,
+                                     top_k=top_k,
+                                     expert_map=expert_map)
+            else:
+                # The current implementation of deepseek moe splits hidden_states
+                # according to tp_size before they are fed into the fused_moe module.
+                # Therefore, all2all is needed no matter how dp/tp is set so as to
+                # dispatch/combine tokens.
+                return fused_experts_with_all2all(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w1_scale=layer.w13_weight_scale,
+                    w2=layer.w2_weight,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    ep_group=self.ep_group,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                )
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:

vllm_ascend/utils.py

Lines changed: 5 additions & 0 deletions
@@ -30,6 +30,7 @@
 import torchair  # type: ignore[import]  # noqa: F401
 from packaging.version import InvalidVersion, Version
 from torch_npu.npu.streams import Event
+from torchair.scope import super_kernel as _super_kernel
 from vllm.logger import logger
 
 import vllm_ascend.envs as envs
@@ -296,6 +297,10 @@ def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
     return _npu_stream_switch(tag, priority) if enabled else nullcontext()
 
 
+def super_kernel(prefix: str, stream: str, enabled: bool = True):
+    return _super_kernel(prefix, stream) if enabled else nullcontext()
+
+
 def npu_wait_tensor(self: torch.Tensor,
                     dependency: torch.Tensor,
                     *,
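
The new helper mirrors `npu_stream_switch` directly above it: when `enabled` is False it falls back to `nullcontext`, so call sites can wrap code unconditionally and toggle fusion per step. A rough sketch of that fallback behaviour, with torchair's scope replaced by a local stub so it runs anywhere:

# Sketch of the fallback semantics of super_kernel(); the real scope comes
# from torchair.scope.super_kernel, stubbed here as a printing context.
from contextlib import contextmanager, nullcontext


@contextmanager
def _super_kernel(prefix: str, stream: str):
    # Stand-in for torchair.scope.super_kernel.
    print(f"open super-kernel scope {prefix!r} with option {stream!r}")
    yield


def super_kernel(prefix: str, stream: str, enabled: bool = True):
    return _super_kernel(prefix, stream) if enabled else nullcontext()


with super_kernel("mlp.experts", "stream-fusion=1", enabled=True):
    pass  # ops recorded here would be fused into one super kernel
with super_kernel("mlp.experts", "stream-fusion=1", enabled=False):
    pass  # plain nullcontext: no fusion scope is opened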
