Hi,
First of all, great work on the project! However, I've encountered an issue with memory not being released when using USP. Specifically, I'm using USP for end-to-end sequence parallelism outside the multi-layer Transformer blocks: after all Transformer blocks have run, the final output is gathered via `all_gather`. Here is a simplified version of the code:
```python
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Each rank keeps only its local shard of the sequence.
    local_hidden_states = hidden_states.chunk(world_size, dim=0)[rank].detach().clone()
    local_hidden_states = self.patch_embed(local_hidden_states)
    rotary_pos_emb = self.rot_pos_emb(grid_thw)
    local_rotary_pos_emb = rotary_pos_emb.chunk(world_size, dim=0)[rank].detach().clone()
    # Run all Transformer blocks on the local shard (USP handles attention internally).
    for blk in self.blocks:
        local_hidden_states = blk(local_hidden_states, rotary_pos_emb=local_rotary_pos_emb)
    # Gather the full sequence back on every rank.
    S, D = local_hidden_states.shape
    hidden_states_gather = torch.zeros(
        world_size * S, D, dtype=local_hidden_states.dtype, device=local_hidden_states.device
    )
    dist.all_gather_into_tensor(hidden_states_gather, local_hidden_states)
    return hidden_states_gather
```
However, I've noticed that GPU memory usage keeps accumulating over time and isn't properly released. I've attached a pickle file with the memory snapshot, which can be viewed with [PyTorch Memory Visualization](https://pytorch.ac.cn/memory_viz):
gpu_mem.zip
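For completeness, a minimal sketch of how such a snapshot can be captured (the helper name, `max_entries`, and the output path are placeholders of mine; this relies on PyTorch's private `torch.cuda.memory._record_memory_history` / `_dump_snapshot` API available in recent versions):

```python
import torch

# Hypothetical helper: record allocator history during a few forward passes
# and dump a snapshot that can be opened in the memory visualizer.
def capture_memory_snapshot(model, batches, out_path="gpu_mem.pickle"):
    torch.cuda.memory._record_memory_history(max_entries=100_000)
    for hidden_states, grid_thw in batches:
        _ = model(hidden_states, grid_thw)
    torch.cuda.synchronize()
    torch.cuda.memory._dump_snapshot(out_path)
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```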
Upon analysis, I observed that the memory allocated by `torch.empty` on line 94 in ring/utils.py is never released. Several other operations, such as `tensor.to(dtype)` and `all_to_all`, also appear to leave memory unfreed. I suspect this is related to how USP is used here, rather than being a problem with any single operation.
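To illustrate the symptom, here is a hypothetical check (the function name, step count, and logging are mine, not from the actual code) that runs the forward repeatedly with the same inputs, discards the output, and logs allocator statistics; this is how the accumulation shows up:

```python
import torch
import torch.distributed as dist

# Hypothetical check: repeat the forward, drop the output each time,
# and watch whether allocated/reserved memory still grows across iterations.
def check_memory_growth(model, hidden_states, grid_thw, steps=10):
    for step in range(steps):
        out = model(hidden_states, grid_thw)
        del out
        torch.cuda.synchronize()
        allocated = torch.cuda.memory_allocated() / 2**20
        reserved = torch.cuda.memory_reserved() / 2**20
        if dist.get_rank() == 0:
            print(f"step {step}: allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB")
```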
If you have any insights or suggestions that could help resolve this issue, I would greatly appreciate it!
Thanks!