@@ -1,15 +1,17 @@
 import json
 import random
-from typing import List, Dict
+from typing import Dict, List
 
 import torch
 
+
 class ExpertLoadBalancer(object):
+
     def __init__(self, expert_map_path, global_expert_num):
         self.expert_map_path = expert_map_path
         self.global_expert_num = global_expert_num
-        self.expert_map_tensor, self.layers_num, self.ranks_num = \
-            self.expert_file_to_tensor()
+        self.expert_map_tensor, self.layers_num, self.ranks_num = (
+            self.expert_file_to_tensor())
 
     def expert_file_to_tensor(self):
         with open(self.expert_map_path, "r") as f:
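This first hunk ends just as expert_file_to_tensor begins reading the map file, so neither the parser body nor the on-disk schema is visible in the diff. The rest of the class only requires that the parsed result be a (layers_num, ranks_num, experts_per_rank) integer tensor of global expert IDs, plus the two counts. As a purely hypothetical illustration (every field name here is an assumption, not taken from the diff), a compatible map file for one MoE layer, two ranks, and four logical experts might look like:

{
  "moe_layer_count": 1,
  "layer_list": [
    {
      "layer_id": 0,
      "device_count": 2,
      "device_list": [
        {"device_id": 0, "device_expert": [0, 1, 2]},
        {"device_id": 1, "device_expert": [0, 2, 3]}
      ]
    }
  ]
}

Note that experts 0 and 2 each appear on both ranks; those duplicates are the "redundant" experts that get_global_redundant_expert_num counts later in the diff.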
@@ -43,19 +45,22 @@ def generate_index_dicts(tensor_2d):
 
     def generate_expert_placement_map(self):
         expert_placement_map = torch.full(
-            (self.layers_num, self.ranks_num, self.global_expert_num),
-            -1, dtype=torch.int32)
+            (self.layers_num, self.ranks_num, self.global_expert_num),
+            -1,
+            dtype=torch.int32,
+        )
         for layer_id in range(self.layers_num):
             for gpu_id in range(self.ranks_num):
                 e_ids = self.expert_map_tensor[layer_id, gpu_id]
-                expert_placement_map[layer_id, gpu_id, e_ids] = \
-                    torch.arange(len(e_ids), dtype=torch.int32)
+                expert_placement_map[layer_id, gpu_id,
+                                     e_ids] = torch.arange(len(e_ids),
+                                                           dtype=torch.int32)
         return expert_placement_map
 
     def generate_log2phy_expert_map(self, layer_id):
         concatenated = torch.flatten(self.expert_map_tensor[layer_id])
         rank_expert_to_global = self.generate_index_dicts(
-            self.expert_map_tensor[layer_id])
+            self.expert_map_tensor[layer_id])
         result_dict: Dict[int, List[int]] = {}
         for idx, value in enumerate(concatenated):
             key = value.item()
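To make the placement logic concrete, here is a minimal self-contained sketch of the loop above for the hypothetical layout from the JSON example: one layer, two ranks, four logical experts, three slots per rank. The -1 sentinel marks experts a rank does not host, which is exactly what get_rank_placement_map counts against in a later hunk:

import torch

layers_num, ranks_num, global_expert_num = 1, 2, 4
# expert_map_tensor[layer, rank] lists the global expert IDs hosted by that rank.
expert_map_tensor = torch.tensor([[[0, 1, 2], [0, 2, 3]]], dtype=torch.int32)

expert_placement_map = torch.full(
    (layers_num, ranks_num, global_expert_num),
    -1,
    dtype=torch.int32,
)
for layer_id in range(layers_num):
    for gpu_id in range(ranks_num):
        e_ids = expert_map_tensor[layer_id, gpu_id]
        # Slot i on this rank serves global expert e_ids[i], so the map
        # stores the local slot index at position [layer, rank, expert].
        expert_placement_map[layer_id, gpu_id, e_ids] = torch.arange(
            len(e_ids), dtype=torch.int32)

print(expert_placement_map[0])
# tensor([[ 0,  1,  2, -1],
#         [ 0, -1,  1,  2]], dtype=torch.int32)

# The per-rank local expert count used by get_rank_placement_map:
print(torch.sum(torch.ne(expert_placement_map[0, 1], -1)).item())  # 3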
@@ -64,7 +69,8 @@ def generate_log2phy_expert_map(self, layer_id):
             result_dict[key].append(idx)
 
         log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
-                                 -1, dtype=torch.int32)
+                                 -1,
+                                 dtype=torch.int32)
         for rank in range(self.ranks_num):
             for key in result_dict:
                 indices_in_concat = result_dict[key]
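The loop body that actually fills log2phy_map sits between this hunk and the next and is not part of the diff. From what is visible, result_dict maps each logical expert ID to every flattened physical position holding a copy, and the otherwise-unused import random at the top of the file suggests duplicates are broken by random choice. The sketch below is therefore an assumed reconstruction, not the elided code: each rank resolves a logical ID to its own copy when it hosts one, and to a randomly chosen copy otherwise.

import random
from typing import Dict, List

import torch

ranks_num, global_expert_num, experts_per_rank = 2, 4, 3
# Flattened per-layer map: rank 0 hosts [0, 1, 2], rank 1 hosts [0, 2, 3].
concatenated = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.int32)

# Logical expert ID -> all flattened physical positions holding a copy.
result_dict: Dict[int, List[int]] = {}
for idx, value in enumerate(concatenated):
    result_dict.setdefault(value.item(), []).append(idx)

log2phy_map = torch.full((ranks_num, global_expert_num), -1, dtype=torch.int32)
for rank in range(ranks_num):
    local = range(rank * experts_per_rank, (rank + 1) * experts_per_rank)
    for key, indices_in_concat in result_dict.items():
        # Assumed policy: prefer the rank's own copy, else pick one at random.
        local_copies = [i for i in indices_in_concat if i in local]
        log2phy_map[rank][key] = (local_copies[0] if local_copies else
                                  random.choice(indices_in_concat))

print(log2phy_map)
# tensor([[0, 1, 2, 5],
#         [3, 1, 4, 5]], dtype=torch.int32)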
@@ -79,7 +85,7 @@ def get_rank_placement_map(self, layer_id, rank_id):
         expert_placement_map = self.generate_expert_placement_map()
         layer_expert_map = expert_placement_map[layer_id]
         rank_expert_map = layer_expert_map[rank_id].to(
-            torch.npu.current_device())
+            torch.npu.current_device())
         rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item()
         return rank_local_expert_num, rank_expert_map
 
@@ -88,8 +94,9 @@ def get_rank_log2phy_map(self, layer_id, rank_id):
         return layer_log2phy_map[rank_id]
 
     def get_global_redundant_expert_num(self):
-        global_redundant_expert_num = len(self.expert_map_tensor[0][0]) \
-            * self.ranks_num - self.global_expert_num
+        global_redundant_expert_num = (
+            len(self.expert_map_tensor[0][0]) * self.ranks_num -
+            self.global_expert_num)
         return global_redundant_expert_num
 
 
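Because every rank hosts the same number of physical slots, the redundancy count reduces to len(self.expert_map_tensor[0][0]) * self.ranks_num - self.global_expert_num; in the toy layout used above that is 3 * 2 - 4 = 2, matching the duplicated copies of experts 0 and 2.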
@@ -99,4 +106,3 @@ def get_global_redundant_expert_num(self):
 # print(rank_placement_map)
 # rank_phy2log_map = expert_load_balancer.get_rank_log2phy_map(1, 0)
 # print(rank_phy2log_map)
-
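Putting the pieces together, the commented-out driver above hints at intended usage. The sketch below is hypothetical end to end: it assumes the JSON schema from the first example actually matches what expert_file_to_tensor expects, assumes a module path for the import, and leaves the placement-map call commented out because it moves tensors to torch.npu.current_device() and therefore needs an Ascend NPU build of PyTorch (torch_npu):

import json

from expert_load_balancer import ExpertLoadBalancer  # hypothetical module path

# Hypothetical map file using the schema sketched earlier.
expert_map = {
    "moe_layer_count": 1,
    "layer_list": [{
        "layer_id": 0,
        "device_count": 2,
        "device_list": [
            {"device_id": 0, "device_expert": [0, 1, 2]},
            {"device_id": 1, "device_expert": [0, 2, 3]},
        ],
    }],
}
with open("/tmp/expert_map.json", "w") as f:
    json.dump(expert_map, f)

balancer = ExpertLoadBalancer("/tmp/expert_map.json", global_expert_num=4)
print(balancer.get_global_redundant_expert_num())  # 3 * 2 - 4 == 2
# NPU-only, since the result is moved to the current NPU device:
# num_local, rank_map = balancer.get_rank_placement_map(layer_id=0, rank_id=0)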