
Commit af07de0

[Distributed] refine training_pipeline_step to avoid memory leaks (#11103)
1 parent b30ad6a commit af07de0

File tree: 1 file changed (+22 -7 lines)

paddlenlp/trainer/trainer.py

Lines changed: 22 additions & 7 deletions
@@ -44,6 +44,11 @@
 from paddle import framework
 from paddle.distributed.fleet.meta_parallel import PipelineLayer

+try:
+    from paddle.distributed.fleet.meta_parallel import PipelineDatasetPreprocessor
+except:
+    PipelineDatasetPreprocessor = None
+
 try:
     from paddle.base import core
 except:
@@ -2583,22 +2588,32 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
         # for v in self._pp_data_buffer[0].values():
         #     assert isinstance(v, paddle.Tensor), f"Only support tensor as pipeline mode input, got type {type(v)}"

-        with self.autocast_smart_context_manager():
-            inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer)
-            self._pp_data_buffer = []
-
         model.train()
         if model._dp_comm_overlap or model._sharding_comm_overlap:
             for _, buffers in model._chunk_2_comm_buffers.items():
                 for buffer in buffers:
                     buffer._acc_steps = self.args.gradient_accumulation_steps

-        inputs = model._prepare_training(
-            inputs, self.optimizer, self.lr_scheduler
-        )  # None, None => [optimizer, lr_scheduler]
         model.optimizer = None  # we do not use `PipelineParallel` to handler optimizer step
         model.lr_scheduler = None

+        def _dataset_process_function():
+            # Pass a local function to forward_backward_pipeline instead of the dataset itself.
+            # This prevents the dataset from being passed as a direct argument to forward_backward_pipeline,
+            # which would create additional reference counts that cannot be cleared, leading to GPU memory leaks.
+            with self.autocast_smart_context_manager():
+                inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer)
+                self._pp_data_buffer = []
+
+            return model._prepare_training(
+                inputs, self.optimizer, self.lr_scheduler
+            )  # None, None => [optimizer, lr_scheduler]
+
+        if PipelineDatasetPreprocessor is None:
+            inputs = _dataset_process_function()
+        else:
+            inputs = PipelineDatasetPreprocessor(_dataset_process_function)
+
         with self.autocast_smart_context_manager():
             loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None)

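Below is a minimal, self-contained sketch of the pattern the commit applies: instead of materializing the prepared inputs in the caller and handing the tensors to forward_backward_pipeline as a direct argument, a zero-argument function is passed across the call boundary so the consumer builds and releases the data itself. All names in the sketch (LazyInputs, run_pipeline, build_inputs) are hypothetical stand-ins for PipelineDatasetPreprocessor, forward_backward_pipeline, and the _prepare_* helpers; it illustrates the idea, not the Paddle implementation.

# Hypothetical illustration of the deferred-input pattern from this commit.
class LazyInputs:
    """Wrap a zero-argument callable that builds the inputs on demand."""

    def __init__(self, fn):
        self._fn = fn

    def __call__(self):
        return self._fn()


def run_pipeline(inputs):
    """Consume the inputs; accepts either the data itself or a LazyInputs wrapper."""
    batch = inputs() if isinstance(inputs, LazyInputs) else inputs
    result = sum(sum(chunk) for chunk in batch)  # stand-in for the forward/backward work
    # In the deferred case `batch` is the only reference to the data, so it can be
    # reclaimed as soon as this function drops it.
    del batch
    return result


def build_inputs():
    # Stand-in for preparing the micro-batches and clearing the staging buffer.
    return [list(range(1_000)) for _ in range(4)]


# Eager: the caller's local keeps the whole batch alive, in addition to the
# callee's parameter, for as long as this frame holds it.
inputs = build_inputs()
print(run_pipeline(inputs))  # 1998000

# Deferred: only the small callable crosses the call boundary; the batch is
# created and released entirely inside run_pipeline.
print(run_pipeline(LazyInputs(build_inputs)))  # 1998000

When PipelineDatasetPreprocessor cannot be imported (older Paddle builds), the commit falls back to calling the local function eagerly, which reproduces the previous behavior; the deferred path is only taken on builds that provide the wrapper.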