From d0d0c7dad6b39bcdb9ba4175d1df411d84eeb29c Mon Sep 17 00:00:00 2001
From: YiYang <15594999221@163.com>
Date: Tue, 8 Jul 2025 15:30:18 +0800
Subject: [PATCH] [feat] Add static shared expert external placement feature
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 vllm_ascend/ascend_config.py             |  1 +
 vllm_ascend/ops/expert_load_balancer.py  | 15 ++++++++++++---
 vllm_ascend/ops/fused_moe.py             | 15 ++++++++++-----
 vllm_ascend/quantization/quant_config.py |  3 ++-
 vllm_ascend/quantization/w8a8_dynamic.py |  9 ++++++---
 5 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index c3043e7a73..794a3e46cf 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -37,6 +37,7 @@ def __init__(self, vllm_config):
             ascend_scheduler_config)
 
         self.expert_map_path = additional_config.get("expert_map_path", None)
+        self.shared_expert_rank_num = additional_config.get("shared_expert_rank_num", 0)
         self.chunked_prefill_for_mla = additional_config.get(
             "chunked_prefill_for_mla", False)
         self.enable_weight_nz_layout = additional_config.get(
diff --git a/vllm_ascend/ops/expert_load_balancer.py b/vllm_ascend/ops/expert_load_balancer.py
index c6eec64a36..f65ce035d7 100644
--- a/vllm_ascend/ops/expert_load_balancer.py
+++ b/vllm_ascend/ops/expert_load_balancer.py
@@ -7,8 +7,10 @@
 
 class ExpertLoadBalancer(object):
 
-    def __init__(self, expert_map_path, global_expert_num):
+    def __init__(self, expert_map_path, global_expert_num, shared_expert_rank_num):
         self.expert_map_path = expert_map_path
+        self.rank_local_expert_num = 0
+        self.shared_expert_rank_num = shared_expert_rank_num
         self.global_expert_num = global_expert_num
         self.expert_map_tensor, self.layers_num, self.ranks_num = (
             self._expert_file_to_tensor())
@@ -26,6 +28,7 @@ def _expert_file_to_tensor(self):
                 device_data.append(device["device_expert"])
             tensor_data.append(device_data)
         expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
+        self.rank_local_expert_num = expert_map_tensor.shape[-1]  # experts per rank in the map
         return expert_map_tensor, layers_num, gpus_num
 
     def generate_index_dicts(self, tensor_2d):
@@ -49,8 +52,14 @@ def generate_expert_placement_map(self):
             dtype=torch.int32,
         )
         for layer_id in range(self.layers_num):
-            for gpu_id in range(self.ranks_num):
+            if self.shared_expert_rank_num > 0:
+                for gpu_id in range(self.shared_expert_rank_num):
+                    e_ids = range(self.rank_local_expert_num)
+                    expert_placement_map[layer_id, gpu_id,
+                                         e_ids] = torch.arange(len(e_ids),
+                                                               dtype=torch.int32)
+            for gpu_id in range(self.shared_expert_rank_num, self.ranks_num):
                 e_ids = self.expert_map_tensor[layer_id, gpu_id]
                 expert_placement_map[layer_id, gpu_id,
                                      e_ids] = torch.arange(len(e_ids),
                                                            dtype=torch.int32)
@@ -73,7 +82,7 @@ def generate_log2phy_expert_map(self, layer_id):
         for rank in range(self.ranks_num):
             for key in result_dict:
                 indices_in_concat = result_dict[key]
-                if key in rank_expert_to_global[rank]:
+                if key in rank_expert_to_global[rank] and rank >= self.shared_expert_rank_num:
                     log2phy_map[rank][key] = rank_expert_to_global[rank][key]
                 else:
                     chosen_index = random.choice(indices_in_concat)
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index fe1164fd4d..8e5f280ee6 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -122,6 +122,7 @@ def fused_experts_with_mc2(
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None,
     is_torchair: bool = False,
+    shared_expert_rank_num: int = 0,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     quant_mode = 0
     ep_group = get_ep_group()
@@ -143,7 +144,7 @@ def fused_experts_with_mc2(
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
-        "shared_expert_rank_num": 0,
+        "shared_expert_rank_num": shared_expert_rank_num,
         "moe_expert_num": moe_expert_num,
         "global_bs": global_bs,
     }
@@ -212,7 +213,7 @@ def fused_experts_with_mc2(
         "expand_idx": expand_idx,
         "expert_scales": topk_weights.to(torch.float32),
         "expert_shard_type": 0,
-        "shared_expert_rank_num": 0,
+        "shared_expert_rank_num": shared_expert_rank_num,
         "moe_expert_num": moe_expert_num,
         "global_bs": global_bs,
     }
@@ -903,6 +904,7 @@ def apply(
         is_prefill: bool = False,
         enable_force_load_balance: bool = False,
         shared_experts: Optional[Any] = None,
+        shared_expert_rank_num: int = 0,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -954,7 +956,8 @@ def apply(
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled)
+                is_torchair=self.torchair_graph_enabled,
+                shared_expert_rank_num=shared_expert_rank_num)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -1057,10 +1060,11 @@ def __init__(
 
         ascend_config = get_ascend_config()
         expert_map_path = ascend_config.expert_map_path
+        self.shared_expert_rank_num = ascend_config.shared_expert_rank_num
         if expert_map_path and os.path.exists(expert_map_path):
             # moe expert load balance
             expert_load_balancer = ExpertLoadBalancer(expert_map_path,
-                                                      self.global_num_experts)
+                                                      self.global_num_experts, self.shared_expert_rank_num)
             self.local_num_experts, self.expert_map = \
                 expert_load_balancer.get_rank_placement_map(
                     self.moe_instance_id,
@@ -1099,7 +1103,7 @@ def __init__(
 
         assert self.quant_method is not None
 
-        local_num_experts = torch.sum(self.expert_map != -1) \
+        local_num_experts = self.local_num_experts \
             if self.expert_map is not None else num_experts
 
         moe_quant_params = {
@@ -1209,6 +1213,7 @@ def forward(self,
             and self.enable_multistream_moe and not is_prefill else None,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
+            shared_expert_rank_num=self.shared_expert_rank_num,
         )
 
         if shared_experts:
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 1b06a4294a..67700e4b36 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -348,6 +348,7 @@ def apply(
         enable_force_load_balance: bool = False,
         log2phy: torch.Tensor = None,
         global_redundant_expert_num=0,
+        shared_expert_rank_num: int = 0,
         **kwargs,
     ) -> torch.Tensor:
         return self.quant_method.apply(
@@ -355,7 +356,7 @@ def apply(
             global_num_experts, expert_map, topk_group, num_expert_group,
             custom_routing_function, scoring_func, e_score_correction_bias,
             is_prefill, enable_force_load_balance, log2phy,
-            global_redundant_expert_num, **kwargs)
+            global_redundant_expert_num, shared_expert_rank_num=shared_expert_rank_num, **kwargs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index a9938c14f2..e120cebbff 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -215,6 +215,7 @@ def fused_experts_with_mc2(
     w2_scale_bias: torch.Tensor = None,
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
+    shared_expert_rank_num: int = 0,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
@@ -242,7 +243,7 @@ def fused_experts_with_mc2(
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
-        "shared_expert_rank_num": 0,
+        "shared_expert_rank_num": shared_expert_rank_num,
         "moe_expert_num": moe_expert_num,
         "global_bs": global_bs,
     }
@@ -290,7 +291,7 @@ def fused_experts_with_mc2(
         "expand_idx": expand_idx,
         "expert_scales": topk_weights.to(torch.float32),
         "expert_shard_type": 0,
-        "shared_expert_rank_num": 0,
+        "shared_expert_rank_num": shared_expert_rank_num,
         "moe_expert_num": moe_expert_num,
         "global_bs": global_bs,
     }
@@ -738,6 +739,7 @@ def apply(
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        shared_expert_rank_num: int = 0,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -807,7 +809,8 @@ def apply(
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
                 quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale)
+                dynamic_scale_for_share=shared_dequant_scale,
+                shared_expert_rank_num=shared_expert_rank_num)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
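
Usage sketch (not part of the diff above): the new "shared_expert_rank_num" key
rides on the same additional_config plumbing that already carries
"expert_map_path" in ascend_config.py. The model path and expert-map JSON path
below are placeholders, and the vLLM LLM/additional_config entry point is
assumed from existing vllm-ascend usage; only the two config keys come from
this patch.

# Hypothetical launch snippet showing how the option would be enabled.
from vllm import LLM

llm = LLM(
    model="/path/to/w8a8-quantized-moe-model",  # placeholder
    additional_config={
        # Existing option: static expert placement map produced offline.
        "expert_map_path": "/path/to/expert_map.json",  # placeholder
        # New option from this patch: number of leading EP ranks reserved for
        # shared experts; forwarded to the MC2 dispatch/combine kwargs and
        # honored by ExpertLoadBalancer when building the placement map.
        "shared_expert_rank_num": 2,
    },
)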