|
44 | 44 | from paddle import framework |
45 | 45 | from paddle.distributed.fleet.meta_parallel import PipelineLayer |
46 | 46 |
|
| 47 | +try: |
| 48 | + from paddle.distributed.fleet.meta_parallel import PipelineDatasetPreprocessor |
| 49 | +except: |
| 50 | + PipelineDatasetPreprocessor = None |
| 51 | + |
47 | 52 | try: |
48 | 53 | from paddle.base import core |
49 | 54 | except: |
@@ -2583,22 +2588,32 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle |
2583 | 2588 | # for v in self._pp_data_buffer[0].values(): |
2584 | 2589 | # assert isinstance(v, paddle.Tensor), f"Only support tensor as pipeline mode input, got type {type(v)}" |
2585 | 2590 |
|
2586 | | - with self.autocast_smart_context_manager(): |
2587 | | - inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer) |
2588 | | - self._pp_data_buffer = [] |
2589 | | - |
2590 | 2591 | model.train() |
2591 | 2592 | if model._dp_comm_overlap or model._sharding_comm_overlap: |
2592 | 2593 | for _, buffers in model._chunk_2_comm_buffers.items(): |
2593 | 2594 | for buffer in buffers: |
2594 | 2595 | buffer._acc_steps = self.args.gradient_accumulation_steps |
2595 | 2596 |
|
2596 | | - inputs = model._prepare_training( |
2597 | | - inputs, self.optimizer, self.lr_scheduler |
2598 | | - ) # None, None => [optimizer, lr_scheduler] |
2599 | 2597 | model.optimizer = None # we do not use `PipelineParallel` to handle the optimizer step |
2600 | 2598 | model.lr_scheduler = None |
2601 | 2599 |
|
| 2600 | + def _dataset_process_function(): |
| 2601 | + # Pass a local function to forward_backward_pipeline instead of the dataset itself. |
| 2602 | + # Handing the dataset over as a direct argument would create additional references to it |
| 2603 | + # that cannot be released, leading to GPU memory leaks. |
| 2604 | + with self.autocast_smart_context_manager(): |
| 2605 | + inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer) |
| 2606 | + self._pp_data_buffer = [] |
| 2607 | + |
| 2608 | + return model._prepare_training( |
| 2609 | + inputs, self.optimizer, self.lr_scheduler |
| 2610 | + ) # None, None => [optimizer, lr_scheduler] |
| 2611 | + |
| 2612 | + if PipelineDatasetPreprocessor is None: |
| 2613 | + inputs = _dataset_process_function() |
| 2614 | + else: |
| 2615 | + inputs = PipelineDatasetPreprocessor(_dataset_process_function) |
| 2616 | + |
2602 | 2617 | with self.autocast_smart_context_manager(): |
2603 | 2618 | loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None) |
2604 | 2619 |
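For context, here is a minimal, standalone sketch of the deferred-preprocessing idea behind this change. The LazyPipelineInputs class and run_pipeline_step helper below are hypothetical illustrations, not Paddle's PipelineDatasetPreprocessor API: the trainer hands the pipeline a zero-argument callable wrapped in a thin object, and the prepared inputs are only materialized, consumed, and dropped inside the step, so no extra long-lived reference keeps them resident on the GPU.

# Illustrative sketch only; names and behavior are assumptions, not Paddle's implementation.
class LazyPipelineInputs:
    """Hold a zero-argument callable and materialize the inputs only when asked."""

    def __init__(self, prepare_fn):
        self._prepare_fn = prepare_fn

    def load(self):
        # Materialize the inputs on demand; the caller owns the only reference.
        return self._prepare_fn()


def run_pipeline_step(prepare_inputs, forward_backward):
    """Defer input preparation until the pipeline step actually consumes the data."""
    lazy_inputs = LazyPipelineInputs(prepare_inputs)
    inputs = lazy_inputs.load()      # created here, inside the step
    loss = forward_backward(inputs)  # consumed here
    del inputs                       # the last reference dies with the step
    return loss


if __name__ == "__main__":
    # Toy usage: "preparing" inputs just builds a list; "forward_backward" sums it.
    print(run_pipeline_step(lambda: [1.0, 2.0, 3.0], sum))  # 6.0

In the actual change, the same idea is expressed by passing _dataset_process_function, either directly or wrapped in PipelineDatasetPreprocessor when the installed Paddle provides it, instead of the already-prepared inputs.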
|
|