     get_last_checkpoint,
     get_scheduler,
     has_length,
+    init_optimizer,
     set_seed,
     should_skip_data,
     speed_metrics,
     split_parallel_config,
 )
 from .training_args import TrainingArguments
-from .unified_checkpoint import UnifiedCheckpointHandler
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver

@@ -197,7 +197,6 @@
 if is_datasets_available():
     import datasets

-
 try:
     from paddle.distributed.fleet.utils import mix_precision_utils
 except:
@@ -914,7 +913,7 @@ def train(
         self._memory_tracker.start()

         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model:
+            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)

             if self.args.should_load_sharding_stage1_model:
@@ -934,14 +933,32 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
+            elif not self.args.using_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
                 # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            else:
+                assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"
+
+                model = self._wrap_model(self.model_wrapped)
+                if model is not self.model:
+                    self.model_wrapped = model
+
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+                if resume_from_checkpoint is not None:
+                    model_sharded_state_dict = self.model.sharded_state_dict()
+                    self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    init_optimizer(self.optimizer)
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint)
+                    self._load_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:
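Note: the new resume branch above loads a "flex" checkpoint by merging the model's and the optimizer's sharded state dicts and handing one flat mapping to the distributed checkpoint loader. A minimal standalone sketch of that flow follows; it assumes objects exposing the sharded_state_dict() interface used in this PR, and treats init_optimizer (imported from .trainer_utils above) and the dist.load_state_dict(state_dict, path) call as this PR uses them, not as a stable public API. The absolute import path is only an assumption.

    # Sketch only: the flex-checkpoint resume path outside the Trainer class.
    import paddle.distributed as dist
    from paddlenlp.trainer.trainer_utils import init_optimizer  # assumed path; the diff imports from .trainer_utils

    def resume_from_flex_checkpoint(model, optimizer, ckpt_dir):
        model_ssd = model.sharded_state_dict()
        # Create the optimizer accumulators before requesting their sharded views.
        init_optimizer(optimizer)
        opt_ssd = optimizer.sharded_state_dict(model_ssd)
        # A single flat mapping lets the loader resolve every shard in one pass.
        dist.load_state_dict({**model_ssd, **opt_ssd}, ckpt_dir)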
@@ -1342,6 +1359,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                             logger.warning(
                                 f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
                             )
+
                     elif isinstance(self.optimizer, HybridParallelOptimizer):
                         self.optimizer._step(parameters_list)
                     else:
@@ -1968,7 +1986,6 @@ def apply_decay_param_fun(x):
                 grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
                 **optimizer_kwargs,
             )
-
         return self.optimizer

     def _apply_to_optimizer(self, action):
@@ -2234,7 +2251,6 @@ def _wrap_model(self, model, training=True):
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
                 assert self.optimizer is not None, "optimizer is empty!"
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
-
         # Pipeline mode
         if in_pipeline_parallel_mode:
             if self.args.amp_master_grad:
@@ -2284,15 +2300,13 @@ def get_expected_keys(inputs, keys):
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
             if (
                 hasattr(self.args, "enable_sharding_comm_overlap")
                 and self.args.enable_sharding_comm_overlap
                 and self.args.unified_checkpoint
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model.register_sharding_comm_overlap_hook(self.optimizer)
-
         # No pipeline mode, sharding only
         if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
             # Sharded DDP!
@@ -2306,7 +2320,6 @@ def get_expected_keys(inputs, keys):
                 model = paddle.distributed.fleet.meta_parallel.TensorParallel(
                     model, hcg, strategy=fleet.fleet._user_defined_strategy
                 )
-
             if ShardingOption.SHARD_OP in self.args.sharding:
                 if self.args.amp_master_grad:
                     mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use
@@ -2348,6 +2361,7 @@ def get_expected_keys(inputs, keys):
                     offload=cpu_offload,
                     **extra_kwargs,
                 )
+
                 if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
                     assert hasattr(optimizer, "use_main_grad"), (
                         "Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2373,7 +2387,6 @@ def get_expected_keys(inputs, keys):
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
         # stage1 has v1 and v2 version
         if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
             if "split_param" in self.args.sharding_parallel_config:
@@ -2388,7 +2401,6 @@ def get_expected_keys(inputs, keys):
                 and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
             ):
                 self.optimizer._set_broadcast_overlap(True, model)
-
         return model

     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
@@ -2700,6 +2712,10 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)

+        model_sharded_state_dict = self.model.sharded_state_dict()
+        if self.args.using_flex_checkpoint:
+            os.makedirs(output_dir, exist_ok=True)
+
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
             metric_to_check = self.args.metric_for_best_model
@@ -2763,23 +2779,32 @@ def _save_checkpoint(self, model, metrics=None):
                         signal_dir,
                     )
                 else:
-                    if self.dp_group.rank > 0:  # this should only work for MoE saving
-                        self._save_ckpt_func(
-                            self._filter_moe_no_sync_optimizer_params(),
-                            os.path.join(output_dir, optimizer_name),
-                            saved_signal_path,
-                        )
-
-                    else:
-                        state_dict = self.optimizer.state_dict()
-                        save_path = os.path.join(output_dir, optimizer_name)
-                        if self.args.use_async_save:
-                            assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                            self._async_optimizer_saver.run(
-                                state_dict, save_path, saved_signal_path=saved_signal_path
+                    if not self.args.using_flex_checkpoint:
+                        if self.dp_group.rank > 0:  # this should only work for MoE saving
+                            self._save_ckpt_func(
+                                self._filter_moe_no_sync_optimizer_params(),
+                                os.path.join(output_dir, optimizer_name),
+                                saved_signal_path,
                             )
+
                         else:
-                        self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                            state_dict = self.optimizer.state_dict()
+                            save_path = os.path.join(output_dir, optimizer_name)
+                            if self.args.use_async_save:
+                                assert not strtobool(
+                                    os.getenv("FLAG_LLM_PDC", "False")
+                                ), "Dont support FLAG_LLM_PDC"
+                                self._async_optimizer_saver.run(
+                                    state_dict, save_path, saved_signal_path=saved_signal_path
+                                )
+                            else:
+                                self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                    else:
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2800,7 +2825,7 @@ def _save_checkpoint(self, model, metrics=None):
                     output_dir,
                     signal_dir,
                 )
-            else:
+            elif not self.args.using_flex_checkpoint:
                 if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                     self._save_ckpt_func(
                         self._filter_moe_no_sync_optimizer_params(),
@@ -2814,6 +2839,13 @@ def _save_checkpoint(self, model, metrics=None):
                         saved_signal_path,
                     )

+            else:
+                optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                dist.save_state_dict(
+                    {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                    output_dir,
+                )
+
         # FIXME: maybe only save one copy
         paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

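Note: when using_flex_checkpoint is set, both save branches above write the checkpoint by passing the merged model/optimizer sharded state dict plus the output directory to dist.save_state_dict, instead of paddle.save-ing a per-rank optimizer file. A rough sketch of the save side, mirroring the earlier resume sketch (same assumptions about the sharded_state_dict() interface and the dist.save_state_dict(state_dict, path) signature used in this diff):

    # Sketch only: flex-checkpoint save, the counterpart of resume_from_flex_checkpoint.
    import os
    import paddle.distributed as dist

    def save_flex_checkpoint(model, optimizer, output_dir):
        os.makedirs(output_dir, exist_ok=True)
        model_ssd = model.sharded_state_dict()
        opt_ssd = optimizer.sharded_state_dict(model_ssd)
        # Every rank contributes its shards; they are reassembled later from the
        # same directory via dist.load_state_dict.
        dist.save_state_dict({**model_ssd, **opt_ssd}, output_dir)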
@@ -3077,6 +3109,24 @@ def _save(
             with open(path, "w") as f:
                 json.dump(model_meta, f)

+    def _load_scheduler(self, checkpoint):
+        if checkpoint is None:
+            self.runtime_timer.stop()
+            return
+
+        if not self.args.ignore_load_lr_and_optim:
+            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+                self.lr_scheduler.set_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
+                )
+            else:
+                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
+
+            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
+                self.scaler.load_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
+                )
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         self.runtime_timer.start("checkpoint loading time")
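Note: _load_scheduler factors the lr-scheduler/scaler restore out of _load_optimizer_and_scheduler so the flex-checkpoint resume path in train() can reuse it (see the call that replaces the inlined block further below). Under the hood it is the plain Paddle round trip sketched here; the LinearWarmup schedule and the file name are only illustrative stand-ins.

    # Sketch: the scheduler round trip behind paddle.save(..., SCHEDULER_NAME)
    # in _save_checkpoint and set_state_dict(paddle.load(...)) in _load_scheduler.
    import os
    import tempfile
    import paddle

    sched = paddle.optimizer.lr.LinearWarmup(learning_rate=0.001, warmup_steps=10, start_lr=0.0, end_lr=0.001)
    for _ in range(5):
        sched.step()

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "scheduler.pdparams")  # stand-in for SCHEDULER_NAME
        paddle.save(sched.state_dict(), path)
        restored = paddle.optimizer.lr.LinearWarmup(learning_rate=0.001, warmup_steps=10, start_lr=0.0, end_lr=0.001)
        restored.set_state_dict(paddle.load(path))
        assert restored.last_epoch == sched.last_epoch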
@@ -3118,6 +3168,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model = self.model_wrapped
+
             opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
                 model=model,
                 optimizer=self.optimizer,
@@ -3149,18 +3200,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")

-        if not self.args.ignore_load_lr_and_optim:
-            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
-                self.lr_scheduler.set_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
-                )
-            else:
-                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
-
-            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
-                self.scaler.load_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
-                )
+        self._load_scheduler(checkpoint)

         if self.args.offload_optim:
             logger.info("Offloading optimizer state...")