@@ -407,6 +407,12 @@ class TrainingArguments:
         Whether to release gradients during training. Default is `False`.
     ckpt_quant_stage (`str`, *optional*):
         Whether to activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0).
+    save_checkpoint_format (`str`, *optional*):
+        Specifies the format for saving checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint'. (default: None). This setting is ignored if the corresponding switch is configured.
+    load_checkpoint_format (`str`, *optional*):
+        Specifies the format for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint'. (default: None). This setting is ignored if the corresponding switch is configured.
+    aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
+        The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None.
     """

     output_dir: str = field(
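For context, a minimal usage sketch of the new arguments (assuming a PaddleNLP-style `TrainingArguments` entry point; the import path and the values below are illustrative placeholders, not taken from this diff):

    from paddlenlp.trainer import TrainingArguments  # assumed import path

    args = TrainingArguments(
        output_dir="./checkpoints",
        # New-style format selectors added in this change:
        save_checkpoint_format="flex_checkpoint",  # or "sharding_io" / "unified_checkpoint" / None
        load_checkpoint_format="flex_checkpoint",
        # Optional FlexCheckpoint weight-mapping config; its schema is not defined in this diff.
        aoa_config=None,
    )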
@@ -941,6 +947,29 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use async_save instead of paddle.save."},
     )
+    save_checkpoint_format: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Specifies the format used to save checkpoints. "
+                "Available options: 'sharding_io', 'unified_checkpoint', "
+                "'flex_checkpoint'. "
+                "This setting is ignored if the corresponding switch is configured."
+            )
+        },
+    )
+
+    load_checkpoint_format: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Specifies the format used to load checkpoints. "
+                "Available options: 'sharding_io', 'unified_checkpoint', "
+                "'flex_checkpoint'. "
+                "This setting is ignored if the corresponding switch is configured."
+            )
+        },
+    )
     ordered_save_group_size: int = field(
         default=0,
         metadata={
@@ -1106,6 +1135,13 @@ class TrainingArguments:
         default=None, metadata={"help": "NCCL中通信组的细粒度控制的配置文件路径, 默认值为None, 代表不启用此项配置"}
     )

+    aoa_config: Optional[dict[str, list[str]]] = field(
+        default=None,
+        metadata={
+            "help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None."
+        },
+    )
+
     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
         if in_auto_parallel_align_mode():
@@ -1210,7 +1246,8 @@ def __post_init__(self):
             raise ValueError("AdamW Mini currently doesn't support tensor parallelism.")

         self._post_init_parallel_degree()
-
+        self._post_init_save_checkpoint_format()
+        self._post_init_load_checkpoint_format()
         if self.to_static:
             assert world_size == 1 or self.enable_auto_parallel, (
                 "It's not supported for training in static mode except the following cases: "
@@ -1864,24 +1901,31 @@ def is_context_parallel_supported():
         else:
             if world_size > 1:
                 if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
-                    if self.unified_checkpoint:
+                    if self.save_checkpoint_format in [
+                        "unified_checkpoint",
+                        "flex_checkpoint",
+                    ] or self.load_checkpoint_format in ["unified_checkpoint", "flex_checkpoint"]:
                         # DP use hybrid group
                         strategy = fleet.DistributedStrategy()
                         fleet.init(is_collective=True, strategy=strategy)
                     else:
                         paddle.distributed.init_parallel_env()

         if (
-            self.unified_checkpoint
+            (
+                self.save_checkpoint_format == "unified_checkpoint"
+                or self.load_checkpoint_format == "unified_checkpoint"
+            )
             and self.sharding_parallel_degree > 0
             and ShardingOption.FULL_SHARD in self.sharding
         ):
             logger.warning(
-                "Unified checkpoint currently do not support sharding stage3, set `unified_checkpoint` to False."
+                "Unified checkpoint currently does not support sharding stage3; disabling the unified_checkpoint format."
             )
-            self.unified_checkpoint = False
+            self.save_checkpoint_format = None
+            self.load_checkpoint_format = None

-        if self.unified_checkpoint:
+        if self.save_checkpoint_format == "unified_checkpoint" or self.load_checkpoint_format == "unified_checkpoint":
             unified_checkpoint_config = set(self.unified_checkpoint_config.split(" "))
             if sys.platform.startswith("win") and "async_save" in self.unified_checkpoint_config:
                 raise ValueError("Currently do not support asynchronous saving for Windows system!")
@@ -2134,6 +2178,30 @@ def _post_init_parallel_degree(self):
         if self.use_hybrid_parallel and self.enable_auto_parallel:
             self.use_hybrid_parallel = False

+    def _post_init_save_checkpoint_format(self):
+        if self.save_checkpoint_format:
+            valid_modes = ["unified_checkpoint", "sharding_io", "flex_checkpoint"]
+            assert (
+                self.save_checkpoint_format in valid_modes
+            ), f"Invalid save_checkpoint_format: {self.save_checkpoint_format}. Only these formats are allowed: {valid_modes}."
+        else:
+            if self.unified_checkpoint:
+                self.save_checkpoint_format = "unified_checkpoint"
+            elif self.save_sharded_model:
+                self.save_checkpoint_format = "sharding_io"
+
+    def _post_init_load_checkpoint_format(self):
+        if self.load_checkpoint_format:
+            valid_modes = ["unified_checkpoint", "sharding_io", "flex_checkpoint"]
+            assert (
+                self.load_checkpoint_format in valid_modes
+            ), f"Invalid load_checkpoint_format: {self.load_checkpoint_format}. Only these formats are allowed: {valid_modes}."
+        else:
+            if self.unified_checkpoint:
+                self.load_checkpoint_format = "unified_checkpoint"
+            elif self.load_sharded_model:
+                self.load_checkpoint_format = "sharding_io"
+
     def add_moe_comm_group(self):
         hybrid_configs = fleet.fleet._user_defined_strategy.hybrid_configs
         hcg = fleet.get_hybrid_communicate_group()
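As a usage note on the fallback logic above: when `save_checkpoint_format` / `load_checkpoint_format` are left unset, the legacy boolean switches still select a format. A rough sketch of the resulting behavior (the comments only restate the code in this hunk; constructing `TrainingArguments` may require additional arguments in practice):

    # If save_checkpoint_format is None (or empty):
    #   unified_checkpoint=True  -> save_checkpoint_format becomes "unified_checkpoint"
    #   save_sharded_model=True  -> save_checkpoint_format becomes "sharding_io"
    # If save_checkpoint_format is set to anything outside valid_modes,
    # the assert above raises during __post_init__.
    args = TrainingArguments(output_dir="./ckpt", unified_checkpoint=True)
    assert args.save_checkpoint_format == "unified_checkpoint"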
@@ -2462,6 +2530,8 @@ def should_save_model_state(self):
             return True
         elif self.enable_auto_parallel:
             return True
+        elif self.save_checkpoint_format == "flex_checkpoint":
+            return False
         elif self.use_hybrid_parallel:
             # save on dataset rank 0
             return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
@@ -2480,14 +2550,16 @@ def should_save_sharding_stage1_model(self):
         if self.enable_auto_parallel:
             return False
         return (
-            ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.save_sharded_model
+            ShardingOption.SHARD_OP in self.sharding
+            and self.sharding_parallel_degree > 1
+            and self.save_checkpoint_format == "sharding_io"
         )

     @property
     def should_load_sharding_stage1_model(self):
         if self.enable_auto_parallel:
             return False
-        return self.load_sharded_model
+        return self.load_checkpoint_format == "sharding_io"

     @property
     def should_load_dataset(self):