Commit 541e4b6

Author: offline0806 (committed)

[EPLB] Initialize EPLB when dynamic EPLB is enabled.
1 parent 5c90a6f commit 541e4b6

File tree (4 files changed: +71 -82 lines)

    vllm_ascend/eplb/adaptor/vllm_adaptor.py
    vllm_ascend/ops/common_fused_moe.py
    vllm_ascend/ops/fused_moe.py
    vllm_ascend/torchair/ops/torchair_fused_moe.py

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 0 additions & 3 deletions
@@ -187,8 +187,6 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):

            record["layer_list"].append(layer_record)

-        print(record)
-
        with open(expert_map_record_path, "w") as f:
            json.dump(record, f, indent=4)

@@ -201,7 +199,6 @@ def do_update_expert_weight(self, layer_id, local_expert_to_replace,
        for expert_tensor, buffer_tensor in zip(
                self.expert_param_per_layer[layer_id][local_expert_to_replace],
                self.buffer_tensor_list[buffer_tensor_id]):
-            # expert_tensor.copy_(buffer_tensor)
            expert_tensor = buffer_tensor.clone()

    def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
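
As context for the change in do_update_expert_weight above: rebinding with clone() and copying in place with copy_() behave differently in PyTorch. A minimal, self-contained sketch of the two idioms (toy tensors only, not the adaptor's real per-layer parameter lists):

```python
import torch

# Toy stand-ins for one expert's weight shard and a staging buffer.
expert_param = torch.zeros(4, 8)
buffer_tensor = torch.ones(4, 8)

# Rebinding: the local name now points at a fresh copy of the buffer,
# while the tensor it previously referred to keeps its old values.
expert_tensor = expert_param
expert_tensor = buffer_tensor.clone()
print(expert_param.sum().item())   # 0.0 -- original storage untouched

# In-place copy: writes the buffer's values into the existing storage,
# so every reference to that tensor observes the update.
expert_param.copy_(buffer_tensor)
print(expert_param.sum().item())   # 32.0 -- original storage updated
```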

vllm_ascend/ops/common_fused_moe.py

Lines changed: 22 additions & 27 deletions
@@ -27,17 +27,17 @@
    FusedMoEParallelConfig  # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-
+from vllm_ascend.eplb.core.eplb_utils import (
+    determine_default_expert_map,
+    determine_default_log2phy_map)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
                                                 AlltoAllCommImpl, MC2CommImpl)
from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is
-from vllm.logger import logger

original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__

@@ -298,31 +298,26 @@ def __init__(
        self.moe_config.mc2_group = get_mc2_group()
        ascend_config = get_ascend_config()
        self.dynamic_eplb = ascend_config.dynamic_eplb
-        self.expert_map_path = ascend_config.expert_map_path
-        self.global_redundant_expert_num = ascend_config.init_redundancy_expert
-        self.global_num_experts = num_experts + self.global_redundant_expert_num
-        if self.expert_map_path and os.path.exists(self.expert_map_path) and os.access(self.expert_map_path, os.R_OK):
-            self.expert_load_balancer = ExpertLoadBalancer(self.expert_map_path, self.global_num_experts)
-            self.local_num_experts, self.expert_map = (self.expert_load_balancer.get_rank_placement_map(self.moe_instance_id, self.ep_rank))
-            self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(self.moe_instance_id, self.ep_rank).npu()
-            self.global_redundant_expert_num = (self.expert_load_balancer.get_global_redundant_expert_num())
-        else:
-            self.local_num_experts, self.expert_map = determine_expert_map(self.ep_size, self.ep_rank, self.global_num_experts)
-            if self.dynamic_eplb:
-                self.global_redundant_expert_num = ascend_config.init_redundancy_expert
-                from vllm_ascend.eplb.core.eplb_utils import (
-                    determine_default_expert_map,
-                    determine_default_log2phy_map)
-                self.local_num_experts, self.expert_map = determine_default_expert_map(
-                    self.global_num_experts, self.ep_size, self.ep_rank,
-                    self.global_redundant_expert_num)
-                self.log2phy = determine_default_log2phy_map(
-                    self.global_num_experts, self.ep_size, self.ep_rank,
-                    self.global_redundant_expert_num)
-
-        self.moe_load = None
-        local_num_experts = (torch.sum(self.expert_map != -1) if self.expert_map is not None else num_experts)
        if self.dynamic_eplb:
+            self.expert_map_path = ascend_config.expert_map_path
+            self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+            self.global_num_experts = num_experts + self.global_redundant_expert_num
+            if self.expert_map_path and os.path.exists(self.expert_map_path) and os.access(self.expert_map_path, os.R_OK):
+                self.expert_load_balancer = ExpertLoadBalancer(self.expert_map_path, self.global_num_experts)
+                self.local_num_experts, self.expert_map = (self.expert_load_balancer.get_rank_placement_map(self.moe_instance_id, self.ep_rank))
+                self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(self.moe_instance_id, self.ep_rank).npu()
+                self.global_redundant_expert_num = (self.expert_load_balancer.get_global_redundant_expert_num())
+            else:
+                self.local_num_experts, self.expert_map = determine_expert_map(self.ep_size, self.ep_rank, self.global_num_experts)
+                if self.dynamic_eplb:
+                    self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+                    self.local_num_experts, self.expert_map = determine_default_expert_map(
+                        self.global_num_experts, self.ep_size, self.ep_rank,
+                        self.global_redundant_expert_num)
+                    self.log2phy = determine_default_log2phy_map(
+                        self.global_num_experts, self.ep_size, self.ep_rank,
+                        self.global_redundant_expert_num)
+            local_num_experts = (torch.sum(self.expert_map != -1) if self.expert_map is not None else num_experts)
            self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)

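
In the rewritten branch above, local_num_experts is derived by counting non-negative entries of the expert map, where -1 marks a global expert that is not hosted on the current EP rank. A small self-contained sketch of that counting pattern with a toy map (the real map comes from determine_expert_map or the eplb_utils helpers):

```python
import torch

# Hypothetical per-rank expert map: index = global expert id, value = local
# slot on this rank, -1 = expert hosted on some other rank.
expert_map = torch.tensor([-1, -1, 0, 1, -1, 2, -1, -1])

# Same counting expression as in the diff; fall back to the full expert
# count when no map exists.
num_experts = expert_map.numel()
local_num_experts = (int(torch.sum(expert_map != -1))
                     if expert_map is not None else num_experts)

# Per-local-expert load counters, sized like self.moe_load in the commit.
moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
print(local_num_experts, moe_load.shape)   # 3 torch.Size([3])
```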

vllm_ascend/ops/fused_moe.py

Lines changed: 25 additions & 30 deletions
@@ -37,6 +37,9 @@
    FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.quantization.base_config import \
    QuantizationConfig
+from vllm_ascend.eplb.core.eplb_utils import (
+    determine_default_expert_map,
+    determine_default_log2phy_map)

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState

@@ -290,37 +293,29 @@ def __init__(
            self.moe_parallel_config.ep_size, is_deepseek_v3_r1)

        ascend_config = get_ascend_config()
-        expert_map_path = ascend_config.expert_map_path
        self.dynamic_eplb = ascend_config.dynamic_eplb
-        if expert_map_path and os.path.exists(expert_map_path):
-            # moe expert load balance
-            expert_load_balancer = ExpertLoadBalancer(expert_map_path,
-                                                      self.global_num_experts)
-            self.local_num_experts, self.expert_map = \
-                expert_load_balancer.get_rank_placement_map(
-                    self.moe_instance_id,
-                    get_ep_group().rank_in_group)
-            self.log2phy = expert_load_balancer.get_rank_log2phy_map(
-                self.moe_instance_id,
-                get_ep_group().rank_in_group)
-            self.global_redundant_expert_num = \
-                expert_load_balancer.get_global_redundant_expert_num()
-        else:
-            # Create a tensor of size num_experts filled with -1
-            self.local_num_experts, self.expert_map = determine_expert_map(
-                self.ep_size,
-                get_ep_group().rank_in_group, self.global_num_experts)
-            if self.dynamic_eplb:
-                self.global_redundant_expert_num = ascend_config.init_redundancy_expert
-                from vllm_ascend.eplb.core.eplb_utils import (
-                    determine_default_expert_map,
-                    determine_default_log2phy_map)
-                self.local_num_experts, self.expert_map = determine_default_expert_map(
-                    self.global_num_experts, self.ep_size, self.ep_rank,
-                    self.global_redundant_expert_num)
-                self.log2phy = determine_default_log2phy_map(
-                    self.global_num_experts, self.ep_size, self.ep_rank,
-                    self.global_redundant_expert_num)
+        if self.dynamic_eplb:
+            self.expert_map_path = ascend_config.expert_map_path
+            self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+            self.global_num_experts = num_experts + self.global_redundant_expert_num
+            if self.expert_map_path and os.path.exists(self.expert_map_path) and os.access(self.expert_map_path, os.R_OK):
+                self.expert_load_balancer = ExpertLoadBalancer(self.expert_map_path, self.global_num_experts)
+                self.local_num_experts, self.expert_map = (self.expert_load_balancer.get_rank_placement_map(self.moe_instance_id, self.ep_rank))
+                self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(self.moe_instance_id, self.ep_rank).npu()
+                self.global_redundant_expert_num = (self.expert_load_balancer.get_global_redundant_expert_num())
+            else:
+                self.local_num_experts, self.expert_map = determine_expert_map(self.ep_size, self.ep_rank, self.global_num_experts)
+                if self.dynamic_eplb:
+                    self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+                    self.local_num_experts, self.expert_map = determine_default_expert_map(
+                        self.global_num_experts, self.ep_size, self.ep_rank,
+                        self.global_redundant_expert_num)
+                    self.log2phy = determine_default_log2phy_map(
+                        self.global_num_experts, self.ep_size, self.ep_rank,
+                        self.global_redundant_expert_num)
+            local_num_experts = (torch.sum(self.expert_map != -1) if self.expert_map is not None else num_experts)
+            self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
+

        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
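
The new guard in this file (and in common_fused_moe.py) only trusts a static placement file that both exists and is readable, falling back to the default placement otherwise. A hedged sketch of that check in isolation, using a hypothetical JSON path and helper name (the real handling lives in ExpertLoadBalancer):

```python
import json
import os


def load_static_expert_map(expert_map_path: str):
    """Hypothetical helper: return the parsed expert-map JSON, or None when
    the path is unset, missing, or not readable by the current process."""
    if (expert_map_path and os.path.exists(expert_map_path)
            and os.access(expert_map_path, os.R_OK)):
        with open(expert_map_path) as f:
            return json.load(f)
    return None


record = load_static_expert_map("/tmp/expert_map.json")  # hypothetical path
print("static placement loaded" if record is not None
      else "falling back to default placement")
```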

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 24 additions & 22 deletions
@@ -37,7 +37,9 @@
    FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.quantization.base_config import \
    QuantizationConfig
-
+from vllm_ascend.eplb.core.eplb_utils import (
+    determine_default_expert_map,
+    determine_default_log2phy_map)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group

@@ -1012,27 +1014,27 @@ def __init__(

        ascend_config = get_ascend_config()
        self.dynamic_eplb = ascend_config.dynamic_eplb
-        self.expert_map_path = ascend_config.expert_map_path
-        self.global_redundant_expert_num = ascend_config.init_redundancy_expert
-        self.global_num_experts = num_experts + self.global_redundant_expert_num
-        if self.expert_map_path and os.path.exists(self.expert_map_path):
-            self.expert_load_balancer = ExpertLoadBalancer(self.expert_map_path, self.global_num_experts)
-            self.local_num_experts, self.expert_map = (self.expert_load_balancer.get_rank_placement_map(self.moe_instance_id, self.ep_rank))
-            self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(self.moe_instance_id, self.ep_rank).npu()
-            self.global_redundant_expert_num = (self.expert_load_balancer.get_global_redundant_expert_num())
-        else:
-            self.local_num_experts, self.expert_map = determine_expert_map(self.ep_size, self.ep_rank, self.global_num_experts)
-            if self.dynamic_eplb:
-                self.global_redundant_expert_num = ascend_config.init_redundancy_expert
-                from vllm_ascend.eplb.core.eplb_utils import (
-                    determine_default_expert_map,
-                    determine_default_log2phy_map)
-                self.local_num_experts, self.expert_map = determine_default_expert_map(
-                    self.global_num_experts, self.ep_size, self.ep_rank,
-                    self.global_redundant_expert_num)
-                self.log2phy = determine_default_log2phy_map(
-                    self.global_num_experts, self.ep_size, self.ep_rank,
-                    self.global_redundant_expert_num)
+        if self.dynamic_eplb:
+            self.expert_map_path = ascend_config.expert_map_path
+            self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+            self.global_num_experts = num_experts + self.global_redundant_expert_num
+            if self.expert_map_path and os.path.exists(self.expert_map_path) and os.access(self.expert_map_path, os.R_OK):
+                self.expert_load_balancer = ExpertLoadBalancer(self.expert_map_path, self.global_num_experts)
+                self.local_num_experts, self.expert_map = (self.expert_load_balancer.get_rank_placement_map(self.moe_instance_id, self.ep_rank))
+                self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(self.moe_instance_id, self.ep_rank).npu()
+                self.global_redundant_expert_num = (self.expert_load_balancer.get_global_redundant_expert_num())
+            else:
+                self.local_num_experts, self.expert_map = determine_expert_map(self.ep_size, self.ep_rank, self.global_num_experts)
+                if self.dynamic_eplb:
+                    self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+                    self.local_num_experts, self.expert_map = determine_default_expert_map(
+                        self.global_num_experts, self.ep_size, self.ep_rank,
+                        self.global_redundant_expert_num)
+                    self.log2phy = determine_default_log2phy_map(
+                        self.global_num_experts, self.ep_size, self.ep_rank,
+                        self.global_redundant_expert_num)
+            local_num_experts = (torch.sum(self.expert_map != -1) if self.expert_map is not None else num_experts)
+            self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)

        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
        self.enable_multistream_moe = \
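
All four constructors also build a log2phy map, i.e. a routing table from logical expert IDs to physical expert slots once redundant copies are counted. The toy example below only illustrates that idea; it is not the behavior of determine_default_log2phy_map or ExpertLoadBalancer.get_rank_log2phy_map:

```python
import torch

# Toy setup: 4 logical experts plus 2 redundant physical copies, giving 6
# physical slots. Slots 0-3 host logical experts 0-3; slots 4 and 5 are extra
# copies of logical experts 1 and 3 (an arbitrary, illustrative placement).
phy2log = torch.tensor([0, 1, 2, 3, 1, 3])

# One possible log2phy view: route each logical expert to a single physical
# slot, preferring the redundant copy when it exists.
log2phy = torch.tensor([0, 4, 2, 5])

topk_ids = torch.tensor([1, 3, 3, 0])      # logical experts chosen per token
physical_ids = log2phy[topk_ids]           # physical slots the tokens go to
print(physical_ids.tolist())               # [4, 5, 5, 0]
print(phy2log[physical_ids].tolist())      # maps back to [1, 3, 3, 0]
```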
