
Commit 03bc1a0

[AutoParallel] GradientClipByGlobalNorm: Add align mode branch in autotrainer (#10960)
* Add test branch for align mode
* Fix test bug
* Rerun CI
* Add a warning when max_grad_norm is not 0.0
1 parent b6f214e commit 03bc1a0

File tree

3 files changed: +85 -3 lines changed

paddlenlp/trainer/auto_trainer.py
paddlenlp/trainer/training_args.py
scripts/distribute/ci_case_auto.sh


paddlenlp/trainer/auto_trainer.py

Lines changed: 14 additions & 0 deletions

@@ -16,6 +16,7 @@
 import os
 import random
 import time
+import types
 from typing import Any, Dict, Optional, Union

 import numpy as np
@@ -24,6 +25,7 @@
 import paddle.distributed.auto_parallel.intermediate.parallelize as parallelize
 import paddle.nn as nn
 from paddle.distributed import fleet
+from paddle.distributed.auto_parallel._utils import _patch_grads_for_step
 from paddle.profiler.utils import switch_job_schedule_profiler
 from tqdm.auto import tqdm

@@ -518,6 +520,18 @@ def _inner_training_loop(
             npu_accelerate_plugin(self.optimizer)

         model, dist_loader = self._wrap_for_auto(model, train_dataloader)
+
+        if (
+            dist.in_auto_parallel_align_mode()
+        ):  # When in auto parallel align mode, patch the optimizer step function
+
+            orig_step = (
+                self.optimizer.step.__func__ if hasattr(self.optimizer.step, "__func__") else self.optimizer.step
+            )
+            decorator = _patch_grads_for_step(amp_master_grad=self.args.amp_master_grad)
+            new_step = decorator(orig_step)
+            self.optimizer.__dict__["step"] = types.MethodType(new_step, self.optimizer)
+
         train_dataloader = dist_loader()
         if resume_from_checkpoint is not None:
             self._load_from_checkpoint(resume_from_checkpoint)
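The patch re-binds the optimizer's `step` at the instance level: it unbinds the original method via `__func__`, wraps it with the decorator returned by `_patch_grads_for_step`, and re-attaches it with `types.MethodType`. Below is a minimal, self-contained sketch of that re-binding mechanism; the decorator body here is a placeholder for illustration, not Paddle's actual `_patch_grads_for_step`.

import types


def patch_grads_for_step(amp_master_grad=False):
    """Illustrative decorator factory (stand-in for Paddle's _patch_grads_for_step):
    wraps an unbound optimizer step so gradients can be touched just before the update."""

    def decorator(step_fn):
        def wrapped_step(self, *args, **kwargs):
            # Hypothetical pre-step hook: e.g. make sure clipping sees FP32
            # master gradients when amp_master_grad is enabled.
            if amp_master_grad:
                pass  # inspect / adjust the optimizer's gradients here
            return step_fn(self, *args, **kwargs)

        return wrapped_step

    return decorator


class ToyOptimizer:
    def step(self):
        print("base step")


opt = ToyOptimizer()
# Unbind (get the plain function) so the wrapper receives `self` explicitly.
orig_step = opt.step.__func__ if hasattr(opt.step, "__func__") else opt.step
new_step = patch_grads_for_step(amp_master_grad=True)(orig_step)
# Shadow the class-level `step` with an instance-level bound method.
opt.__dict__["step"] = types.MethodType(new_step, opt)
opt.step()  # runs the wrapped version, then the original step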

paddlenlp/trainer/training_args.py

Lines changed: 6 additions & 1 deletion

@@ -1109,7 +1109,12 @@ class TrainingArguments:
     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
         if in_auto_parallel_align_mode():
-            self.max_grad_norm = 0.0
+            # self.max_grad_norm = 0.0
+            # The current auto_hybrid_pp has aligned the handling of ClipGradByGlobalNorm with the original dygraph semi-auto parallel and dynamic manual-parallel modes and can correctly handle grad_clip, so it is no longer necessary to set max_grad_norm=0.0.
+            if self.max_grad_norm != 0.0:
+                warnings.warn(
+                    "max_grad_norm is not 0.0. ClipGradByGlobalNorm will be executed; to disable it, set max_grad_norm=0.0."
+                )
             os.environ["FLAGS_max_inplace_grad_add"] = "65536"
             os.environ["FLAGS_embedding_deterministic"] = "1"
             os.environ["FLAGS_cudnn_deterministic"] = "1"

scripts/distribute/ci_case_auto.sh

Lines changed: 65 additions & 2 deletions

@@ -935,7 +935,7 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2_hybrid_pp() {
         --data_impl "mmap" \
         --enable_auto_parallel 1 \
         --to_static 0 \
-        --max_grad_norm 0.0 \
+        --max_grad_norm 1.0 \
         >>${log_path}/$FUNCNAME 2>&1
     loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
     ips=-1
@@ -1003,7 +1003,7 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2_hybrid_pp() {
         --data_impl "mmap" \
         --enable_auto_parallel 1 \
         --to_static 0 \
-        --max_grad_norm 0.0 \
+        --max_grad_norm 1.0 \
         --resume_from_checkpoint "${case_out_dir}/checkpoint-9" \
         >>${log_path}/$FUNCNAME 2>&1
     pp_resume_from_hybrid_ckpt_loss=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
@@ -1012,6 +1012,69 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2_hybrid_pp() {
     echo "pp_resume from hybrid ckpt result: loss=$pp_resume_from_hybrid_ckpt_loss ips=$pp_resume_from_hybrid_ckpt_ips mem=$pp_resume_from_hybrid_ckpt_mem"

     check_result $FUNCNAME ${loss} ${pp_resume_from_hybrid_ckpt_loss} ${ips} ${pp_resume_from_hybrid_ckpt_ips} ${mem} ${pp_resume_from_hybrid_ckpt_mem}
+
+    echo "=========== $FUNCNAME run dygraph auto hybrid pp in align mode ==========="
+    export FLAGS_enable_auto_parallel_align_mode=1
+    task_name="llama_auto_bs8_fp16_dp2mp2pp2_hybrid_pp_in_align_mode"
+    align_mode_case_out_dir="output/$task_name"
+    align_mode_case_log_dir="output/$task_name""_log"
+    rm -rf $align_mode_case_out_dir
+    rm -rf $align_mode_case_log_dir
+
+    python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $align_mode_case_log_dir run_pretrain_auto.py \
+        --model_type "llama_pp" \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --input_dir "./data" \
+        --output_dir $align_mode_case_out_dir \
+        --split 949,50,1 \
+        --max_seq_length 2048 \
+        --hidden_size 1024 \
+        --intermediate_size 3072 \
+        --num_hidden_layers 8 \
+        --num_attention_heads 32 \
+        --per_device_train_batch_size 4 \
+        --per_device_eval_batch_size 4 \
+        --n_microbatch 4 \
+        --gradient_accumulation_steps 1 \
+        --use_flash_attention 1 \
+        --use_fused_rms_norm 0 \
+        --fp16 1 \
+        --fp16_opt_level "O2" \
+        --amp_master_grad 1 \
+        --scale_loss 1024 \
+        --pipeline_parallel_degree 2 \
+        --pipeline_schedule_mode "FThenB" \
+        --tensor_parallel_degree 2 \
+        --sharding_parallel_degree 1 \
+        --learning_rate 0.0001 \
+        --min_learning_rate 0.00001 \
+        --max_steps 10 \
+        --save_steps 20 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --logging_steps 1 \
+        --dataloader_num_workers 1 \
+        --sharding "" \
+        --eval_steps 1000000 \
+        --disable_tqdm true \
+        --continue_training 0 \
+        --recompute 0 \
+        --do_train \
+        --do_eval \
+        --device "gpu" \
+        --data_impl "mmap" \
+        --enable_auto_parallel 1 \
+        --to_static 0 \
+        --max_grad_norm 1.0 \
+        >>${log_path}/$FUNCNAME 2>&1
+    align_mode_loss=`cat $align_mode_case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    align_mode_ips=-1
+    align_mode_mem=-1
+    echo "result: loss=$align_mode_loss ips=$align_mode_ips mem=$align_mode_mem"
+
+    check_result $FUNCNAME ${loss} ${align_mode_loss} ${ips} ${align_mode_ips} ${mem} ${align_mode_mem}
+
     echo "=========== $FUNCNAME run end ==========="
     fi
 }
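The new CI case reruns the same hybrid-pp job with `FLAGS_enable_auto_parallel_align_mode=1` and `--max_grad_norm 1.0`, then compares its step-10 loss against the baseline run via `check_result`. A rough Python equivalent of the grep/awk loss extraction and that comparison is sketched below; the log paths and tolerance are illustrative, not taken from the script.

import re


def extract_step10_loss(workerlog_path: str) -> float:
    """Mirror: grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'"""
    with open(workerlog_path) as f:
        for line in f:
            if "global_step: 10" in line:
                match = re.search(r"loss: ([^,]+)", line)
                if match:
                    return float(match.group(1))
    raise ValueError(f"no step-10 loss found in {workerlog_path}")


# Hypothetical paths and tolerance for illustration; the real comparison is done by check_result.
baseline_loss = extract_step10_loss("output/llama_auto_bs8_fp16_dp2mp2pp2_hybrid_pp_log/workerlog.0")
align_loss = extract_step10_loss("output/llama_auto_bs8_fp16_dp2mp2pp2_hybrid_pp_in_align_mode_log/workerlog.0")
assert abs(baseline_loss - align_loss) < 1e-5, f"align-mode loss {align_loss} diverges from baseline {baseline_loss}"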
