File tree Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Original file line number Diff line number Diff line change @@ -1384,16 +1384,15 @@ def forward(self,
1384
1384
if isinstance (e_hidden_states , tuple ):
1385
1385
e_hidden_states , shared_hidden_states = e_hidden_states
1386
1386
1387
- if ( tp_size > 1 and fused_moe_state not in [
1388
- FusedMoEState . AllGather , FusedMoEState . AllGatherEP ,
1389
- FusedMoEState . NaiveMulticast
1390
- ] and not replace_allreduce ):
1391
- dist . all_gather ( list ( chunk_hidden_states ), e_hidden_states ,
1392
- self . tp_group )
1393
- final_hidden_states = torch . cat ( chunk_hidden_states , dim = 0 )
1387
+ if fused_moe_state != FusedMoEState . AllGather :
1388
+ if tp_size > 1 :
1389
+ dist . all_gather ( list ( chunk_hidden_states ), e_hidden_states ,
1390
+ self . tp_group )
1391
+ final_hidden_states = torch . cat ( chunk_hidden_states , dim = 0 )
1392
+ else :
1393
+ final_hidden_states = e_hidden_states
1394
1394
if num_tokens < tp_size :
1395
1395
final_hidden_states = final_hidden_states [:num_tokens ]
1396
- dispose_tensor (e_hidden_states )
1397
1396
elif self .dp_size > 1 :
1398
1397
if fused_moe_state == FusedMoEState .NaiveMulticast :
1399
1398
start = 0 if self .dp_rank == 0 else cu_tokens_across_dp_cpu [
You can’t perform that action at this time.
0 commit comments