diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index fdce86316878..9cec6c16bd13 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -159,6 +159,7 @@
     get_last_checkpoint,
     get_scheduler,
     has_length,
+    init_optimizer,
     set_seed,
     should_skip_data,
     speed_metrics,
@@ -197,7 +198,6 @@
 if is_datasets_available():
     import datasets

-
 try:
     from paddle.distributed.fleet.utils import mix_precision_utils
 except:
@@ -914,7 +914,7 @@ def train(
         self._memory_tracker.start()

         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model:
+            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)

             if self.args.should_load_sharding_stage1_model:
@@ -934,7 +934,7 @@
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
+            elif not self.args.using_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
                 # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
@@ -942,6 +942,24 @@
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            else:
+                assert self.args.using_flex_checkpoint, "FlexCheckpoint should be enabled in this branch!"
+
+                model = self._wrap_model(self.model_wrapped)
+                if model is not self.model:
+                    self.model_wrapped = model
+
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+                if resume_from_checkpoint is not None:
+                    model_sharded_state_dict = self.model.sharded_state_dict()
+                    self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    init_optimizer(self.optimizer)
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config)
+                    self._load_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:
@@ -1342,6 +1360,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                             logger.warning(
                                 f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
                             )
+
                     elif isinstance(self.optimizer, HybridParallelOptimizer):
                         self.optimizer._step(parameters_list)
                     else:
@@ -1968,7 +1987,6 @@ def apply_decay_param_fun(x):
                 grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
                 **optimizer_kwargs,
             )
-
         return self.optimizer

     def _apply_to_optimizer(self, action):
@@ -2234,7 +2252,6 @@ def _wrap_model(self, model, training=True):
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
                 assert self.optimizer is not None, "optimizer is empty!"
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
-
         # Pipeline mode
         if in_pipeline_parallel_mode:
             if self.args.amp_master_grad:
@@ -2284,7 +2301,6 @@ def get_expected_keys(inputs, keys):
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
             if (
                 hasattr(self.args, "enable_sharding_comm_overlap")
                 and self.args.enable_sharding_comm_overlap
@@ -2292,7 +2308,6 @@ def get_expected_keys(inputs, keys):
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model.register_sharding_comm_overlap_hook(self.optimizer)
-
         # No pipeline mode, sharding only
         if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
             # Sharded DDP!
@@ -2306,7 +2321,6 @@ def get_expected_keys(inputs, keys):
                 model = paddle.distributed.fleet.meta_parallel.TensorParallel(
                     model, hcg, strategy=fleet.fleet._user_defined_strategy
                 )
-
             if ShardingOption.SHARD_OP in self.args.sharding:
                 if self.args.amp_master_grad:
                     mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use
@@ -2348,6 +2362,7 @@ def get_expected_keys(inputs, keys):
                     offload=cpu_offload,
                     **extra_kwargs,
                 )
+
                 if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
                     assert hasattr(optimizer, "use_main_grad"), (
                         "Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2373,7 +2388,6 @@ def get_expected_keys(inputs, keys):
                 if self.args.amp_master_grad:
                     self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
                 self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
         # stage1 has v1 and v2 version
         if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
             if "split_param" in self.args.sharding_parallel_config:
@@ -2388,7 +2402,6 @@ def get_expected_keys(inputs, keys):
                 and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
             ):
                 self.optimizer._set_broadcast_overlap(True, model)
-
         return model

     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
@@ -2700,6 +2713,10 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)

+        model_sharded_state_dict = self.model.sharded_state_dict()
+        if self.args.using_flex_checkpoint:
+            os.makedirs(output_dir, exist_ok=True)
+
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
             metric_to_check = self.args.metric_for_best_model
@@ -2763,23 +2780,32 @@ def _save_checkpoint(self, model, metrics=None):
                         signal_dir,
                     )
                 else:
-                    if self.dp_group.rank > 0:  # this should only work for MoE saving
-                        self._save_ckpt_func(
-                            self._filter_moe_no_sync_optimizer_params(),
-                            os.path.join(output_dir, optimizer_name),
-                            saved_signal_path,
-                        )
-
-                    else:
-                        state_dict = self.optimizer.state_dict()
-                        save_path = os.path.join(output_dir, optimizer_name)
-                        if self.args.use_async_save:
-                            assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                            self._async_optimizer_saver.run(
-                                state_dict, save_path, saved_signal_path=saved_signal_path
+                    if not self.args.using_flex_checkpoint:
+                        if self.dp_group.rank > 0:  # this should only work for MoE saving
+                            self._save_ckpt_func(
+                                self._filter_moe_no_sync_optimizer_params(),
+                                os.path.join(output_dir, optimizer_name),
+                                saved_signal_path,
                             )
+
                         else:
-                            self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                            state_dict = self.optimizer.state_dict()
+                            save_path = os.path.join(output_dir, optimizer_name)
+                            if self.args.use_async_save:
+                                assert not strtobool(
+                                    os.getenv("FLAG_LLM_PDC", "False")
+                                ), "Dont support FLAG_LLM_PDC"
+                                self._async_optimizer_saver.run(
+                                    state_dict, save_path, saved_signal_path=saved_signal_path
+                                )
+                            else:
+                                self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                    else:
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2800,7 +2826,7 @@ def _save_checkpoint(self, model, metrics=None):
                     output_dir,
                     signal_dir,
                 )
-            else:
+            elif not self.args.using_flex_checkpoint:
                 if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                     self._save_ckpt_func(
                         self._filter_moe_no_sync_optimizer_params(),
@@ -2814,6 +2840,13 @@ def _save_checkpoint(self, model, metrics=None):
                         saved_signal_path,
                     )

+            else:
+                optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                dist.save_state_dict(
+                    {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                    output_dir,
+                )
+
         # FIXME: maybe only save one copy
         paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -3077,6 +3110,24 @@ def _save(
             with open(path, "w") as f:
                 json.dump(model_meta, f)

+    def _load_scheduler(self, checkpoint):
+        if checkpoint is None:
+            self.runtime_timer.stop()
+            return
+
+        if not self.args.ignore_load_lr_and_optim:
+            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+                self.lr_scheduler.set_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
+                )
+            else:
+                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
+
+        if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
+            self.scaler.load_state_dict(
+                paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
+            )
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         self.runtime_timer.start("checkpoint loading time")
@@ -3118,6 +3169,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model = self.model_wrapped
+
             opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
                 model=model,
                 optimizer=self.optimizer,
@@ -3149,18 +3201,7 @@
                 optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")

-        if not self.args.ignore_load_lr_and_optim:
-            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
-                self.lr_scheduler.set_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
-                )
-            else:
-                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
-
-        if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
-            self.scaler.load_state_dict(
-                paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
-            )
+        self._load_scheduler(checkpoint)

         if self.args.offload_optim:
             logger.info("Offloading optimizer state...")
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index d8d88d1cd4ad..30f93e337b51 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -53,6 +53,21 @@
 from ..utils.pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
 from .utils.helper import distributed_file

+try:
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizerV2,
+    )
+except:
+    DygraphShardingOptimizerV2 = None
+
+try:
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizer,
+    )
+except:
+    DygraphShardingOptimizer = None
+
+
 __all__ = [
     "TrainOutput",
     "PredictionOutput",
@@ -1357,3 +1372,63 @@ def set_comm_config(configs, attr, dict_obj):
     set_comm_config("moe_sharding_configs", "check_nccl_config", nccl_config.get("moe_sharding_check", None))
     set_comm_config("default_comm_group_configs", "nccl_config", nccl_config.get("default", None))
     return strategy
+
+
+def init_optimizer(optimizer):
+    """
+    Initialize the optimizer's states according to its type.
+
+    For DygraphShardingOptimizer (V1), initializes accumulators for local parameters.
+    For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters.
+    For other cases, initializes accumulators for all parameters.
+
+    Args:
+        optimizer: The optimizer instance to be initialized.
+    """
+    if DygraphShardingOptimizer is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizer):
+        local_params = optimizer._rank2params[optimizer._sharding_rank]
+        optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params)
+        return
+
+    elif DygraphShardingOptimizerV2 is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizerV2):
+
+        def init_param_optimizer_states(param_iter):
+            master_weights = {}
+            state_dict = {}
+            moments = ("moment1_0", "moment2_0")
+            betas = ("beta1_pow_acc_0", "beta2_pow_acc_0")
+            for static_name, shape, no_need_master_weights in param_iter:
+                if not no_need_master_weights:
+                    master_weights[static_name] = paddle.zeros(shape, dtype="float32")
+                    prefix = f"{static_name}_fp32_master_0_"
+                else:
+                    prefix = f"{static_name}_"
+
+                for moment in moments:
+                    key = f"{prefix}{moment}"
+                    state_dict[key] = paddle.zeros(shape, dtype="float32")
+                for beta in betas:
+                    key = f"{prefix}{beta}"
+                    state_dict[key] = paddle.zeros((1,), dtype="float32")
+            return master_weights, state_dict
+
+        def buffer_params():
+            for buffer in optimizer._comm_buffer_list:
+                for param_name, grad_view in buffer._sharding_param_grad_view.items():
+                    param_begin = grad_view._param_begin
+                    param_end = grad_view._param_end
+                    shape = (param_end - param_begin,)
+                    no_need_master_weights = grad_view._param.dtype == paddle.float32
+
+                    if shape[0] > 0:
+                        yield param_name, shape, no_need_master_weights
+
+        master_weights, state_dict = init_param_optimizer_states(buffer_params())
+        state_dict["master_weights"] = master_weights
+        state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06}
+
+        optimizer.set_state_dict(state_dict)
+        return
+    optimizer._create_accumulators(
+        paddle.base.framework.default_main_program().global_block(), optimizer._parameter_list
+    )
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 30a3e7b3dc62..0a47988b5cd9 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -407,6 +407,10 @@ class TrainingArguments:
             Whether to release gradients during training. Default is `False`.
         ckpt_quant_stage (`str`, *optional*):
             Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0).
+        using_flex_checkpoint (`bool`, *optional*):
+            Whether to use FlexCheckpoint for save and load. Default is `False`.
+        aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
+            The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None.
     """

     output_dir: str = field(
@@ -921,6 +925,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use async_save instead of paddle.save."},
     )
+    using_flex_checkpoint: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to use FlexCheckpoint for saving and loading."},
+    )
     ordered_save_group_size: int = field(
         default=0,
         metadata={
@@ -1082,6 +1090,13 @@ class TrainingArguments:
         default=None,
         metadata={"help": "NCCL中通信组的细粒度控制的配置文件路径, 默认值为None, 代表不启用此项配置"}
     )
+    aoa_config: Optional[dict[str, list[str]]] = field(
+        default=None,
+        metadata={
+            "help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None."
+        },
+    )
+
     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
         if in_auto_parallel_align_mode():
@@ -2355,6 +2370,8 @@ def should_save_model_state(self):
             return True
         elif self.enable_auto_parallel:
             return True
+        elif self.using_flex_checkpoint:
+            return False
         elif self.use_hybrid_parallel:
             # save on dataset rank 0
             return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 84c9bcc7ff4e..18c48f5470a8 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -30,6 +30,9 @@
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.recompute.recompute import recompute
+from paddle.distributed.flex_checkpoint.dcp.sharded_weight import (
+    build_sharded_state_dict,
+)

 from paddlenlp.transformers.refined_recompute import (
     RRColumnParallelLinear,
@@ -1367,7 +1370,6 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:

     @classmethod
     def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
-
         from paddlenlp.transformers.conversion_utils import split_or_merge_func

         fn = split_or_merge_func(
@@ -1995,6 +1997,14 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         )
         return logits

+    def sharded_state_dict(
+        self,
+        structured_name_prefix: str = "",
+    ):
+        axis = 0 if self.transpose_y else 1
+        state_dict = self.state_dict(structured_name_prefix="")
+        return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix)
+

 class LlamaForCausalLM(LlamaPretrainedModel):
     enable_to_static_method = True
diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index b5756896c65a..60fad0cec54b 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -3167,6 +3167,19 @@ def state_dict(self, *args, **kwargs):

         return state_dict

+    def sharded_state_dict(self, *args, **kwargs):
+        sharded_state_dict = super().sharded_state_dict(*args, **kwargs)
+        if self._single_to_pp_mapping is None:
+            self._set_pipeline_name_mapping()
+        assert len(self._single_to_pp_mapping) > 0, "The pipeline stage must have parameters!"
+
+        for k in list(sharded_state_dict.keys()):
+            v = sharded_state_dict.pop(k)
+            v.tensor_key = self._pp_to_single_mapping[k]
+            sharded_state_dict[self._pp_to_single_mapping[k]] = v
+
+        return sharded_state_dict
+
     def set_state_dict(self, state_dict, *args, **kwargs):
         if self._single_to_pp_mapping is None:
             self._set_pipeline_name_mapping()
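
Usage sketch (illustrative only, not part of the patch): the two new `TrainingArguments` fields introduced above, `using_flex_checkpoint` and `aoa_config`, could be enabled as shown below. The output directory is a placeholder, and the model/dataset wiring needed for a full `Trainer` run is omitted.

```python
# Minimal sketch, assuming only the fields added by this patch; "./checkpoints" is a placeholder path.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    using_flex_checkpoint=True,  # route checkpoint save/load through dist.save_state_dict / dist.load_state_dict
    aoa_config=None,  # optional dict[str, list[str]] mapping model weights to checkpoint content
)
```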