
Commit c202a1d
Author: lt (committed)
merge expert load 1227
2 parents: ef40e69 + e82ee11

File tree: 4 files changed (+190, -27 lines)


vllm_ascend/ops/expert_load_balancer.py (91 additions & 8 deletions)

@@ -1,20 +1,78 @@
 import json
 import random
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import torch
 
+from vllm_ascend.ascend_config import get_ascend_config
 
-class ExpertLoadBalancer(object):
 
-    def __init__(self, expert_map_path, global_expert_num):
-        self.expert_map_path = expert_map_path
+class ExpertLoadBalancer:
+    """
+    ExpertLoadBalancer is a singleton class responsible for managing and
+    recording the mapping and load balancing of experts across multiple layers
+    and devices in a distributed Mixture-of-Experts (MoE) model.
+    """
+
+    _instance = None
+    """The singleton instance of ExpertLoadBalancer."""
+
+    def __init__(self, expert_map_path: Optional[str], global_expert_num: int):
+        """
+        This method should only be called once, and it raises an exception if
+        an instance already exists.
+
+        Args:
+            expert_map_path (str): Path to the expert map file. If None, only
+                used for recording expert load.
+            global_expert_num (int): Total number of global experts.
+        Raises:
+            Exception: If an instance of ExpertLoadBalancer already exists.
+        """
+
+        if ExpertLoadBalancer._instance is not None:
+            raise Exception(
+                "This class is a singleton, cannot be instantiated "
+                "more than once.")
+
         self.global_expert_num = global_expert_num
-        self.expert_map_tensor, self.layers_num, self.ranks_num = (
-            self._expert_file_to_tensor())
 
-    def _expert_file_to_tensor(self):
-        with open(self.expert_map_path, "r") as f:
+        # If expert_map_path is not provided, we only record the expert load.
+        if expert_map_path is not None:
+            self.expert_map_tensor, self.layers_num, self.ranks_num = (
+                self._expert_file_to_tensor(expert_map_path))
+        else:
+            self.expert_map_tensor = None
+            # TODO: change the num layer source
+            self.layers_num = 58
+            self.ranks_num = None
+
+        self._torchair_graph_enabled = \
+            get_ascend_config().torchair_graph_config.enabled
+
+        self._all_layers_logical_expert_load_record = \
+            torch.zeros((self.layers_num, self.global_expert_num),
+                        dtype=torch.int64,
+                        device=torch.npu.current_device())
+        # Always enable expert load recording if torchair graph is enabled.
+        self._recording = self._torchair_graph_enabled
+
+    @staticmethod
+    def get_instance():
+        if ExpertLoadBalancer._instance is None:
+            raise ValueError(
+                "ExpertLoadBalancer instance has not been initialized.")
+        return ExpertLoadBalancer._instance
+
+    @staticmethod
+    def init_instance(expert_map_path: Optional[str], global_expert_num: int):
+        """Initialize the singleton instance of ExpertLoadBalancer."""
+        ExpertLoadBalancer._instance = ExpertLoadBalancer(
+            expert_map_path, global_expert_num)
+        return ExpertLoadBalancer._instance
+
+    def _expert_file_to_tensor(self, expert_map_path: str):
+        with open(expert_map_path, "r") as f:
             data = json.load(f)
             layers_num = data["moe_layer_count"]
             gpus_num = data["layer_list"][0]["device_count"]

@@ -97,3 +155,28 @@ def get_global_redundant_expert_num(self):
             len(self.expert_map_tensor[0][0]) * self.ranks_num -
             self.global_expert_num)
         return global_redundant_expert_num
+
+    def accumulate_expert_distribution_record(self, layer_id: int,
+                                              topk_ids: torch.Tensor):
+        if not self._recording:
+            return
+        flattened_topk_ids = topk_ids.flatten().to(torch.int64)
+        ones = torch.ones_like(flattened_topk_ids)
+        self._all_layers_logical_expert_load_record[layer_id].scatter_add_(
+            0, flattened_topk_ids, ones)
+
+    def start_expert_distribution_record(self):
+        """Start recording the expert distribution."""
+        self._all_layers_logical_expert_load_record.zero_()
+        self._recording = True
+
+    def stop_expert_distribution_record(self):
+        """Stop recording the expert distribution."""
+        # If torchair graph is not enabled, we do not turn off the recording.
+        self._recording = self._torchair_graph_enabled
+
+    def export_local_expert_distribution_record(self):
+        """Export the local expert distribution record and reset it."""
+        local_record = self._all_layers_logical_expert_load_record.clone()
+        self._all_layers_logical_expert_load_record.zero_()
+        return local_record
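For orientation, the recording lifecycle these additions introduce can be exercised roughly as follows. This is a minimal sketch, assuming an Ascend environment with vllm-ascend installed and the ascend config already initialized (the constructor calls get_ascend_config()); the expert count of 256 and the random topk_ids are placeholders, not values taken from this commit.

import torch

from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer

# One-time singleton setup; in this commit the worker does this in
# _init_worker_distributed_environment (see worker_v1.py below).
balancer = ExpertLoadBalancer.init_instance(expert_map_path=None,
                                            global_expert_num=256)

balancer.start_expert_distribution_record()

# Each MoE layer feeds its routing decisions into a per-layer counter;
# the record tensor has shape (layers_num, global_expert_num).
topk_ids = torch.randint(0, 256, (4, 8), device="npu")  # stand-in routing output
balancer.accumulate_expert_distribution_record(layer_id=0, topk_ids=topk_ids)

# Snapshot the per-layer, per-expert token counts and reset them.
record = ExpertLoadBalancer.get_instance().export_local_expert_distribution_record()
balancer.stop_expert_distribution_record()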

vllm_ascend/ops/fused_moe.py (45 additions & 16 deletions)

@@ -945,15 +945,15 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts), topk_ids
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,
                                  topk_weights=topk_weights,
                                  topk_ids=topk_ids,
                                  top_k=top_k,
-                                 expert_map=expert_map)
+                                 expert_map=expert_map), topk_ids
         elif MOE_ALL2ALL_BUFFER:
             return fused_experts_with_all2all_buffer(
                 hidden_states=x,
@@ -965,16 +965,17 @@ def apply(
                 max_model_len=self.max_model_len,
                 global_batch_size=self.global_batch_size,
                 expert_map=expert_map,
-                ep_group=get_ep_group())
+                ep_group=get_ep_group()), topk_ids
         else:
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group())
+            return fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=get_ep_group()), topk_ids
 
 
 class AscendFusedMoE(FusedMoE):
@@ -1048,22 +1049,26 @@ def __init__(
         self.log2phy = None
         self.global_redundant_expert_num = 0
 
+        self.expert_load_balancer = ExpertLoadBalancer.get_instance()
         ascend_config = get_ascend_config()
         expert_map_path = ascend_config.expert_map_path
         self.dynamic_eplb = ascend_config.dynamic_eplb
         if expert_map_path and os.path.exists(expert_map_path):
+            # only support in MC2 and graph mode
+            if not (VLLM_ENABLE_MC2
+                    and ascend_config.torchair_graph_config.enabled):
+                raise NotImplementedError(
+                    "EPLB is only supported in MC2 and graph mode")
             # moe expert load balance
-            expert_load_balancer = ExpertLoadBalancer(expert_map_path,
-                                                      self.global_num_experts)
             self.local_num_experts, self.expert_map = \
-                expert_load_balancer.get_rank_placement_map(
+                self.expert_load_balancer.get_rank_placement_map(
                     self.moe_instance_id,
                     get_ep_group().rank_in_group)
-            self.log2phy = expert_load_balancer.get_rank_log2phy_map(
+            self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
                 self.moe_instance_id,
                 get_ep_group().rank_in_group)
             self.global_redundant_expert_num = \
-                expert_load_balancer.get_global_redundant_expert_num()
+                self.expert_load_balancer.get_global_redundant_expert_num()
         else:
             # Create a tensor of size num_experts filled with -1
             self.local_num_experts, self.expert_map = determine_expert_map(
@@ -1199,6 +1204,9 @@ def forward(self,
         if self.dynamic_eplb == True:
             self.calculate_moe_load()
 
+        self.expert_load_balancer.accumulate_expert_distribution_record(
+            self.moe_instance_id, self.topk_ids)
+
         if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
             dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                             self.tp_group)
@@ -1216,6 +1224,27 @@ def forward(self,
             dispose_tensor(e_hidden_states)
         else:
             final_hidden_states = e_hidden_states
+        self.expert_load_balancer.accumulate_expert_distribution_record(
+            self.moe_instance_id, topk_ids)
+
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.torchair_graph_enabled:
+                if USING_LCCL_COM:  # type: ignore
+                    e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        e_hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+                elif self.torchair_graph_enabled and not is_prefill:
+                    e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        e_hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+                else:
+                    e_hidden_states = get_ep_group().combine(e_hidden_states)
 
         if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
             final_hidden_states = tensor_model_parallel_all_reduce(

vllm_ascend/quantization/w8a8_dynamic.py (1 addition & 1 deletion)

@@ -617,7 +617,7 @@ def apply(
         global_redundant_expert_num: int = 0,
         shared_experts: Optional[Any] = None,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
vllm_ascend/worker/worker_v1.py (53 additions & 2 deletions)

@@ -17,6 +17,8 @@
 # Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
 #
 
+import time
+from pathlib import Path
 from typing import Optional
 
 import torch
@@ -38,9 +40,13 @@
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerBase
 
-from vllm_ascend.ascend_config import init_ascend_config
+import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.device_allocator.camem import CaMemAllocator
-from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
+from vllm_ascend.distributed.parallel_state import (get_ep_group,
+                                                    get_etp_group,
+                                                    init_ascend_model_parallel)
+from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import try_register_lib
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -233,6 +239,33 @@ def profile(self, is_start: bool = True):
         else:
             self.profiler.stop()
 
+    def expert_distribution_record(self, is_start: bool = True):
+        assert envs.VLLM_EXPERT_DISTRIBUTION_RECORDER_DIR is not None, \
+            "VLLM_EXPERT_DISTRIBUTION_RECORDER_DIR is not set. " \
+            "Please set it to enable expert distribution recording."
+
+        if is_start:
+            logger.info("Starting expert distribution record.")
+            self.expert_load_balancer.start_expert_distribution_record()
+        else:
+            logger.info("Stopping expert distribution record.")
+            self.expert_load_balancer.stop_expert_distribution_record()
+
+    def dump_expert_distribution_record(self):
+        assert envs.VLLM_EXPERT_DISTRIBUTION_RECORDER_DIR is not None, \
+            "VLLM_EXPERT_DISTRIBUTION_RECORDER_DIR is not set. " \
+            "Please set it to enable expert distribution recording."
+
+        logger.info("Dumping expert distribution record.")
+        local_expert_distribution = \
+            self.expert_load_balancer.export_local_expert_distribution_record()
+
+        ep_rank = get_ep_group().rank_in_group
+        etp_rank = get_etp_group().rank_in_group
+        _dump_to_file(
+            f"expert_distribution_recorder_{time.time()}_{ep_rank}_{etp_rank}.pt",
+            local_expert_distribution)
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
 
@@ -277,6 +310,15 @@ def _init_worker_distributed_environment(self) -> None:
         )
         ensure_kv_transfer_initialized(self.vllm_config)
 
+        # Initialize the expert load balancer.
+        if self.vllm_config.model_config.is_deepseek_mla:
+            expert_map_path = get_ascend_config().expert_map_path
+            num_logical_experts = \
+                self.vllm_config.model_config.hf_config.n_routed_experts
+            self.expert_load_balancer = ExpertLoadBalancer.init_instance(
+                expert_map_path=expert_map_path,
+                global_expert_num=num_logical_experts)
+
     def _init_profiler(self):
         # Torch profiler. Enabled and configured through env vars:
         # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
@@ -310,3 +352,12 @@ def _init_profiler(self):
                     torch_profiler_trace_dir))
         else:
             return None
+
+
+def _dump_to_file(name, data):
+    save_dir = Path(envs.VLLM_EXPERT_DISTRIBUTION_RECORDER_DIR)
+    path_output = save_dir / name
+    logger.info(f"Write expert distribution to {path_output}")
+    if not save_dir.exists():
+        save_dir.mkdir(parents=True, exist_ok=True)
+    torch.save(data, str(path_output))
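dump_expert_distribution_record() above simply torch.save()s the per-worker (layers_num, global_expert_num) count tensor into VLLM_EXPERT_DISTRIBUTION_RECORDER_DIR, one file per EP/ETP rank. A hedged post-processing sketch follows; the directory path and the summary logic are illustrative, not part of the commit.

from pathlib import Path

import torch

# Placeholder for the directory pointed to by VLLM_EXPERT_DISTRIBUTION_RECORDER_DIR.
record_dir = Path("/tmp/expert_distribution")
for dump in sorted(record_dir.glob("expert_distribution_recorder_*.pt")):
    record = torch.load(dump, map_location="cpu")  # shape: (layers_num, global_expert_num)
    per_expert = record.sum(dim=0)                 # tokens routed to each logical expert
    print(dump.name, "hottest expert:", int(per_expert.argmax()),
          "tokens:", int(per_expert.max()))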
