31 changes: 22 additions & 9 deletions paddleformers/trainer/trainer.py
@@ -45,6 +45,11 @@
from paddle import framework
from paddle.distributed.fleet.meta_parallel import PipelineLayer

try:
from paddle.distributed.fleet.meta_parallel import PipelineDatasetPreprocessor
except:
PipelineDatasetPreprocessor = None

try:
from paddle.base import core
except:
@@ -2641,24 +2646,32 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
if len(self._pp_data_buffer) != self.args.gradient_accumulation_steps:
    return paddle.zeros([])

# for v in self._pp_data_buffer[0].values():
#     assert isinstance(v, paddle.Tensor), f"Only support tensor as pipeline mode input, got type {type(v)}"
with self.autocast_smart_context_manager():
    inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer)
    self._pp_data_buffer = []

model.train()
if model._dp_comm_overlap or model._sharding_comm_overlap:
    for _, buffers in model._chunk_2_comm_buffers.items():
        for buffer in buffers:
            buffer._acc_steps = self.args.gradient_accumulation_steps

inputs = model._prepare_training(
    inputs, self.optimizer, self.lr_scheduler
)  # None, None => [optimizer, lr_scheduler]
model.optimizer = None  # we do not use `PipelineParallel` to handle the optimizer step
model.lr_scheduler = None

def _dataset_process_function():
# Pass a local function to forward_backward_pipeline instead of the dataset itself.
# This prevents the dataset from being passed as a direct argument to forward_backward_pipeline,
# which would create additional reference counts that cannot be cleared, leading to GPU memory leaks.
with self.autocast_smart_context_manager():
inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer)
self._pp_data_buffer = []

return model._prepare_training(
inputs, self.optimizer, self.lr_scheduler
) # None, None => [optimizer, lr_scheduler]

if PipelineDatasetPreprocessor is None:
inputs = _dataset_process_function()
else:
inputs = PipelineDatasetPreprocessor(_dataset_process_function)

with self.autocast_smart_context_manager():
    loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None)

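The trainer.py change defers micro-batch preprocessing into the local `_dataset_process_function`, so `forward_backward_pipeline` never receives the buffered dataset as a direct argument and no extra reference keeps it alive on the GPU. Below is a minimal, framework-free sketch of that deferral pattern; it is a plain-Python illustration, not PaddleFormers code, and `TinyScheduler`, `make_inputs`, and `buffer` are made-up names. The point it demonstrates: handing the scheduler a zero-argument callable lets the buffer be cleared as soon as preprocessing runs.

```python
# Plain-Python sketch of the deferred-preprocessing pattern (illustrative names only).
from typing import Callable, List, Union


class TinyScheduler:
    """Stand-in for a pipeline scheduler that accepts inputs or an input factory."""

    def run(self, inputs_or_factory: Union[List[int], Callable[[], List[int]]]) -> int:
        # Build the inputs lazily when a factory is passed; by this point the
        # caller's buffer has already been cleared inside the factory.
        inputs = inputs_or_factory() if callable(inputs_or_factory) else inputs_or_factory
        return sum(inputs)  # stand-in for the pipeline "loss"


buffer: List[int] = [1, 2, 3]  # plays the role of self._pp_data_buffer


def make_inputs() -> List[int]:
    """Hand over the buffered micro-batches and clear the buffer in one step."""
    global buffer
    inputs, buffer = buffer, []
    return inputs


loss = TinyScheduler().run(make_inputs)  # pass the callable, not make_inputs()
assert buffer == []  # the module-level buffer holds no reference after preprocessing
print(loss)  # 6
```

When `PipelineDatasetPreprocessor` is importable, the PR wraps the same callable in it instead of invoking it eagerly, so the lazy hand-off happens inside Paddle's pipeline scheduler.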
14 changes: 14 additions & 0 deletions paddleformers/trainer/training_args.py
@@ -1281,6 +1281,8 @@ def __post_init__(self):
"use_dualpipev",
"forward_backward_overlap_scheduler",
"enable_dynamic_shape",
"sync_moment",
"sync_param",
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, accept config is disable_p2p_cache_shape, disable_partial_send_recv."
@@ -1333,6 +1335,18 @@
in pipeline_parallel_config,
"enable_dynamic_shape": "enable_dynamic_shape" in pipeline_parallel_config,
}

pp_sync_param = "sync_param" in pipeline_parallel_config
pp_sync_moment = "sync_moment" in pipeline_parallel_config

if pp_sync_param:
logger.info("setting pp sync_param")
strategy.hybrid_configs["pp_configs"].sync_param = True

if pp_sync_moment:
logger.info("setting pp sync_moment")
strategy.hybrid_configs["pp_configs"].sync_moment = True

if dygraph_pp_configs["dp_comm_overlap"]:
raise ValueError("overlap has accuracy issue") # TODO: fix `overalap` + `delay_scale` issue

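The training_args.py change registers `sync_param` and `sync_moment` as accepted `pipeline_parallel_config` tokens and forwards them to `strategy.hybrid_configs["pp_configs"]`. A hedged usage sketch follows; it assumes `pipeline_parallel_config` is passed as a space-separated string of tokens, as the membership checks above suggest, so verify the exact argument spellings against your installed PaddleFormers version.

```python
# Hypothetical usage sketch only: enabling the new pp_configs switches through
# TrainingArguments. Assumes pipeline_parallel_config takes a space-separated
# token string, as the `"sync_param" in pipeline_parallel_config` checks suggest.
from paddleformers.trainer.training_args import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    pipeline_parallel_degree=4,
    pipeline_parallel_config="sync_param sync_moment",  # tokens added by this PR
)
# In __post_init__ these tokens translate into
#   strategy.hybrid_configs["pp_configs"].sync_param = True
#   strategy.hybrid_configs["pp_configs"].sync_moment = True
```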