Commit dfc5935

Add super kernel in moe
Signed-off-by: NNUCJ <616151263@qq.com>
1 parent 507dce5 commit dfc5935

3 files changed (+141, -124 lines)


vllm_ascend/ops/fused_moe.py

Lines changed: 14 additions & 8 deletions
@@ -48,7 +48,7 @@
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_ascend_soc_version, npu_stream_switch,
-                               npu_wait_tensor)
+                               npu_wait_tensor, super_kernel)

 VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER

@@ -1123,6 +1123,7 @@ def __init__(

         AscendFusedMoE.moe_counter += 1
         self.moe_instance_id = AscendFusedMoE.moe_counter
+        self.prefix = prefix

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1264,20 +1265,24 @@ def forward(

         forward_context = get_forward_context()
         fused_moe_state = get_forward_context().fused_moe_state
+        is_prefill = get_forward_context().with_prefill
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
         quantized_x_for_share, dynamic_scale_for_share = None, None
         from vllm_ascend.quantization.w8a8_dynamic import \
             AscendW8A8DynamicFusedMoEMethod

         if self.enable_multistream_moe:
             assert gate is not None
-            router_logits, _ = gate(hidden_states)
-            if (isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod)
-                    and fused_moe_state == FusedMoEState.MC2):
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = (
-                        torch_npu.npu_dynamic_quant(hidden_states))
+            with super_kernel(self.prefix,
+                              "stream-fusion=1",
+                              enabled=not is_prefill):
+                router_logits, _ = gate(hidden_states)
+                if (isinstance(self.quant_method.quant_method,
+                               AscendW8A8DynamicFusedMoEMethod)
+                        and fused_moe_state == FusedMoEState.MC2):
+                    with npu_stream_switch("moe_secondary", 0):
+                        quantized_x_for_share, dynamic_scale_for_share = (
+                            torch_npu.npu_dynamic_quant(hidden_states))

         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
@@ -1354,6 +1359,7 @@ def forward(
             dynamic_scale_for_share=dynamic_scale_for_share,
             mc2_mask=mc2_mask,
             token_dispatcher=self.token_dispatcher,
+            prefix=self.prefix,
         )

         if shared_experts:
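The change above wraps the multistream gate computation (and the optional W8A8 dynamic quantization) in a super_kernel scope named after the layer's prefix, enabling it only for decode steps (enabled=not is_prefill). Below is a minimal sketch of that conditional-scoping pattern; fake_super_kernel, route, and the toy gate are illustrative stand-ins, not part of the commit, and no NPU is needed to run it.

# Sketch only: the real scope comes from torchair.scope.super_kernel on Ascend NPUs.
from contextlib import contextmanager, nullcontext

@contextmanager
def fake_super_kernel(prefix: str, options: str):
    # Pretend scope; on hardware, ops issued inside may be fused into one super kernel.
    print(f"enter scope {prefix!r} with {options!r}")
    yield
    print(f"exit scope {prefix!r}")

def super_kernel(prefix: str, stream: str, enabled: bool = True):
    # Same shape as the helper added in vllm_ascend/utils.py: no-op when disabled.
    return fake_super_kernel(prefix, stream) if enabled else nullcontext()

def route(gate, hidden_states, prefix: str, is_prefill: bool):
    # Decode-only scoping, mirroring the AscendFusedMoE.forward change above.
    with super_kernel(prefix, "stream-fusion=1", enabled=not is_prefill):
        router_logits = gate(hidden_states)
    return router_logits

def toy_gate(x):
    return [v * 2 for v in x]

route(toy_gate, [0.1, 0.2], "layers.0.mlp.experts", is_prefill=False)  # scoped
route(toy_gate, [0.1, 0.2], "layers.0.mlp.experts", is_prefill=True)   # unscoped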

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 122 additions & 116 deletions
@@ -30,7 +30,8 @@
 from vllm_ascend.ops.fused_moe import select_experts
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version,
-                               npu_stream_switch, npu_wait_tensor)
+                               npu_stream_switch, npu_wait_tensor,
+                               super_kernel)

 CHUNK_SIZE: int = ascend_envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE

@@ -853,125 +854,130 @@ def apply(
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        prefix: str = "",
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
-
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # top_k is currently set to 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fix)
-                renorm=0,  # 0: softmax->topk (fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid (fix)
-                # out_flag=False,  # todo new api; whether to output the third output
-                # y2_flag=False,  # old api; whether to output the third output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
-
-        fused_moe_state = get_forward_context().fused_moe_state
-        shared_gate_up, shared_dequant_scale = None, None
-        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-            with npu_stream_switch("moe_secondary", 0):
-                npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-
-        # this is a naive implementation for experts load balance so as
-        # to avoid accumulating too many tokens on a single rank.
-        # currently it is only activated when doing profile runs.
-        if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
-
-        topk_weights = topk_weights.to(x.dtype)
-
-        if fused_moe_state == FusedMoEState.MC2:
-            return fused_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale,
-                mc2_mask=kwargs.get("mc2_mask", None))
-        elif fused_moe_state == FusedMoEState.MC2_PREFILL:
-            return fused_prefill_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale,
-                mc2_mask=kwargs.get("mc2_mask", None))
-        elif fused_moe_state == FusedMoEState.AllGather:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w1_scale=layer.w13_weight_scale,
-                                 w2=layer.w2_weight,
-                                 w2_scale=layer.w2_weight_scale,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
-        else:
-            # The current implementation of deepseek moe splits hidden_states
-            # according to tp_size before they are fed into the fused_moe module.
-            # Therefore, all2all is needed no matter how dp/tp is set so as to
-            # dispatch/combine tokens.
-            return fused_experts_with_all2all(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w1_scale=layer.w13_weight_scale,
-                w2=layer.w2_weight,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                ep_group=self.ep_group,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-            )
+        if shared_experts is not None:
+            router_logits = router_logits.float()
+        with super_kernel(prefix,
+                          "stream-fusion=1",
+                          enabled=shared_experts is not None):
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+            if global_num_experts == 256:
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # top_k is currently set to 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fix: 4
+                    group_count=num_expert_group,  # fix 8
+                    group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fix)
+                    renorm=0,  # 0: softmax->topk (fix); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid (fix)
+                    # out_flag=False,  # todo new api; whether to output the third output
+                    # y2_flag=False,  # old api; whether to output the third output
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
+            else:
+                topk_weights, topk_ids = select_experts(
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    top_k=top_k,
+                    use_grouped_topk=use_grouped_topk,
+                    renormalize=renormalize,
+                    topk_group=topk_group,
+                    num_expert_group=num_expert_group,
+                    custom_routing_function=custom_routing_function,
+                    scoring_func=scoring_func,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+
+            fused_moe_state = get_forward_context().fused_moe_state
+            shared_gate_up, shared_dequant_scale = None, None
+            if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    npu_wait_tensor(quantized_x_for_share, router_logits)
+                    share_up_out, _ = shared_experts.gate_up_proj(
+                        (quantized_x_for_share, dynamic_scale_for_share))
+                    shared_gate_up, shared_dequant_scale = share_up_out[
+                        0], share_up_out[1]
+
+            # this is a naive implementation for experts load balance so as
+            # to avoid accumulating too many tokens on a single rank.
+            # currently it is only activated when doing profile runs.
+            if enable_force_load_balance:
+                topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+            topk_weights = topk_weights.to(x.dtype)
+
+            if fused_moe_state == FusedMoEState.MC2:
+                return fused_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    quantized_x_for_share=shared_gate_up,
+                    dynamic_scale_for_share=shared_dequant_scale,
+                    mc2_mask=kwargs.get("mc2_mask", None))
+            elif fused_moe_state == FusedMoEState.MC2_PREFILL:
+                return fused_prefill_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    quantized_x_for_share=shared_gate_up,
+                    dynamic_scale_for_share=shared_dequant_scale,
+                    mc2_mask=kwargs.get("mc2_mask", None))
+            elif fused_moe_state == FusedMoEState.AllGather:
+                return fused_experts(hidden_states=x,
+                                     w1=layer.w13_weight,
+                                     w1_scale=layer.w13_weight_scale,
+                                     w2=layer.w2_weight,
+                                     w2_scale=layer.w2_weight_scale,
+                                     topk_weights=topk_weights,
+                                     topk_ids=topk_ids,
+                                     top_k=top_k,
+                                     expert_map=expert_map)
+            else:
+                # The current implementation of deepseek moe splits hidden_states
+                # according to tp_size before they are fed into the fused_moe module.
+                # Therefore, all2all is needed no matter how dp/tp is set so as to
+                # dispatch/combine tokens.
+                return fused_experts_with_all2all(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w1_scale=layer.w13_weight_scale,
+                    w2=layer.w2_weight,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    ep_group=self.ep_group,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                )

     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
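Here the entire routing-and-dispatch body of apply() moves under a super_kernel scope keyed by the caller-supplied prefix and enabled only when shared experts are present (router logits are also upcast to float in that case). The sketch below shows how the prefix is threaded from a layer into apply(); the Fake* classes, the prefix value, and the no-op super_kernel stub are assumptions for illustration, not the vLLM-Ascend classes.

from contextlib import nullcontext

def super_kernel(prefix, stream, enabled=True):
    # Stand-in for vllm_ascend.utils.super_kernel; a real run would open a
    # torchair super-kernel scope named by `prefix` when enabled.
    return nullcontext()

class FakeQuantMethod:
    def apply(self, router_logits, shared_experts=None, prefix="", **kwargs):
        if shared_experts is not None:
            # The shared-expert path also consumes the logits, so keep them in float.
            router_logits = [float(v) for v in router_logits]
        with super_kernel(prefix, "stream-fusion=1",
                          enabled=shared_experts is not None):
            # top-k routing and expert dispatch would run inside this scope
            return max(range(len(router_logits)), key=router_logits.__getitem__)

class FakeMoELayer:
    def __init__(self, prefix):
        self.prefix = prefix                 # e.g. "model.layers.7.mlp.experts"
        self.quant_method = FakeQuantMethod()

    def forward(self, router_logits, shared_experts=None):
        # Forwarding self.prefix gives every MoE layer a distinctly named scope.
        return self.quant_method.apply(router_logits,
                                       shared_experts=shared_experts,
                                       prefix=self.prefix)

layer = FakeMoELayer("model.layers.7.mlp.experts")
print(layer.forward([0.1, 0.9, 0.3], shared_experts=object()))  # -> 1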

vllm_ascend/utils.py

Lines changed: 5 additions & 0 deletions
@@ -30,6 +30,7 @@
 import torchair  # type: ignore[import]  # noqa: F401
 from packaging.version import InvalidVersion, Version
 from torch_npu.npu.streams import Event
+from torchair.scope import super_kernel as _super_kernel
 from vllm.logger import logger

 import vllm_ascend.envs as envs
@@ -296,6 +297,10 @@ def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
     return _npu_stream_switch(tag, priority) if enabled else nullcontext()


+def super_kernel(prefix: str, stream: str, enabled: bool = True):
+    return _super_kernel(prefix, stream) if enabled else nullcontext()
+
+
 def npu_wait_tensor(self: torch.Tensor,
                     dependency: torch.Tensor,
                     *,
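The new super_kernel helper mirrors npu_stream_switch directly above it: it forwards to torchair.scope.super_kernel when enabled and falls back to nullcontext() otherwise, so call sites can use a single with statement unconditionally. A hedged usage sketch, assuming an Ascend environment with torch_npu/torchair installed; the prefix string is a made-up example.

# Requires the Ascend software stack; super_kernel is the helper added above.
from vllm_ascend.utils import super_kernel

prefix = "model.layers.0.mlp.experts"  # hypothetical scope name

with super_kernel(prefix, "stream-fusion=1", enabled=True):
    ...  # NPU ops issued here may be grouped into one super kernel

with super_kernel(prefix, "stream-fusion=1", enabled=False):
    ...  # plain nullcontext: no scoping applied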
