-import torch
 import json
 import random
+from typing import List, Dict
+
+import torch
 
 class ExpertLoadBalancer(object):
     def __init__(self, expert_map_path, global_expert_num):
         self.expert_map_path = expert_map_path
         self.global_expert_num = global_expert_num
-        self.expert_map_tensor, self.layers_num, self.ranks_num = self.expert_file_to_tensor()
+        self.expert_map_tensor, self.layers_num, self.ranks_num = \
+            self.expert_file_to_tensor()
 
     def expert_file_to_tensor(self):
         with open(self.expert_map_path, "r") as f:
@@ -39,25 +42,29 @@ def generate_index_dicts(tensor_2d):
         return dict_list
 
     def generate_expert_placement_map(self):
-        expert_placement_map = torch.full((self.layers_num, self.ranks_num, self.global_expert_num),
-                                          -1, dtype=torch.int32)
+        expert_placement_map = torch.full(
+            (self.layers_num, self.ranks_num, self.global_expert_num),
+            -1, dtype=torch.int32)
         for layer_id in range(self.layers_num):
             for gpu_id in range(self.ranks_num):
                 e_ids = self.expert_map_tensor[layer_id, gpu_id]
-                expert_placement_map[layer_id, gpu_id, e_ids] = torch.arange(len(e_ids), dtype=torch.int32)
+                expert_placement_map[layer_id, gpu_id, e_ids] = \
+                    torch.arange(len(e_ids), dtype=torch.int32)
         return expert_placement_map
 
     def generate_log2phy_expert_map(self, layer_id):
         concatenated = torch.flatten(self.expert_map_tensor[layer_id])
-        rank_expert_to_global = self.generate_index_dicts(self.expert_map_tensor[layer_id])
-        result_dict = {}
+        rank_expert_to_global = self.generate_index_dicts(
+            self.expert_map_tensor[layer_id])
+        result_dict: Dict[int, List[int]] = {}
         for idx, value in enumerate(concatenated):
             key = value.item()
             if key not in result_dict:
                 result_dict[key] = []
             result_dict[key].append(idx)
 
-        log2phy_map = torch.full((self.ranks_num, self.global_expert_num), -1, dtype=torch.int32)
+        log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
+                                 -1, dtype=torch.int32)
         for rank in range(self.ranks_num):
             for key in result_dict:
                 indices_in_concat = result_dict[key]
@@ -71,7 +78,8 @@ def generate_log2phy_expert_map(self, layer_id):
     def get_rank_placement_map(self, layer_id, rank_id):
         expert_placement_map = self.generate_expert_placement_map()
         layer_expert_map = expert_placement_map[layer_id]
-        rank_expert_map = layer_expert_map[rank_id].to(torch.npu.current_device())
+        rank_expert_map = layer_expert_map[rank_id].to(
+            torch.npu.current_device())
         rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item()
         return rank_local_expert_num, rank_expert_map
 
@@ -80,7 +88,8 @@ def get_rank_log2phy_map(self, layer_id, rank_id):
         return layer_log2phy_map[rank_id]
 
     def get_global_redundant_expert_num(self):
-        global_redundant_expert_num = len(self.expert_map_tensor[0][0]) * self.ranks_num - self.global_expert_num
+        global_redundant_expert_num = len(self.expert_map_tensor[0][0]) \
+            * self.ranks_num - self.global_expert_num
        return global_redundant_expert_num
 
 
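
For reference, a minimal usage sketch of the class this diff reformats, based only on the method signatures visible above. The expert-map file name, expert count, layer id, and rank id are hypothetical placeholders, ExpertLoadBalancer is assumed to be importable from this module, and get_rank_placement_map assumes an Ascend build of PyTorch where torch.npu is available:

# Hypothetical usage sketch; the values below are placeholders, not from the diff.
balancer = ExpertLoadBalancer("expert_map.json", global_expert_num=256)

# Per-rank placement for one MoE layer: the local expert count and a tensor that
# maps each global expert id to its local slot (-1 if not hosted on this rank).
local_num, placement = balancer.get_rank_placement_map(layer_id=0, rank_id=0)

# Logical-to-physical expert map for the same layer and rank.
log2phy = balancer.get_rank_log2phy_map(layer_id=0, rank_id=0)

# Number of redundant expert copies across all ranks
# (experts per rank * rank count - global expert count).
num_redundant = balancer.get_global_redundant_expert_num()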