
Commit cfc3e7f

adapt_flex_checkpoint
1 parent b6f214e commit cfc3e7f

File tree

5 files changed: +357 -38 lines changed


paddlenlp/trainer/trainer.py

Lines changed: 132 additions & 36 deletions
@@ -161,6 +161,7 @@
     get_last_checkpoint,
     get_scheduler,
     has_length,
+    init_optimizer,
     set_seed,
     should_skip_data,
     speed_metrics,
@@ -199,7 +200,6 @@
 if is_datasets_available():
     import datasets

-
 try:
     from paddle.distributed.fleet.utils import mix_precision_utils
 except:
@@ -812,6 +812,10 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
         if resume_from_checkpoint is not None:
             path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
             path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
+            if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None:
+                success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
+            else:
+                success, err_msg = True, None
             if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None:
                 success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
             else:
@@ -822,6 +826,11 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
                     self.zcc_manager.set_ema_state_dict(path)
                 else:
                     logger.info(f"ZCC EMA does not load {path} because {err_msg}")
+                if success:
+                    logger.info(f"ZCC EMA load from {path}")
+                    self.zcc_manager.set_ema_state_dict(path)
+                else:
+                    logger.info(f"ZCC EMA does not load {path} because {err_msg}")
             else:
                 logger.info(f"ZCC EMA state dict not found, in: {path}")

@@ -929,13 +938,13 @@ def train(
         self._memory_tracker.start()

         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model:
+            if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)

             if self.args.should_load_sharding_stage1_model:
                 model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)

-            elif self.args.should_save_sharding_stage1_model:
+            elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
                 # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
                 model = self._wrap_model(self.model_wrapped)
@@ -949,13 +958,43 @@
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            elif self.args.load_flex_checkpoint:
+                model = self._wrap_model(self.model_wrapped)
+                if model is not self.model:
+                    self.model_wrapped = model
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+                if resume_from_checkpoint is not None:
+                    if not self.args.ignore_load_lr_and_optim:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        accessible_files = os.listdir(resume_from_checkpoint)
+                        metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
+                        assert len(metadata_files) == 1, "Only support one metadata file now."
+                        metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
+                        state_dict_metadata = metadata.state_dict_metadata
+                        init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
+                        )
+                        self._load_scheduler(resume_from_checkpoint)
+                    else:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        sharded_state_dict = model_sharded_state_dict
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
+                        )
             else:
                 model = self._wrap_model(self.model_wrapped)
                 # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
+
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
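Read in isolation, the new `load_flex_checkpoint` branch above mixes optimizer bootstrapping and distributed loading. The sketch below restates that flow as a standalone helper. It is a minimal sketch, assuming `init_optimizer` lives in the same `trainer_utils` import block the first hunk extends; `trainer` and `checkpoint_dir` are illustrative names, not part of the PaddleNLP API, and the calls simply mirror the diff.

```python
import os

import paddle
import paddle.distributed as dist

from paddlenlp.trainer.trainer_utils import init_optimizer  # assumed location


def resume_from_flex_checkpoint(trainer, checkpoint_dir):
    # Every rank builds its sharded view of the model; the optimizer's sharded
    # state dict is derived from it.
    model_ssd = trainer.model.sharded_state_dict()

    # The checkpoint directory is expected to hold exactly one *.metadata file
    # describing how tensors were laid out when the checkpoint was written.
    metadata_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".metadata")]
    assert len(metadata_files) == 1, "Only support one metadata file now."
    metadata = paddle.load(os.path.join(checkpoint_dir, metadata_files[0]))

    # Pre-build optimizer state so its sharded layout matches the metadata, then
    # load model and optimizer shards in one pass; aoa_config can remap tensors
    # across parallel strategies.
    init_optimizer(trainer.optimizer, model_ssd, metadata.state_dict_metadata)
    optimizer_ssd = trainer.optimizer.sharded_state_dict(model_ssd)
    dist.load_state_dict(
        {**model_ssd, **optimizer_ssd}, checkpoint_dir, aoa_config=trainer.args.aoa_config
    )

    # Scheduler (and scaler) state is restored by the helper added in this commit.
    trainer._load_scheduler(checkpoint_dir)
```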
@@ -1357,6 +1396,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                     logger.warning(
                         f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
                     )
+
                 elif isinstance(self.optimizer, HybridParallelOptimizer):
                     self.optimizer._step(parameters_list)
                 else:
@@ -1993,7 +2033,6 @@ def apply_decay_param_fun(x):
                 grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
                 **optimizer_kwargs,
             )
-
         return self.optimizer

     def _apply_to_optimizer(self, action):
@@ -2033,6 +2072,13 @@ def _load_rng_state(self, checkpoint):
             return

         rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth")
+        if not os.path.isfile(rng_file):
+            logger.info(
+                "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                "fashion, reproducibility is not guaranteed."
+            )
+            return
+        rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth")
         if not os.path.isfile(rng_file):
             logger.info(
                 "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
@@ -2238,7 +2284,6 @@ def _wrap_model(self, model, training=True):
             mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
             assert self.optimizer is not None, "optimizer is empty!"
             self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
-
         # Pipeline mode
         if in_pipeline_parallel_mode:
             if self.args.amp_master_grad:
@@ -2288,15 +2333,13 @@ def get_expected_keys(inputs, keys):
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
             if (
                 hasattr(self.args, "enable_sharding_comm_overlap")
                 and self.args.enable_sharding_comm_overlap
                 and self.args.unified_checkpoint
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model.register_sharding_comm_overlap_hook(self.optimizer)
-
         # No pipeline mode, sharding only
         if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
             # Sharded DDP!
@@ -2310,7 +2353,6 @@ def get_expected_keys(inputs, keys):
                 model = paddle.distributed.fleet.meta_parallel.TensorParallel(
                     model, hcg, strategy=fleet.fleet._user_defined_strategy
                 )
-
             if ShardingOption.SHARD_OP in self.args.sharding:
                 if self.args.amp_master_grad:
                     mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use
@@ -2352,6 +2394,7 @@ def get_expected_keys(inputs, keys):
                     offload=cpu_offload,
                     **extra_kwargs,
                 )
+
                 if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
                     assert hasattr(optimizer, "use_main_grad"), (
                         "Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2377,7 +2420,6 @@ def get_expected_keys(inputs, keys):
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
         # stage1 has v1 and v2 version
         if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
             if "split_param" in self.args.sharding_parallel_config:
@@ -2720,6 +2762,10 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)

+        if self.args.save_flex_checkpoint:
+            model_sharded_state_dict = self.model.sharded_state_dict()
+            os.makedirs(output_dir, exist_ok=True)
+
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
             metric_to_check = self.args.metric_for_best_model
@@ -2779,23 +2825,38 @@ def _save_checkpoint(self, model, metrics=None):
                         signal_dir,
                     )
                 else:
-                    if self.dp_group.rank > 0:  # this should only work for MoE saving
-                        self._save_ckpt_func(
-                            self._filter_moe_no_sync_optimizer_params(),
-                            os.path.join(output_dir, optimizer_name),
-                            saved_signal_path,
+                    if self.args.save_flex_checkpoint:
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
                         )
-
+                        if self.args.should_save:
+                            if self.tokenizer is not None and self.args.save_tokenizer:
+                                self.tokenizer.save_pretrained(output_dir)
+                            # Good practice: save your training arguments together with the trained model
+                            paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
                     else:
-                        state_dict = self.optimizer.state_dict()
-                        save_path = os.path.join(output_dir, optimizer_name)
-                        if self.args.use_async_save:
-                            assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                            self._async_optimizer_saver.run(
-                                state_dict, save_path, saved_signal_path=saved_signal_path
+                        if self.dp_group.rank > 0:  # this should only work for MoE saving
+                            self._save_ckpt_func(
+                                self._filter_moe_no_sync_optimizer_params(),
+                                os.path.join(output_dir, optimizer_name),
+                                saved_signal_path,
                             )
+
                         else:
-                            self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                            state_dict = self.optimizer.state_dict()
+                            save_path = os.path.join(output_dir, optimizer_name)
+                            if self.args.use_async_save:
+                                assert not strtobool(
+                                    os.getenv("FLAG_LLM_PDC", "False")
+                                ), "Dont support FLAG_LLM_PDC"
+                                self._async_optimizer_saver.run(
+                                    state_dict, save_path, saved_signal_path=saved_signal_path
+                                )
+                            else:
+                                self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2806,7 +2867,12 @@ def _save_checkpoint(self, model, metrics=None):
                     or "remove_master_weight" not in self.args.unified_checkpoint_config
                 ):
                     paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
-            if self.args.should_save or self.args.use_expert_parallel:
+
+            if (
+                self.args.should_save
+                or self.args.use_expert_parallel
+                or (self.args.data_parallel_degree > 1 and self.args.save_flex_checkpoint)
+            ):
                 if not self.args.use_hybrid_parallel:
                     logger.info("Saving optimizer files.")
                     if self.args.unified_checkpoint:
@@ -2816,6 +2882,17 @@ def _save_checkpoint(self, model, metrics=None):
                             output_dir,
                             signal_dir,
                         )
+                    elif self.args.save_flex_checkpoint:
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
+                        if self.args.should_save:
+                            if self.tokenizer is not None and self.args.save_tokenizer:
+                                self.tokenizer.save_pretrained(output_dir)
+                            # Good practice: save your training arguments together with the trained model
+                            paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
                     else:
                         if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                             self._save_ckpt_func(
@@ -2849,7 +2926,17 @@ def _save_checkpoint(self, model, metrics=None):

             if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
                 self._offload_optimizer()
-
+        else:
+            if self.args.save_flex_checkpoint:
+                dist.save_state_dict(
+                    model_sharded_state_dict,
+                    output_dir,
+                )
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
         self.runtime_timer.stop()

         # Maybe delete some older checkpoints.
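The flex-checkpoint save logic above is spread across the hybrid-parallel, expert-parallel, and plain data-parallel branches of `_save_checkpoint`. This is a minimal sketch of what those branches boil down to once the MoE and unified-checkpoint cases are stripped away; `trainer` and `output_dir` are illustrative names, the calls mirror the diff, and the training-arguments filename is an assumption standing in for the `TRAINING_ARGS_NAME` constant.

```python
import os

import paddle
import paddle.distributed as dist


def save_flex_checkpoint(trainer, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # Every rank contributes its shards of the model and optimizer state.
    model_ssd = trainer.model.sharded_state_dict()
    optimizer_ssd = trainer.optimizer.sharded_state_dict(model_ssd)
    dist.save_state_dict({**model_ssd, **optimizer_ssd}, output_dir)

    # Small replicated artifacts are written once, by the rank where should_save is True.
    if trainer.args.should_save:
        if trainer.tokenizer is not None and trainer.args.save_tokenizer:
            trainer.tokenizer.save_pretrained(output_dir)
        paddle.save(trainer.args, os.path.join(output_dir, "training_args.bin"))  # assumed TRAINING_ARGS_NAME value
```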
@@ -3064,6 +3151,7 @@ def _save(
         else:
             if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
                 config_to_save = None
+                self.sharding_io.set_optimizer(self.optimizer)
                 state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
                     self.model, merge_tensor_parallel=merge_tensor_parallel
                 )
@@ -3093,6 +3181,24 @@ def _save(
             with open(path, "w") as f:
                 json.dump(model_meta, f)

+    def _load_scheduler(self, checkpoint):
+        if checkpoint is None:
+            self.runtime_timer.stop()
+            return
+
+        if not self.args.ignore_load_lr_and_optim:
+            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+                self.lr_scheduler.set_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
+                )
+            else:
+                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
+
+        if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
+            self.scaler.load_state_dict(
+                paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
+            )
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         self.runtime_timer.start("checkpoint loading time")
@@ -3134,6 +3240,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model = self.model_wrapped
+
             opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
                 model=model,
                 optimizer=self.optimizer,
@@ -3165,18 +3272,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")

-        if not self.args.ignore_load_lr_and_optim:
-            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
-                self.lr_scheduler.set_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
-                )
-            else:
-                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
-
-        if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
-            self.scaler.load_state_dict(
-                paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
-            )
+        self._load_scheduler(checkpoint)

         if self.args.offload_optim:
             logger.info("Offloading optimizer state...")
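An assumption-heavy usage sketch of how the pieces might fit together from user code. The field names `save_flex_checkpoint`, `load_flex_checkpoint`, and `aoa_config` are taken from the attribute reads in this diff; whether they are plain `TrainingArguments` fields, and what their defaults are, is defined in one of the other changed files not shown in this view, so treat the constructor call below as illustrative rather than authoritative.

```python
from paddlenlp.trainer import Trainer, TrainingArguments


def run_with_flex_checkpoint(model, train_ds, resume_dir=None):
    # Assumed field names, mirroring the self.args.* reads in the diff.
    args = TrainingArguments(
        output_dir="./checkpoints",
        save_flex_checkpoint=True,  # checkpoints written via dist.save_state_dict
        load_flex_checkpoint=True,  # resume goes through the new branch in Trainer.train()
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_ds)
    trainer.train(resume_from_checkpoint=resume_dir)  # e.g. a previously saved checkpoint directory
    return trainer
```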
