File tree: 1 file changed, +11 -5 lines changed.
Original file line number | Diff line number | Diff line change
@@ -1342,11 +1342,17 @@ def forward(self,
1342
1342
final_hidden_states = final_hidden_states[start:end, :]
1343
1343
dispose_tensor(e_hidden_states)
1344
1344
elif fused_moe_state == FusedMoEState.AllGather:
1345
- final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
1346
- e_hidden_states,
1347
- "sum",
1348
- scatter_dim=0,
1349
- group=get_dp_group().device_group)
1345
+ final_hidden_states_shape = (
1346
+ e_hidden_states.size(0) //
1347
+ self.dp_size, ) + e_hidden_states.shape[1:]
1348
+ final_hidden_states = torch.empty(
1349
+ final_hidden_states_shape,
1350
+ dtype=e_hidden_states.dtype,
1351
+ device=e_hidden_states.device)
1352
+ dist.reduce_scatter_tensor(final_hidden_states,
1353
+ e_hidden_states,
1354
+ op=dist.ReduceOp.SUM,
1355
+ group=get_dp_group().device_group)
1350
1356
final_hidden_states = final_hidden_states[:num_tokens]
1351
1357
dispose_tensor(e_hidden_states)
1352
1358
else:
You can’t perform that action at this time.
0 commit comments