|
45 | 45 | from paddle import framework |
46 | 46 | from paddle.distributed.fleet.meta_parallel import PipelineLayer |
47 | 47 |
|
| 48 | +try: |
| 49 | + from paddle.distributed.fleet.meta_parallel import PipelineDatasetPreprocessor |
| 50 | +except: |
| 51 | + PipelineDatasetPreprocessor = None |
| 52 | + |
48 | 53 | try: |
49 | 54 | from paddle.base import core |
50 | 55 | except: |
@@ -2756,22 +2761,32 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle |
2756 | 2761 | # for v in self._pp_data_buffer[0].values(): |
2757 | 2762 | # assert isinstance(v, paddle.Tensor), f"Only support tensor as pipeline mode input, got type {type(v)}" |
2758 | 2763 |
|
2759 | | - with self.autocast_smart_context_manager(): |
2760 | | - inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer) |
2761 | | - self._pp_data_buffer = [] |
2762 | | - |
2763 | 2764 | model.train() |
2764 | 2765 | if model._dp_comm_overlap or model._sharding_comm_overlap: |
2765 | 2766 | for _, buffers in model._chunk_2_comm_buffers.items(): |
2766 | 2767 | for buffer in buffers: |
2767 | 2768 | buffer._acc_steps = self.args.gradient_accumulation_steps |
2768 | 2769 |
|
2769 | | - inputs = model._prepare_training( |
2770 | | - inputs, self.optimizer, self.lr_scheduler |
2771 | | - ) # None, None => [optimizer, lr_scheduler] |
2772 | 2770 | model.optimizer = None # we do not use `PipelineParallel` to handle the optimizer step
2773 | 2771 | model.lr_scheduler = None |
2774 | 2772 |
|
| 2773 | + def _dataset_process_function(): |
| 2774 | + # Pass a local function to forward_backward_pipeline instead of the prepared dataset itself.
| 2775 | + # Passing the dataset as a direct argument would create additional references that
| 2776 | + # cannot be cleared, leading to GPU memory leaks.
| 2777 | + with self.autocast_smart_context_manager(): |
| 2778 | + inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer) |
| 2779 | + self._pp_data_buffer = [] |
| 2780 | + |
| 2781 | + return model._prepare_training( |
| 2782 | + inputs, self.optimizer, self.lr_scheduler |
| 2783 | + ) # None, None => [optimizer, lr_scheduler] |
| 2784 | + |
| 2785 | + if PipelineDatasetPreprocessor is None: |
| 2786 | + inputs = _dataset_process_function() |
| 2787 | + else: |
| 2788 | + inputs = PipelineDatasetPreprocessor(_dataset_process_function) |
| 2789 | + |
2775 | 2790 | with self.autocast_smart_context_manager(): |
2776 | 2791 | loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None) |
2777 | 2792 |
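
The new `_dataset_process_function` defers input preparation into a closure, so `forward_backward_pipeline` never holds the buffered batch as a direct argument. A minimal, hypothetical sketch of that deferral pattern is shown below; `LazyBatch`, `make_prepare_fn`, and `run_pipeline` are illustrative stand-ins, not PaddlePaddle APIs.

```python
# Sketch of the lazy-preparation idea: the pipeline receives a callable (possibly
# wrapped) rather than the already-prepared data, so the caller's frame never pins
# a second reference to the raw micro-batch buffers.

def make_prepare_fn(buffer):
    def _prepare():
        # Consume and clear the buffer inside the closure, mirroring how
        # _dataset_process_function drains self._pp_data_buffer.
        data = list(buffer)
        buffer.clear()
        return data
    return _prepare


class LazyBatch:
    """Materializes the batch only when the pipeline asks for it."""

    def __init__(self, prepare_fn):
        self._prepare_fn = prepare_fn

    def get(self):
        data = self._prepare_fn()
        self._prepare_fn = None  # drop the closure so nothing keeps the raw batch alive
        return data


def run_pipeline(batch):
    # Stand-in for forward_backward_pipeline: unwrap lazily if given a LazyBatch.
    data = batch.get() if isinstance(batch, LazyBatch) else batch
    return len(data)


pp_data_buffer = [{"input_ids": [1, 2, 3]}]
print(run_pipeline(LazyBatch(make_prepare_fn(pp_data_buffer))))  # -> 1
print(pp_data_buffer)  # -> [] (the buffer was drained inside the closure)
```

The `PipelineDatasetPreprocessor` fallback in the diff follows the same shape: when the wrapper class is unavailable, the closure is simply called eagerly and its result passed through unchanged.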
|
|