Commit ee6f328

add fix warm up & change init expert map from file
1 parent f1f7b95 commit ee6f328

File tree: 4 files changed, +37 / -12 lines


vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 21 additions & 0 deletions

@@ -121,6 +121,27 @@ def get_init_expert_map(self, num_moe_layers):
 
         return all_expert_maps
 
+    def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
+        expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor(expert_map_path)
+        for layer_idx in range(num_moe_layers):
+            self.expert_map_per_layer_cpu[layer_idx] = \
+                expert_map_tensor[layer_idx][self.rank_id]
+
+    def _expert_file_to_tensor(self, expert_map_path: str):
+        with open(expert_map_path, "r") as f:
+            data = json.load(f)
+            layers_num = data["moe_layer_count"]
+            gpus_num = data["layer_list"][0]["device_count"]
+
+            tensor_data = []
+            for layer in data["layer_list"]:
+                device_data = []
+                for device in layer["device_list"]:
+                    device_data.append(device["device_expert"])
+                tensor_data.append(device_data)
+            expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
+            return expert_map_tensor, layers_num, gpus_num
+
     def do_update_expert_map(self, layer_id, updated_expert_map):
         self.expert_map_per_layer[layer_id].copy_(updated_expert_map)
         self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
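For context, the shape of the JSON file parsed by _expert_file_to_tensor can be inferred from the keys it reads (moe_layer_count, layer_list, device_count, device_list, device_expert). Below is a hypothetical example under that assumption; the values are illustrative, and every device_expert list must have the same length or the torch.tensor call will fail:

# Hypothetical expert-map file for _expert_file_to_tensor; key names are taken
# from the diff above, all values are made up for illustration.
import json

import torch

expert_map = {
    "moe_layer_count": 2,
    "layer_list": [
        {
            "device_count": 2,
            "device_list": [
                {"device_expert": [0, 1, 2, 3]},  # experts held by rank 0
                {"device_expert": [4, 5, 6, 7]},  # experts held by rank 1
            ],
        },
        {
            "device_count": 2,
            "device_list": [
                {"device_expert": [4, 5, 6, 7]},
                {"device_expert": [0, 1, 2, 3]},
            ],
        },
    ],
}

with open("expert_map.json", "w") as f:
    json.dump(expert_map, f, indent=4)

# Mirror of the parsing loop: the nested lists must be rectangular
# (same number of experts per device) for torch.tensor to accept them.
tensor_data = [
    [device["device_expert"] for device in layer["device_list"]]
    for layer in expert_map["layer_list"]
]
print(torch.tensor(tensor_data, dtype=torch.int32).shape)  # torch.Size([2, 2, 4])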

vllm_ascend/eplb/eplb_updator.py

Lines changed: 7 additions & 5 deletions

@@ -24,18 +24,19 @@
 
 class EplbUpdator:
 
-    def __init__(self, redundant_enable):
-        self.init_eplb(redundant_enable)
+    def __init__(self, expert_map_path):
+        self.init_eplb(expert_map_path)
 
     def set_adaptor(self, adaptor):
         self.adaptor = adaptor
         self.eplb_loader = D2DExpertWeightLoader(eplb_adaptor=self.adaptor)
         self.num_moe_layers = self.adaptor.num_moe_layers
 
-    def init_eplb(self, redundant_enable):
+    def init_eplb(self, expert_map_path):
 
-        self.redundant_enable = redundant_enable
+        self.redundant_enable = (expert_map_path != None)
         self.num_iterations: torch.int64 = 130
+        self.expert_map_path = expert_map_path
 
         self.weight_update_counter = 0
         self.expert_map_initialized = False
@@ -82,7 +83,8 @@ def get_update_iteration(self):
     def get_init_expert_map(self):
         try:
             if not self.expert_map_initialized:
-                self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map(self.num_moe_layers)
+                # self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map(self.num_moe_layers)
+                self.shared_dict["expert_maps"] = self.adaptor.get_init_expert_map_from_file(self.num_moe_layers, self.expert_map_path)
             self.expert_map_initialized = True
         except Exception as e:
             logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}", exc_info=True)
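Because the constructor now receives the path itself instead of a precomputed boolean, redundancy support falls out of whether a path was configured at all. A minimal usage sketch, assuming EplbUpdator can be constructed standalone; the path below is made up:

# Hypothetical construction showing how redundant_enable is derived.
from vllm_ascend.eplb.eplb_updator import EplbUpdator

updator = EplbUpdator("/tmp/expert_map.json")
assert updator.redundant_enable is True   # a path was given

updator = EplbUpdator(None)
assert updator.redundant_enable is False  # no expert map file configured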

vllm_ascend/worker/model_runner_v1.py

Lines changed: 8 additions & 7 deletions

@@ -371,7 +371,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.dynamic_eplb = ascend_config.dynamic_eplb
         if self.dynamic_eplb == True:
             self.eplb_adaptor = None
-            self.eplb_updator = EplbUpdator(ascend_config.expert_map_path != None)
+            self.eplb_updator = EplbUpdator(ascend_config.expert_map_path)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -1508,12 +1508,6 @@ def _dummy_run(
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds)
 
-        #EPLB
-        if self.dynamic_eplb == True:
-            self.eplb_adaptor = VllmEplbAdaptor(model=self.model)
-            self.eplb_updator.set_adaptor(self.eplb_adaptor)
-            self.eplb_updator.warm_up_eplb()
-
         return hidden_states
 
     def profile_run(self) -> None:
@@ -1555,6 +1549,13 @@ def profile_run(self) -> None:
         self.encoder_cache.clear()
         gc.collect()
 
+    def eplb_warmup(self):
+        #EPLB
+        if self.dynamic_eplb == True:
+            self.eplb_adaptor = VllmEplbAdaptor(model=self.model)
+            self.eplb_updator.set_adaptor(self.eplb_adaptor)
+            self.eplb_updator.warm_up_eplb()
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 0 deletions

@@ -203,6 +203,7 @@ def compile_or_warm_up_model(self) -> None:
         for size in sorted(warmup_sizes, reverse=True):
             logger.info("Compile and warming up model for size %d", size)
             self.model_runner._dummy_run(size)
+        self.model_runner.eplb_warmup()
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
         # Reset the seed to ensure that the random state is not affected by
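Taken together with the model_runner_v1.py change above, this makes the warm-up ordering explicit: _dummy_run no longer rebuilds the EPLB adaptor on every invocation, and the worker performs the one-time warm-up after all sized dummy runs. A minimal runnable sketch of the new ordering, with a stub standing in for the real model runner:

# Stub illustrating the call order after this commit; only the ordering
# mirrors the diffs, everything else is simplified.
class StubModelRunner:
    def __init__(self):
        self.eplb_warmup_calls = 0

    def _dummy_run(self, size):
        pass  # no EPLB side effects here anymore

    def eplb_warmup(self):
        self.eplb_warmup_calls += 1  # adaptor wiring + warm_up_eplb() happen once

runner = StubModelRunner()
for size in sorted([512, 1024], reverse=True):
    runner._dummy_run(size)
runner.eplb_warmup()
assert runner.eplb_warmup_calls == 1  # warm-up runs exactly once, after dummy runs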
