diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 9a3bad94c59b..e90634f91b8d 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -161,6 +161,7 @@ get_last_checkpoint, get_scheduler, has_length, + init_optimizer, set_seed, should_skip_data, speed_metrics, @@ -199,7 +200,6 @@ if is_datasets_available(): import datasets - try: from paddle.distributed.fleet.utils import mix_precision_utils except: @@ -381,7 +381,10 @@ def __init__( is_ema=self.args.sharded_model_from_ema, ) - if self.args.unified_checkpoint: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + or self.args.load_checkpoint_format == "unified_checkpoint" + ): self.unified_checkpoint_handler = UnifiedCheckpointHandler(self.args) if self.sharding is not None and self.optimizer is not None: @@ -435,8 +438,9 @@ def _save_ckpt_func(state_dict, path, signal_path=None): not self.args.ignore_save_lr_and_optim ), "ignore_save_lr_and_optim should be False when using zero cost checkpoint" assert self.args.use_hybrid_parallel, "use_hybrid_parallel must be True when using zero cost checkpoint" - assert ( - not self.args.unified_checkpoint + assert not ( + self.args.save_checkpoint_format == "unified_checkpoint" + or self.args.load_checkpoint_format == "unified_checkpoint" ), "use_unified_checkpoint should be False when using zero cost checkpoint" assert not strtobool( os.getenv("FLAG_LLM_PDC", "False") @@ -474,7 +478,10 @@ def _save_ckpt_func(state_dict, path, signal_path=None): or isinstance(self.model, LoKrModel) or isinstance(self.model, ReFTModel) ): - if self.args.unified_checkpoint and "skip_save_model_weight" in self.args.unified_checkpoint_config: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + and "skip_save_model_weight" in self.args.unified_checkpoint_config + ): self.args.unified_checkpoint_config.remove("skip_save_model_weight") logger.warning( "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config." 
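Editor's note: the hunks above replace the boolean `self.args.unified_checkpoint` checks with string comparisons against the new `save_checkpoint_format` / `load_checkpoint_format` fields that this PR adds to `training_args.py` (see below). A minimal sketch of how a training script might opt into the new fields, and of the legacy-flag fallback implemented by the `_post_init_save_checkpoint_format` / `_post_init_load_checkpoint_format` hooks, is shown here for orientation; the `output_dir` value is hypothetical and the exact fallback behaviour is only what this diff suggests, not an authoritative reference.

```python
# Sketch only (assumes a PaddleNLP build that includes this PR).
from paddlenlp.trainer import TrainingArguments

# New-style configuration: checkpoint format is selected by name.
args = TrainingArguments(
    output_dir="./checkpoints",  # hypothetical path
    save_checkpoint_format="unified_checkpoint",  # None | "sharding_io" | "unified_checkpoint" | "flex_checkpoint"
    load_checkpoint_format="unified_checkpoint",
)

# Legacy scripts that only set the old boolean switch should still work:
# when the new fields are left as None, __post_init__ maps the old switches
# (unified_checkpoint / save_sharded_model / load_sharded_model) onto them.
legacy_args = TrainingArguments(
    output_dir="./checkpoints",
    unified_checkpoint=True,
)
assert legacy_args.save_checkpoint_format == "unified_checkpoint"
assert legacy_args.load_checkpoint_format == "unified_checkpoint"
```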
@@ -658,14 +665,17 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): # Load potential model checkpoint if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: - uc_async_save = self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config + uc_async_save = ( + self.args.load_checkpoint_format == "unified_checkpoint" + and "async_save" in self.args.unified_checkpoint_config + ) resume_from_checkpoint = get_last_checkpoint( self.args.output_dir, signal_folder=self.args.output_signal_dir, uc_async_save=uc_async_save ) if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})") - if self.args.unified_checkpoint: + if self.args.load_checkpoint_format == "unified_checkpoint": if resume_from_checkpoint is not None: use_unified_checkpoint = False if self.is_unified_checkpoint(resume_from_checkpoint): @@ -929,13 +939,18 @@ def train( self._memory_tracker.start() if not self.args.enable_auto_parallel: - if not self.args.should_load_sharding_stage1_model: + if ( + not self.args.should_load_sharding_stage1_model + and not self.args.load_checkpoint_format == "flex_checkpoint" + ): self._load_from_checkpoint(resume_from_checkpoint) if self.args.should_load_sharding_stage1_model: model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint) - elif self.args.should_save_sharding_stage1_model: + elif self.args.should_save_sharding_stage1_model and not ( + self.args.load_checkpoint_format == "flex_checkpoint" + ): # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model. # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks. model = self._wrap_model(self.model_wrapped) @@ -949,13 +964,43 @@ def train( if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self._load_optimizer_and_scheduler(resume_from_checkpoint) + elif self.args.load_checkpoint_format == "flex_checkpoint": + model = self._wrap_model(self.model_wrapped) + if model is not self.model: + self.model_wrapped = model + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + if resume_from_checkpoint is not None: + if not self.args.ignore_load_lr_and_optim: + model_sharded_state_dict = self.model.sharded_state_dict() + accessible_files = os.listdir(resume_from_checkpoint) + metadata_files = [file for file in accessible_files if file.endswith(".metadata")] + assert len(metadata_files) == 1, "Only support one metadata file now." 
+ metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0])) + state_dict_metadata = metadata.state_dict_metadata + init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata) + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict} + dist.load_state_dict( + sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config + ) + self._load_scheduler(resume_from_checkpoint) + else: + model_sharded_state_dict = self.model.sharded_state_dict() + sharded_state_dict = model_sharded_state_dict + dist.load_state_dict( + sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config + ) else: model = self._wrap_model(self.model_wrapped) # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model + if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) + self._load_optimizer_and_scheduler(resume_from_checkpoint) else: model = self.model_wrapped @@ -1357,6 +1402,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): logger.warning( f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" ) + elif isinstance(self.optimizer, HybridParallelOptimizer): self.optimizer._step(parameters_list) else: @@ -1438,7 +1484,10 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): logger.info("\nTraining completed. \n") # unlink shared_memory if used. - if self.args.unified_checkpoint: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + or self.args.load_checkpoint_format == "unified_checkpoint" + ): self.unified_checkpoint_handler.unlink_shared_memory() if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: @@ -1451,7 +1500,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): self._load_best_model_from_peft_checkpoint() else: - if self.args.unified_checkpoint: + if self.args.load_checkpoint_format == "unified_checkpoint": self.unified_checkpoint_handler.load_unified_checkpoint( self.model, self.state.best_model_checkpoint, @@ -1501,7 +1550,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): return TrainOutput(self.state.global_step, train_loss, metrics) def _load_best_model_from_peft_checkpoint(self): - if self.args.unified_checkpoint: + if self.args.load_checkpoint_format == "unified_checkpoint": self.unified_checkpoint_handler.load_unified_checkpoint( self.model, self.state.best_model_checkpoint, @@ -1993,7 +2042,6 @@ def apply_decay_param_fun(x): grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None, **optimizer_kwargs, ) - return self.optimizer def _apply_to_optimizer(self, action): @@ -2033,6 +2081,13 @@ def _load_rng_state(self, checkpoint): return rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth") + if not os.path.isfile(rng_file): + logger.info( + "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " + "fashion, reproducibility is not guaranteed." 
+ ) + return + rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth") if not os.path.isfile(rng_file): logger.info( "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " @@ -2042,7 +2097,7 @@ def _load_rng_state(self, checkpoint): checkpoint_rng_state = paddle.load(rng_file, return_numpy=True) if checkpoint_rng_state.get("world_size", None) != self.args.world_size: - logger.warn("Cannot load rng states when changing world size of training job.") + logger.warning("Cannot load rng states when changing world size of training job.") return random.setstate(checkpoint_rng_state["python"]) @@ -2238,7 +2293,6 @@ def _wrap_model(self, model, training=True): mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) assert self.optimizer is not None, "optimizer is empty!" self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) - # Pipeline mode if in_pipeline_parallel_mode: if self.args.amp_master_grad: @@ -2288,15 +2342,16 @@ def get_expected_keys(inputs, keys): if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) self.optimizer = fleet.distributed_optimizer(self.optimizer) - if ( hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap - and self.args.unified_checkpoint + and ( + self.args.save_checkpoint_format == "unified_checkpoint" + or self.args.load_checkpoint_format == "unified_checkpoint" + ) and "split_param" in split_parallel_config(self.args.sharding_parallel_config) ): model.register_sharding_comm_overlap_hook(self.optimizer) - # No pipeline mode, sharding only if not in_pipeline_parallel_mode and in_sharding_parallel_mode: # Sharded DDP! @@ -2310,7 +2365,6 @@ def get_expected_keys(inputs, keys): model = paddle.distributed.fleet.meta_parallel.TensorParallel( model, hcg, strategy=fleet.fleet._user_defined_strategy ) - if ShardingOption.SHARD_OP in self.args.sharding: if self.args.amp_master_grad: mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use @@ -2352,6 +2406,7 @@ def get_expected_keys(inputs, keys): offload=cpu_offload, **extra_kwargs, ) + if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad: assert hasattr(optimizer, "use_main_grad"), ( "Current installed paddle doesn't support sharding stage 2 with main grad, " @@ -2377,7 +2432,6 @@ def get_expected_keys(inputs, keys): if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) self.optimizer = fleet.distributed_optimizer(self.optimizer) - # stage1 has v1 and v2 version if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding: if "split_param" in self.args.sharding_parallel_config: @@ -2629,7 +2683,10 @@ def save_model( if self.args.should_save_model_state: self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel) else: - if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + and "async_save" in self.args.unified_checkpoint_config + ): os.makedirs(signal_dir, exist_ok=True) if self.is_in_train: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 @@ -2645,7 +2702,7 @@ def save_model( # For ckpt integrity paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done")) if ( - self.args.unified_checkpoint + self.args.save_checkpoint_format == 
"unified_checkpoint" and "async_save" in self.args.unified_checkpoint_config and not self.is_in_train ): @@ -2720,6 +2777,10 @@ def _save_checkpoint(self, model, metrics=None): else: self.save_model(output_dir) + if self.args.save_checkpoint_format == "flex_checkpoint": + model_sharded_state_dict = self.model.sharded_state_dict() + os.makedirs(output_dir, exist_ok=True) + # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: metric_to_check = self.args.metric_for_best_model @@ -2764,14 +2825,16 @@ def _save_checkpoint(self, model, metrics=None): optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}") - if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer): + if self.args.save_checkpoint_format == "unified_checkpoint" and ( + self.args.offload_optim or self.args.tensorwise_offload_optimizer + ): self._reload_optimizer() if self.args.use_hybrid_parallel: if self.dp_group.rank <= 0 or self.args.use_expert_parallel: os.makedirs(output_dir, exist_ok=True) logger.info("Saving optimizer files.") - if self.args.unified_checkpoint: + if self.args.save_checkpoint_format == "unified_checkpoint": self.unified_checkpoint_handler.save_unified_optimizer( self.model, self.optimizer, @@ -2779,25 +2842,43 @@ def _save_checkpoint(self, model, metrics=None): signal_dir, ) else: - if self.dp_group.rank > 0: # this should only work for MoE saving - self._save_ckpt_func( - self._filter_moe_no_sync_optimizer_params(), - os.path.join(output_dir, optimizer_name), - saved_signal_path, + if self.args.save_checkpoint_format == "flex_checkpoint": + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + dist.save_state_dict( + {**model_sharded_state_dict, **optimizer_sharded_state_dict}, + output_dir, ) - + if self.args.should_save: + if self.tokenizer is not None and self.args.save_tokenizer: + self.tokenizer.save_pretrained(output_dir) + # Good practice: save your training arguments together with the trained model + paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) else: - state_dict = self.optimizer.state_dict() - save_path = os.path.join(output_dir, optimizer_name) - if self.args.use_async_save: - assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC" - self._async_optimizer_saver.run( - state_dict, save_path, saved_signal_path=saved_signal_path + if self.dp_group.rank > 0: # this should only work for MoE saving + self._save_ckpt_func( + self._filter_moe_no_sync_optimizer_params(), + os.path.join(output_dir, optimizer_name), + saved_signal_path, ) + else: - self._save_ckpt_func(state_dict, save_path, saved_signal_path) + state_dict = self.optimizer.state_dict() + save_path = os.path.join(output_dir, optimizer_name) + if self.args.use_async_save: + assert not strtobool( + os.getenv("FLAG_LLM_PDC", "False") + ), "Dont support FLAG_LLM_PDC" + self._async_optimizer_saver.run( + state_dict, save_path, saved_signal_path=saved_signal_path + ) + else: + self._save_ckpt_func(state_dict, save_path, saved_signal_path) + else: - if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + and "async_save" in self.args.unified_checkpoint_config + ): global_rank = paddle.distributed.get_rank() if 
paddle.distributed.get_world_size() > 1 else -1 os.makedirs(signal_dir, exist_ok=True) paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}")) @@ -2806,16 +2887,32 @@ def _save_checkpoint(self, model, metrics=None): or "remove_master_weight" not in self.args.unified_checkpoint_config ): paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}")) - if self.args.should_save or self.args.use_expert_parallel: + + if ( + self.args.should_save + or self.args.use_expert_parallel + or (self.args.data_parallel_degree > 1 and self.args.save_checkpoint_format == "flex_checkpoint") + ): if not self.args.use_hybrid_parallel: logger.info("Saving optimizer files.") - if self.args.unified_checkpoint: + if self.args.save_checkpoint_format == "unified_checkpoint": self.unified_checkpoint_handler.save_unified_optimizer( self.model, self.optimizer, output_dir, signal_dir, ) + elif self.args.save_checkpoint_format == "flex_checkpoint": + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + dist.save_state_dict( + {**model_sharded_state_dict, **optimizer_sharded_state_dict}, + output_dir, + ) + if self.args.should_save: + if self.tokenizer is not None and self.args.save_tokenizer: + self.tokenizer.save_pretrained(output_dir) + # Good practice: save your training arguments together with the trained model + paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) else: if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel: self._save_ckpt_func( @@ -2836,7 +2933,7 @@ def _save_checkpoint(self, model, metrics=None): if self.do_grad_scaling: paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) else: - if self.args.unified_checkpoint and not self.args.use_hybrid_parallel: + if self.args.save_checkpoint_format == "unified_checkpoint" and not self.args.use_hybrid_parallel: if "async_save" in self.args.unified_checkpoint_config: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 os.makedirs(signal_dir, exist_ok=True) @@ -2847,9 +2944,21 @@ def _save_checkpoint(self, model, metrics=None): ): paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}")) - if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer): + if self.args.save_checkpoint_format == "unified_checkpoint" and ( + self.args.offload_optim or self.args.tensorwise_offload_optimizer + ): self._offload_optimizer() - + else: + if self.args.save_checkpoint_format == "flex_checkpoint": + dist.save_state_dict( + model_sharded_state_dict, + output_dir, + ) + if self.args.should_save: + if self.tokenizer is not None and self.args.save_tokenizer: + self.tokenizer.save_pretrained(output_dir) + # Good practice: save your training arguments together with the trained model + paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) self.runtime_timer.stop() # Maybe delete some older checkpoints. @@ -2952,7 +3061,10 @@ def _save( # signal_dir is used for asynchronous saving situations. 
signal_dir = self.args.output_signal_dir - if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + and "async_save" in self.args.unified_checkpoint_config + ): if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]: signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1]) os.makedirs(signal_dir, exist_ok=True) @@ -2964,7 +3076,7 @@ def _save( if ( strtobool(os.getenv("FLAG_LLM_PDC", "False")) and paddle.distributed.get_rank() == 0 - and self.args.unified_checkpoint + and self.args.save_checkpoint_format == "unified_checkpoint" and "async_save" in self.args.unified_checkpoint_config ): world_size = paddle.distributed.get_world_size() @@ -2987,7 +3099,7 @@ def _save( # Good practice: save your training arguments together with the trained model paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) - if self.args.unified_checkpoint: + if self.args.save_checkpoint_format == "unified_checkpoint": unified_checkpoint_config_backup = self.args.unified_checkpoint_config # backup and remove unified_checkpoint_config for not trine stage if not self.is_in_train: @@ -3064,6 +3176,7 @@ def _save( else: if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model: config_to_save = None + self.sharding_io.set_optimizer(self.optimizer) state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config( self.model, merge_tensor_parallel=merge_tensor_parallel ) @@ -3093,6 +3206,24 @@ def _save( with open(path, "w") as f: json.dump(model_meta, f) + def _load_scheduler(self, checkpoint): + if checkpoint is None: + self.runtime_timer.stop() + return + + if not self.args.ignore_load_lr_and_optim: + if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + self.lr_scheduler.set_state_dict( + paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME))) + ) + else: + raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}") + + if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)): + self.scaler.load_state_dict( + paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True) + ) + def _load_optimizer_and_scheduler(self, checkpoint): """If optimizer and scheduler states exist, load them.""" self.runtime_timer.start("checkpoint loading time") @@ -3112,7 +3243,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): ) else: use_unified_checkpoint = False - if self.args.unified_checkpoint: + if self.args.load_checkpoint_format == "unified_checkpoint": if self.is_unified_checkpoint(checkpoint): use_unified_checkpoint = True else: @@ -3134,6 +3265,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): and "split_param" in split_parallel_config(self.args.sharding_parallel_config) ): model = self.model_wrapped + opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer( model=model, optimizer=self.optimizer, @@ -3165,18 +3297,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.") - if not self.args.ignore_load_lr_and_optim: - if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)): - self.lr_scheduler.set_state_dict( - paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME))) - ) - else: - 
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}") - - if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)): - self.scaler.load_state_dict( - paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True) - ) + self._load_scheduler(checkpoint) if self.args.offload_optim: logger.info("Offloading optimizer state...") diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index d8d88d1cd4ad..0b9fa9ea5c16 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -37,6 +37,10 @@ import paddle import paddle.distributed as dist from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + DygraphShardingOptimizerV2, +) from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.io import IterableDataset from paddle.optimizer.lr import LambdaDecay @@ -1357,3 +1361,87 @@ def set_comm_config(configs, attr, dict_obj): set_comm_config("moe_sharding_configs", "check_nccl_config", nccl_config.get("moe_sharding_check", None)) set_comm_config("default_comm_group_configs", "nccl_config", nccl_config.get("default", None)) return strategy + + +def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata): + """ + Initialize the optimizer's states according to its type. + + For DygraphShardingOptimizer (V1), initializes accumulators for local parameters. + For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters. + For other cases, initializes accumulators for all parameters. + + Args: + optimizer: The optimizer instance to be initialized. 
+ """ + optimizer_state_names = [".moment1_0", ".moment2_0", ".beta1_pow_acc_0", ".beta2_pow_acc_0", ".w_0"] + inner_opt = getattr(optimizer, "_inner_opt", None) + static_to_struct_mapping = {} + model_sharded_state_dict = dict(sorted(model_sharded_state_dict.items())) + for k, v in model_sharded_state_dict.items(): + if v.local_tensor.name not in static_to_struct_mapping: + static_to_struct_mapping[v.local_tensor.name] = k + + if isinstance(inner_opt, DygraphShardingOptimizer): + local_params = optimizer._rank2params[optimizer._sharding_rank] + param_list = [] + for param in local_params: + param_name = param.name + struct_name = static_to_struct_mapping[param_name] + if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names): + continue + param_list.append(param) + optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list) + return + + elif isinstance(inner_opt, DygraphShardingOptimizerV2): + + def init_param_optimizer_states(param_iter): + master_weights = {} + state_dict = {} + moments = ("moment1_0", "moment2_0") + betas = ("beta1_pow_acc_0", "beta2_pow_acc_0") + for static_name, shape, no_need_master_weights in param_iter: + if not no_need_master_weights: + master_weights[static_name] = paddle.zeros(shape, dtype="float32") + prefix = f"{static_name}_fp32_master_0_" + else: + prefix = f"{static_name}_" + + for moment in moments: + key = f"{prefix}{moment}" + state_dict[key] = paddle.zeros(shape, dtype="float32") + for beta in betas: + key = f"{prefix}{beta}" + state_dict[key] = paddle.zeros((1,), dtype="float32") + return master_weights, state_dict + + def buffer_params(): + for buffer in optimizer._comm_buffer_list: + for param_name, grad_view in buffer._sharding_param_grad_view.items(): + struct_name = static_to_struct_mapping[param_name] + if not any( + struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names + ): + continue + param_begin = grad_view._param_begin + param_end = grad_view._param_end + shape = (param_end - param_begin,) + no_need_master_weights = grad_view._param.dtype == paddle.float32 + + if shape[0] > 0: + yield param_name, shape, no_need_master_weights + + master_weights, state_dict = init_param_optimizer_states(buffer_params()) + state_dict["master_weights"] = master_weights + state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06} + optimizer.set_state_dict(state_dict) + return + param_list = [] + for param in optimizer._parameter_list: + param_name = param.name + struct_name = static_to_struct_mapping[param_name] + if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names): + continue + param_list.append(param) + optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 88da91adced6..35d387be80a2 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -407,6 +407,12 @@ class TrainingArguments: Whether to release gradients during training. Default is `False`. ckpt_quant_stage (`str`, *optional*): Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0). + save_checkpoint_format (`str`, *optional*): + Specifies the format for saving checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint'. (default: None). 
This setting is ignored if the corresponding switch is configured. + load_checkpoint_format (`str`, *optional*): + Specifies the format for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint'. (default: None). This setting is ignored if the corresponding switch is configured. + aoa_config (`Optional[dict[str, list[str]]]`, *optional*): + The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None. """ output_dir: str = field( @@ -941,6 +947,29 @@ class TrainingArguments: default=False, metadata={"help": "Whether to use async_save instead of paddle.save."}, ) + save_checkpoint_format: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Specifies the format used to save checkpoints. " + "Available options: 'sharding_io', 'unified_checkpoint', " + "'flex_checkpoint'." + "This setting is ignored if the corresponding switch is configured." + ) + }, + ) + + load_checkpoint_format: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Specifies the format used to load checkpoints. " + "Available options: 'sharding_io', 'unified_checkpoint', " + "'flex_checkpoint'." + "This setting is ignored if the corresponding switch is configured." + ) + }, + ) ordered_save_group_size: int = field( default=0, metadata={ @@ -1106,6 +1135,13 @@ class TrainingArguments: default=None, metadata={"help": "NCCL中通信组的细粒度控制的配置文件路径, 默认值为None, 代表不启用此项配置"} ) + aoa_config: Optional[dict[str, list[str]]] = field( + default=None, + metadata={ + "help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None." + }, + ) + def __post_init__(self): world_size = paddle.distributed.get_world_size() if in_auto_parallel_align_mode(): @@ -1205,7 +1241,8 @@ def __post_init__(self): raise ValueError("AdamW Mini currently doesn't support tensor parallelism.") self._post_init_parallel_degree() - + self._post_init_save_checkpoint_format() + self._post_init_load_checkpoint_format() if self.to_static: assert world_size == 1 or self.enable_auto_parallel, ( "It's not supported for training in static mode except the following cases : " @@ -1859,7 +1896,10 @@ def is_context_parallel_supported(): else: if world_size > 1: if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(): - if self.unified_checkpoint: + if self.save_checkpoint_format in [ + "unified_checkpoint", + "flex_checkpoint", + ] or self.load_checkpoint_format in ["unified_checkpoint", "flex_checkpoint"]: # DP use hybrid group strategy = fleet.DistributedStrategy() fleet.init(is_collective=True, strategy=strategy) @@ -1867,16 +1907,20 @@ def is_context_parallel_supported(): paddle.distributed.init_parallel_env() if ( - self.unified_checkpoint + ( + self.save_checkpoint_format == "unified_checkpoint" + or self.load_checkpoint_format == "unified_checkpoint" + ) and self.sharding_parallel_degree > 0 and ShardingOption.FULL_SHARD in self.sharding ): logger.warning( - "Unified checkpoint currently do not support sharding stage3, set `unified_checkpoint` to False." + "Unified checkpoint currently do not support sharding stage3, disabling unified_checkpoint format." 
) - self.unified_checkpoint = False + self.save_checkpoint_format = None + self.load_checkpoint_format = None - if self.unified_checkpoint: + if self.save_checkpoint_format == "unified_checkpoint" or self.load_checkpoint_format == "unified_checkpoint": unified_checkpoint_config = set(self.unified_checkpoint_config.split(" ")) if sys.platform.startswith("win") and "async_save" in self.unified_checkpoint_config: raise ValueError("Currently do not support asynchronous saving for Windows system!") @@ -2129,6 +2173,30 @@ def _post_init_parallel_degree(self): if self.use_hybrid_parallel and self.enable_auto_parallel: self.use_hybrid_parallel = False + def _post_init_save_checkpoint_format(self): + if self.save_checkpoint_format: + valid_modes = ["unified_checkpoint", "sharding_io", "flex_checkpoint"] + assert ( + self.save_checkpoint_format in valid_modes + ), f"Invalid save_checkpoint_format: {self.save_checkpoint_format}, Only these formats are allowed: {valid_modes}." + else: + if self.unified_checkpoint: + self.save_checkpoint_format = "unified_checkpoint" + elif self.save_sharded_model: + self.save_checkpoint_format = "sharding_io" + + def _post_init_load_checkpoint_format(self): + if self.load_checkpoint_format: + valid_modes = ["unified_checkpoint", "sharding_io", "flex_checkpoint"] + assert ( + self.load_checkpoint_format in valid_modes + ), f"Invalid load_checkpoint_format: {self.load_checkpoint_format}, Only these formats are allowed: {valid_modes}." + else: + if self.unified_checkpoint: + self.load_checkpoint_format = "unified_checkpoint" + elif self.load_sharded_model: + self.load_checkpoint_format = "sharding_io" + def add_moe_comm_group(self): hybrid_configs = fleet.fleet._user_defined_strategy.hybrid_configs hcg = fleet.get_hybrid_communicate_group() @@ -2457,6 +2525,8 @@ def should_save_model_state(self): return True elif self.enable_auto_parallel: return True + elif self.save_checkpoint_format == "flex_checkpoint": + return False elif self.use_hybrid_parallel: # save on dataset rank 0 return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel) @@ -2475,14 +2545,16 @@ def should_save_sharding_stage1_model(self): if self.enable_auto_parallel: return False return ( - ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.save_sharded_model + ShardingOption.SHARD_OP in self.sharding + and self.sharding_parallel_degree > 1 + and self.save_checkpoint_format == "sharding_io" ) @property def should_load_sharding_stage1_model(self): if self.enable_auto_parallel: return False - return self.load_sharded_model + return self.load_checkpoint_format == "sharding_io" @property def should_load_dataset(self): diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 84c9bcc7ff4e..18c48f5470a8 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -30,6 +30,9 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ( + build_sharded_state_dict, +) from paddlenlp.transformers.refined_recompute import ( RRColumnParallelLinear, @@ -1367,7 +1370,6 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: @classmethod def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): - from 
paddlenlp.transformers.conversion_utils import split_or_merge_func
 
         fn = split_or_merge_func(
@@ -1995,6 +1997,14 @@ def forward(self, hidden_states, tensor_parallel_output=None):
             )
         return logits
 
+    def sharded_state_dict(
+        self,
+        structured_name_prefix: str = "",
+    ):
+        axis = 0 if self.transpose_y else 1
+        state_dict = self.state_dict(structured_name_prefix="")
+        return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix)
+
 
 class LlamaForCausalLM(LlamaPretrainedModel):
     enable_to_static_method = True
diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index b5756896c65a..b478b253835f 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -3167,6 +3167,19 @@ def state_dict(self, *args, **kwargs):
 
         return state_dict
 
+    def sharded_state_dict(self, *args, **kwargs):
+        sharded_state_dict = super().sharded_state_dict(*args, **kwargs)
+        if self._single_to_pp_mapping is None:
+            self._set_pipeline_name_mapping()
+        assert len(self._single_to_pp_mapping) > 0, "The pipeline stage must have parameters!"
+
+        for k in list(sharded_state_dict.keys()):
+            v = sharded_state_dict.pop(k)
+            v.key = self._pp_to_single_mapping[k]
+            sharded_state_dict[self._pp_to_single_mapping[k]] = v
+
+        return sharded_state_dict
+
     def set_state_dict(self, state_dict, *args, **kwargs):
         if self._single_to_pp_mapping is None:
             self._set_pipeline_name_mapping()
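Editor's note: pulling the pieces together, the `flex_checkpoint` path saves and restores a merged model+optimizer sharded state dict through the `paddle.distributed` flex-checkpoint APIs, and `aoa_config` describes how checkpoint contents map onto model weights. The sketch below condenses the save/load flow that this PR wires into `Trainer._save_checkpoint` and `Trainer.train`. It is illustrative only: `flex_save` / `flex_load` are hypothetical wrapper names, and it assumes a distributed environment plus a model and optimizer that expose `sharded_state_dict()` (as the Llama and pipeline-parallel changes above provide).

```python
# Condensed sketch of the flex_checkpoint round trip, using only calls that
# appear in this diff (dist.save_state_dict / dist.load_state_dict,
# sharded_state_dict, init_optimizer, the ".metadata" file).
import os

import paddle
import paddle.distributed as dist

from paddlenlp.trainer.trainer_utils import init_optimizer


def flex_save(model, optimizer, output_dir):
    # Model and optimizer expose per-rank sharded views of their state.
    model_sd = model.sharded_state_dict()
    opt_sd = optimizer.sharded_state_dict(model_sd)
    os.makedirs(output_dir, exist_ok=True)
    dist.save_state_dict({**model_sd, **opt_sd}, output_dir)


def flex_load(model, optimizer, checkpoint_dir, aoa_config=None):
    model_sd = model.sharded_state_dict()

    # The checkpoint is expected to contain exactly one ".metadata" file.
    metadata_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".metadata")]
    assert len(metadata_files) == 1, "Only support one metadata file now."
    metadata = paddle.load(os.path.join(checkpoint_dir, metadata_files[0]))

    # Allocate optimizer accumulators/master weights so they can receive the
    # loaded shards, then load model and optimizer state together.
    init_optimizer(optimizer, model_sd, metadata.state_dict_metadata)
    opt_sd = optimizer.sharded_state_dict(model_sd)
    dist.load_state_dict({**model_sd, **opt_sd}, checkpoint_dir, aoa_config=aoa_config)
```

In the trainer itself this logic is guarded by `save_checkpoint_format == "flex_checkpoint"` / `load_checkpoint_format == "flex_checkpoint"`, and the scheduler/scaler are restored separately via the new `_load_scheduler` helper.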