 
 import torch
 
+
 class ExpertLoadBalancer(object):
 
     def __init__(self, expert_map_path, global_expert_num):
         self.expert_map_path = expert_map_path
         self.global_expert_num = global_expert_num
-        self.expert_map_tensor, self.layers_num, self.ranks_num = \
-            self.expert_file_to_tensor()
+        self.expert_map_tensor, self.layers_num, self.ranks_num = (
+            self.expert_file_to_tensor()
+        )
 
     def expert_file_to_tensor(self):
         with open(self.expert_map_path, "r") as f:
@@ -43,28 +45,33 @@ def generate_index_dicts(tensor_2d):
 
     def generate_expert_placement_map(self):
         expert_placement_map = torch.full(
-            (self.layers_num, self.ranks_num, self.global_expert_num),
-            -1, dtype=torch.int32)
+            (self.layers_num, self.ranks_num, self.global_expert_num),
+            -1,
+            dtype=torch.int32,
+        )
         for layer_id in range(self.layers_num):
             for gpu_id in range(self.ranks_num):
                 e_ids = self.expert_map_tensor[layer_id, gpu_id]
-                expert_placement_map[layer_id, gpu_id, e_ids] = \
-                    torch.arange(len(e_ids), dtype=torch.int32)
+                expert_placement_map[layer_id, gpu_id, e_ids] = torch.arange(
+                    len(e_ids), dtype=torch.int32
+                )
         return expert_placement_map
 
     def generate_log2phy_expert_map(self, layer_id):
         concatenated = torch.flatten(self.expert_map_tensor[layer_id])
         rank_expert_to_global = self.generate_index_dicts(
-            self.expert_map_tensor[layer_id])
+            self.expert_map_tensor[layer_id]
+        )
         result_dict: Dict[int, List[int]] = {}
         for idx, value in enumerate(concatenated):
             key = value.item()
             if key not in result_dict:
                 result_dict[key] = []
             result_dict[key].append(idx)
 
-        log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
-                                 -1, dtype=torch.int32)
+        log2phy_map = torch.full(
+            (self.ranks_num, self.global_expert_num), -1, dtype=torch.int32
+        )
         for rank in range(self.ranks_num):
             for key in result_dict:
                 indices_in_concat = result_dict[key]
@@ -78,8 +85,7 @@ def generate_log2phy_expert_map(self, layer_id):
     def get_rank_placement_map(self, layer_id, rank_id):
         expert_placement_map = self.generate_expert_placement_map()
         layer_expert_map = expert_placement_map[layer_id]
-        rank_expert_map = layer_expert_map[rank_id].to(
-            torch.npu.current_device())
+        rank_expert_map = layer_expert_map[rank_id].to(torch.npu.current_device())
         rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item()
         return rank_local_expert_num, rank_expert_map
 
@@ -88,8 +94,9 @@ def get_rank_log2phy_map(self, layer_id, rank_id):
         return layer_log2phy_map[rank_id]
 
     def get_global_redundant_expert_num(self):
-        global_redundant_expert_num = len(self.expert_map_tensor[0][0]) \
-            * self.ranks_num - self.global_expert_num
+        global_redundant_expert_num = (
+            len(self.expert_map_tensor[0][0]) * self.ranks_num - self.global_expert_num
+        )
         return global_redundant_expert_num
 
 
@@ -99,4 +106,3 @@ def get_global_redundant_expert_num(self):
 # print(rank_placement_map)
 # rank_phy2log_map = expert_load_balancer.get_rank_log2phy_map(1, 0)
 # print(rank_phy2log_map)
-
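
For reference, a minimal usage sketch of the reformatted class, modeled on the commented-out example at the bottom of the file. The JSON path and the global_expert_num value below are hypothetical, the expert-map file layout is defined by expert_file_to_tensor() (whose body is outside this diff), and get_rank_placement_map() assumes an Ascend host with torch_npu installed, since it moves the map to torch.npu.current_device().

    import torch_npu  # noqa: F401  # assumed available on an Ascend host

    # Hypothetical inputs; the real map file and expert count come from the deployment.
    expert_load_balancer = ExpertLoadBalancer(
        expert_map_path="/path/to/expert_map.json",
        global_expert_num=256,
    )

    # Per-rank placement map for one layer and the count of its valid (non -1) entries.
    rank_local_expert_num, rank_placement_map = (
        expert_load_balancer.get_rank_placement_map(layer_id=1, rank_id=0)
    )

    # Logical-to-physical expert map for the same layer/rank, plus the number of
    # redundant expert copies across all ranks.
    rank_log2phy_map = expert_load_balancer.get_rank_log2phy_map(layer_id=1, rank_id=0)
    global_redundant_expert_num = expert_load_balancer.get_global_redundant_expert_num()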