Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
from paddle import framework
from paddle.distributed.fleet.meta_parallel import PipelineLayer

# Optional import: PipelineDatasetPreprocessor only exists in newer Paddle
# releases. Fall back to None so callers can feature-detect it at runtime.
# Catch ImportError specifically — a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and hide unrelated failures.
try:
    from paddle.distributed.fleet.meta_parallel import PipelineDatasetPreprocessor
except ImportError:
    PipelineDatasetPreprocessor = None

try:
from paddle.base import core
except:
Expand Down Expand Up @@ -2583,22 +2588,32 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
# for v in self._pp_data_buffer[0].values():
# assert isinstance(v, paddle.Tensor), f"Only support tensor as pipeline mode input, got type {type(v)}"

with self.autocast_smart_context_manager():
inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer)
self._pp_data_buffer = []

model.train()
if model._dp_comm_overlap or model._sharding_comm_overlap:
for _, buffers in model._chunk_2_comm_buffers.items():
for buffer in buffers:
buffer._acc_steps = self.args.gradient_accumulation_steps

inputs = model._prepare_training(
inputs, self.optimizer, self.lr_scheduler
) # None, None => [optimizer, lr_scheduler]
model.optimizer = None # we do not use `PipelineParallel` to handler optimizer step
model.lr_scheduler = None

def _dataset_process_function():
# Pass a local function to forward_backward_pipeline instead of the dataset itself.
# This prevents the dataset from being passed as a direct argument to forward_backward_pipeline,
# which would create additional reference counts that cannot be cleared, leading to GPU memory leaks.
with self.autocast_smart_context_manager():
inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer)
self._pp_data_buffer = []

return model._prepare_training(
inputs, self.optimizer, self.lr_scheduler
) # None, None => [optimizer, lr_scheduler]

if PipelineDatasetPreprocessor is None:
inputs = _dataset_process_function()
else:
inputs = PipelineDatasetPreprocessor(_dataset_process_function)

with self.autocast_smart_context_manager():
loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None)

Expand Down
Loading