 from vllm_ascend.ops.fused_moe import select_experts
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version,
-                               npu_stream_switch, npu_wait_tensor)
+                               npu_stream_switch, npu_wait_tensor,
+                               super_kernel)

 CHUNK_SIZE: int = ascend_envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE

@@ -853,125 +854,130 @@ def apply(
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        prefix: str = "",
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
-
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # top_k is currently set to 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fixed: 4
-                group_count=num_expert_group,  # fixed: 8
-                group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fixed)
-                renorm=0,  # 0: softmax->topk (fixed); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid (fixed)
-                # out_flag=False,  # todo new api; whether to emit the third output
-                # y2_flag=False,  # old api; whether to emit the third output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
-
-        fused_moe_state = get_forward_context().fused_moe_state
-        shared_gate_up, shared_dequant_scale = None, None
-        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-            with npu_stream_switch("moe_secondary", 0):
-                npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-
-        # this is a naive implementation for experts load balance so as
-        # to avoid accumulating too many tokens on a single rank.
-        # currently it is only activated when doing profile runs.
-        if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
-
-        topk_weights = topk_weights.to(x.dtype)
-
-        if fused_moe_state == FusedMoEState.MC2:
-            return fused_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale,
-                mc2_mask=kwargs.get("mc2_mask", None))
-        elif fused_moe_state == FusedMoEState.MC2_PREFILL:
-            return fused_prefill_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale,
-                mc2_mask=kwargs.get("mc2_mask", None))
-        elif fused_moe_state == FusedMoEState.AllGather:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w1_scale=layer.w13_weight_scale,
-                                 w2=layer.w2_weight,
-                                 w2_scale=layer.w2_weight_scale,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
-        else:
-            # The current implementation of deepseek moe splits hidden_states
-            # according to tp_size before they are fed into the fused_moe module.
-            # Therefore, all2all is needed no matter how dp/tp is set so as to
-            # dispatch/combine tokens.
-            return fused_experts_with_all2all(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w1_scale=layer.w13_weight_scale,
-                w2=layer.w2_weight,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                ep_group=self.ep_group,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-            )
+        if shared_experts is not None:
+            router_logits = router_logits.float()
+        with super_kernel(prefix,
+                          "stream-fusion=1",
+                          enabled=shared_experts is not None):
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+            if global_num_experts == 256:
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # top_k is currently set to 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fixed: 4
+                    group_count=num_expert_group,  # fixed: 8
+                    group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fixed)
+                    renorm=0,  # 0: softmax->topk (fixed); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid (fixed)
+                    # out_flag=False,  # todo new api; whether to emit the third output
+                    # y2_flag=False,  # old api; whether to emit the third output
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
+            else:
+                topk_weights, topk_ids = select_experts(
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    top_k=top_k,
+                    use_grouped_topk=use_grouped_topk,
+                    renormalize=renormalize,
+                    topk_group=topk_group,
+                    num_expert_group=num_expert_group,
+                    custom_routing_function=custom_routing_function,
+                    scoring_func=scoring_func,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+
+            fused_moe_state = get_forward_context().fused_moe_state
+            shared_gate_up, shared_dequant_scale = None, None
+            if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    npu_wait_tensor(quantized_x_for_share, router_logits)
+                    share_up_out, _ = shared_experts.gate_up_proj(
+                        (quantized_x_for_share, dynamic_scale_for_share))
+                    shared_gate_up, shared_dequant_scale = share_up_out[
+                        0], share_up_out[1]
+
+            # this is a naive implementation for experts load balance so as
+            # to avoid accumulating too many tokens on a single rank.
+            # currently it is only activated when doing profile runs.
+            if enable_force_load_balance:
+                topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+            topk_weights = topk_weights.to(x.dtype)
+
+            if fused_moe_state == FusedMoEState.MC2:
+                return fused_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    quantized_x_for_share=shared_gate_up,
+                    dynamic_scale_for_share=shared_dequant_scale,
+                    mc2_mask=kwargs.get("mc2_mask", None))
+            elif fused_moe_state == FusedMoEState.MC2_PREFILL:
+                return fused_prefill_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    quantized_x_for_share=shared_gate_up,
+                    dynamic_scale_for_share=shared_dequant_scale,
+                    mc2_mask=kwargs.get("mc2_mask", None))
+            elif fused_moe_state == FusedMoEState.AllGather:
+                return fused_experts(hidden_states=x,
+                                     w1=layer.w13_weight,
+                                     w1_scale=layer.w13_weight_scale,
+                                     w2=layer.w2_weight,
+                                     w2_scale=layer.w2_weight_scale,
+                                     topk_weights=topk_weights,
+                                     topk_ids=topk_ids,
+                                     top_k=top_k,
+                                     expert_map=expert_map)
+            else:
+                # The current implementation of deepseek moe splits hidden_states
+                # according to tp_size before they are fed into the fused_moe module.
+                # Therefore, all2all is needed no matter how dp/tp is set so as to
+                # dispatch/combine tokens.
+                return fused_experts_with_all2all(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w1_scale=layer.w13_weight_scale,
+                    w2=layer.w2_weight,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    ep_group=self.ep_group,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                )

     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
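
The diff above wraps the MoE gating/dispatch block in a `super_kernel` scope so it can be fused when shared experts are present. Below is a minimal, hypothetical stand-in (not the real `vllm_ascend.utils.super_kernel`) that only mirrors the calling convention visible in the diff, so the control flow can be traced off-device; the `prefix` string, the stub name, and the routing placeholder are illustrative assumptions.

# Hypothetical stand-in for vllm_ascend.utils.super_kernel, kept only to
# illustrate the calling convention used in the diff above; the real helper
# opens an NPU super-kernel fusion scope instead of printing.
from contextlib import contextmanager


@contextmanager
def super_kernel_stub(prefix, option, enabled=True):
    # When disabled (e.g. no shared experts), act as a plain no-op context.
    if not enabled:
        yield
        return
    print(f"enter super-kernel scope {prefix!r} ({option})")
    try:
        yield
    finally:
        print(f"exit super-kernel scope {prefix!r}")


# Usage mirroring the patched apply(): fuse routing only when shared experts exist.
shared_experts = object()  # placeholder; any non-None value enables the scope
with super_kernel_stub("model.layers.3.mlp", "stream-fusion=1",
                       enabled=shared_experts is not None):
    pass  # npu_moe_gating_top_k / select_experts would run here on-device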