Commit 7c2ff96

Merge branch 'develop' into adapter_flex_checkpoint

2 parents: 1d439b0 + 119ed11

File tree: 49 files changed, +693 -380 lines

.github/workflows/approval.yml

Lines changed: 3 additions & 4 deletions
@@ -23,9 +23,8 @@ jobs:
       - name: Update paddle
         run: |
           wget -q --no-proxy https://xly-devops.bj.bcebos.com/PaddleTest/PaddleNLP/PaddleNLP-develop.tar.gz --no-check-certificate
-          rm -rf PaddleNLP-develop && tar zxf PaddleNLP-develop.tar.gz >/dev/null
-          mv PaddleNLP-develop PaddleNLP && rm -rf PaddleNLP-develop.tar.gz >/dev/null
-          cd PaddleNLP/
+          tar zxf PaddleNLP-develop.tar.gz --strip-components=1 >/dev/null
+          rm -rf PaddleNLP-develop.tar.gz >/dev/null
           git fetch origin pull/${PR_ID}/head
           git checkout -b origin_pr FETCH_HEAD
           git remote add upstream https://github.com/PaddlePaddle/PaddleNLP.git
@@ -44,5 +43,5 @@ jobs:
       - name: Display Required Approvers
         if: steps.check-bypass.outputs.can-skip != 'true'
         run: |
-          cd PaddleNLP/scripts/ci_approval
+          cd scripts/ci_approval
           bash -x run_ci_approval.sh
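
The updated step unpacks the source tree in place rather than extracting, renaming, and entering a PaddleNLP directory. As a rough illustration of what `--strip-components=1` does, here is a minimal Python sketch (a hypothetical helper, not part of the CI script):

import tarfile

def extract_strip_top_level(archive_path, dest="."):
    # Rough equivalent of `tar zxf archive --strip-components=1`:
    # drop the leading path component of every member before extracting.
    with tarfile.open(archive_path, "r:gz") as tf:
        for member in tf.getmembers():
            parts = member.name.split("/", 1)
            if len(parts) < 2 or not parts[1]:
                continue  # skip the top-level directory entry itself
            member.name = parts[1]
            tf.extract(member, path=dest)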

csrc/gpu/unittest/test_get_padding_offset_v2.py

Lines changed: 0 additions & 1 deletion
@@ -64,6 +64,5 @@ def test_get_padding_offset_v2(self):
         assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed."
         assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed."

-
 if __name__ == "__main__":
     unittest.main()

llm/tools/preprocess/create_pretraining_data.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ def get_whole_word_mask_tokens(tokens, words, max_word_length=6):
             i += 1
             continue

-        # add "##" mark on the middel tokens of Chinese words
+        # add "##" mark on the middle tokens of Chinese words
         # such as ["通过", "利用"] -> ["通", "##过", "利", "##用"]
         has_add = False
         for length in range(max_word_length, 0, -1):
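
The comment fix above concerns whole-word masking for Chinese. As context, a minimal sketch of the "##" marking convention it describes (a simplified, hypothetical helper; the real get_whole_word_mask_tokens also aligns the word list against the tokenizer output):

def mark_word_pieces(words):
    # ["通过", "利用"] -> ["通", "##过", "利", "##用"]
    marked = []
    for word in words:
        for i, ch in enumerate(word):
            marked.append(ch if i == 0 else "##" + ch)
    return marked

print(mark_word_pieces(["通过", "利用"]))  # ['通', '##过', '利', '##用']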

llm/utils/fused_layers.py

Lines changed: 2 additions & 2 deletions
@@ -106,11 +106,11 @@ def sp_async_reducesctter(x_grad):
 def sync_mp_allreduce(task, dist_tensor):
     mp_placement_index = dist_tensor.process_mesh.dim_names.index("mp")
     new_placments = list()
-    for idx, placment in enumerate(dist_tensor.placements):
+    for idx, placement in enumerate(dist_tensor.placements):
         if idx == mp_placement_index:
             new_placments.append(dist.Replicate())
         else:
-            new_placments.append(placment)
+            new_placments.append(placement)
     place = paddle.framework._current_expected_place()
     place = paddle.framework._get_paddle_place(place)

llm/utils/sp_async_reduce_scatter.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def forward_pre_hook(layer, input):
     ipp = id2ipp[id(layer)]


-def forward_post_hook(layer, input, ouput):
+def forward_post_hook(layer, input, output):
     paddle.nn.functional.linear = paddle_nn_functional_linear
     if is_fused_matmul_bias_supported():
         paddle.incubate.nn.functional.fused_linear = paddle_incubate_nn_functional_fused_linear

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 137 additions & 46 deletions
Large diffs are not rendered by default.

paddlenlp/trainer/auto_trainer.py

Lines changed: 14 additions & 0 deletions
@@ -16,6 +16,7 @@
 import os
 import random
 import time
+import types
 from typing import Any, Dict, Optional, Union

 import numpy as np
@@ -24,6 +25,7 @@
 import paddle.distributed.auto_parallel.intermediate.parallelize as parallelize
 import paddle.nn as nn
 from paddle.distributed import fleet
+from paddle.distributed.auto_parallel._utils import _patch_grads_for_step
 from paddle.profiler.utils import switch_job_schedule_profiler
 from tqdm.auto import tqdm

@@ -518,6 +520,18 @@ def _inner_training_loop(
             npu_accelerate_plugin(self.optimizer)

         model, dist_loader = self._wrap_for_auto(model, train_dataloader)
+
+        if (
+            dist.in_auto_parallel_align_mode()
+        ):  # When in auto parallel align mode, patching the optimizer step function
+
+            orig_step = (
+                self.optimizer.step.__func__ if hasattr(self.optimizer.step, "__func__") else self.optimizer.step
+            )
+            decorator = _patch_grads_for_step(amp_master_grad=self.args.amp_master_grad)
+            new_step = decorator(orig_step)
+            self.optimizer.__dict__["step"] = types.MethodType(new_step, self.optimizer)
+
         train_dataloader = dist_loader()
         if resume_from_checkpoint is not None:
             self._load_from_checkpoint(resume_from_checkpoint)
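
The block added above rebinds the optimizer's step method on the instance so extra gradient handling can run before each step in align mode. A minimal sketch of that monkey-patching pattern, with a stand-in decorator instead of the real _patch_grads_for_step (all names below are illustrative):

import types

class Optimizer:
    def step(self):
        print("original step")

def make_decorator():
    # Stand-in for _patch_grads_for_step(...): returns a wrapper around the unbound step function.
    def decorator(orig_step):
        def new_step(self, *args, **kwargs):
            print("extra work before step")  # e.g. touch up gradients here
            return orig_step(self, *args, **kwargs)
        return new_step
    return decorator

opt = Optimizer()
orig_step = opt.step.__func__ if hasattr(opt.step, "__func__") else opt.step
new_step = make_decorator()(orig_step)
# Writing into __dict__ shadows the class attribute for this instance only.
opt.__dict__["step"] = types.MethodType(new_step, opt)
opt.step()  # prints "extra work before step", then "original step"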

paddlenlp/trainer/trainer.py

Lines changed: 22 additions & 7 deletions
@@ -45,6 +45,11 @@
 from paddle import framework
 from paddle.distributed.fleet.meta_parallel import PipelineLayer

+try:
+    from paddle.distributed.fleet.meta_parallel import PipelineDatasetPreprocessor
+except:
+    PipelineDatasetPreprocessor = None
+
 try:
     from paddle.base import core
 except:
@@ -2756,22 +2761,32 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
         # for v in self._pp_data_buffer[0].values():
         #     assert isinstance(v, paddle.Tensor), f"Only support tensor as pipeline mode input, got type {type(v)}"

-        with self.autocast_smart_context_manager():
-            inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer)
-            self._pp_data_buffer = []
-
         model.train()
         if model._dp_comm_overlap or model._sharding_comm_overlap:
             for _, buffers in model._chunk_2_comm_buffers.items():
                 for buffer in buffers:
                     buffer._acc_steps = self.args.gradient_accumulation_steps

-        inputs = model._prepare_training(
-            inputs, self.optimizer, self.lr_scheduler
-        )  # None, None => [optimizer, lr_scheduler]
         model.optimizer = None  # we do not use `PipelineParallel` to handler optimizer step
         model.lr_scheduler = None

+        def _dataset_process_function():
+            # Pass a local function to forward_backward_pipeline instead of the dataset itself.
+            # This prevents the dataset from being passed as a direct argument to forward_backward_pipeline,
+            # which would create additional reference counts that cannot be cleared, leading to GPU memory leaks.
+            with self.autocast_smart_context_manager():
+                inputs = model._prepare_pipeline_inputs_func(self._pp_data_buffer)
+                self._pp_data_buffer = []
+
+            return model._prepare_training(
+                inputs, self.optimizer, self.lr_scheduler
+            )  # None, None => [optimizer, lr_scheduler]
+
+        if PipelineDatasetPreprocessor is None:
+            inputs = _dataset_process_function()
+        else:
+            inputs = PipelineDatasetPreprocessor(_dataset_process_function)
+
         with self.autocast_smart_context_manager():
             loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None)
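
The new _dataset_process_function defers input preparation: the pipeline receives a zero-argument callable and materializes the prepared inputs itself, so no extra reference to the data buffer survives the call. A toy, Paddle-independent sketch of the idea (all names are illustrative):

def run_pipeline(get_inputs):
    # Inputs are created only here, inside the callee's frame; once this
    # function returns, no outer reference keeps them alive.
    inputs = get_inputs()
    result = sum(inputs)
    del inputs  # release as early as possible
    return result

buffer = list(range(5))

def prepare():
    global buffer
    data, buffer = buffer, []  # hand over the buffer and clear it
    return data

print(run_pipeline(prepare))  # 10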

paddlenlp/trainer/training_args.py

Lines changed: 6 additions & 1 deletion
@@ -1145,7 +1145,12 @@ class TrainingArguments:
     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
         if in_auto_parallel_align_mode():
-            self.max_grad_norm = 0.0
+            # self.max_grad_norm = 0.0
+            # The current auto_hybrid_pp has aligned the handling of ClipGradByGlobalNorm with the original dygraph semi-auto parallel and dynamic manual-parallel modes and can correctly handle grad_clip, so it is no longer necessary to set max_grad_norm=0.0.
+            if self.max_grad_norm != 0.0:
+                warnings.warn(
+                    "max_grad_norm is not 0.0, so ClipGradByGlobalNorm will be executed; if you want to disable it, please set max_grad_norm=0.0"
+                )
             os.environ["FLAGS_max_inplace_grad_add"] = "65536"
             os.environ["FLAGS_embedding_deterministic"] = "1"
             os.environ["FLAGS_cudnn_deterministic"] = "1"

paddlenlp/transformers/clipseg/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -340,7 +340,7 @@ def forward(
         attn_weights = nn.functional.softmax(attn_weights, axis=-1)

         if output_attentions:
-            # this operation is a bit akward, but it's required to
+            # this operation is a bit awkward, but it's required to
             # make sure that attn_weights keeps its gradient.
             # In order to do so, attn_weights have to reshaped
             # twice and have to be reused in the following
