
[feat]add shared expert feature #1668


Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions vllm_ascend/ascend_config.py
@@ -37,6 +37,7 @@ def __init__(self, vllm_config):
ascend_scheduler_config)

self.expert_map_path = additional_config.get("expert_map_path", None)
self.shared_expert_rank_num = additional_config.get("shared_expert_rank_num", 0)
self.chunked_prefill_for_mla = additional_config.get(
"chunked_prefill_for_mla", False)
self.enable_weight_nz_layout = additional_config.get(
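The hunk above adds a new `shared_expert_rank_num` option, read from `additional_config` with a default of 0 (no dedicated shared-expert ranks). Below is a minimal usage sketch, assuming Ascend-specific options are passed through vLLM's `additional_config` engine argument; the model name, map path, and rank count are placeholders, not values from this PR.

```python
# Hypothetical usage sketch: only the two option names come from this diff.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",               # placeholder model
    additional_config={
        "expert_map_path": "/path/to/expert_map.json",   # pre-existing option
        "shared_expert_rank_num": 2,                     # new option; 0 keeps the old behaviour
    },
)
```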
15 changes: 12 additions & 3 deletions vllm_ascend/ops/expert_load_balancer.py
@@ -7,8 +7,10 @@

class ExpertLoadBalancer(object):

def __init__(self, expert_map_path, global_expert_num):
def __init__(self, expert_map_path, global_expert_num, shared_expert_rank_num):
self.expert_map_path = expert_map_path
self.rank_local_expert_num = 0
self.shared_expert_rank_num = shared_expert_rank_num
self.global_expert_num = global_expert_num
self.expert_map_tensor, self.layers_num, self.ranks_num = (
self._expert_file_to_tensor())
@@ -26,6 +28,7 @@
device_data.append(device["device_expert"])
tensor_data.append(device_data)
expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
self.rank_local_expert_num = expert_map_tensor.shape[self.shared_expert_rank_num][0]
return expert_map_tensor, layers_num, gpus_num

def generate_index_dicts(self, tensor_2d):
@@ -49,7 +52,13 @@
dtype=torch.int32,
)
for layer_id in range(self.layers_num):
for gpu_id in range(self.ranks_num):
if self.shared_expert_rank_num > 0:
for gpu_id in range(self.shared_expert_rank_num):
e_ids = range(sel.rank_local_expert_num)

Check failure on line 57 in vllm_ascend/ops/expert_load_balancer.py (GitHub Actions / lint (3.10), Ruff F821): vllm_ascend/ops/expert_load_balancer.py:57:35: F821 Undefined name `sel`
expert_placement_map[layer_id, gpu_id,
e_ids] = torch.arange(len(e_ids),
dtype=torch.int32)
for gpu_id in range(self.shared_expert_rank_num, self.ranks_num):
e_ids = self.expert_map_tensor[layer_id, gpu_id]
expert_placement_map[layer_id, gpu_id,
e_ids] = torch.arange(len(e_ids),
@@ -73,7 +82,7 @@
for rank in range(self.ranks_num):
for key in result_dict:
indices_in_concat = result_dict[key]
if key in rank_expert_to_global[rank]:
if key in rank_expert_to_global[rank] and not rank < self.shared_expert_rank_num:
log2phy_map[rank][key] = rank_expert_to_global[rank][key]
else:
chosen_index = random.choice(indices_in_concat)
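The placement logic above splits ranks into two groups: the first `shared_expert_rank_num` ranks each hold `rank_local_expert_num` shared-expert slots, while the remaining ranks take their routed experts from the JSON expert map, and the log2phy change keeps logical experts on shared ranks pointing at a redundant copy elsewhere. Below is a standalone sketch of the placement step, assuming the CI-flagged `sel` is meant to be `self` and treating `rank_local_expert_num` as a plain argument rather than deriving it from the map file.

```python
# Standalone sketch of generate_expert_placement_map() after this change
# (simplified; assumes `sel` in the diff is a typo for `self`).
import torch

def build_placement_map(expert_map_tensor: torch.Tensor,   # [layers, ranks, experts_per_rank]
                        global_expert_num: int,
                        shared_expert_rank_num: int,
                        rank_local_expert_num: int) -> torch.Tensor:
    layers_num, ranks_num, _ = expert_map_tensor.shape
    # -1 marks experts not hosted on a rank (assumption for this sketch).
    placement = torch.full((layers_num, ranks_num, global_expert_num),
                           -1, dtype=torch.int32)
    for layer_id in range(layers_num):
        # Ranks [0, shared_expert_rank_num) only host shared-expert slots.
        for gpu_id in range(shared_expert_rank_num):
            e_ids = torch.arange(rank_local_expert_num)
            placement[layer_id, gpu_id, e_ids] = torch.arange(
                len(e_ids), dtype=torch.int32)
        # Remaining ranks host the routed experts listed in the map file.
        for gpu_id in range(shared_expert_rank_num, ranks_num):
            e_ids = expert_map_tensor[layer_id, gpu_id].long()
            placement[layer_id, gpu_id, e_ids] = torch.arange(
                len(e_ids), dtype=torch.int32)
    return placement
```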
15 changes: 10 additions & 5 deletions vllm_ascend/ops/fused_moe.py
@@ -122,6 +122,7 @@ def fused_experts_with_mc2(
moe_all_to_all_group_name: Optional[str] = None,
shared_experts: Optional[Any] = None,
is_torchair: bool = False,
shared_expert_rank_num: int = 0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
quant_mode = 0
ep_group = get_ep_group()
@@ -143,7 +144,7 @@ def fused_experts_with_mc2(
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"shared_expert_rank_num": shared_expert_rank_num,
"moe_expert_num": moe_expert_num,
"global_bs": global_bs,
}
@@ -212,7 +213,7 @@ def fused_experts_with_mc2(
"expand_idx": expand_idx,
"expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"shared_expert_rank_num": shared_expert_rank_num,
"moe_expert_num": moe_expert_num,
"global_bs": global_bs,
}
@@ -903,6 +904,7 @@ def apply(
is_prefill: bool = False,
enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
shared_expert_rank_num:int = 0,
**kwargs,
) -> torch.Tensor:

@@ -954,7 +956,8 @@ def apply(
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled)
is_torchair=self.torchair_graph_enabled,
shared_expert_rank_num=shared_expert_rank_num)
elif fused_moe_state == FusedMoEState.AllGather:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
@@ -1057,10 +1060,11 @@ def __init__(

ascend_config = get_ascend_config()
expert_map_path = ascend_config.expert_map_path
self.shared_expert_rank_num = ascend_config.shared_expert_rank_num
if expert_map_path and os.path.exists(expert_map_path):
# moe expert load balance
expert_load_balancer = ExpertLoadBalancer(expert_map_path,
self.global_num_experts)
self.global_num_experts, self.shared_expert_rank_num)
self.local_num_experts, self.expert_map = \
expert_load_balancer.get_rank_placement_map(
self.moe_instance_id,
@@ -1099,7 +1103,7 @@ def __init__(

assert self.quant_method is not None

local_num_experts = torch.sum(self.expert_map != -1) \
local_num_experts = self.local_num_experts \
if self.expert_map is not None else num_experts

moe_quant_params = {
@@ -1209,6 +1213,7 @@ def forward(self,
and self.enable_multistream_moe and not is_prefill else None,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
shared_expert_rank_num=self.shared_expert_rank_num
)

if shared_experts:
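In this file the previously hard-coded `"shared_expert_rank_num": 0` in both MC2 stage dictionaries is replaced by the value threaded down from `AscendConfig`, and the unquantized `apply()` now forwards it into `fused_experts_with_mc2`. A small, self-contained sketch of the kwargs construction follows; the field names match the diff, but the actual dispatch/combine call into `torch_npu` is omitted, so this is an illustration of the plumbing, not the full implementation.

```python
# Sketch only: builds the dispatch-stage kwargs the way the diff does, with the
# configured value forwarded instead of a literal 0. The combine-stage dict
# carries the same field.
from typing import Any, Dict
import torch

def mc2_dispatch_kwargs(hidden_states: torch.Tensor,
                        topk_ids: torch.Tensor,
                        moe_expert_num: int,
                        global_bs: int,
                        shared_expert_rank_num: int = 0) -> Dict[str, Any]:
    return {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,
        "shared_expert_rank_num": shared_expert_rank_num,  # was hard-coded 0
        "moe_expert_num": moe_expert_num,
        "global_bs": global_bs,
    }
```

With the default of 0 the dictionaries are identical to what the old code produced, so existing configurations without the new option behave as before.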
3 changes: 2 additions & 1 deletion vllm_ascend/quantization/quant_config.py
@@ -348,14 +348,15 @@ def apply(
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
global_redundant_expert_num=0,
shared_expert_rank_num: int = 0,
**kwargs,
) -> torch.Tensor:
return self.quant_method.apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, e_score_correction_bias,
is_prefill, enable_force_load_balance, log2phy,
global_redundant_expert_num, **kwargs)
global_redundant_expert_num, shared_expert_rank_num, **kwargs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
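This wrapper only plumbs the new argument through: `apply()` accepts `shared_expert_rank_num` (default 0) and forwards it to the underlying quant method positionally, right after `global_redundant_expert_num` and before `**kwargs`, so the inner signature must list its parameters in the same order. A minimal sketch of that pattern with placeholder classes (not the real vLLM classes):

```python
# Placeholder classes, only illustrating why positional forwarding works when
# both signatures keep the same parameter order.
class InnerQuantMethod:
    def apply(self, x, global_redundant_expert_num=0,
              shared_expert_rank_num=0, **kwargs):
        return {"redundant": global_redundant_expert_num,
                "shared_expert_ranks": shared_expert_rank_num}

class FusedMoEMethodWrapper:
    def __init__(self):
        self.quant_method = InnerQuantMethod()

    def apply(self, x, global_redundant_expert_num=0,
              shared_expert_rank_num=0, **kwargs):
        # Positional forwarding, mirroring the diff.
        return self.quant_method.apply(
            x, global_redundant_expert_num, shared_expert_rank_num, **kwargs)

print(FusedMoEMethodWrapper().apply(None, shared_expert_rank_num=2))
# -> {'redundant': 0, 'shared_expert_ranks': 2}
```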
9 changes: 6 additions & 3 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -215,6 +215,7 @@ def fused_experts_with_mc2(
w2_scale_bias: torch.Tensor = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
shared_expert_rank_num: int = 0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if log2phy:
topk_ids = log2phy[topk_ids]
@@ -242,7 +243,7 @@
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"shared_expert_rank_num": shared_expert_rank_num,
"moe_expert_num": moe_expert_num,
"global_bs": global_bs,
}
@@ -290,7 +291,7 @@ def fused_experts_with_mc2(
"expand_idx": expand_idx,
"expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"shared_expert_rank_num": shared_expert_rank_num,
"moe_expert_num": moe_expert_num,
"global_bs": global_bs,
}
@@ -738,6 +739,7 @@ def apply(
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
shared_expert_rank_num: int = 0,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
@@ -807,7 +809,8 @@ def apply(
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
quantized_x_for_share=shared_gate_up,
dynamic_scale_for_share=shared_dequant_scale)
dynamic_scale_for_share=shared_dequant_scale,
shared_expert_rank_num=shared_expert_rank_num,)
elif fused_moe_state == FusedMoEState.AllGather:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
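The W8A8 dynamic-quant path mirrors the float path: the argument is added to `fused_experts_with_mc2` and `apply()`, and it replaces the literal 0 in both the dispatch- and combine-stage kwargs. Note also the pre-existing `topk_ids = log2phy[topk_ids]` remap at the top of this function; as the balancer change above reads, a logical expert on a shared-expert rank always resolves to a redundant physical copy on a routed rank. A toy remap follows, with made-up ids purely for illustration.

```python
# Toy illustration of the logical-to-physical expert-id remap (values made up).
import torch

log2phy = torch.tensor([3, 7, 12, 15])      # logical expert id -> physical expert id
topk_ids = torch.tensor([[0, 2], [1, 3]])   # per-token logical top-k choices
print(log2phy[topk_ids])                    # tensor([[ 3, 12], [ 7, 15]])
```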