diff --git a/docs/source/assets/eplb_swift_balancer.png b/docs/source/assets/eplb_swift_balancer.png
new file mode 100644
index 0000000000..ed696a057c
Binary files /dev/null and b/docs/source/assets/eplb_swift_balancer.png differ
diff --git a/docs/source/index.md b/docs/source/index.md
index 5f421b4679..0f0fad0b0c 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -49,6 +49,7 @@ user_guide/env_vars
 user_guide/additional_config
 user_guide/graph_mode.md
 user_guide/release_notes
+user_guide/eplb_swift_balancer
 :::

 % How to contribute to the vLLM Ascend project
diff --git a/docs/source/user_guide/eplb_swift_balancer.md b/docs/source/user_guide/eplb_swift_balancer.md
new file mode 100644
index 0000000000..3ff018bcad
--- /dev/null
+++ b/docs/source/user_guide/eplb_swift_balancer.md
@@ -0,0 +1,33 @@
+# Swift Balancer
+
+## Overview
+Rebalancing the experts of MoE models is practically mandatory for LLM serving. Changing the expert placement with a stop-the-world approach hurts TTFT and TPOT, so balancing the expert load asynchronously is the better choice.
+We have launched Swift Balancer to support dynamic expert load balancing with zero-overhead expert movement.
+
+## Design
+
+![Swift Balancer workflow](../assets/eplb_swift_balancer.png)
+
+The overall workflow is:
+1. Record the expert distribution during the forward pass. We use `expert_token_num` after dispatch instead of `topk_ids`, which yields a much smaller tensor and reduces the cost of HBM recording and the add operator.
+2. All-gather the expert distribution. All-gather is used instead of all-reduce because it moves less traffic.
+3. Wake up the EPLB worker process with the expert distribution once `num_iterations_eplb_update` forward iterations have passed, and run the EPLB algorithm in the worker.
+4. Generate the P2P send/recv ops and other operations such as log2phy in the worker as well, since they would otherwise cost a long CPU time in the main process.
+5. Launch `batch_isend_irecv` on an asynchronous stream before the forward pass.
+6. After the forward pass, wait for the `batch_isend_irecv` to finish, then update the expert map and expert weights.
+
+Our profiling shows that the expert weight transfer is hidden in the bubble between forward iterations. The CPU cost of the EPLB algorithm and of other Python operations such as log2phy is hidden by the EPLB worker process as well.
+
+## Examples
+
+Currently Swift Balancer improves TPOT by about 5 ms with EP size 64, while each layer's expert movement costs less than 2 ms.
+
+We add new parameters for EPLB:
+"dynamic_eplb": true --- enables dynamic EPLB.
+"num_iterations_eplb_update": 400 --- number of forward iterations after which an EPLB update starts.
+"gate_eplb": true --- EPLB updates only once; false by default.
+"num_wait_worker_iterations": 30 --- number of forward iterations the EPLB worker gets to finish its CPU task. In our tests the default value of 30 covers most cases.
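For offline inference the same options go into `additional_config` when constructing the engine. A minimal sketch (the checkpoint path, parallel sizes and sampling settings below are illustrative placeholders, not part of this patch):

```python
from vllm import LLM, SamplingParams

# Hypothetical MoE checkpoint and a small expert-parallel deployment; adjust to your cluster.
llm = LLM(
    model="/path/to/deepseek-v3-w8a8",
    tensor_parallel_size=4,
    enable_expert_parallel=True,
    additional_config={
        "dynamic_eplb": True,                # enable Swift Balancer
        "num_iterations_eplb_update": 400,   # collect load and rebalance every 400 steps
        "gate_eplb": True,                   # rebalance only once
        "num_wait_worker_iterations": 30,    # iterations granted to the worker's CPU stage
    },
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
```

When serving online, pass the same JSON through `--additional-config`: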
+--additional-config '{ "dynamic_eplb":true,"num_iterations_eplb_update":400, "gate_eplb":true, "num_wait_worker_iterations":30}' \ No newline at end of file diff --git a/examples/eplb_generate_map.py b/examples/eplb_generate_map.py new file mode 100644 index 0000000000..b498e73a06 --- /dev/null +++ b/examples/eplb_generate_map.py @@ -0,0 +1,65 @@ +import numpy as np +import json +import argparse + + +def split_and_insert(n, k, m): + ''' + n: expert num + k: card num + m: redundant expert num, make sure m%k==0 + ''' + + A = np.arange(n) + + B = np.random.choice(n, size=m, replace=False) + + groups = np.array_split(A, k) + + for j in range(m // k): + for i in range(k): + groups[i] = np.append(groups[i], B[i + j * k]) + return np.concatenate(groups) + + +def random_generation(n_layer=58, n_expert=256, start_layer_idx=0, device_count=128, n_redundant=128, output_name=""): + expert_data = {} + expert_data["moe_layer_count"] = n_layer + layer_list = [] + for i in range(n_layer): + layer = {"layer_id": start_layer_idx + i, "device_count": device_count} + random_placement = split_and_insert(n_expert, device_count, n_redundant) + device_list = [] + step = random_placement.shape[0] // device_count + for j in range(device_count): + device = {} + device["device_id"] = j + device["device_expert"] = random_placement[j * step: (j + 1) * step].tolist() + device_list.append(device) + layer["device_list"] = device_list + layer_list.append(layer) + + expert_data["layer_list"] = layer_list + json_file_path = output_name + + with open(json_file_path, "w") as f: + json.dump(expert_data, f, indent=4) + + print(f"JSON file generated: {json_file_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="python generate_map.py --n_layers 2 --n_experts 256 --card_num 8 --n_redundant 8 --output expert_map.json") + parser.add_argument("--n_layers", type=int, required=True) + parser.add_argument("--n_experts", type=int, required=True) + parser.add_argument("--card_num", type=int, required=True) + parser.add_argument("--n_redundant", type=int, default=0) + parser.add_argument("--output", type=str, default="expert_map.json") + args = parser.parse_args() + + n_layers = args.n_layers + n_experts = args.n_experts + card_num = args.card_num + n_redundant = args.n_redundant + output = args.output + + random_generation(n_layers, n_experts, 0, card_num, n_redundant, output) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 8ea67994ea..838edff927 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -37,6 +37,10 @@ def __init__(self, vllm_config): ascend_scheduler_config) self.expert_map_path = additional_config.get("expert_map_path", None) + self.dynamic_eplb = additional_config.get("dynamic_eplb", False) + self.num_iterations_eplb_update = additional_config.get("num_iterations_eplb_update", 400) + self.gate_eplb = additional_config.get("gate_eplb", False) + self.num_wait_worker_iterations = additional_config.get("num_wait_worker_iterations", 30) self.chunked_prefill_for_mla = additional_config.get( "chunked_prefill_for_mla", False) self.enable_weight_nz_layout = additional_config.get( diff --git a/vllm_ascend/eplb/__init__.py b/vllm_ascend/eplb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/eplb/adaptor/abstract_adaptor.py b/vllm_ascend/eplb/adaptor/abstract_adaptor.py new file mode 100644 index 0000000000..0fffa6123d --- /dev/null +++ b/vllm_ascend/eplb/adaptor/abstract_adaptor.py @@ -0,0 +1,42 @@ +# +# 
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class EplbAdaptor(ABC):
+
+    def __init__(self, **args):
+        pass
+
+    @abstractmethod
+    def get_rank_expert_workload(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_init_expert_map(self, num_moe_layers=None):
+        raise NotImplementedError
+
+    @abstractmethod
+    def do_update_expert_map(self, layer_id: Any,
+                             updated_expert_map: Any) -> Any:
+        raise NotImplementedError
+
+    @abstractmethod
+    def do_update_expert_weight(self, layer_id: Any,
+                                local_expert_to_replace: Any,
+                                buffer_tensor_id: Any) -> Any:
+        raise NotImplementedError
\ No newline at end of file
diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
new file mode 100644
index 0000000000..df6259e39a
--- /dev/null
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -0,0 +1,204 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# +import os +import json +import torch +import random +import torch.distributed as dist +import numpy as np + +from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor +from vllm.logger import logger + + + +class VllmEplbAdaptor(EplbAdaptor): + + def __init__(self, model, **args): + super().__init__(**args) + self.model = model + self.rank_id = dist.get_rank() + self.world_size = dist.get_world_size() + self.param_dict = dict(self.model.named_parameters()) + self.num_dense_layers = self.model.config.first_k_dense_replace + self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers + self.global_expert_num = self.model.config.n_routed_experts + + + # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 is supported here + self.expert_weight_names = ["w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset", + "w2_weight_scale", "w2_weight_offset"] + + self.expert_map_per_layer = dict() # reference to expert map on device for expert map update + self.expert_map_per_layer_cpu = dict() # copy of expert map on CPU to avoid device synchronize frequently + for layer_idx in range(self.num_moe_layers): + self.expert_map_per_layer[self.num_dense_layers + layer_idx] =\ + self.model.get_expert_map(self.num_dense_layers + layer_idx) + + # TODO: here we set number of buffer tensor equal to number of expert in each laryer, which can be improved + num_buffer_tensor = torch.where(self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel() + self.buffer_tensor_list = [[] for _ in range(num_buffer_tensor)] + self.init_buffer_tensor(num_buffer_tensor) + + self.expert_param_per_layer = dict() + self.init_expert_param_per_layer() + + self.log2phy_map_per_layer = dict() + for layer_idx in range(self.num_moe_layers): + self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] =\ + self.model.get_log2phy_map(self.num_dense_layers + layer_idx) + + self.all_topk_ids = [] + + def init_buffer_tensor(self, num_buffer_tensor): + for name in self.expert_weight_names: + complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name + expert_tensor = self.param_dict[complete_name].data[0:num_buffer_tensor] + buffer_tensors = torch.empty_like(expert_tensor) + for buffer_id in range(num_buffer_tensor): + self.buffer_tensor_list[buffer_id].append(buffer_tensors[buffer_id]) + + def init_expert_param_per_layer(self): + num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) +\ + ".mlp.experts." + self.expert_weight_names[0]].data.shape[0] + for moe_layer_id in range(self.num_moe_layers): + layer_idx = self.num_dense_layers + moe_layer_id + self.expert_param_per_layer[layer_idx] = list() + for local_expert_id in range(num_local_expert): + self.expert_param_per_layer[layer_idx].append( + [self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." 
+ name].data[local_expert_id] + for name in self.expert_weight_names] + ) + + def get_rank_expert_workload(self) -> torch.Tensor: + self.moe_load = self.model.get_all_moe_loads() + return self.moe_load + + def get_init_expert_map(self, num_moe_layers): + expert_map = self.model.get_all_expert_map(num_moe_layers) + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + gathered = torch.empty((world_size, *expert_map.shape), # [W, L, E] + dtype=expert_map.dtype, + device=expert_map.device) + + dist.all_gather_into_tensor(gathered, expert_map) + all_maps = gathered.permute(1, 0, 2) + all_expert_maps = all_maps.cpu() + + for layer_idx in range(num_moe_layers): + self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \ + all_expert_maps[layer_idx][self.rank_id] + + return all_expert_maps + + def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path): + + try: + expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor(expert_map_path) + expert_map_all = self.local2global(expert_map_tensor) + except (TypeError, FileNotFoundError, OSError): + expert_map_all = self.determine_expert_map_all() + + for layer_idx in range(num_moe_layers): + self.expert_map_per_layer_cpu[layer_idx+3] = \ + expert_map_all[layer_idx][self.rank_id] + return expert_map_all + + def _expert_file_to_tensor(self, expert_map_path: str): + with open(expert_map_path, "r") as f: + data = json.load(f) + layers_num = data["moe_layer_count"] + gpus_num = data["layer_list"][0]["device_count"] + + tensor_data = [] + for layer in data["layer_list"]: + device_data = [] + for device in layer["device_list"]: + device_data.append(device["device_expert"]) + tensor_data.append(device_data) + expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32) + return expert_map_tensor, layers_num, gpus_num + logger.error(f"failed to read expert_map_path: {expert_map_path}") + + def do_update_expert_map(self, layer_id, updated_expert_map): + self.expert_map_per_layer[layer_id].copy_(updated_expert_map) + self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map) + + def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id): + for expert_tensor, buffer_tensor in zip( + self.expert_param_per_layer[layer_id][local_expert_to_replace], + self.buffer_tensor_list[buffer_tensor_id] + ): + expert_tensor.copy_(buffer_tensor) + + def do_update_log2phy_map(self, layer_id, updated_log2phy_map): + if self.log2phy_map_per_layer[layer_id] is not None: + self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map) + + def local2global(self, + placement_local: torch.Tensor + ) -> torch.Tensor: + + L, G, E_local = placement_local.shape + device = placement_local.device + + max_id = torch.max(placement_local) + E_global = (max_id + 1).item() if max_id >= 0 else 0 + + if E_global == 0: + return torch.empty((L, G, 0), dtype=torch.long, device=device) + + placement_global = torch.full((L, G, E_global), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement_local >= 0 + l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) + gid_idx = placement_local[l_idx, g_idx, slot_idx] + + placement_global[l_idx, g_idx, gid_idx] = slot_idx + + return placement_global + + def determine_expert_map_all(self): + + local_num_experts = self.global_expert_num // self.world_size + + expert_map_all = torch.full( + (self.num_moe_layers, self.world_size, self.global_expert_num), + -1, + dtype=torch.int32 + ) + + for r in range(self.world_size): + if r < 
self.world_size - 1: + start = r * local_num_experts + end = (r + 1) * local_num_experts + local_count = local_num_experts + else: + start = r * local_num_experts + end = self.global_expert_num + local_count = self.global_expert_num - r * local_num_experts + + local_ids = torch.arange(local_count, dtype=torch.int32) + expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(self.num_moe_layers, -1) + + return expert_map_all \ No newline at end of file diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py new file mode 100644 index 0000000000..db06037595 --- /dev/null +++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py @@ -0,0 +1,122 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +import torch +import torch.distributed as dist +from enum import Enum + +from vllm.logger import logger + +class ExpertWeightUpdateState(Enum): + WAITING = 0 # waiting for updated expert_map by EplbWorker + READY = 1 # ready for d2d expert weights updating + TRANSFERING = 2 # d2d finished and waiting for updating expert_map into model + +class D2DExpertWeightLoader: + + def __init__(self, eplb_adaptor): + self.comm_op_list = None + self.eplb_adaptor = eplb_adaptor + + self.updated_expert_map = None + self.updated_log2phy_map = None + self.layer_id = -1 # layer id to be updated + self.state = ExpertWeightUpdateState.WAITING + self.recv_expert_list = [] + self.mock_flag = True + + def generate_expert_d2d_transfer_task(self, expert_send_info, expert_recv_info, + updated_expert_map, layer_id): + # When current send/recv and weight.expert_map update tasks are not finished, cannot accept new d2d task + if self.state != ExpertWeightUpdateState.WAITING: + logger.error("current d2d weight update tasks are on-going, cannot accept new weight update task") + return + + # If neither send nor receive task is needed for this layer on this rank, return + if not (expert_send_info or expert_recv_info): + return + + self.updated_expert_map = updated_expert_map + + self.layer_id = layer_id + self.comm_op_list = [] + for send_info in expert_send_info: + dst_rank, global_expert_id_to_send = send_info + local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item() + for src_tensor in self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]: + self.comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank)) + + buffer_tensor_id = 0 + for recv_info in expert_recv_info: + recv_rank, global_expert_id_to_recv = recv_info + for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[buffer_tensor_id]: + self.comm_op_list.append(dist.P2POp(dist.irecv, buffer_tensor, recv_rank)) + local_expert_to_replace = self.updated_expert_map[global_expert_id_to_recv].item() + self.recv_expert_list.append((local_expert_to_replace, buffer_tensor_id)) + buffer_tensor_id += 1 + + self.state = 
ExpertWeightUpdateState.READY + + def set_log2phy_map(self, log2phy_map): + self.updated_log2phy_map = log2phy_map + + def asyn_expert_weight_transfer(self, reqs): + # Only when send/recv tasks are parsed into self.comm_op_list, d2d send/recv tasks can be luanched + if self.state != ExpertWeightUpdateState.READY: + return + + # set asynchronous stream for d2d expert weight transfer + if self.comm_op_list: + ret_list = dist.batch_isend_irecv(self.comm_op_list) + reqs.extend(ret_list) + + self.state = ExpertWeightUpdateState.TRANSFERING + + def update_expert_map_and_weight(self, reqs, redundant_enable): + # Only after send/recv tasks have been luanched, expert_map and weight can be updated + if self.state != ExpertWeightUpdateState.TRANSFERING: + return + + # Waiting for send/recv tasks finish + for req in reqs: + req.wait() + + if self.comm_op_list is not None: + self.comm_op_list = None + + # update expert_map + self.eplb_adaptor.do_update_expert_map(self.layer_id, self.updated_expert_map) + + #update log2phy_map + if redundant_enable: + self.eplb_adaptor.do_update_log2phy_map(self.layer_id, self.updated_log2phy_map) + + # update expert weight + buffer_tensor_id = 0 + for recv_expert_info in self.recv_expert_list: + local_expert_to_replace, buffer_tensor_id = recv_expert_info + self.eplb_adaptor.do_update_expert_weight(self.layer_id, local_expert_to_replace, buffer_tensor_id) + + logger.info(f"[EPLB] finished update expert weight for layer: {self.layer_id}") + + self.recv_expert_list = [] + self.updated_expert_map = None + self.layer_id = -1 + self.state = ExpertWeightUpdateState.WAITING + + def load_impl(self, old_expert_table, new_expert_table): + raise NotImplementedError + diff --git a/vllm_ascend/eplb/core/eplb_utils.py b/vllm_ascend/eplb/core/eplb_utils.py new file mode 100644 index 0000000000..70de56f5d7 --- /dev/null +++ b/vllm_ascend/eplb/core/eplb_utils.py @@ -0,0 +1,72 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
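`D2DExpertWeightLoader` above is a small state machine (WAITING -> READY -> TRANSFERING) that is driven once per EPLB update. A hedged sketch of the intended call sequence, wrapped in an illustrative helper (the real call sites live in the worker/model-runner integration, not in this file):

```python
# Sketch of one layer's expert-weight update; `loader` is a D2DExpertWeightLoader
# bound to a VllmEplbAdaptor, and the update info comes from the EPLB worker.
def apply_layer_update(loader, send_info, recv_info, new_expert_map,
                       new_log2phy_map, layer_id, redundant_enable):
    reqs = []

    # WAITING -> READY: translate send/recv pairs into P2P ops for this layer.
    loader.generate_expert_d2d_transfer_task(send_info, recv_info,
                                             new_expert_map, layer_id)
    loader.set_log2phy_map(new_log2phy_map)

    # READY -> TRANSFERING: kick off batch_isend_irecv on the async stream
    # before the next forward pass so the copies overlap with compute.
    loader.asyn_expert_weight_transfer(reqs)

    # ... the next forward pass runs here ...

    # TRANSFERING -> WAITING: wait for the P2P ops, then swap in the expert map,
    # the log2phy map (when redundant experts are enabled) and the weights.
    loader.update_expert_map_and_weight(reqs, redundant_enable)
```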
+#
+
+import torch
+import random
+
+
+def generate_log2phy_map(expert_map):
+    # expert_map: [num_ranks, num_global_experts]; each entry is the local slot id
+    # of the logical expert on that rank, or -1 if the rank does not host it.
+    num_local_experts = expert_map.max() + 1
+    log2phy_map = expert_map.clone()
+    num_ranks, num_global_expert = log2phy_map.shape
+
+    # Turn local slot ids into global physical ids: rank_id * num_local_experts + slot.
+    row_indices = torch.arange(num_ranks).view(-1, 1).expand(num_ranks, num_global_expert) * num_local_experts
+    log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1]
+
+    for idx in range(num_global_expert):
+        positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0]
+        negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0]
+        num_rank_holding_expert = positive_rank_idx.size(0)
+
+        if num_rank_holding_expert == 1:
+            log2phy_map[negative_rank_idx, idx] = torch.full((num_ranks - 1,),
+                                                             log2phy_map[positive_rank_idx, idx].item(),
+                                                             dtype=log2phy_map.dtype)
+        else:
+            # Ranks without this expert pick one of its physical replicas at random.
+            random_list = [random.choice(log2phy_map[positive_rank_idx, idx])
+                           for _ in range(num_ranks - num_rank_holding_expert)]
+            log2phy_map[negative_rank_idx, idx] = torch.tensor(random_list,
+                                                               dtype=log2phy_map.dtype)
+
+    return log2phy_map
+
+
+def determine_default_log2phy_map(global_expert_num, world_size, rank_id):
+    # Build the default (uniform, non-redundant) expert map across all ranks and
+    # derive the log2phy map of the given rank from it.
+    local_num_experts = global_expert_num // world_size
+
+    expert_map_all = torch.full(
+        (world_size, global_expert_num), -1, dtype=torch.int32
+    )
+
+    for r in range(world_size):
+        if r < world_size - 1:
+            start = r * local_num_experts
+            end = (r + 1) * local_num_experts
+            local_count = local_num_experts
+        else:
+            start = r * local_num_experts
+            end = global_expert_num
+            local_count = global_expert_num - r * local_num_experts
+
+        local_ids = torch.arange(local_count, dtype=torch.int32)
+        expert_map_all[r, start:end] = local_ids
+
+    log2phy_map_all = generate_log2phy_map(expert_map_all)
+
+    return log2phy_map_all[rank_id]
\ No newline at end of file
diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py
new file mode 100644
index 0000000000..07550aff6c
--- /dev/null
+++ b/vllm_ascend/eplb/core/eplb_worker.py
@@ -0,0 +1,412 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
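To illustrate `generate_log2phy_map` from `eplb_utils.py` above, a small worked example with two ranks, four logical experts and no redundant experts:

```python
import torch

from vllm_ascend.eplb.core.eplb_utils import generate_log2phy_map

# expert_map[rank][logical_expert] = local slot id on that rank, or -1 if absent.
# Rank 0 hosts logical experts 0 and 1; rank 1 hosts logical experts 2 and 3.
expert_map = torch.tensor([[0, 1, -1, -1],
                           [-1, -1, 0, 1]], dtype=torch.int32)

log2phy = generate_log2phy_map(expert_map)
# Local slots become global physical ids (rank_id * num_local_experts + slot), and
# each rank that does not host an expert points at a replica on a rank that does:
# tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3]], dtype=torch.int32)
print(log2phy)
```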
+# + +import time +import numpy as np +import networkx as nx +import torch +import torch_npu +import logging +import torch.distributed as dist +from multiprocessing import Process, Queue, Manager +from abc import ABC, abstractmethod +from vllm.logger import logger + +from vllm_ascend.eplb.core.policy.policy_factory import PolicyFactory, DynamicConfig +from vllm_ascend.eplb.core.eplb_utils import generate_log2phy_map + + +class EplbWorker: + + def __init__(self, shared_dict, policy_type, enable_d2d: bool = True, redundant_enable=0): + self.policy_type = policy_type + self.policy = PolicyFactory.generate_policy(policy_type, DynamicConfig()) + self.shared_dict = shared_dict + self.old_expert_maps = None + self.enable_d2d = enable_d2d + self.redundant_enable = redundant_enable + self.rank_id = dist.get_rank() + + def do_update(self): + # put data in to queue + # in process self.policy.generate_policy() + # get epxert table && tensor + + # async stream + # D2D + # H2D + + # Get initial expert_map + if self.old_expert_maps is None: + self.old_expert_maps = self.get_init_expert_maps() + if self.old_expert_maps is not None: + self.num_local_experts = self.old_expert_maps.max() + 1 + else: + raise ValueError("Failed to get expert_maps from shared_dict.") + + # Get MOE load information + load_info = self.fetch_and_sum_load_info() + if load_info is None: + return + + # Get the updated expert table based on the workload information + old_placement = self.global2local(self.old_expert_maps, self.num_local_experts) + _, _, new_placement = self.calculate_rebalance_experts(load_info, old_placement) + + if not torch.is_tensor(new_placement): + new_placement = torch.tensor(new_placement) + self.check_expert_placement(old_placement, new_placement) + new_expert_maps = self.local2global(new_placement) + self.update_expert_map(new_expert_maps) + logger.debug(f"[EPLB Process new_map differs, performing D2D") + + update_info = self.compose_expert_update_info_greedy(new_expert_maps, self.old_expert_maps) + self.old_expert_maps = new_expert_maps + logger.info("EPLB Process compute complete") + + packed_update_info = self.pack_update_info(update_info) + + return packed_update_info + + def check_expert_placement(self, old_placement, new_placement): + num_layers = old_placement.shape[0] + num_ranks = old_placement.shape[1] + + for layer_id in range(num_layers): + # check if any logical expert is not placed on any rank + if torch.unique(new_placement[layer_id]).numel() < torch.unique(old_placement[layer_id]).numel(): + logger.error(f"There exists expert not placed on any rank in layer {layer_id}") + new_placement[layer_id] = old_placement[layer_id] + continue + + for rank_id in range(num_ranks): + new_placement_check = new_placement[layer_id][rank_id] + old_placement_check = old_placement[layer_id][rank_id] + + # check if same logical experts are placed on the same NPU + if new_placement_check.numel() != torch.unique(new_placement_check).numel(): + logger.error(f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid") + new_placement[layer_id] = old_placement[layer_id] + break + + # check if there is any experts movement inside one NPU + expert_not_move = torch.isin(new_placement_check, old_placement_check) + if not torch.equal(new_placement_check[expert_not_move], old_placement_check[expert_not_move]): + logger.error(f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid") + new_placement[layer_id] = 
old_placement[layer_id] + break + + def compose_expert_update_info_bipartite(self, updated_expert_maps_org, current_expert_maps_org): + # transform numpy array to torch tensor + updated_expert_maps = updated_expert_maps_org.clone() + current_expert_maps = current_expert_maps_org.clone() + updated_expert_maps = np.array(updated_expert_maps) + current_expert_maps = np.array(current_expert_maps) + + num_layers = current_expert_maps.shape[0] + num_ranks = current_expert_maps.shape[1] + num_experts = current_expert_maps.shape[2] + + for layer_id in range(num_layers): + updated_expert_maps_this_layer = updated_expert_maps[layer_id] + current_expert_maps_this_layer = current_expert_maps[layer_id] + updated_expert_maps_this_layer_org = updated_expert_maps_org[layer_id] + + from typing import Any + + expert_send_info_this_layer: dict[Any, Any] = {} + expert_recv_info_this_layer: dict[Any, Any] = {} + + # Guard Clause: if there is no expert weight update, avoid subsequent processing + if (np.equal(updated_expert_maps_this_layer, + current_expert_maps_this_layer)).all(): + yield (expert_send_info_this_layer, expert_recv_info_this_layer, + updated_expert_maps_this_layer_org, layer_id) + + # Parse expert_ids each rank needs to receive from other ranks + dst_rank_indices, experts_to_recv = np.where((current_expert_maps_this_layer == -1) + & (updated_expert_maps_this_layer != -1)) + + # record src ranks for potential transfer + src_ranks_set = dict() + for idx in range(len(dst_rank_indices)): + expert_id = experts_to_recv[idx].item() + if expert_id not in src_ranks_set: + src_ranks_set[expert_id] = np.where( + current_expert_maps_this_layer[:, expert_id] != -1)[0] + + # loop until all experts are scheduled + while len(dst_rank_indices) > 0: + # construct bipartite graph + graph_expert_update = nx.Graph() + for idx in range(len(dst_rank_indices)): + dst_rank_id = dst_rank_indices[idx].item() + expert_id = experts_to_recv[idx].item() + # add src ranks + src_rank_ids = src_ranks_set[expert_id] + graph_expert_update.add_nodes_from(src_rank_ids, bipartite=0) + # add dest rank + graph_expert_update.add_node(str(dst_rank_id), bipartite=1) + # add edges + for src_rank_id in src_rank_ids: + graph_expert_update.add_edge(src_rank_id, str(dst_rank_id)) + + # graph may not be connected + connected_components = list(nx.connected_components(graph_expert_update)) + all_matches = {} + # matching in this loop + for i, component in enumerate(connected_components): + subgraph = graph_expert_update.subgraph(component) + component_matching = nx.bipartite.maximum_matching(subgraph) + all_matches.update(component_matching) + + for src_rank, dst_rank in all_matches.items(): + dst_rank = int(dst_rank) + assert src_rank != dst_rank + if graph_expert_update.nodes[src_rank]['bipartite'] == 0: + # currently not scheduled experts in rank dst_rank + experts_v = experts_to_recv[np.where( + dst_rank_indices == dst_rank)] + # src: src_rank, dest: dst_rank, expert: expert_id + expert_id = np.intersect1d(experts_v, np.where( + current_expert_maps_this_layer[src_rank] != -1))[0] + + # record send/rcv pairs + if src_rank not in expert_send_info_this_layer: + expert_send_info_this_layer[src_rank] = [] + if dst_rank not in expert_recv_info_this_layer: + expert_recv_info_this_layer[dst_rank] = [] + expert_send_info_this_layer[src_rank].append((dst_rank, expert_id)) + expert_recv_info_this_layer[dst_rank].append((src_rank, expert_id)) + + remove_index = np.where(np.logical_and( + dst_rank_indices == dst_rank, experts_to_recv == expert_id)) + + # 
update + dst_rank_indices = np.delete( + dst_rank_indices, remove_index) + experts_to_recv = np.delete(experts_to_recv, remove_index) + + yield (expert_send_info_this_layer, expert_recv_info_this_layer, + updated_expert_maps_this_layer_org, layer_id) + + # TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases + def compose_expert_update_info_greedy(self, updated_expert_maps, current_expert_maps): + num_layers = current_expert_maps.shape[0] + num_ranks = current_expert_maps.shape[1] + num_experts = current_expert_maps.shape[2] + + for layer_id in range(num_layers): + updated_expert_maps_this_layer = updated_expert_maps[layer_id] + current_expert_maps_this_layer = current_expert_maps[layer_id] + + expert_send_info_this_layer: dict[Any, Any] = {} + expert_recv_info_this_layer: dict[Any, Any] = {} + + # Guard Clause: if there is no expert weight update, avoid subsequent processing + if torch.equal(updated_expert_maps_this_layer, current_expert_maps_this_layer): + yield (expert_send_info_this_layer, expert_recv_info_this_layer, updated_expert_maps_this_layer, layer_id) + + # Parse expert_ids each rank needs to receive from other ranks + dst_rank_indices, experts_to_recv = torch.where((current_expert_maps_this_layer == -1) \ + & (updated_expert_maps_this_layer != -1)) + + # Parse expert_ids each rank needs to send to other ranks + src_rank_indices, experts_to_send = torch.where((current_expert_maps_this_layer != -1) \ + & (updated_expert_maps_this_layer == -1)) + + for idx in range(len(dst_rank_indices)): + dst_rank_id = dst_rank_indices[idx].item() + expert_id = experts_to_recv[idx].item() + if dst_rank_id not in expert_recv_info_this_layer: + expert_recv_info_this_layer[dst_rank_id] = [] + + if not torch.isin(torch.tensor(expert_id), experts_to_send).any(): + # if expert_id are not sent out from any npu, it will be copied from one npu holding this expert + candidate_src_rank_indices = torch.where(current_expert_maps_this_layer[:, expert_id] != -1)[0] + else: + candidate_src_rank_indices = src_rank_indices[experts_to_send == expert_id] + + #TODO: improve selection criterion of npu sending expert_id considering such as intra-node or inter-node... + src_rank_id = candidate_src_rank_indices[0].item() + if src_rank_id not in expert_send_info_this_layer: + expert_send_info_this_layer[src_rank_id] = [] + + expert_send_info_this_layer[src_rank_id].append((dst_rank_id, expert_id)) + expert_recv_info_this_layer[dst_rank_id].append((src_rank_id, expert_id)) + + yield (expert_send_info_this_layer, expert_recv_info_this_layer, updated_expert_maps_this_layer, layer_id) + + + def calculate_rebalance_experts(self, load_info, old_placement): + """ + Compute `new_map` by calling the `rebalance_experts` method of the policy instance. + """ + if self.old_expert_maps is None: + return False, None, None + + changed, priority, new_map = self.policy.rebalance_experts(old_placement, load_info) + return changed, priority, new_map + + def get_init_expert_maps(self): + """ + Read the initial expert_map from shared_dict. + """ + return self.shared_dict.get("expert_maps", None) + + def fetch_and_sum_load_info(self): + """ + Each time the subprocess is awakened, read the latest moe_load + (shape: [num_moe_layers, num_experts_per_layer]) from shared_dict. 
+ """ + return self.shared_dict.get("moe_load", None) + + def update_expert_map(self, expert_maps): + + self.shared_dict["expert_maps"] = expert_maps + + def global2local(self, + placement: torch.Tensor, + E_local: int + ) -> tuple[torch.Tensor, torch.Tensor]: + + L, G, _ = placement.shape + device = placement.device + + pt_local = torch.full((L, G, E_local), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement >= 0 + l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True) + + slot_idx = placement[l_idx, g_idx, k_idx] + + pt_local[l_idx, g_idx, slot_idx] = k_idx + + return pt_local + + + def local2global(self, + placement_local: torch.Tensor + ) -> torch.Tensor: + + L, G, E_local = placement_local.shape + device = placement_local.device + + max_id = torch.max(placement_local) + E_global = (max_id + 1).item() if max_id >= 0 else 0 + + if E_global == 0: + return torch.empty((L, G, 0), dtype=torch.long, device=device) + + placement_global = torch.full((L, G, E_global), + fill_value=-1, + dtype=torch.long, + device=device) + + valid = placement_local >= 0 + l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True) + gid_idx = placement_local[l_idx, g_idx, slot_idx] + + placement_global[l_idx, g_idx, gid_idx] = slot_idx + + return placement_global + + def pack_update_info(self, update_info_generator): + """ + Pack a list of update info tuples for efficient IPC. + """ + send_all = [] + recv_all = [] + maps = [] + log2phy_all = [] + layer_ids = [] + + for send_info, recv_info, new_expert_map, layer_id in update_info_generator: + + send_info_this_rank = send_info[self.rank_id] if self.rank_id in send_info else [] + recv_info_this_rank = recv_info[self.rank_id] if self.rank_id in recv_info else [] + send_all.append(send_info_this_rank) + recv_all.append(recv_info_this_rank) + + maps.append(new_expert_map[self.rank_id].numpy().tolist()) + + if self.redundant_enable: + log2phy_map = generate_log2phy_map(new_expert_map) + log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist()) + else: + log2phy_all.append([]) + + layer_ids.append(layer_id) + + return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids)) + +class EplbProcess: + def __init__(self, shared_dict, planner_q, block_update_q, redundant_enable, policy_type: int = 0, enable_d2d: bool = True): + """ + Args: + shared_dict: Cross-process shared dict returned by Manager().dict() + policy_type: Integer passed to PolicyFactory.generate_policy + enable_d2d: Whether to enable D2D loading + """ + self.shared_dict = shared_dict + self.policy_type = policy_type + self.enable_d2d = enable_d2d + self.planner_q = planner_q + self.block_update_q = block_update_q + self.redundant_enable = redundant_enable + + # Create EplbWorker instance + self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d, self.redundant_enable) + + + def worker_process(self, planner_q, block_update_q): + """ + Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete. + """ + while True: + try: + + planner_q.get() + + packed_update_info = self.worker.do_update() + + while True: + if not block_update_q.empty(): + continue + block_update_q.put(packed_update_info) + break + + except Exception as e: + logger.warning(f"[EPLB subprocess Exiting due to error: {e}", exc_info=True) + break + + def _launch_process(self): + """ + Use spawn method to launch subprocess and return (planner_q, block_update_q, proc). 
+ """ + proc = Process( + target=self.worker_process, + args=(self.planner_q,self.block_update_q), + daemon=True + ) + + proc.start() + return proc + diff --git a/vllm_ascend/eplb/core/policy/__init__.py b/vllm_ascend/eplb/core/policy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/eplb/core/policy/policy_abstract.py b/vllm_ascend/eplb/core/policy/policy_abstract.py new file mode 100644 index 0000000000..36b2e1d376 --- /dev/null +++ b/vllm_ascend/eplb/core/policy/policy_abstract.py @@ -0,0 +1,40 @@ +# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from abc import abstractmethod + + +class DynamicConfig: + placement_policy = None + + max_transferred_expert_per_layer = 100 # Maximum number of experts that can be migrated per layer on a single host + ep_worldsize = 64 # Total number of dies across the entire cluster where experts are distributed + num_die_per_host = 8 # Number of dies on each host machine + + +class EplbPolicy: + def __init__(self, config: DynamicConfig): + self.config = config + + @abstractmethod + def rebalance_experts(self, current_expert_table, expert_workload): + """ + Pass in the weights and return expert replication and placement under relevant constraints. + INPUT: + current_expert_table: [layerId, rankId, expert_num_i] + expert_workload = expert_table[layer0][rankId][expert_num_i] + + RETURNED: (res, expert_table) + res: + 1 -- table_changed + 0 -- not_changed + + expert_table: [layerId, rankId, expert_num_i] + expert_num_i --- [0, MaxExpertPerRank] + expertID = expert_table[layer0][rankId][expert_num_i] + array_values: + [0, 1, 2, 3, 248] + [4, 5, 6, 7, 254] + [8, 9, 10, 11, 71] + ... + [252, 253, 254, 255, 0] + """ + pass \ No newline at end of file diff --git a/vllm_ascend/eplb/core/policy/policy_dynamic_ep.py b/vllm_ascend/eplb/core/policy/policy_dynamic_ep.py new file mode 100644 index 0000000000..7fa030724f --- /dev/null +++ b/vllm_ascend/eplb/core/policy/policy_dynamic_ep.py @@ -0,0 +1,338 @@ +# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. 
+from collections import defaultdict +import numpy as np + +from .policy_abstract import EplbPolicy, DynamicConfig + + +class DynamicTable: + # workload_table: + # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position + # Size: number of layers * number of GPUs * number of experts per GPU per layer + # The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer + # For experts that are not available or collected, the value is set to -1 + workload_table = None + + # placement_table: + # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position + # Size: number of layers * number of GPUs * number of experts per GPU per layer + # The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer + # For experts that are not available or collected, the value is set to -1 + placement_table = None + + +class DynamicEplb(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + @staticmethod + def add_redundant(current_expert_table, expert_workload, num_original_expert): + layer_num, npu_num, experts_per_npu = expert_workload.shape + workload_new = np.zeros((layer_num, num_original_expert)) + for layer_idx in range(layer_num): + workload_dict = defaultdict(int) + placement_layer = current_expert_table[layer_idx].copy() + workload_layer = expert_workload[layer_idx].copy() + for npu_idx in range(npu_num): + for expert_idx in range(experts_per_npu): + workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx] + for expert_idx in range(num_original_expert): + workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] + return workload_new + + @staticmethod + # Split hot (high-load) experts into redundant experts + def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): + # Step 1: Sort the items by weight in descending order (we are sorting by weight now) + # Sort based on the second element (the second value of each tuple) + route_expert_num = len(origin_weights) + route_expert_redundancy = [[] for _ in range(route_expert_num)] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + # Step 2: Calculate the number of items per box + expert_num = route_expert_num + num_redundancy_expert + items_per_box = expert_num // card_num # Number of items per box + remaining_items = expert_num % card_num # Number of items per box + + # Step 3: Initialize card_num boxes with empty lists to store item IDs + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num # To store the total weight of each box + box_counts = [0] * card_num # To store the number of items in each box + index = 0 + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + cur_weight = 0 + for item, weight in origin_weights: + if item == i: + cur_weight = weight + + 
boxes[index].append(i) + boxes_weights[index].append(cur_weight) + box_weights[index] += cur_weight + box_counts[index] += 1 + index += 1 + + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + origin_weights = [origin_weights[idx] for idx in sorted_indices] + # Step 4: Distribute items into boxes based on weight + for item_id, weight in origin_weights: + # Find the box with the least items but not full + min_box_index = -1 + for i in range(card_num): + if item_id in boxes[i]: + continue + # Only choose boxes that still have space (box_counts[i] < items_per_box) + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + min_box_index = i + + # Place the item (id) into the selected box + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + # If there's an imbalance in the remaining items, reduce the "remaining_items" counter + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + # Step 5: Output each box's contents and total weight + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], # List of item IDs in the box + "weight": boxes_weights[i], + "total_weight": box_weights[i], # Total weight in this box + "item_count": box_counts[i] # Number of items in the box + }) + + return result, boxes + + # Split hot (high-load) experts into redundant experts + @staticmethod + def compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert): + route_expert_num = len(origin_weights) + route_expert_redundancy = [[] for _ in range(route_expert_num)] + for i in range(num_redundancy_expert): + sorted_indices = np.argsort([t[1] for t in origin_weights], kind='stable')[::-1] + weights = [origin_weights[idx] for idx in sorted_indices] + tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) + route_expert_redundancy[weights[0][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) + weights[0] = (weights[0][0], avg_weight) + origin_weights = weights + + expert_num = route_expert_num + num_redundancy_expert + if card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + all_weights = np.zeros((expert_num,), dtype='object') + all_weights[: route_expert_num] = origin_weights + + index = route_expert_num + for i in range(route_expert_num): + redundancy_num = len(route_expert_redundancy[i]) + for _ in range(redundancy_num): + for item, weight in origin_weights: + if item == i: + all_weights[index] = (item, weight) + index += 1 + + sorted_indices = np.argsort([t[1] for t in all_weights], kind='stable')[::-1] + all_weights = [all_weights[idx] for idx in sorted_indices] + for item_id, weight in all_weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + if item_id not in boxes[i]: + min_box_index = i + + boxes[min_box_index].append(item_id) + 
boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + # Scheme without redundant experts + @staticmethod + def compute_balanced_pack(origin_weights, card_num): + sorted_indices = np.argsort([t[1] for t in origin_weights])[::-1] + weights = origin_weights[sorted_indices] + expert_num = len(weights) + if card_num == 0: + raise RuntimeError("card_num can not be 0.") + items_per_box = expert_num // card_num + remaining_items = expert_num % card_num + + boxes = [[] for _ in range(card_num)] + boxes_weights = [[] for _ in range(card_num)] + box_weights = [0] * card_num + box_counts = [0] * card_num + + for item_id, weight in weights: + min_box_index = -1 + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: + min_box_index = i + + boxes[min_box_index].append(item_id) + boxes_weights[min_box_index].append(weight) + box_weights[min_box_index] += weight + box_counts[min_box_index] += 1 + + if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: + remaining_items -= 1 + + result = [] + for i in range(card_num): + result.append({ + "box_index": i + 1, + "items": boxes[i], + "weight": boxes_weights[i], + "total_weight": box_weights[i], + "item_count": box_counts[i] + }) + + return result, boxes + + @staticmethod + def get_redundant_num(npu_num, counts): + redundant_num_each_npu = np.sum(counts - 1) + return redundant_num_each_npu + + @staticmethod + def calculate_max_heat_per_layer(workload_table, layer_num): + max_heat_per_layer = [] + for layer_idx in range(layer_num): + npu_heats_now = np.sum(workload_table[layer_idx], axis=1) + max_heat_per_layer.append(np.max(npu_heats_now)) + return max_heat_per_layer + + @staticmethod + def constraint_expert_local_exchange(current_expert_table, global_deployment): + for layer_id in range(len(global_deployment)): + for card_id in range(len(global_deployment[layer_id])): + current_list = [int(x) for x in current_expert_table[layer_id][card_id]] + new_list = [int(x) for x in global_deployment[layer_id][card_id]] + num = len(new_list) + + new_index = [-1] * num + new_result = [-1] * num + remaining_elements = [] + + for i in range(num): + flag = True + for j in range(num): + if new_list[i] == current_list[j] and new_index[j] == -1: + new_index[j] = 0 + new_result[j] = current_list[j] + flag = False + break + if flag: + remaining_elements.append(new_list[i]) + + index = 0 + for k in range(num): + if new_result[k] == -1: + new_result[k] = remaining_elements[index] + index += 1 + + global_deployment[layer_id][card_id] = new_result + + return global_deployment + + + def rebalance_experts(self, current_expert_table, expert_workload): + + info = DynamicTable() + info.workload_table = np.array(expert_workload) + info.placement_table = np.array(current_expert_table) + layer_num, num_npus, experts_per_npu= info.workload_table.shape + expert_ids, counts = np.unique(info.placement_table[0], return_counts=True) + num_redundancy_expert = self.get_redundant_num(num_npus, counts) + num_original_expert = len(expert_ids) + 
layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num) + npu_heat_all_origin = sum(max_heat_per_layer_before) + + # Perform load balancing and deploy redundant experts + layer_num = layer_workloads.shape[0] + expert_num = layer_workloads.shape[1] + # Validate that the number of experts, number of cards, and number of redundant experts do not exceed the number of cards + if num_original_expert != expert_num: + raise ValueError(f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}") + + if num_npus <= 0: + raise ValueError("the number of NPUs must be greater than 0") + + if num_npus < num_redundancy_expert: + raise ValueError(f"the number of NPUs {num_npus} must be greater than or equal to the number of redundant experts {num_redundancy_expert}") + + # Number of experts deployed on each card includes one redundant expert + global_deployment = [[[] for _ in range(num_npus)] for _ in range(layer_num)] + # Iterate to obtain the placement strategy for each layer, taking computational balance into account + max_heat_per_layer_after = np.zeros([layer_num]) + for layer in range(layer_num): + # Get the expert IDs and their corresponding workloads for the current layer; + # workloads need to be normalized, and one redundant expert is added per card + weights = np.zeros((expert_num,), dtype='object') + for expert_id, workload_weight in enumerate(layer_workloads[layer]): + weights[expert_id] = (expert_id, workload_weight) + + # Obtain the globally balanced placement strategy for each layer + result, layer_deployment = self.original_compute_balanced_pack_redundancy( + weights, num_npus, num_redundancy_expert + ) + + global_deployment[layer] = layer_deployment + max_heat_per_layer_after[layer] = max(result, key=lambda x: x['total_weight'])['total_weight'] + + new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment) + # Obtain the priority of each layer + layer_changed_ratio = [] + for layer_idx in range(layer_num): + layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / max_heat_per_layer_before[layer_idx]) + + per_layer_priority = np.argsort(layer_changed_ratio) + npu_heat_all_after = sum(max_heat_per_layer_after) + + change = 0 + if npu_heat_all_after < 0.95 * npu_heat_all_origin: + change = 1 + + return change, per_layer_priority, np.array(new_global_deployment).tolist() + diff --git a/vllm_ascend/eplb/core/policy/policy_dynamic_ep_v2.py b/vllm_ascend/eplb/core/policy/policy_dynamic_ep_v2.py new file mode 100644 index 0000000000..0aa9525de7 --- /dev/null +++ b/vllm_ascend/eplb/core/policy/policy_dynamic_ep_v2.py @@ -0,0 +1,672 @@ +# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. 
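Before the v2 policy below, here is a toy invocation of `DynamicEplb` above on one layer with two NPUs and four experts, without redundant experts (a sketch; real tables come from the recorded MoE load):

```python
from vllm_ascend.eplb.core.policy.policy_abstract import DynamicConfig
from vllm_ascend.eplb.core.policy.policy_dynamic_ep import DynamicEplb

policy = DynamicEplb(DynamicConfig())

# Both tables are [layer, npu, experts_per_npu]; the workload is the recorded heat.
current_expert_table = [[[0, 1], [2, 3]]]
expert_workload = [[[90, 10], [5, 5]]]

changed, layer_priority, new_table = policy.rebalance_experts(
    current_expert_table, expert_workload)
# `changed` is 1 only if the hottest NPU's load drops below 95% of the original;
# `new_table` keeps the [layer, npu, experts_per_npu] shape.
print(changed, layer_priority, new_table)
```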
+from collections import defaultdict +import numpy as np +from abc import abstractmethod + +class DynamicConfig: + placement_policy = None + + max_transferred_expert_per_layer = 100 # Maximum number of experts that can be migrated per layer on a single host + ep_worldsize = 64 # Total number of dies across the entire cluster where experts are distributed + num_die_per_host = 8 # Number of dies on each host machine + + +class EplbPolicy: + def __init__(self, config: DynamicConfig): + self.config = config + + @abstractmethod + def rebalance_experts(self, current_expert_table, expert_workload): + """ + Pass in the weights and return expert replication and placement under relevant constraints. + INPUT: + current_expert_table: [layerId, rankId, expert_num_i] + expert_workload = expert_table[layer0][rankId][expert_num_i] + + RETURNED: (res, expert_table) + res: + 1 -- table_changed + 0 -- not_changed + + expert_table: [layerId, rankId, expert_num_i] + expert_num_i --- [0, MaxExpertPerRank] + expertID = expert_table[layer0][rankId][expert_num_i] + array_values: + [0, 1, 2, 3, 248] + [4, 5, 6, 7, 254] + [8, 9, 10, 11, 71] + ... + [252, 253, 254, 255, 0] + """ + pass + +class DynamicTable: + # workload_table: + # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position + # Size: number of layers * number of GPUs * number of experts per GPU per layer + # The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer + # For experts that are not available or collected, the value is set to -1 + workload_table = None + + # placement_table: + # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position + # Size: number of layers * number of GPUs * number of experts per GPU per layer + # The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer + # For experts that are not available or collected, the value is set to -1 + placement_table = None + +class DynamicEplbV2(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + + @staticmethod + def safe_divide(a, b): + if b == 0: + print("Division by zero is not allowed") + return 0 + return a / b + + @staticmethod + def safe_exact_divide(a, b): + if b == 0: + print("Division by zero is not allowed") + return 0 + return a // b + + @staticmethod + def safe_mod(a, b): + if b == 0: + print("Division by zero is not allowed") + return 0 + return a % b + + @staticmethod + def add_redundant(current_expert_table, expert_workload, num_original_expert): + layer_num, npu_num, experts_per_npu = expert_workload.shape + workload_new = np.zeros((layer_num, num_original_expert)) + for layer_idx in range(layer_num): + workload_dict = defaultdict(int) + placement_layer = current_expert_table[layer_idx].copy() + workload_layer = expert_workload[layer_idx].copy() + for npu_idx in range(npu_num): + for expert_idx in range(experts_per_npu): + workload_dict[placement_layer[npu_idx][expert_idx]] += workload_layer[npu_idx][expert_idx] + for expert_idx in range(num_original_expert): + workload_new[layer_idx][expert_idx] = workload_dict[expert_idx] + return workload_new + + @staticmethod + def get_redundant_num(npu_num, counts): + redundant_num_each_npu = np.sum(counts - 1) + return redundant_num_each_npu + + @staticmethod + def calculate_max_heat_per_layer(workload_table, layer_num): + max_heat_per_layer = [] + for layer_idx in range(layer_num): + 
npu_heats_now = np.sum(workload_table[layer_idx], axis=1) + max_heat_per_layer.append(np.max(npu_heats_now)) + return max_heat_per_layer + + def calculate_initial_imbalance(self, global_deployment, new_layer_workloads): + + device_num = global_deployment.shape[1] + layer_imbalance = [] + expert_num = np.zeros_like(new_layer_workloads) + for layer_id, layer in enumerate(global_deployment): + for device in layer: + for expert_id in device: + expert_num[layer_id][expert_id] += 1 + + for layer_id, layer in enumerate(global_deployment): + cur_layer_max_workload = 0 + total_workload = 0 + for box in layer: + box_workload = 0 + for expert_id in box: + update_workload = self.safe_divide(new_layer_workloads[layer_id][expert_id], expert_num[layer_id][expert_id]) + box_workload += update_workload + total_workload += update_workload + if cur_layer_max_workload < box_workload: + cur_layer_max_workload = box_workload + + cur_layer_imbalance = self.safe_divide(cur_layer_max_workload, (self.safe_divide(total_workload, device_num))) + layer_imbalance.append(cur_layer_imbalance) + + return layer_imbalance + + + def compute_redundant_assignments(self, base_experts, num_redundant_experts, num_experts): + + redundant_assignments = [[] for _ in range(num_experts)] + current_weights = base_experts.copy() + + for i in range(num_redundant_experts): + sorted_indices = np.argsort([w for _, w in current_weights], kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + target_expert = sorted_weights[0] + expert_id, original_weight = target_expert + + current_redundancy = len(redundant_assignments[expert_id]) + new_avg_weight = self.safe_divide(original_weight * (current_redundancy + 1), (current_redundancy + 2)) + + redundant_assignments[expert_id].append(num_experts + i) + current_weights[sorted_indices[0]] = (expert_id, new_avg_weight) + + sorted_indices = np.argsort([w for _, w in current_weights], kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + return redundant_assignments, sorted_weights + + def repeat_compute_redundant_assignments(self, layer_workloads, rendun_pos, num_experts, num_exist_expert, device_assignments, device_counts, expert_from_device, com_between_devices): + + current_weights = np.zeros((num_experts,), dtype='object') + for expert_id, workload_weight in enumerate(layer_workloads): + current_weights[expert_id] = (expert_id, workload_weight) + + devices_with_slots = [] + for device_id, device_rendun_pos in enumerate(rendun_pos): + if len(device_rendun_pos) != 0: + devices_with_slots.append(device_id) + + while devices_with_slots: + sorted_indices = np.argsort([w for _, w in current_weights], kind='stable')[::-1] + sorted_weights = [current_weights[i] for i in sorted_indices] + + for index, target_weight in enumerate(sorted_weights): + expert_id, original_weight = target_weight + if original_weight == -1: + print("Error:Redundant expert failure re-occurred") + redundancy_successful = True + break + redundancy_successful = False + for cur_device_id in devices_with_slots: + if expert_id not in device_assignments[cur_device_id]: + pos = rendun_pos[cur_device_id].pop() + if len(rendun_pos[cur_device_id]) == 0: + devices_with_slots = [device_id for device_id in devices_with_slots if device_id != cur_device_id] + device_assignments[cur_device_id][pos] = expert_id + device_counts[cur_device_id] += 1 + communication_box_index = expert_from_device[expert_id] + com_between_devices[cur_device_id][communication_box_index] = expert_id + 
new_weight = self.safe_divide((original_weight * num_exist_expert[expert_id]), (num_exist_expert[expert_id] + 1)) + sorted_weights[index] = (expert_id, new_weight) + num_exist_expert[expert_id] += 1 + redundancy_successful = True + break + if redundancy_successful: + break + + sorted_indices = np.argsort([id for id, _ in sorted_weights], kind='stable') + sorted_weights = [sorted_weights[i][1] for i in sorted_indices] + + return sorted_weights, device_assignments, device_counts, com_between_devices + + @staticmethod + def prepare_expert_list(base_experts, redundant_assignments, num_redundant_experts): + redundant_expert_list = np.empty(num_redundant_experts, dtype=object) + + index = 0 + num_experts = len(redundant_assignments) + for expert_id in range(num_experts): + for _ in redundant_assignments[expert_id]: + redundant_expert_list[index] = (expert_id, next(w for eid, w in base_experts if eid == expert_id)) + index += 1 + + sorted_indices = np.argsort([w for _, w in redundant_expert_list], kind='stable')[::-1] + return [redundant_expert_list[i] for i in sorted_indices] + + + @staticmethod + def non_redundant_expert_information(origin_deployment, updated_weights, rendun_pos): + + device_num = len(origin_deployment) + num_experts_per_device = origin_deployment.shape[1] + device_assignments = [[-1 for _ in range(num_experts_per_device)] for _ in range(device_num)] + device_weights = [[0 for _ in range(num_experts_per_device)] for _ in range(device_num)] + device_loads = [0] * device_num + device_counts = [0] * device_num + + for device_id, device in enumerate(origin_deployment): + for index, expert_id in enumerate(device): + if index in rendun_pos[device_id]: + continue + device_assignments[device_id][index] = expert_id + cur_weight = next(weight for expert_id_of_weight, weight in updated_weights if expert_id_of_weight == expert_id) + device_weights[device_id][index] = cur_weight + device_loads[device_id] += cur_weight + device_counts[device_id] += 1 + + return device_assignments, device_weights, device_loads, device_counts + + def recomputing_initial_weight(self, layer_workloads, device_assignments): + num_all_experts = [0] * len(layer_workloads) + for device in device_assignments: + for expert_id in device: + if expert_id != -1: + num_all_experts[expert_id] += 1 + + cur_layer_workload = [] + for expert_id, weight in enumerate(layer_workloads): + if num_all_experts[expert_id] == 0: + cur_layer_workload.append(-1) + else: + cur_layer_workload.append(self.safe_divide(weight, num_all_experts[expert_id])) + + return cur_layer_workload, num_all_experts + + def distribute_redun_experts(self, layer_workloads, device_assignments, device_weights, device_loads, device_counts, redundant_expert_list, + expert_from_device, num_experts, rendun_pos): + + num_devices = len(device_assignments) + com_between_devices = [{} for _ in range(num_devices)] + + for expert_id, weight in redundant_expert_list: + candidate = -1 + for dev_id in range(num_devices): + if len(rendun_pos[dev_id]) == 0: + continue + if expert_id in device_assignments[dev_id]: + continue + if candidate == -1 or device_loads[dev_id] < device_loads[candidate]: + candidate = dev_id + if candidate != -1: + pos = rendun_pos[candidate].pop() + device_assignments[candidate][pos] = expert_id + device_weights[candidate][pos] = weight + device_loads[candidate] += weight + device_counts[candidate] += 1 + + communication_box_index = expert_from_device[expert_id] + com_between_devices[candidate][communication_box_index] = expert_id + + if any(sublist for 
sublist in rendun_pos): + cur_layer_workload, num_exist_expert = self.recomputing_initial_weight(layer_workloads, device_assignments) + + update_workload, device_assignments, device_counts, com_between_devices = self.repeat_compute_redundant_assignments(cur_layer_workload, rendun_pos, + num_experts, num_exist_expert, + device_assignments, device_loads, + expert_from_device, com_between_devices) + + device_loads = [0] * len(device_counts) + for device_id, device in enumerate(device_assignments): + for index, expert_id in enumerate(device): + device_weights[device_id][index] = update_workload[expert_id] + device_loads[device_id] += update_workload[expert_id] + + return device_assignments, device_weights, device_loads, device_counts, com_between_devices + + def redundancy_again(self, layer_workloads, origin_weights, origin_deployment, expert_from_device, num_node, + is_node_redundant, rendun_pos): + + + num_experts = len(origin_weights) + if is_node_redundant: + num_experts = num_experts * num_node + + num_redundant_experts = 0 + for rank_empty_pos in rendun_pos: + num_redundant_experts += len(rank_empty_pos) + + redundant_assignments, updated_weights = self.compute_redundant_assignments(origin_weights, + num_redundant_experts, + num_experts) + + redundant_expert_list = self.prepare_expert_list(updated_weights, redundant_assignments, num_redundant_experts) + + device_assignments, device_weights, device_loads, device_counts = self.non_redundant_expert_information( + origin_deployment, updated_weights, rendun_pos) + + device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts( + layer_workloads, + device_assignments, + device_weights, + device_loads, + device_counts, + redundant_expert_list, + expert_from_device, + num_experts, + rendun_pos) + + + return device_assignments, device_weights, device_loads, device_counts, com_between_devices + + @staticmethod + def generate_allocation_report(device_assignments, device_weights, device_loads, device_counts): + + report = [] + max_load = 0.0 + + for dev_id in range(len(device_assignments)): + current_load = device_loads[dev_id] + max_load = max(max_load, current_load) + + report.append({ + "device_id": dev_id + 1, + "assigned_experts": device_assignments[dev_id], + "expert_weights": device_weights[dev_id], + "total_load": current_load, + "expert_count": device_counts[dev_id] + }) + + return report, max_load + + @staticmethod + def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id, next_device_id, cur_layer_result, + com_between_devices): + + cur_device_deployment = cur_layer_result[cur_device_id]['assigned_experts'] + next_device_deployment = cur_layer_result[next_device_id]['assigned_experts'] + + cur_device_weight = cur_layer_result[cur_device_id]['expert_weights'] + next_device_weight = cur_layer_result[next_device_id]['expert_weights'] + + cur_expert_id = cur_device_deployment[cur_exchange_index] + next_expert_id = next_device_deployment[next_exchange_index] + cur_device_deployment[cur_exchange_index] = next_expert_id + next_device_deployment[next_exchange_index] = cur_expert_id + + cur_expert_weight = cur_device_weight[cur_exchange_index] + next_expert_weight = next_device_weight[next_exchange_index] + cur_device_weight[cur_exchange_index] = next_expert_weight + next_device_weight[next_exchange_index] = cur_expert_weight + + cur_layer_result[cur_device_id]['total_load'] += next_expert_weight - cur_expert_weight + cur_layer_result[next_device_id]['total_load'] += 
cur_expert_weight - next_expert_weight + + com_between_devices[cur_device_id][next_device_id] = next_expert_id + com_between_devices[next_device_id][cur_device_id] = cur_expert_id + + def redundant_expert_deployment(self, layer_workloads, original_deployment, expert_from_device, node_num, + is_node_redundant, rendun_pos): + device_num, per_device_expert_num = original_deployment.shape + route_expert_num = layer_workloads.shape[0] + per_node_device_num = self.safe_exact_divide(device_num, node_num) + per_node_route_expert_num = per_node_device_num * (per_device_expert_num - 1) + + weights = np.zeros((route_expert_num,), dtype='object') + for expert_id, workload_weight in enumerate(layer_workloads): + weights[expert_id] = (expert_id, workload_weight) + + if is_node_redundant: + + device_assignments = [] + device_weights = [] + device_loads = [] + device_counts = [] + com_between_devices = [] + + for node_id in range(node_num): + cur_node_weights = weights[ + node_id * per_node_route_expert_num: (node_id + 1) * per_node_route_expert_num] + cur_original_deployment = original_deployment[ + node_id * per_node_device_num: (node_id + 1) * per_node_device_num] + + cur_node_rendun_pos = rendun_pos[node_id * per_node_device_num: (node_id + 1) * per_node_device_num] + + cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again( + layer_workloads, + cur_node_weights, + cur_original_deployment, + expert_from_device, + node_num, + is_node_redundant, + cur_node_rendun_pos) + device_assignments += cur_device_assignments + device_weights += cur_device_weights + device_loads += cur_device_loads + device_counts += cur_device_counts + com_between_devices += cur_com_between_devices + + else: + device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again( + layer_workloads, + weights, + original_deployment, + expert_from_device, + node_num, + is_node_redundant, + rendun_pos) + report, max_load = self.generate_allocation_report(device_assignments, device_weights, device_loads, + device_counts) + + return report, max_load, com_between_devices + + @staticmethod + def two_device_exchange_experts(cur_device_result, exchange_device_result, cur_exchanged_expert_id, + next_exchanged_expert_id, ave_workload, increment, num_redundancy_expert): + + cur_device_weight = cur_device_result['expert_weights'] + next_device_weight = exchange_device_result['expert_weights'] + + cur_device_expert_id = cur_device_result['assigned_experts'] + next_device_expert_id = exchange_device_result['assigned_experts'] + + cur_device_total_weight = cur_device_result['total_load'] + next_device_total_weight = exchange_device_result['total_load'] + max_weight = max(cur_device_total_weight, next_device_total_weight) + + cur_exchange_index = -1 + next_exchange_index = -1 + + for index, weight in enumerate(cur_device_weight): + for next_index, next_weight in enumerate(next_device_weight): + change_flag = True + if (cur_device_expert_id[index] in next_device_expert_id or next_device_expert_id[next_index] in cur_device_expert_id): + change_flag = False + if (cur_device_expert_id[index] not in cur_exchanged_expert_id) and ( + next_device_expert_id[next_index] not in next_exchanged_expert_id) and change_flag: + + cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight + next_total_weight_after_exchange = next_device_total_weight - next_weight + weight + exchange_max_weight = max(cur_total_weight_after_exchange, 
next_total_weight_after_exchange) + if exchange_max_weight < max_weight and (max_weight - exchange_max_weight) >= ( + ave_workload * increment): + max_weight = exchange_max_weight + cur_exchange_index = index + next_exchange_index = next_index + + return cur_exchange_index, next_exchange_index + + def expert_exchange_between_devices(self, ave_workload, increment, cur_layer_result, com_between_devices, num_redundancy_expert, + node_idx=0, per_node_device_num=0, is_node_redundant=False): + + if is_node_redundant: + cur_devices_result = cur_layer_result[node_idx * per_node_device_num:(node_idx + 1) * per_node_device_num] + else: + cur_devices_result = cur_layer_result + + devices_total_weight = [] + for device in cur_devices_result: + devices_total_weight.append((device['total_load'], device['device_id'] - 1)) + + exchange_frequency = 100 + while exchange_frequency > 0: + exchange_frequency -= 1 + devices_total_weight.sort(key=lambda x: x[0]) + max_weight_device_id = devices_total_weight[-1][1] + exchange = False + for index in range(0, len(devices_total_weight) - 1): + min_weight_device_id = devices_total_weight[index][1] + if min_weight_device_id not in com_between_devices[max_weight_device_id]: + cur_exchanged_expert_id = list(com_between_devices[max_weight_device_id].values()) + next_exchanged_expert_id = list(com_between_devices[min_weight_device_id].values()) + + cur_exchange_index, next_exchange_index = self.two_device_exchange_experts( + cur_layer_result[max_weight_device_id], + cur_layer_result[min_weight_device_id], + cur_exchanged_expert_id, + next_exchanged_expert_id, + ave_workload, + increment, + num_redundancy_expert) + + if cur_exchange_index != -1: + self.exchange_expert(cur_exchange_index, + next_exchange_index, + max_weight_device_id, + min_weight_device_id, + cur_layer_result, + com_between_devices) + + devices_total_weight[-1] = ( + cur_layer_result[max_weight_device_id]['total_load'], max_weight_device_id) + devices_total_weight[index] = ( + cur_layer_result[min_weight_device_id]['total_load'], min_weight_device_id) + exchange = True + break + + if not exchange: + break + + def exchange_experts(self, layer_result, layer_com_between_devices, num_nodes, device_num, is_node_redundant, + ave_workload, increment, num_redundancy_expert, org_deployment): + + global_deployment = [] + + if is_node_redundant: + per_node_device_num = self.safe_exact_divide(device_num, num_nodes) + for node_idx in range(num_nodes): + self.expert_exchange_between_devices(ave_workload, increment, layer_result, + layer_com_between_devices, num_redundancy_expert, + node_idx, per_node_device_num, is_node_redundant) + else: + self.expert_exchange_between_devices(ave_workload, increment, layer_result, layer_com_between_devices, num_redundancy_expert) + + max_workload = 0 + for box in layer_result: + global_deployment.append(box['assigned_experts']) + if max_workload < box['total_load']: + max_workload = box['total_load'] + + global_deployment = np.array(global_deployment) + + return global_deployment, max_workload + + + def count_elements(self, lst): + count = 0 + for item in lst: + if isinstance(item, list): + count += self.count_elements(item) + else: + count += 1 + return count + + + @staticmethod + def constraint_expert_local_exchange(current_expert_table, global_deployment): + for layer_id in range(len(global_deployment)): + for card_id in range(len(global_deployment[layer_id])): + current_list = [int(x) for x in current_expert_table[layer_id][card_id]] + new_list = [int(x) for x in 
global_deployment[layer_id][card_id]] + num = len(new_list) + + new_index = [-1] * num + new_result = [-1] * num + remaining_elements = [] + + for i in range(num): + flag = True + for j in range(num): + if new_list[i] == current_list[j] and new_index[j] == -1: + new_index[j] = 0 + new_result[j] = current_list[j] + flag = False + break + if flag: + remaining_elements.append(new_list[i]) + + index = 0 + for k in range(num): + if new_result[k] == -1: + new_result[k] = remaining_elements[index] + index += 1 + + global_deployment[layer_id][card_id] = new_result + + return global_deployment + + + def rebalance_experts(self, current_expert_table, expert_workload, is_node_redundant = False, increment = 0.01): + info = DynamicTable() + info.workload_table = expert_workload.numpy() + info.placement_table = current_expert_table.numpy() + layer_num, num_npus, experts_per_npu= info.workload_table.shape + expert_ids, counts = np.unique(info.placement_table[0], return_counts=True) + num_redundancy_expert = self.get_redundant_num(num_npus, counts) + num_original_expert = len(expert_ids) + layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert) + max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num) + npu_heat_all_origin = sum(max_heat_per_layer_before) + + num_node = self.safe_exact_divide(num_npus, 8) + layer_num = layer_workloads.shape[0] + expert_num = layer_workloads.shape[1] + expert_from_device = np.zeros((layer_num, num_original_expert)) + + if num_original_expert != expert_num: + raise ValueError(f"The number of original experts ({num_original_expert}) must match expert_num ({expert_num})") + + if num_npus <= 0: + raise ValueError("The number of NPUs must be greater than 0") + + if num_npus < num_redundancy_expert: + raise ValueError(f"The number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})") + + global_deployment = [[[] for _ in range(num_npus)] for _ in range(layer_num)] + layer_initial_imbalance = self.calculate_initial_imbalance(info.placement_table, layer_workloads) + max_heat_per_layer_after = np.zeros([layer_num]) + sum_num = 0 + for layer in range(layer_num): + # print(f"Load imbalance ratio of layer {layer} under the new workload", layer_initial_imbalance[layer]) + if layer_initial_imbalance[layer] < 1.01: + global_deployment[layer] = info.placement_table[layer] + continue + + ave_workload = self.safe_divide(np.sum(layer_workloads[layer]), num_npus) + + rendun_pos = [[] for _ in range(num_npus)] + existing_experts = set() + for device_id, device in enumerate(info.placement_table[layer]): + for index, expert_id in enumerate(device): + if expert_id not in existing_experts: + existing_experts.add(expert_id) + expert_from_device[layer][expert_id] = device_id + else: + rendun_pos[device_id].append(index) + + result, max_workload, com_between_devices = self.redundant_expert_deployment(layer_workloads[layer], + info.placement_table[layer], + expert_from_device[layer], + num_node, is_node_redundant, rendun_pos) + # print(layer, f"Imbalance Ratio after Redundancy Adjustment:", self.safe_divide(max_workload, ave_workload)) + + global_deployment[layer], new_max_workload = self.exchange_experts(result, com_between_devices, + num_node, num_npus, is_node_redundant, ave_workload, + increment, num_redundancy_expert, info.placement_table[layer]) + # print(layer, f"Imbalance Ratio after Swap Adjustment:", self.safe_divide(new_max_workload, ave_workload)) + 
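+            # Descriptive note: count how many experts were exchanged between devices for this
+            # layer (sum_num) and record the hottest device after rebalancing; the per-layer
+            # ratios computed below prioritise layers and gate whether the new deployment is adopted.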
+
+            for device_id in range(num_npus):
+                com_between_devices[device_id] = {key: value for key, value in
+                                                  com_between_devices[device_id].items()}
+                sum_num += self.count_elements(com_between_devices[device_id])
+
+            max_heat_per_layer_after[layer] = max(result, key=lambda x: x['total_load'])['total_load']
+
+        layer_changed_ratio = []
+        for layer_idx in range(layer_num):
+            layer_changed_ratio.append(self.safe_divide(max_heat_per_layer_after[layer_idx], max_heat_per_layer_before[layer_idx]))
+
+        per_layer_priority = np.argsort(layer_changed_ratio)
+        npu_heat_all_after = sum(max_heat_per_layer_after)
+
+        change = 0
+        if npu_heat_all_after < 0.95 * npu_heat_all_origin:
+            change = 1
+
+        new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment)
+
+        return change, per_layer_priority, np.array(new_global_deployment).tolist()
\ No newline at end of file
diff --git a/vllm_ascend/eplb/core/policy/policy_factory.py b/vllm_ascend/eplb/core/policy/policy_factory.py
new file mode 100644
index 0000000000..5f70f05fc1
--- /dev/null
+++ b/vllm_ascend/eplb/core/policy/policy_factory.py
@@ -0,0 +1,22 @@
+# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+from .policy_abstract import EplbPolicy, DynamicConfig
+from .policy_random import RandomLoadBalance
+from .policy_dynamic_ep import DynamicEplb
+from .policy_dynamic_ep_v2 import DynamicEplbV2
+
+
+
+class PolicyFactory:
+    @staticmethod
+    def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
+        policy = {
+            # Constraint when applying the Dynamic EPLB policy V2:
+            # if redundant experts exist, only one redundant expert may be placed
+            # on each NPU, and its physical expert index must be 0.
+            # The new placement is composed with a greedy device-to-device (d2d)
+            # expert weight update.
+            0: RandomLoadBalance,  # RandomLoadBalance: swap the last physical expert between NPU 1 and NPU 3
+            1: DynamicEplb,        # Dynamic EPLB policy: full expert re-placement based on the current MoE load
+            2: DynamicEplbV2,      # Dynamic EPLB policy V2: re-placement with a constrained number of expert moves
+        }
+        return policy.get(policy_type, RandomLoadBalance)(config)
diff --git a/vllm_ascend/eplb/core/policy/policy_random.py b/vllm_ascend/eplb/core/policy/policy_random.py
new file mode 100644
index 0000000000..ff2234f372
--- /dev/null
+++ b/vllm_ascend/eplb/core/policy/policy_random.py
@@ -0,0 +1,30 @@
+# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
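+# RandomLoadBalance is a debugging/testing policy: for every MoE layer it swaps the
+# last physical expert slot between two fixed cards (3 and 1) and reports the table
+# as changed. A minimal usage sketch via PolicyFactory (illustrative only; assumes
+# placement and workload are torch tensors shaped [num_layers, num_ranks, experts_per_rank]):
+#
+#     policy = PolicyFactory.generate_policy(policy_type=0, config=DynamicConfig())
+#     changed, layer_priority, new_table = policy.rebalance_experts(placement, workload)
+#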
+import copy +import random +import torch +import torch + +from .policy_abstract import EplbPolicy, DynamicConfig + +random.seed(42) + +class RandomLoadBalance(EplbPolicy): + def __init__(self, config: DynamicConfig): + super().__init__(config) + + def rebalance_experts(self, current_expert_table, expert_workload): + new_table = copy.deepcopy(current_expert_table) + num_layers = len(current_expert_table) + num_card = len(current_expert_table[0]) + + for i in range(num_layers): + # randomly choose two card + # indices = random.sample(range(num_card), 2) + indices = [3,1] + + # swap redundant experts + expert_id_to_exchange = new_table[i][indices[0]][-1].clone() + new_table[i][indices[0]][-1] = new_table[i][indices[1]][-1] + new_table[i][indices[1]][-1] = expert_id_to_exchange + + return 1, [-i for i in range(num_layers)], new_table \ No newline at end of file diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py new file mode 100644 index 0000000000..c50ec766bf --- /dev/null +++ b/vllm_ascend/eplb/eplb_updator.py @@ -0,0 +1,217 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch +import numpy +import torch.distributed as dist +import vllm.envs as envs +from multiprocessing import Queue, Manager + +from vllm.logger import logger +from vllm_ascend.eplb.core.eplb_worker import EplbProcess +from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader +from vllm_ascend.ascend_config import get_ascend_config +from queue import Queue +from typing import Any + + +class EplbUpdator: + + def __init__(self, expert_map_path): + self.ascend_config = get_ascend_config() + self.init_eplb(expert_map_path) + + def set_adaptor(self, adaptor): + self.adaptor = adaptor + self.eplb_loader = D2DExpertWeightLoader(eplb_adaptor=self.adaptor) + self.num_moe_layers = self.adaptor.num_moe_layers + self.global_expert_num = self.adaptor.global_expert_num + + def init_eplb(self, expert_map_path): + self.rank_id = dist.get_rank() + self.num_expert_load_gather = 10 + self.periodic_load_gather = True + self.redundant_enable = (expert_map_path is not None) + self.num_iterations_eplb_update: torch.int64 = self.ascend_config.num_iterations_eplb_update + self.expert_map_path = expert_map_path + + try: + if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING: + self.num_expert_load_gather = self.num_iterations_eplb_update + self.periodic_load_gather = False + except Exception as e: + self.num_expert_load_gather = self.num_iterations_eplb_update + self.periodic_load_gather = False + + self.expert_map_initialized = False + self.gate_eplb = self.ascend_config.gate_eplb + + self.reqs = [] + self.update_info_all = [] + + self.cur_iterations: torch.int64 = 0 + + self.num_wait_worker_iterations: torch.int64 = self.ascend_config.num_wait_worker_iterations + + self.planner_block_queue = Queue() + self.block_update_queue = Queue(maxsize=1) + + self.manager = 
Manager()
+        self.shared_dict = self.manager.dict({
+            # expert map of the current rank_id: [num_layers, num_experts]
+            "expert_map": None,
+            # MoE load (heat) information: [num_layers, world_size, num_experts]
+            "moe_load": None,
+            # expert maps of all ranks: [num_layers, world_size, num_experts]
+            "expert_maps": None,
+        })
+
+        self.eplb = EplbProcess(
+            shared_dict = self.shared_dict,
+            planner_q = self.planner_block_queue,
+            block_update_q = self.block_update_queue,
+            redundant_enable = self.redundant_enable,
+            policy_type = 1,
+            enable_d2d = True
+        )
+
+        self.eplb_process = self.eplb._launch_process()
+
+        logger.info(f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})")
+
+    def update_iteration(self):
+        self.cur_iterations += 1
+        if self.cur_iterations == (self.num_iterations_eplb_update +\
+            self.num_wait_worker_iterations + self.num_moe_layers):
+            self.adaptor.model.clear_all_moe_loads()
+            if not self.gate_eplb:
+                self.cur_iterations = 0
+
+    def get_update_info_flag(self):
+        return self.cur_iterations == (self.num_iterations_eplb_update + self.num_wait_worker_iterations - 1)
+
+    def wakeup_eplb_worker_flag(self):
+        return self.cur_iterations == (self.num_iterations_eplb_update - 1)
+
+    def update_expert_weight_flag(self):
+        weight_update_counter = self.cur_iterations - (self.num_iterations_eplb_update + self.num_wait_worker_iterations)
+        return (weight_update_counter >= 0 and weight_update_counter < self.num_moe_layers)
+
+    def get_init_expert_map(self):
+        try:
+            if not self.expert_map_initialized:
+                self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers, self.expert_map_path)
+                self.expert_map_initialized = True
+        except Exception as e:
+            logger.warning(f"[ModelRunner] Failed to load the initial expert map: {e}", exc_info=True)
+
+    def wakeup_eplb_worker(self):
+        self.planner_block_queue.put(1)
+
+    def forward_before(self):
+        if self.update_expert_weight_flag():
+            (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(0)
+            rank_id = torch.distributed.get_rank()
+            if self.redundant_enable:
+                log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
+                self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
+            updated_expert_map_this_rank = torch.from_numpy(numpy.array(updated_expert_map))
+            #logger.info(f"check update info, layer = {layer_id}, send = {expert_send_info_this_rank}, recv = {expert_recv_info_this_rank}")
+            self.eplb_loader.generate_expert_d2d_transfer_task(expert_send_info, expert_recv_info,
+                updated_expert_map_this_rank, layer_id + self.adaptor.num_dense_layers)
+
+            # set up the asynchronous stream for the d2d expert weight update
+            self.reqs = []
+            self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
+
+    def take_update_info_from_eplb_process(self):
+        # A fixed number of iterations after the eplb worker was triggered, fetch the update info it produced.
+        if self.get_update_info_flag():
+            self.update_info_all = self.block_update_queue.get()
+
+
+    def forward_end(self):
+        if self.wakeup_eplb_worker_flag():
+            moe_load = self.compute_and_set_moe_load(is_clear=True)
+            self.wakeup_eplb_worker()
+
+        if self.update_expert_weight_flag():
+            self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)
+
+        self.update_iteration()
+
+    def compute_and_set_moe_load(self, is_clear=False):
+        local_load = self.adaptor.get_rank_expert_workload()
+
+        self._gather_buffer = None
+        if dist.is_initialized():
+            self.world_size = dist.get_world_size()
+            self.device = local_load.device
+            if self._gather_buffer is None:
+                shape = (self.world_size, 
*local_load.shape) + self._gather_buffer = torch.empty(shape, + dtype=local_load.dtype, + device=self.device) + + dist.all_gather_into_tensor(self._gather_buffer, local_load) + + moe_load = self._gather_buffer.permute(1, 0, 2) + self.shared_dict["moe_load"] = moe_load.cpu() + logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}") + else: + moe_load = local_load.unsqueeze(1) + self.shared_dict["moe_load"] = moe_load.cpu() + logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}") + return moe_load + + def warm_up_eplb(self): + + self.get_init_expert_map() + self.compute_and_set_moe_load() + + src_tensor = torch.empty((1,), device=self.device) + self_rank = dist.get_rank() + + comm_op_list = [] + + for dst_rank in range(self.world_size): + if dst_rank == self_rank: + continue + comm_op_list.append( + dist.P2POp(dist.isend, src_tensor, dst_rank) + ) + + for src_rank in range(self.world_size): + if src_rank == self_rank: + continue + comm_op_list.append( + dist.P2POp(dist.irecv, src_tensor, src_rank) + ) + if comm_op_list: + reqs = dist.batch_isend_irecv(comm_op_list) + + for req in reqs: + req.wait() + + def shutdown(self): + """ + Clean up the EPLB process. + """ + if self.eplb_process.is_alive(): + self.eplb_process.terminate() + self.eplb_process.join() + logger.info("[ModelRunner] EPLB process terminated") diff --git a/vllm_ascend/eplb/tool/eplb_utils.py b/vllm_ascend/eplb/tool/eplb_utils.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index fb1ed6f11b..fe5b35f792 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -735,6 +735,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + self.num_dense_layers = self.config.first_k_dense_replace + self.num_moe_layers = self.config.num_hidden_layers - self.num_dense_layers self.model = CustomDeepseekV2Model(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "model")) @@ -773,6 +775,32 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params + def get_expert_map(self, layer_id): + return self.model.layers[layer_id].mlp.experts.get_map() + + def get_log2phy_map(self, layer_id): + return self.model.layers[layer_id].mlp.experts.get_log2phy_map() + + def get_all_expert_map(self, num_moe_layers): + all_loads = [] + for layer_id in range(num_moe_layers): + load_tensor = self.get_expert_map(layer_id + self.num_dense_layers) # (num_experts_per_layer,) + all_loads.append(load_tensor) + + return torch.stack(all_loads, dim=0) + + def get_all_moe_loads(self): + all_moe_loads = torch.stack( + [self.model.layers[layer_id + self.num_dense_layers].mlp.experts.moe_load \ + for layer_id in range(self.num_moe_layers)], + dim=0 + ) + return all_moe_loads + + def clear_all_moe_loads(self): + for layer_id in range(self.num_moe_layers): + self.model.layers[layer_id + self.num_dense_layers].mlp.experts.clear_moe_load() + class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): pass diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 3fa9c8be74..c78727b859 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1143,6 +1143,7 @@ def __init__( ascend_config = get_ascend_config() expert_map_path = ascend_config.expert_map_path + self.dynamic_eplb = ascend_config.dynamic_eplb if expert_map_path and 
os.path.exists(expert_map_path): # moe expert load balance expert_load_balancer = ExpertLoadBalancer(expert_map_path, @@ -1158,6 +1159,10 @@ def __init__( # Create a tensor of size num_experts filled with -1 self.local_num_experts, self.expert_map = determine_expert_map( self.ep_size, self.ep_rank, self.global_num_experts) + if self.dynamic_eplb: + from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map + self.log2phy = determine_default_log2phy_map(self.global_num_experts, + self.ep_size, self.ep_rank) self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_multistream_moe = ( @@ -1188,6 +1193,10 @@ def __init__( local_num_experts = (torch.sum(self.expert_map != -1) if self.expert_map is not None else num_experts) + self.moe_load = None + if self.dynamic_eplb: + self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) + moe_quant_params = { "num_experts": local_num_experts, "hidden_size": hidden_size, @@ -1337,9 +1346,19 @@ def forward( token_dispatcher=self.token_dispatcher, ) - if shared_experts: - if isinstance(e_hidden_states, tuple): - e_hidden_states, shared_hidden_states = e_hidden_states + if isinstance(e_hidden_states, tuple): + if len(e_hidden_states) == 4: + e_hidden_states, shared_hidden_states, expert_token_num, group_list_type = e_hidden_states + else: + e_hidden_states, expert_token_num, group_list_type = e_hidden_states + + if self.dynamic_eplb: + self.moe_load += expert_token_num if group_list_type else \ + torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]]) + + # if shared_experts: + # if isinstance(e_hidden_states, tuple): + # e_hidden_states, shared_hidden_states = e_hidden_states if fused_moe_state != FusedMoEState.AllGather: if tp_size > 1: @@ -1370,6 +1389,18 @@ def forward( else: return final_hidden_states + def update_map(self,new_expert_map): + self.expert_map = new_expert_map + + def get_map(self): + return self.expert_map + + def get_log2phy_map(self): + return self.log2phy + + def clear_moe_load(self): + self.moe_load.zero_() + # ----------------------------------------- TBO-related -------------------------------------------- def _forward_ms_fused_moe_comp( diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index c2fe1db51c..ae315139d4 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -221,7 +221,7 @@ def fused_experts_with_mc2( mc2_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: assert mc2_mask is not None - if log2phy: + if log2phy is not None: topk_ids = log2phy[topk_ids] quant_mode = 2 ep_group = get_mc2_group() @@ -326,14 +326,15 @@ def fused_experts_with_mc2( hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) + group_list_type = 1 if shared_experts is None: - return hidden_states + return hidden_states, expert_token_nums, group_list_type else: with npu_stream_switch("moe_secondary", 0): npu_wait_tensor(shared_act, down_out_list) shared_output, _ = shared_experts.down_proj( (shared_act, swiglu_out_scale)) - return hidden_states, shared_output + return hidden_states, shared_output, expert_token_nums, group_list_type def fused_prefill_experts_with_mc2( @@ -551,7 +552,7 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor, ) if len(original_shape) == 3: final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states + return final_hidden_states, expert_tokens, 
group_list_type def fused_experts(hidden_states: torch.Tensor, @@ -665,7 +666,7 @@ def fused_experts(hidden_states: torch.Tensor, if len(original_shape) == 3: final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states + return final_hidden_states, expert_tokens, group_list_type class AscendW8A8DynamicLinearMethod: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1207618706..c25ae062a8 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -86,6 +86,10 @@ from vllm_ascend.utils import ProfileExecuteDuration from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer +from vllm_ascend.eplb.eplb_updator import EplbUpdator +from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor +from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader + if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import SchedulerOutput @@ -387,6 +391,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank + #EPLB + self.dynamic_eplb = ascend_config.dynamic_eplb + if self.dynamic_eplb == True: + self.eplb_adaptor = None + self.is_eplb_warmuped = False + self.eplb_updator = EplbUpdator(ascend_config.expert_map_path) + # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True self.in_profile_run = False @@ -1335,11 +1346,18 @@ def execute_model( # Return empty ModelRunnerOuptut if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward(scheduler_output) + + if self.dynamic_eplb: + self.eplb_updator.forward_before() + (attn_metadata, hidden_states, spec_decode_metadata, positions, num_scheduled_tokens, sample_indices, finished_sending, finished_recving) = (self._process_reqs(scheduler_output, intermediate_tensors)) + if self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + with ProfileExecuteDuration().capture_async("post process"): logits = self.model.compute_logits(hidden_states[sample_indices], None) @@ -1453,6 +1471,9 @@ def execute_model( logger.info("Profile execute duration [%s]:%s", captured_name, " ".join(dr_str)) + if self.dynamic_eplb: + self.eplb_updator.forward_end() + return model_runner_output def kv_connector_no_forward( @@ -1638,6 +1659,9 @@ def _dummy_run( attn_state=attn_state, ) + if not is_torchair_compile and not self.in_profile_run and self.dynamic_eplb: + self.eplb_updator.forward_before() + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model @@ -1724,6 +1748,13 @@ def _dummy_run( if self.speculative_config and self.speculative_config.method == "deepseek_mtp": assert isinstance(self.drafter, MtpProposer) self.drafter.dummy_run(num_reqs, with_prefill=with_prefill) + + if self.in_profile_run and self.dynamic_eplb: + self.model.clear_all_moe_loads() + if not is_torchair_compile and not self.in_profile_run and self.dynamic_eplb: + self.eplb_updator.take_update_info_from_eplb_process() + self.eplb_updator.forward_end() + return hidden_states @contextmanager @@ -1775,6 +1806,14 @@ def profile_run(self) -> None: self.encoder_cache.clear() gc.collect() + def eplb_warmup(self): + #EPLB + if self.dynamic_eplb and not self.is_eplb_warmuped: + self.is_eplb_warmuped = True + self.eplb_adaptor = 
VllmEplbAdaptor(model=self.model) + self.eplb_updator.set_adaptor(self.eplb_adaptor) + self.eplb_updator.warm_up_eplb() + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index dc004f4b89..c07d1b9568 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -196,6 +196,7 @@ def load_model(self) -> None: self.model_runner.load_model() def compile_or_warm_up_model(self) -> None: + self.model_runner.eplb_warmup() warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() if not self.model_config.enforce_eager: warmup_sizes = [
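
The update cadence wired into EplbUpdator can be sanity-checked offline. Below is a minimal sketch, not part of the patch, using illustrative values for the config knobs (400 update iterations, 30 wait iterations, 58 MoE layers; all deployment and model dependent). eplb_schedule is a hypothetical helper that only mirrors the flag arithmetic of wakeup_eplb_worker_flag, get_update_info_flag and update_expert_weight_flag and touches no vLLM code.

# Standalone sketch: reproduce the phase boundaries of one EPLB update cycle.
def eplb_schedule(num_iterations_eplb_update=400,
                  num_wait_worker_iterations=30,
                  num_moe_layers=58,
                  total_steps=500):
    phases = {"wake_worker": [], "fetch_update_info": [], "update_layer": []}
    for step in range(total_steps):
        # forward_end(): gather the MoE load and wake the eplb worker.
        if step == num_iterations_eplb_update - 1:
            phases["wake_worker"].append(step)
        # take_update_info_from_eplb_process(): collect the worker's result.
        if step == num_iterations_eplb_update + num_wait_worker_iterations - 1:
            phases["fetch_update_info"].append(step)
        # forward_before()/forward_end(): one layer's experts are swapped per step.
        counter = step - (num_iterations_eplb_update + num_wait_worker_iterations)
        if 0 <= counter < num_moe_layers:
            phases["update_layer"].append(step)
    return phases

if __name__ == "__main__":
    p = eplb_schedule()
    print("wake worker at step", p["wake_worker"][0])
    print("fetch update info at step", p["fetch_update_info"][0])
    print("layer updates run from step", p["update_layer"][0], "to", p["update_layer"][-1])

With gate_eplb left at its default (False), the iteration counter resets after the last layer update and the whole cycle repeats.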