
Commit f1f936b

EPLB MoE load collection
Parent: d639144

File tree

5 files changed: +92 −299 lines


vllm_ascend/eplb/core/loader/ssd_loader.py

Lines changed: 0 additions & 295 deletions
This file was deleted.

vllm_ascend/models/deepseek_v2.py

Lines changed: 33 additions & 0 deletions
@@ -723,6 +723,39 @@ def forward(
                                   inputs_embeds)
         return hidden_states
 
+    def update_expert_map(self, new_expert_map, layer_id):
+        self.model.layers[layer_id].mlp.experts.update_map(new_expert_map)
+
+    def update_all_expert_map(self, new_expert_map, num_moe_layers):
+        # The passed-in count is ignored; it is derived from the maps themselves.
+        num_moe_layers = len(new_expert_map)
+        for layer_id in range(num_moe_layers):
+            layer_map = new_expert_map[layer_id].to("npu")
+            # MoE layers start after the leading dense layers (offset 3 here).
+            self.model.layers[3 + layer_id].mlp.experts.update_map(layer_map)
+
+    def get_expert_map(self, layer_id):
+        return self.model.layers[layer_id].mlp.experts.get_map()
+
+    def get_all_expert_map(self, num_moe_layers):
+        all_maps = []
+        for layer_id in range(num_moe_layers):
+            layer_map = self.get_expert_map(3 + layer_id)  # (num_experts_per_layer,)
+            all_maps.append(layer_map)
+        return torch.stack(all_maps, dim=0)
+
+    def get_moe_load(self, layer_id):
+        return self.model.layers[layer_id].mlp.experts.get_moe_load()
+
+    def get_all_moe_loads(self, num_moe_layers) -> torch.Tensor:
+        """
+        output: [num_moe_layers, num_experts_per_layer]
+        """
+        all_loads = []
+        for layer_id in range(num_moe_layers):
+            load_tensor = self.get_moe_load(3 + layer_id)  # (num_experts_per_layer,)
+            all_loads.append(load_tensor)
+        return torch.stack(all_loads, dim=0)
 
 
 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass
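
Taken together, these helpers give an EPLB planner a model-level view: get_all_moe_loads stacks each MoE layer's per-expert routing counts into a [num_moe_layers, num_experts_per_layer] tensor, and update_all_expert_map pushes rebalanced expert maps back down. The hard-coded offset of 3 presumably skips the model's leading dense layers; note that the per-layer getters (get_expert_map, get_moe_load) take absolute layer indices while the all-layer variants apply the offset internally, so callers must keep the two conventions straight. Below is a minimal, self-contained sketch of the stacking pattern, using stand-in classes rather than the real vLLM modules:

import torch

class StubExperts:
    """Stand-in for the AscendFusedMoE module: holds a per-expert load vector."""
    def __init__(self, load):
        self.load = load
    def get_moe_load(self):
        return self.load

class StubLayer:
    """Stand-in for a decoder layer; dense layers have no experts."""
    def __init__(self, experts=None):
        self.mlp = type("MLP", (), {"experts": experts})()

first_dense, num_moe_layers, num_experts = 3, 2, 4
layers = [StubLayer() for _ in range(first_dense)]
layers += [StubLayer(StubExperts(torch.randint(0, 100, (num_experts,))))
           for _ in range(num_moe_layers)]

# Mirrors get_all_moe_loads: skip the dense prefix, stack per-layer loads.
all_loads = torch.stack(
    [layers[first_dense + i].mlp.experts.get_moe_load()
     for i in range(num_moe_layers)],
    dim=0)
print(all_loads.shape)  # torch.Size([2, 4]) -> [num_moe_layers, num_experts_per_layer]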

vllm_ascend/ops/fused_moe.py

Lines changed: 27 additions & 1 deletion
@@ -991,6 +991,8 @@ def __init__(
 
         AscendFusedMoE.moe_counter += 1
         self.moe_instance_id = AscendFusedMoE.moe_counter
+        self.moe_load = None
+        self.topk_ids = None
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1132,7 +1134,7 @@ def forward(self,
                                           hidden_states, router_logits)
 
         # Matrix multiply.
-        hidden_states = self.quant_method.apply(
+        hidden_states, self.topk_ids = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
@@ -1152,6 +1154,8 @@ def forward(self,
             global_redundant_expert_num=self.global_redundant_expert_num,
             **kwargs)
 
+        self.calculate_moe_load()  # accumulate per-expert routing counts for this step
+
         if self.enable_multistream_shared_expert and not is_prefill:
             hidden_states, shared_output = hidden_states
 
@@ -1209,3 +1213,25 @@ def _forward_ms_fused_moe_comp(
             enable_force_load_balance=enable_force_load_balance)
 
         return hidden_states
+
+    def update_map(self, new_expert_map):
+        self.expert_map = new_expert_map
+
+    def get_map(self):
+        return self.expert_map
+
+    def get_moe_load(self):
+        return self.moe_load
+
+    def calculate_moe_load(self):
+        if self.moe_load is None:  # lazily allocate the per-expert counter
+            self.moe_load = torch.zeros(self.num_experts,
+                                        dtype=torch.int64,
+                                        device=self.topk_ids.device)
+
+        ids = self.topk_ids.flatten().to(torch.int64)
+        # Each routed (token, top-k slot) pair contributes one unit of load.
+        ones = torch.ones_like(ids, dtype=torch.int64, device=ids.device)
+        self.moe_load.scatter_add_(0, ids, ones)
+
+        return self.moe_load
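
calculate_moe_load is a scatter-add histogram: topk_ids holds the expert index chosen for every (token, top-k slot) pair of the last forward pass, and each pair adds one unit to its expert's bucket. Since the counter is allocated once and never zeroed here, loads accumulate across steps until whatever consumes them resets or replaces the tensor. A standalone sketch of the same counting trick; torch.bincount gives an equivalent one-liner:

import torch

num_experts = 8
# Router output for 4 tokens with top-2 routing: expert indices per token.
topk_ids = torch.tensor([[0, 3], [3, 5], [0, 3], [7, 5]])

moe_load = torch.zeros(num_experts, dtype=torch.int64)
ids = topk_ids.flatten().to(torch.int64)
moe_load.scatter_add_(0, ids, torch.ones_like(ids))
print(moe_load)  # tensor([2, 0, 0, 3, 0, 2, 0, 1])

# Equivalent histogram via bincount:
assert torch.equal(moe_load, torch.bincount(ids, minlength=num_experts))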
