Skip to content

Commit 257b64b

Browse files
committed
fix: use dist.reduce_scatter_tensor to avoid memory leak
Signed-off-by: boying <897013703@qq.com>
1 parent 392fd72 commit 257b64b

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

vllm_ascend/ops/fused_moe.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1342,11 +1342,17 @@ def forward(self,
13421342
final_hidden_states = final_hidden_states[start:end, :]
13431343
dispose_tensor(e_hidden_states)
13441344
elif fused_moe_state == FusedMoEState.AllGather:
1345-
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
1346-
e_hidden_states,
1347-
"sum",
1348-
scatter_dim=0,
1349-
group=get_dp_group().device_group)
1345+
final_hidden_states_shape = (
1346+
e_hidden_states.size(0) //
1347+
self.dp_size, ) + e_hidden_states.shape[1:]
1348+
final_hidden_states = torch.empty(
1349+
final_hidden_states_shape,
1350+
dtype=e_hidden_states.dtype,
1351+
device=e_hidden_states.device)
1352+
dist.reduce_scatter_tensor(final_hidden_states,
1353+
e_hidden_states,
1354+
op=dist.ReduceOp.SUM,
1355+
group=get_dp_group().device_group)
13501356
final_hidden_states = final_hidden_states[:num_tokens]
13511357
dispose_tensor(e_hidden_states)
13521358
else:

0 commit comments

Comments (0)