@@ -223,6 +223,10 @@ def in_auto_parallel_align_mode():
 
 __all__ = ["Trainer"]
 
+FLEX_CKPT_MODEL_STATE_DIR_NAME = "model_state"
+FLEX_CKPT_OPT_STATE_DIR_NAME = "optimizer_states"
+FLEX_CKPT_MASTER_WEIGHTS_INDEX_NAME = "master_weights"
+
 
 class Trainer:
     """
@@ -929,13 +933,13 @@ def train(
         self._memory_tracker.start()
 
         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
+            if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)
 
             if self.args.should_load_sharding_stage1_model:
                 model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
 
-            elif self.args.should_save_sharding_stage1_model:
+            elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
                 # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
                 model = self._wrap_model(self.model_wrapped)
@@ -949,36 +953,44 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            elif not self.args.using_flex_checkpoint:
+
+            elif self.args.load_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
-                # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
+
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-                self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
-                assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"
 
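+                # Flex checkpoint resume: rebuild the sharded state dicts and load them
+                # through the distributed checkpoint loader.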
+                if resume_from_checkpoint is not None:
+                    if not self.args.ignore_load_lr_and_optim:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        accessible_files = os.listdir(resume_from_checkpoint)
+                        metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
+                        assert len(metadata_files) == 1, "Only support one metadata file now."
+                        metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
+                        state_dict_metadata = metadata.state_dict_metadata
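+                        # Initialize optimizer states to match the checkpoint's sharding metadata
+                        # before loading the combined model + optimizer state dict.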
+                        init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config, offload=False
+                        )
+                        self._load_scheduler(resume_from_checkpoint)
+                    else:
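+                        # Optimizer and LR scheduler loading is skipped; restore only the model weights.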
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        sharded_state_dict = model_sharded_state_dict
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
+                        )
+            else:
                 model = self._wrap_model(self.model_wrapped)
+                # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
-
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-
-                if resume_from_checkpoint is not None:
-                    model_sharded_state_dict = self.model.sharded_state_dict()
-                    accessible_files = os.listdir(resume_from_checkpoint)
-                    metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
-                    assert len(metadata_files) == 1, "Only support one metadata file now."
-                    metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
-                    state_dict_metadata = metadata.state_dict_metadata
-                    init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
-                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
-                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config)
-                    self._load_scheduler(resume_from_checkpoint)
+                self._load_optimizer_and_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:
@@ -2738,7 +2750,7 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)
 
-        if self.args.using_flex_checkpoint:
+        if self.args.save_flex_checkpoint:
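+            # Collect the model's sharded state dict once; the flex-checkpoint save branches below reuse it.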
             model_sharded_state_dict = self.model.sharded_state_dict()
             os.makedirs(output_dir, exist_ok=True)
 
@@ -2801,7 +2813,18 @@ def _save_checkpoint(self, model, metrics=None):
                             signal_dir,
                         )
                     else:
-                        if not self.args.using_flex_checkpoint:
+                        if self.args.save_flex_checkpoint:
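+                            # Write model and optimizer shards through the distributed checkpoint API.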
+                            optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                            dist.save_state_dict(
+                                {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                                output_dir,
+                            )
+                            if self.args.should_save:
+                                if self.tokenizer is not None and self.args.save_tokenizer:
+                                    self.tokenizer.save_pretrained(output_dir)
+                                # Good practice: save your training arguments together with the trained model
+                                paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+                        else:
                             if self.dp_group.rank > 0:  # this should only work for MoE saving
                                 self._save_ckpt_func(
                                     self._filter_moe_no_sync_optimizer_params(),
@@ -2821,12 +2844,7 @@ def _save_checkpoint(self, model, metrics=None):
28212844 )
28222845 else :
28232846 self ._save_ckpt_func (state_dict , save_path , saved_signal_path )
2824- else :
2825- optimizer_sharded_state_dict = self .optimizer .sharded_state_dict (model_sharded_state_dict )
2826- dist .save_state_dict (
2827- {** model_sharded_state_dict , ** optimizer_sharded_state_dict },
2828- output_dir ,
2829- )
2847+
28302848 else :
28312849 if self .args .unified_checkpoint and "async_save" in self .args .unified_checkpoint_config :
28322850 global_rank = paddle .distributed .get_rank () if paddle .distributed .get_world_size () > 1 else - 1
@@ -2852,7 +2870,18 @@ def _save_checkpoint(self, model, metrics=None):
                         output_dir,
                         signal_dir,
                     )
-                elif not self.args.using_flex_checkpoint:
+                elif self.args.save_flex_checkpoint:
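+                    # Flex checkpoint: persist the combined model and optimizer sharded state dict.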
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    dist.save_state_dict(
+                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                        output_dir,
+                    )
+                    if self.args.should_save:
+                        if self.tokenizer is not None and self.args.save_tokenizer:
+                            self.tokenizer.save_pretrained(output_dir)
+                        # Good practice: save your training arguments together with the trained model
+                        paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+                else:
                     if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                         self._save_ckpt_func(
                             self._filter_moe_no_sync_optimizer_params(),
@@ -2866,13 +2895,6 @@ def _save_checkpoint(self, model, metrics=None):
                             saved_signal_path,
                         )
 
-                else:
-                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                    dist.save_state_dict(
-                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                        output_dir,
-                    )
-
             # FIXME: maybe only save one copy
             paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
 
@@ -2893,6 +2915,18 @@ def _save_checkpoint(self, model, metrics=None):
             if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
                 self._offload_optimizer()
 
+        else:
+            if self.args.save_flex_checkpoint:
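+                # Optimizer state is not saved on this path; write only the model's sharded state dict.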
+                dist.save_state_dict(
+                    model_sharded_state_dict,
+                    output_dir,
+                )
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
         self.runtime_timer.stop()
 
         # Maybe delete some older checkpoints.
@@ -3107,6 +3141,7 @@ def _save(
         else:
             if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
                 config_to_save = None
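+                # Give sharding_io access to the current optimizer before building the state dict to save.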
+                self.sharding_io.set_optimizer(self.optimizer)
                 state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
                     self.model, merge_tensor_parallel=merge_tensor_parallel
                 )