File tree Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Original file line number Diff line number Diff line change @@ -1342,11 +1342,17 @@ def forward(self,
1342
1342
final_hidden_states = final_hidden_states [start :end , :]
1343
1343
dispose_tensor (e_hidden_states )
1344
1344
elif fused_moe_state == FusedMoEState .AllGather :
1345
- final_hidden_states = dist ._functional_collectives .reduce_scatter_tensor (
1346
- e_hidden_states ,
1347
- "sum" ,
1348
- scatter_dim = 0 ,
1349
- group = get_dp_group ().device_group )
1345
+ final_hidden_states_shape = (
1346
+ e_hidden_states .size (0 ) //
1347
+ self .dp_size , ) + e_hidden_states .shape [1 :]
1348
+ final_hidden_states = torch .empty (
1349
+ final_hidden_states_shape ,
1350
+ dtype = e_hidden_states .dtype ,
1351
+ device = e_hidden_states .device )
1352
+ dist .reduce_scatter_tensor (final_hidden_states ,
1353
+ e_hidden_states ,
1354
+ op = dist .ReduceOp .SUM ,
1355
+ group = get_dp_group ().device_group )
1350
1356
final_hidden_states = final_hidden_states [:num_tokens ]
1351
1357
dispose_tensor (e_hidden_states )
1352
1358
else :
You can’t perform that action at this time.
0 commit comments