     get_last_checkpoint,
     get_scheduler,
     has_length,
+    init_optimizer,
     set_seed,
     should_skip_data,
     speed_metrics,
199200if is_datasets_available ():
200201 import datasets
201202
202-
203203try :
204204 from paddle .distributed .fleet .utils import mix_precision_utils
205205except :
@@ -812,6 +812,10 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
         if resume_from_checkpoint is not None:
             path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
             path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
+            if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None:
+                success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
+            else:
+                success, err_msg = True, None
             if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None:
                 success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
             else:
@@ -822,6 +826,11 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
                     self.zcc_manager.set_ema_state_dict(path)
                 else:
                     logger.info(f"ZCC EMA does not load {path} because {err_msg}")
+                if success:
+                    logger.info(f"ZCC EMA load from {path}")
+                    self.zcc_manager.set_ema_state_dict(path)
+                else:
+                    logger.info(f"ZCC EMA does not load {path} because {err_msg}")
             else:
                 logger.info(f"ZCC EMA state dict not found, in: {path}")
 
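The two hunks above guard the ZCC EMA restore: the EMA file path is derived from the optimizer checkpoint name, and the state is only loaded when the sharding strategy recorded in the checkpoint matches the current run. Below is a condensed sketch of that intent, collapsed into one guarded load; the enclosing existence check is inferred from the "state dict not found" branch shown above, and _add_variant, PADDLE_OPTIMIZER_NAME and logger are the trainer module's own helpers assumed to be in scope.

# Sketch only: mirrors the logic added in create_zcc_manager, not a drop-in replacement.
import os

def maybe_load_zcc_ema(trainer, resume_from_checkpoint):
    if resume_from_checkpoint is None:
        return
    # EMA weights sit next to the optimizer shards, with "optimizer" swapped for "ema" in the file name.
    path = _add_variant(PADDLE_OPTIMIZER_NAME, trainer.args.optimizer_name_suffix)
    path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")

    if trainer.args.zcc_save_ema_coef is not None and trainer.sharding_io is not None:
        # Only reuse EMA state when the checkpoint was written with the same sharding strategy.
        success, err_msg = trainer.sharding_io.check_same_strategy(resume_from_checkpoint)
    else:
        success, err_msg = True, None

    if os.path.exists(path):
        if success:
            logger.info(f"ZCC EMA load from {path}")
            trainer.zcc_manager.set_ema_state_dict(path)
        else:
            logger.info(f"ZCC EMA does not load {path} because {err_msg}")
    else:
        logger.info(f"ZCC EMA state dict not found, in: {path}")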
@@ -929,13 +938,13 @@ def train(
         self._memory_tracker.start()
 
         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model:
+            if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)
 
             if self.args.should_load_sharding_stage1_model:
                 model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
 
-            elif self.args.should_save_sharding_stage1_model:
+            elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
                 # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
                 model = self._wrap_model(self.model_wrapped)
@@ -949,13 +958,43 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            elif self.args.load_flex_checkpoint:
+                model = self._wrap_model(self.model_wrapped)
+                if model is not self.model:
+                    self.model_wrapped = model
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+                if resume_from_checkpoint is not None:
+                    if not self.args.ignore_load_lr_and_optim:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        accessible_files = os.listdir(resume_from_checkpoint)
+                        metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
+                        assert len(metadata_files) == 1, "Only support one metadata file now."
+                        metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
+                        state_dict_metadata = metadata.state_dict_metadata
+                        init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
+                        )
+                        self._load_scheduler(resume_from_checkpoint)
+                    else:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        sharded_state_dict = model_sharded_state_dict
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
+                        )
             else:
                 model = self._wrap_model(self.model_wrapped)
                 # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
+
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
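Read end to end, the new load_flex_checkpoint branch wraps the model, builds sharded state dicts for model and optimizer, uses the checkpoint's single .metadata file to initialize the optimizer shards via init_optimizer (imported in the hunk at the top of this diff), and then restores everything with dist.load_state_dict. The following is a condensed sketch of that flow as a free-standing helper; the function name is hypothetical, and init_optimizer plus the trainer attributes are assumed to be in scope exactly as they are inside train().

# Sketch of the flex-checkpoint resume path added above, not a drop-in implementation.
import os

import paddle
import paddle.distributed as dist


def resume_from_flex_checkpoint(trainer, resume_from_checkpoint):
    model_sharded_state_dict = trainer.model.sharded_state_dict()

    if not trainer.args.ignore_load_lr_and_optim:
        # The checkpoint directory is expected to carry exactly one *.metadata file describing the shards.
        metadata_files = [f for f in os.listdir(resume_from_checkpoint) if f.endswith(".metadata")]
        assert len(metadata_files) == 1, "Only support one metadata file now."
        metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))

        # Rebuild optimizer shards that match the checkpoint layout before loading into them.
        init_optimizer(trainer.optimizer, model_sharded_state_dict, metadata.state_dict_metadata)
        optimizer_sharded_state_dict = trainer.optimizer.sharded_state_dict(model_sharded_state_dict)
        sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
    else:
        # Weights only: skip optimizer reconstruction entirely.
        sharded_state_dict = model_sharded_state_dict

    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=trainer.args.aoa_config)

    if not trainer.args.ignore_load_lr_and_optim:
        # Scheduler (and loss scaler) state is restored by the _load_scheduler helper introduced later in this diff.
        trainer._load_scheduler(resume_from_checkpoint)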
@@ -1357,6 +1396,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                             logger.warning(
                                 f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
                             )
+
                     elif isinstance(self.optimizer, HybridParallelOptimizer):
                         self.optimizer._step(parameters_list)
                     else:
@@ -1993,7 +2033,6 @@ def apply_decay_param_fun(x):
                 grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
                 **optimizer_kwargs,
             )
-
         return self.optimizer
 
     def _apply_to_optimizer(self, action):
@@ -2033,6 +2072,13 @@ def _load_rng_state(self, checkpoint):
             return
 
         rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth")
+        if not os.path.isfile(rng_file):
+            logger.info(
+                "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                "fashion, reproducibility is not guaranteed."
+            )
+            return
+        rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth")
         if not os.path.isfile(rng_file):
             logger.info(
                 "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
@@ -2238,7 +2284,6 @@ def _wrap_model(self, model, training=True):
             mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
             assert self.optimizer is not None, "optimizer is empty!"
             self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
-
         # Pipeline mode
         if in_pipeline_parallel_mode:
             if self.args.amp_master_grad:
@@ -2288,15 +2333,13 @@ def get_expected_keys(inputs, keys):
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
             if (
                 hasattr(self.args, "enable_sharding_comm_overlap")
                 and self.args.enable_sharding_comm_overlap
                 and self.args.unified_checkpoint
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model.register_sharding_comm_overlap_hook(self.optimizer)
-
         # No pipeline mode, sharding only
         if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
             # Sharded DDP!
@@ -2310,7 +2353,6 @@ def get_expected_keys(inputs, keys):
                 model = paddle.distributed.fleet.meta_parallel.TensorParallel(
                     model, hcg, strategy=fleet.fleet._user_defined_strategy
                 )
-
             if ShardingOption.SHARD_OP in self.args.sharding:
                 if self.args.amp_master_grad:
                     mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use
@@ -2352,6 +2394,7 @@ def get_expected_keys(inputs, keys):
                     offload=cpu_offload,
                     **extra_kwargs,
                 )
+
                 if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
                     assert hasattr(optimizer, "use_main_grad"), (
                         "Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2377,7 +2420,6 @@ def get_expected_keys(inputs, keys):
                 if self.args.amp_master_grad:
                     self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
                 self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
         # stage1 has v1 and v2 version
         if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
             if "split_param" in self.args.sharding_parallel_config:
@@ -2720,6 +2762,10 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)
 
+        if self.args.save_flex_checkpoint:
+            model_sharded_state_dict = self.model.sharded_state_dict()
+            os.makedirs(output_dir, exist_ok=True)
+
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
             metric_to_check = self.args.metric_for_best_model
@@ -2779,23 +2825,38 @@ def _save_checkpoint(self, model, metrics=None):
                     signal_dir,
                 )
             else:
-                if self.dp_group.rank > 0:  # this should only work for MoE saving
-                    self._save_ckpt_func(
-                        self._filter_moe_no_sync_optimizer_params(),
-                        os.path.join(output_dir, optimizer_name),
-                        saved_signal_path,
+                if self.args.save_flex_checkpoint:
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    dist.save_state_dict(
+                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                        output_dir,
                     )
-
+                    if self.args.should_save:
+                        if self.tokenizer is not None and self.args.save_tokenizer:
+                            self.tokenizer.save_pretrained(output_dir)
+                        # Good practice: save your training arguments together with the trained model
+                        paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
                 else:
-                    state_dict = self.optimizer.state_dict()
-                    save_path = os.path.join(output_dir, optimizer_name)
-                    if self.args.use_async_save:
-                        assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                        self._async_optimizer_saver.run(
-                            state_dict, save_path, saved_signal_path=saved_signal_path
+                    if self.dp_group.rank > 0:  # this should only work for MoE saving
+                        self._save_ckpt_func(
+                            self._filter_moe_no_sync_optimizer_params(),
+                            os.path.join(output_dir, optimizer_name),
+                            saved_signal_path,
                         )
+
                     else:
-                        self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                        state_dict = self.optimizer.state_dict()
+                        save_path = os.path.join(output_dir, optimizer_name)
+                        if self.args.use_async_save:
+                            assert not strtobool(
+                                os.getenv("FLAG_LLM_PDC", "False")
+                            ), "Dont support FLAG_LLM_PDC"
+                            self._async_optimizer_saver.run(
+                                state_dict, save_path, saved_signal_path=saved_signal_path
+                            )
+                        else:
+                            self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2806,7 +2867,12 @@ def _save_checkpoint(self, model, metrics=None):
28062867 or "remove_master_weight" not in self .args .unified_checkpoint_config
28072868 ):
28082869 paddle .save (global_rank , os .path .join (signal_dir , f".master_weight.done.{ global_rank } " ))
2809- if self .args .should_save or self .args .use_expert_parallel :
2870+
2871+ if (
2872+ self .args .should_save
2873+ or self .args .use_expert_parallel
2874+ or (self .args .data_parallel_degree > 1 and self .args .save_flex_checkpoint )
2875+ ):
28102876 if not self .args .use_hybrid_parallel :
28112877 logger .info ("Saving optimizer files." )
28122878 if self .args .unified_checkpoint :
@@ -2816,6 +2882,17 @@ def _save_checkpoint(self, model, metrics=None):
                         output_dir,
                         signal_dir,
                     )
+                elif self.args.save_flex_checkpoint:
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    dist.save_state_dict(
+                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                        output_dir,
+                    )
+                    if self.args.should_save:
+                        if self.tokenizer is not None and self.args.save_tokenizer:
+                            self.tokenizer.save_pretrained(output_dir)
+                        # Good practice: save your training arguments together with the trained model
+                        paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
                 else:
                     if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                         self._save_ckpt_func(
@@ -2849,7 +2926,17 @@ def _save_checkpoint(self, model, metrics=None):
 
             if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
                 self._offload_optimizer()
-
+        else:
+            if self.args.save_flex_checkpoint:
+                dist.save_state_dict(
+                    model_sharded_state_dict,
+                    output_dir,
+                )
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
         self.runtime_timer.stop()
 
         # Maybe delete some older checkpoints.
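The saving side mirrors the load path: when save_flex_checkpoint is enabled, _save_checkpoint merges the model and optimizer sharded state dicts and hands them to dist.save_state_dict, writing the tokenizer and training arguments only on the should_save rank; when optimizer state is skipped, only the model shards are saved. A minimal sketch of that branch as a stand-alone helper follows; the function name is hypothetical, and TRAINING_ARGS_NAME is the trainer module's own constant assumed to be in scope.

# Sketch of the flex-checkpoint save path, condensed from the hunks above.
import os

import paddle
import paddle.distributed as dist


def save_flex_checkpoint(trainer, output_dir, include_optimizer=True):
    os.makedirs(output_dir, exist_ok=True)
    model_sharded_state_dict = trainer.model.sharded_state_dict()

    if include_optimizer:
        optimizer_sharded_state_dict = trainer.optimizer.sharded_state_dict(model_sharded_state_dict)
        state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
    else:
        state_dict = model_sharded_state_dict

    # Collective save of the sharded state; every participating rank calls this.
    dist.save_state_dict(state_dict, output_dir)

    if trainer.args.should_save:
        if trainer.tokenizer is not None and trainer.args.save_tokenizer:
            trainer.tokenizer.save_pretrained(output_dir)
        # Good practice: save your training arguments together with the trained model
        paddle.save(trainer.args, os.path.join(output_dir, TRAINING_ARGS_NAME))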
@@ -3064,6 +3151,7 @@ def _save(
         else:
             if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
                 config_to_save = None
+                self.sharding_io.set_optimizer(self.optimizer)
                 state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
                     self.model, merge_tensor_parallel=merge_tensor_parallel
                 )
@@ -3093,6 +3181,24 @@ def _save(
             with open(path, "w") as f:
                 json.dump(model_meta, f)
 
+    def _load_scheduler(self, checkpoint):
+        if checkpoint is None:
+            self.runtime_timer.stop()
+            return
+
+        if not self.args.ignore_load_lr_and_optim:
+            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+                self.lr_scheduler.set_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
+                )
+            else:
+                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
+
+        if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
+            self.scaler.load_state_dict(
+                paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
+            )
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         self.runtime_timer.start("checkpoint loading time")
@@ -3134,6 +3240,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
31343240 and "split_param" in split_parallel_config (self .args .sharding_parallel_config )
31353241 ):
31363242 model = self .model_wrapped
3243+
31373244 opt_state_dict = self .unified_checkpoint_handler .load_unified_optimizer (
31383245 model = model ,
31393246 optimizer = self .optimizer ,
@@ -3165,18 +3272,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")
 
-        if not self.args.ignore_load_lr_and_optim:
-            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
-                self.lr_scheduler.set_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
-                )
-            else:
-                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
-
-        if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
-            self.scaler.load_state_dict(
-                paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
-            )
+        self._load_scheduler(checkpoint)
 
         if self.args.offload_optim:
             logger.info("Offloading optimizer state...")
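Factoring the scheduler and grad-scaler restore out into _load_scheduler leaves a single code path: _load_optimizer_and_scheduler now ends with the call shown below, and the flex-checkpoint branch in train() reuses the same helper right after dist.load_state_dict. A brief summary of the helper's behavior as added in this diff, written against a trainer instance:

# _load_scheduler(checkpoint), as introduced above (summary, not a verbatim copy):
#   - checkpoint is None: stop the runtime timer and return
#   - ignore_load_lr_and_optim is False: restore lr_scheduler from SCHEDULER_NAME, or raise ValueError if missing
#   - do_grad_scaling is True: additionally restore the loss scaler from SCALER_NAME
trainer._load_scheduler(checkpoint)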