File tree Expand file tree Collapse file tree 2 files changed +7
-8
lines changed Expand file tree Collapse file tree 2 files changed +7
-8
lines changed Original file line number Diff line number Diff line change @@ -69,10 +69,9 @@ def generate_log2phy_map(expert_map):
69
69
num_rank_holding_expert = positive_rank_idx .size (0 )
70
70
71
71
if num_rank_holding_expert == 0 :
72
- log2phy_map [:, idx ] = torch .full (
73
- (num_ranks ,),
74
- 0 ,
75
- dtype = log2phy_map .dtype )
72
+ log2phy_map [:, idx ] = torch .full ((num_ranks ,),
73
+ 0 ,
74
+ dtype = log2phy_map .dtype )
76
75
77
76
if num_rank_holding_expert == 1 :
78
77
log2phy_map [negative_rank_idx , idx ] = torch .full (
@@ -84,8 +83,9 @@ def generate_log2phy_map(expert_map):
84
83
random .choice (log2phy_map [positive_rank_idx , idx ])
85
84
for _ in range (num_ranks - num_rank_holding_expert )
86
85
]
87
- log2phy_map [negative_rank_idx , idx ] = torch .tensor (random_list ,
88
- dtype = log2phy_map .dtype )
86
+ log2phy_map [negative_rank_idx ,
87
+ idx ] = torch .tensor (random_list ,
88
+ dtype = log2phy_map .dtype )
89
89
90
90
return log2phy_map
91
91
Original file line number Diff line number Diff line change @@ -443,8 +443,7 @@ def forward(self,
443
443
tuple ) and len (e_hidden_states ) == 2 :
444
444
e_hidden_states , shared_hidden_states = e_hidden_states
445
445
446
- if isinstance (e_hidden_states ,
447
- tuple ) and len (e_hidden_states ) == 3 :
446
+ if isinstance (e_hidden_states , tuple ) and len (e_hidden_states ) == 3 :
448
447
e_hidden_states , group_list_type , expert_tokens = e_hidden_states
449
448
450
449
if self .dynamic_eplb :
You can’t perform that action at this time.
0 commit comments