
Commit 99bdc4b

Author: offline0806
Commit message: Merge remote-tracking branch 'upstream_gitee/main_eplb_0909'
2 parents: 7725088 + 3ea4fff

File tree: 6 files changed (+241, -6 lines)
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
import torch
import unittest
from vllm_ascend.eplb.core.eplb_utils import determine_default_expert_map


class TestDetermineDefaultExpertMap(unittest.TestCase):

    def test_world_size_1(self):
        global_expert_num = 8
        world_size = 1
        global_redundant_expert_num = 0

        expected_counts = [8]
        expected_maps = [[0, 1, 2, 3, 4, 5, 6, 7]]

        local_count, expert_map = determine_default_expert_map(
            global_expert_num, world_size, 0, global_redundant_expert_num)

        self.assertEqual(local_count, expected_counts[0])

        expected_tensor = torch.tensor(expected_maps[0], dtype=torch.int32)
        self.assertTrue(torch.all(expert_map == expected_tensor).item())

    def test_equal_distribution(self):
        global_expert_num = 6
        world_size = 3
        global_redundant_expert_num = 0

        expected_counts = [2, 2, 2]
        expected_maps = [
            [0, 1, -1, -1, -1, -1],  # rank 0
            [-1, -1, 0, 1, -1, -1],  # rank 1
            [-1, -1, -1, -1, 0, 1]  # rank 2
        ]

        for rank_id in range(world_size):
            local_count, expert_map = determine_default_expert_map(
                global_expert_num, world_size, rank_id,
                global_redundant_expert_num)

            self.assertEqual(
                local_count,
                expected_counts[rank_id],
            )

            expected_tensor = torch.tensor(expected_maps[rank_id],
                                           dtype=torch.int32)
            self.assertTrue(torch.all(expert_map == expected_tensor).item())

    def test_unequal_distribution(self):
        global_expert_num = 10
        world_size = 3
        global_redundant_expert_num = 0

        expected_counts = [3, 3, 4]
        expected_maps = [
            [0, 1, 2, -1, -1, -1, -1, -1, -1, -1],  # rank 0
            [-1, -1, -1, 0, 1, 2, -1, -1, -1, -1],  # rank 1
            [-1, -1, -1, -1, -1, -1, 0, 1, 2, 3]  # rank 2
        ]

        for rank_id in range(world_size):
            local_count, expert_map = determine_default_expert_map(
                global_expert_num, world_size, rank_id,
                global_redundant_expert_num)

            self.assertEqual(local_count, expected_counts[rank_id])

            expected_tensor = torch.tensor(expected_maps[rank_id],
                                           dtype=torch.int32)
            self.assertTrue(torch.all(expert_map == expected_tensor).item())

    def test_with_redundancy(self):
        global_expert_num = 7
        world_size = 3
        global_redundant_expert_num = 2

        expected_counts = [3, 3, 3]
        expected_maps = [
            [0, 1, 2, -1, -1, -1, -1],  # rank 0
            [-1, -1, 0, 1, 2, -1, -1],  # rank 1
            [-1, -1, -1, -1, 0, 1, 2]  # rank 2
        ]

        for rank_id in range(world_size):
            local_count, expert_map = determine_default_expert_map(
                global_expert_num, world_size, rank_id,
                global_redundant_expert_num)

            self.assertEqual(local_count, expected_counts[rank_id])

            expected_tensor = torch.tensor(expected_maps[rank_id],
                                           dtype=torch.int32)
            self.assertTrue(torch.all(expert_map == expected_tensor).item())

    def test_redundancy_at_boundary(self):
        global_expert_num = 5
        world_size = 2
        global_redundant_expert_num = 1

        expected_counts = [3, 3]
        expected_maps = [[0, 1, 2, -1, -1], [-1, -1, 0, 1, 2]]

        for rank_id in range(world_size):
            local_count, expert_map = determine_default_expert_map(
                global_expert_num, world_size, rank_id,
                global_redundant_expert_num)

            self.assertEqual(local_count, expected_counts[rank_id])

            expected_tensor = torch.tensor(expected_maps[rank_id],
                                           dtype=torch.int32)
            self.assertTrue(torch.all(expert_map == expected_tensor).item())

vllm_ascend/ascend_config.py

Lines changed: 6 additions & 0 deletions
@@ -45,7 +45,13 @@ def __init__(self, vllm_config):
             ascend_scheduler_config)

         self.expert_map_path = additional_config.get("expert_map_path", None)
+        self.expert_map_record_path = additional_config.get(
+            "expert_map_record_path",
+            None)  # Provide path to export expert map
         # Eplb config
+        self.init_redundancy_expert = additional_config.get(
+            "init_redundancy_expert",
+            0)
         self.dynamic_eplb = additional_config.get("dynamic_eplb", False)
         self.num_iterations_eplb_update = additional_config.get("num_iterations_eplb_update", 400)
         self.gate_eplb = additional_config.get("gate_eplb", False)
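
Both new knobs are plain entries in additional_config, alongside the existing EPLB options. A hedged sketch of how they might be supplied; the values and the record path are illustrative assumptions, not part of this diff:

    # Hedged sketch: keys match the additions above; values and path are hypothetical.
    additional_config = {
        "dynamic_eplb": True,
        "init_redundancy_expert": 2,  # give the first 2 EP ranks one redundant expert each
        "expert_map_record_path": "/tmp/expert_map_record.json",  # export the final expert map here
    }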

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 62 additions & 1 deletion
@@ -21,6 +21,7 @@
 import torch.distributed as dist
 from vllm.logger import logger

+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor


@@ -39,6 +40,7 @@ def __init__(self, model, **args):
         self.num_dense_layers = self.model.config.first_k_dense_replace
         self.global_expert_num = self.model.config.n_routed_experts
         self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
+        self.init_redundancy_expert = get_ascend_config().init_redundancy_expert

         # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
         if self.model.quant_config is not None:
@@ -158,6 +160,35 @@ def _expert_file_to_tensor(self, expert_map_path: str):
             return expert_map_tensor, layers_num, gpus_num
         logger.error(f"failed to read expert_map_path: {expert_map_path}")

+    def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
+        num_local_experts = expert_maps.max() + 1
+        expert_maps_local = self.global2local(expert_maps, num_local_experts)
+
+        expert_maps_list = expert_maps_local.tolist()
+        record: dict[str, Any] = {
+            "moe_layer_count": len(expert_maps_list),
+            "layer_list": []
+        }
+
+        for layer_idx, layer_data in enumerate(expert_maps_list):
+            layer_record: dict[str, Any] = {
+                "layer_id": layer_idx,
+                "device_count": len(layer_data),
+                "device_list": []
+            }
+
+            for device_idx, experts in enumerate(layer_data):
+                device_record = {
+                    "device_id": device_idx,
+                    "device_expert": experts
+                }
+                layer_record["device_list"].append(device_record)
+
+            record["layer_list"].append(layer_record)
+
+        with open(expert_map_record_path, "w") as f:
+            json.dump(record, f, indent=4)
+
     def do_update_expert_map(self, layer_id, updated_expert_map):
         self.expert_map_per_layer[layer_id].copy_(updated_expert_map)
         self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
@@ -173,6 +204,26 @@ def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
         if self.log2phy_map_per_layer[layer_id] is not None:
             self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map)

+    def global2local(self, placement: torch.Tensor,
+                     E_local: int) -> torch.Tensor:
+
+        L, G, _ = placement.shape
+        device = placement.device
+
+        pt_local = torch.full((L, G, E_local),
+                              fill_value=-1,
+                              dtype=torch.long,
+                              device=device)
+
+        valid = placement >= 0
+        l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
+
+        slot_idx = placement[l_idx, g_idx, k_idx]
+
+        pt_local[l_idx, g_idx, slot_idx] = k_idx
+
+        return pt_local
+
     def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:

         L, G, E_local = placement_local.shape
@@ -198,7 +249,10 @@ def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
         return placement_global

     def determine_expert_map_all(self):
-
+        if self.world_size == 1:
+            local_ids = torch.arange(self.global_expert_num, dtype=torch.int32)
+            return local_ids.view(1, 1, -1).expand(self.num_moe_layers, 1, -1)
+
         local_num_experts = self.global_expert_num // self.world_size

         expert_map_all = torch.full(
@@ -215,6 +269,13 @@ def determine_expert_map_all(self):
                 start = r * local_num_experts
                 end = self.global_expert_num
                 local_count = self.global_expert_num - r * local_num_experts
+
+            if r < self.init_redundancy_expert:
+                local_count += 1
+                if end < self.global_expert_num:
+                    end += 1
+                else:
+                    start -= 1

             local_ids = torch.arange(local_count, dtype=torch.int32)
             expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
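
To make the new global2local helper concrete, here is a hedged, standalone restatement of the same indexing trick on a toy placement tensor; the function name is reused for illustration and the toy data is assumed, not taken from the adaptor:

    # Standalone sketch of the global2local inversion added above.
    import torch


    def global2local(placement: torch.Tensor, e_local: int) -> torch.Tensor:
        # placement[l, g, k] = local slot of global expert k on device g, or -1 if absent.
        L, G, _ = placement.shape
        pt_local = torch.full((L, G, e_local), -1, dtype=torch.long)
        l_idx, g_idx, k_idx = (placement >= 0).nonzero(as_tuple=True)
        # Invert the mapping: pt_local[l, g, slot] = global expert id held in that slot.
        pt_local[l_idx, g_idx, placement[l_idx, g_idx, k_idx]] = k_idx
        return pt_local


    # One MoE layer, two devices, four global experts, two local slots per device.
    placement = torch.tensor([[[0, 1, -1, -1],
                               [-1, -1, 0, 1]]])
    print(global2local(placement, e_local=2))
    # tensor([[[0, 1],
    #          [2, 3]]])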

vllm_ascend/eplb/core/eplb_utils.py

Lines changed: 45 additions & 1 deletion
@@ -20,6 +20,37 @@
 import torch


+def determine_default_expert_map(global_expert_num, world_size, rank_id,
+                                 global_redundant_expert_num):
+    if world_size == 1:
+        local_ids = torch.arange(global_expert_num, dtype=torch.int32)
+        return (global_expert_num, local_ids)
+
+    local_num_experts = global_expert_num // world_size
+
+    expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32)
+
+    if rank_id < world_size - 1:
+        start = rank_id * local_num_experts
+        end = (rank_id + 1) * local_num_experts
+        local_count = local_num_experts
+    else:
+        start = rank_id * local_num_experts
+        end = global_expert_num
+        local_count = global_expert_num - rank_id * local_num_experts
+
+    if rank_id < global_redundant_expert_num:
+        local_count += 1
+        if end < global_expert_num:
+            end += 1
+        else:
+            start -= 1
+
+    local_ids = torch.arange(local_count, dtype=torch.int32)
+    expert_map[start:end] = local_ids
+
+    return (local_count, expert_map)
+
 def generate_log2phy_map(expert_map):
     num_local_experts = expert_map.max() + 1
     log2phy_map = expert_map.clone()
@@ -50,7 +81,13 @@ def generate_log2phy_map(expert_map):
     return log2phy_map


-def determine_default_log2phy_map(global_expert_num, world_size, rank_id):
+def determine_default_log2phy_map(global_expert_num, world_size, rank_id, global_redundant_expert_num):
+    if world_size == 1:
+        local_ids = torch.arange(global_expert_num, dtype=torch.int32)
+        expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1)
+        log2phy_map_all = generate_log2phy_map(expert_map_all)
+        return log2phy_map_all[rank_id]
+
     local_num_experts = global_expert_num // world_size

     expert_map_all = torch.full((world_size, global_expert_num),
@@ -66,6 +103,13 @@ def determine_default_log2phy_map(global_expert_num, world_size, rank_id):
             start = r * local_num_experts
             end = global_expert_num
             local_count = global_expert_num - r * local_num_experts
+
+        if r < global_redundant_expert_num:
+            local_count += 1
+            if end < global_expert_num:
+                end += 1
+            else:
+                start -= 1

         local_ids = torch.arange(local_count, dtype=torch.int32)
         expert_map_all[r, start:end] = local_ids
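
As a quick check, determine_default_expert_map can be called directly; this mirrors the test_with_redundancy case from the new test file (7 experts, 3 ranks, 2 redundant experts), so the expected values below come from that test rather than a separate run:

    from vllm_ascend.eplb.core.eplb_utils import determine_default_expert_map

    local_count, expert_map = determine_default_expert_map(
        global_expert_num=7, world_size=3, rank_id=1, global_redundant_expert_num=2)
    print(local_count)           # 3
    print(expert_map.tolist())   # [-1, -1, 0, 1, 2, -1, -1]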

vllm_ascend/eplb/eplb_updator.py

Lines changed: 6 additions & 0 deletions
@@ -39,6 +39,7 @@ def init_eplb(self, expert_map_path, process):
         self.periodic_load_gather = True
         self.num_iterations_eplb_update: torch.int64 = self.ascend_config.num_iterations_eplb_update
         self.expert_map_path = expert_map_path
+        self.expert_map_record_path = self.ascend_config.expert_map_record_path

         try:
             if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
@@ -82,6 +83,11 @@ def update_iteration(self):
         self.cur_iterations += 1
         if self.cur_iterations == (self.num_iterations_eplb_update + \
                 self.num_wait_worker_iterations + self.num_moe_layers):
+            if self.expert_map_record_path is not None:
+                self.adaptor._export_tensor_to_file(
+                    self.shared_dict["expert_maps"],
+                    self.expert_map_record_path)
+
             self.adaptor.model.clear_all_moe_loads()
             if not self.gate_eplb:
                 self.cur_iterations = 0
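
The export runs when the iteration counter reaches the end of an update cycle, and only if expert_map_record_path is set. A hedged back-of-the-envelope for the trigger iteration, using the default num_iterations_eplb_update of 400 and purely hypothetical values for the other two terms:

    # Hypothetical illustration of the trigger condition in update_iteration.
    num_iterations_eplb_update = 400   # default from ascend_config
    num_wait_worker_iterations = 30    # hypothetical
    num_moe_layers = 58                # hypothetical
    print(num_iterations_eplb_update + num_wait_worker_iterations + num_moe_layers)  # 488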

vllm_ascend/ops/fused_moe.py

Lines changed: 9 additions & 4 deletions
@@ -311,11 +311,16 @@ def __init__(
             self.ep_size,
             get_ep_group().rank_in_group, self.global_num_experts)
         if self.dynamic_eplb:
-            from vllm_ascend.eplb.core.eplb_utils import \
-                determine_default_log2phy_map
+            self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+            from vllm_ascend.eplb.core.eplb_utils import (
+                determine_default_expert_map,
+                determine_default_log2phy_map)
+            self.local_num_experts, self.expert_map = determine_default_expert_map(
+                self.global_num_experts, self.ep_size, self.ep_rank,
+                self.global_redundant_expert_num)
             self.log2phy = determine_default_log2phy_map(
-                self.global_num_experts, self.ep_size, self.ep_rank
-            )
+                self.global_num_experts, self.ep_size, self.ep_rank,
+                self.global_redundant_expert_num)

         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
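
With dynamic EPLB enabled, each MoE layer now derives its default expert_map (and the matching log2phy map) from the redundancy-aware helpers. A hedged per-rank preview for a toy setup of 5 global experts, ep_size 2, and one redundant expert, matching test_redundancy_at_boundary above:

    from vllm_ascend.eplb.core.eplb_utils import determine_default_expert_map

    for ep_rank in range(2):
        local_num_experts, expert_map = determine_default_expert_map(5, 2, ep_rank, 1)
        print(ep_rank, local_num_experts, expert_map.tolist())
    # 0 3 [0, 1, 2, -1, -1]
    # 1 3 [-1, -1, 0, 1, 2]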
