152 changes: 117 additions & 35 deletions paddlenlp/trainer/trainer.py
@@ -161,6 +161,7 @@
get_last_checkpoint,
get_scheduler,
has_length,
init_optimizer,
set_seed,
should_skip_data,
speed_metrics,
@@ -199,7 +200,6 @@
if is_datasets_available():
import datasets


try:
from paddle.distributed.fleet.utils import mix_precision_utils
except:
@@ -929,13 +929,13 @@ def train(
self._memory_tracker.start()

if not self.args.enable_auto_parallel:
if not self.args.should_load_sharding_stage1_model:
if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint:
self._load_from_checkpoint(resume_from_checkpoint)

if self.args.should_load_sharding_stage1_model:
model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)

elif self.args.should_save_sharding_stage1_model:
elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint:
# In non-sharded mode, _load_from_checkpoint should be invoked before _wrap_model.
# In this mode, rank 0 loads all params and _wrap_model implicitly broadcasts them from rank 0 to the other ranks.
model = self._wrap_model(self.model_wrapped)
@@ -949,6 +949,36 @@
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)

elif self.args.load_flex_checkpoint:
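# Flex-checkpoint resume: wrap the model first, then rebuild optimizer state from the
# checkpoint metadata so model and optimizer shards can be loaded in a single pass.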
model = self._wrap_model(self.model_wrapped)
if model is not self.model:
self.model_wrapped = model

if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

if resume_from_checkpoint is not None:
if not self.args.ignore_load_lr_and_optim:
model_sharded_state_dict = self.model.sharded_state_dict()
accessible_files = os.listdir(resume_from_checkpoint)
metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
assert len(metadata_files) == 1, "Only support one metadata file now."
metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
state_dict_metadata = metadata.state_dict_metadata
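# Initialize optimizer state against the sharded model state and checkpoint metadata
# so the distributed load below can fill it in place.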
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
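# Model and optimizer shards are loaded together; aoa_config is forwarded to the distributed loader.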
dist.load_state_dict(
sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config, offload=False
)
self._load_scheduler(resume_from_checkpoint)
else:
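# ignore_load_lr_and_optim: restore only the model shards from the flex checkpoint.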
model_sharded_state_dict = self.model.sharded_state_dict()
sharded_state_dict = model_sharded_state_dict
dist.load_state_dict(
sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
)
else:
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
@@ -1357,6 +1387,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
logger.warning(
f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
)

elif isinstance(self.optimizer, HybridParallelOptimizer):
self.optimizer._step(parameters_list)
else:
@@ -1993,7 +2024,6 @@ def apply_decay_param_fun(x):
grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
**optimizer_kwargs,
)

return self.optimizer

def _apply_to_optimizer(self, action):
@@ -2238,7 +2268,6 @@ def _wrap_model(self, model, training=True):
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
assert self.optimizer is not None, "optimizer is empty!"
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)

# Pipeline mode
if in_pipeline_parallel_mode:
if self.args.amp_master_grad:
@@ -2288,15 +2317,13 @@ def get_expected_keys(inputs, keys):
if self.args.amp_master_grad:
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

if (
hasattr(self.args, "enable_sharding_comm_overlap")
and self.args.enable_sharding_comm_overlap
and self.args.unified_checkpoint
and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
model.register_sharding_comm_overlap_hook(self.optimizer)

# No pipeline mode, sharding only
if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
# Sharded DDP!
@@ -2310,7 +2337,6 @@ def get_expected_keys(inputs, keys):
model = paddle.distributed.fleet.meta_parallel.TensorParallel(
model, hcg, strategy=fleet.fleet._user_defined_strategy
)

if ShardingOption.SHARD_OP in self.args.sharding:
if self.args.amp_master_grad:
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use
@@ -2352,6 +2378,7 @@ def get_expected_keys(inputs, keys):
offload=cpu_offload,
**extra_kwargs,
)

if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
assert hasattr(optimizer, "use_main_grad"), (
"Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2377,7 +2404,6 @@ def get_expected_keys(inputs, keys):
if self.args.amp_master_grad:
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

# stage1 has v1 and v2 versions
if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
if "split_param" in self.args.sharding_parallel_config:
@@ -2720,6 +2746,10 @@ def _save_checkpoint(self, model, metrics=None):
else:
self.save_model(output_dir)

if self.args.save_flex_checkpoint:
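# Capture the model's sharded state dict once; the optimizer-saving branches below
# merge it with the optimizer shards when writing the flex checkpoint.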
model_sharded_state_dict = self.model.sharded_state_dict()
os.makedirs(output_dir, exist_ok=True)

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
@@ -2779,23 +2809,38 @@ def _save_checkpoint(self, model, metrics=None):
signal_dir,
)
else:
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
if self.args.save_flex_checkpoint:
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)

if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
else:
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
self._async_optimizer_saver.run(
state_dict, save_path, saved_signal_path=saved_signal_path
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)

else:
self._save_ckpt_func(state_dict, save_path, saved_signal_path)
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(
os.getenv("FLAG_LLM_PDC", "False")
), "Dont support FLAG_LLM_PDC"
self._async_optimizer_saver.run(
state_dict, save_path, saved_signal_path=saved_signal_path
)
else:
self._save_ckpt_func(state_dict, save_path, saved_signal_path)

else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2806,7 +2851,12 @@ def _save_checkpoint(self, model, metrics=None):
or "remove_master_weight" not in self.args.unified_checkpoint_config
):
paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
if self.args.should_save or self.args.use_expert_parallel:

if (
self.args.should_save
or self.args.use_expert_parallel
or (self.args.data_parallel_degree > 1 and not self.args.use_hybrid_parallel)
):
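# Optimizer state is saved when should_save is set, when expert parallelism is enabled,
# or on pure data-parallel runs (data_parallel_degree > 1 without hybrid parallel).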
if not self.args.use_hybrid_parallel:
logger.info("Saving optimizer files.")
if self.args.unified_checkpoint:
@@ -2816,6 +2866,17 @@ def _save_checkpoint(self, model, metrics=None):
output_dir,
signal_dir,
)
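# Flex checkpoints merge the model and optimizer sharded state dicts into a single
# distributed save; tokenizer and training args are still written on should_save ranks.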
elif self.args.save_flex_checkpoint:
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)
if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
else:
if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
self._save_ckpt_func(
@@ -2850,6 +2911,18 @@ def _save_checkpoint(self, model, metrics=None):
if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
self._offload_optimizer()

else:
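# Only the model's sharded state is written in this branch; tokenizer and training
# args are saved on should_save ranks as usual.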
if self.args.save_flex_checkpoint:
dist.save_state_dict(
model_sharded_state_dict,
output_dir,
)
if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

self.runtime_timer.stop()

# Maybe delete some older checkpoints.
@@ -3064,6 +3137,7 @@ def _save(
else:
if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
config_to_save = None
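# Make the optimizer available to sharding_io before manipulating the state dict below.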
self.sharding_io.set_optimizer(self.optimizer)
state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
self.model, merge_tensor_parallel=merge_tensor_parallel
)
@@ -3093,6 +3167,24 @@ def _save(
with open(path, "w") as f:
json.dump(model_meta, f)

def _load_scheduler(self, checkpoint):
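"""Restore LR scheduler state (and grad scaler state, if gradient scaling is enabled) from `checkpoint`."""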
if checkpoint is None:
self.runtime_timer.stop()
return

if not self.args.ignore_load_lr_and_optim:
if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.lr_scheduler.set_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
)
else:
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")

if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)

def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
self.runtime_timer.start("checkpoint loading time")
@@ -3134,6 +3226,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
model = self.model_wrapped

opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
model=model,
optimizer=self.optimizer,
@@ -3165,18 +3258,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")

if not self.args.ignore_load_lr_and_optim:
if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.lr_scheduler.set_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
)
else:
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")

if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)
self._load_scheduler(checkpoint)

if self.args.offload_optim:
logger.info("Offloading optimizer state...")