@@ -269,7 +269,7 @@ def get_group_ids(self):
 
 
 class ShardingIO:
-    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False):
+    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False, is_ema=False):
         self.args = args
         self.model = model
         self.optimizer = optimizer
@@ -281,6 +281,7 @@ def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=F
 
         self.remap_parameter_name = remap_parameter_name
         self.remapper = None
+        self.is_ema = is_ema
 
     def _get_remapper(self, checkpoint):
         if not self.remap_parameter_name:
@@ -351,7 +352,9 @@ def load_model_slices():
             structure_name_map = split_structure_name_mapping(structure_name_map, group_getter)
             for i in range(self.args.sharding_parallel_rank, sharding_degree, cur_sharding_degree):
                 tmp = self._load_one_state_dict_from_checkpoint(
-                    checkpoint, base_weight_name, self.args.sharded_name_suffix(i, j)
+                    checkpoint,
+                    base_weight_name,
+                    self.args.sharded_name_suffix(i, j, sharding_parallel_degree=sharding_degree),
                 )
                 tmp = split_model_state(tmp, group_getter)
                 for gid in gids:
@@ -399,24 +402,33 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
         """
         load state_dict of one shard from_checkpoint, Only load model state dict.
         """
+        if self.is_ema:
+            base_weight_name = base_weight_name.replace("model_state", "ema").replace("pdparams", "pdopt")
         file_path = os.path.join(resume_from_checkpoint, _add_variant(base_weight_name, weight_name_suffix))
         if not os.path.isfile(file_path):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}")
 
         logger.info(f"Loading model from {resume_from_checkpoint}.")
         # We load the model state dict on the CPU to avoid an OOM error.
         state_dict = paddle.load(file_path, return_numpy=True)
+        if self.is_ema:
+            state_dict.pop("master_weights", None)
         state_dict = self._remap_parameter_name(resume_from_checkpoint, state_dict, is_opt=False)
         return state_dict
 
     def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimizer_name_suffix, group_getter=None):
+        if self.is_ema:
+            base_opt_name = base_opt_name.replace("optimizer", "ema")
         optimizer_name = _add_variant(base_opt_name, optimizer_name_suffix)
         path = os.path.join(checkpoint, optimizer_name)
         logger.info(f"load optimizer state from {path}")
         if os.path.isfile(path):
+            opt_state = paddleformers_load(path, map_location="cpu")
+            if self.is_ema:
+                opt_state = {"master_weights": opt_state.get("master_weights", {})}
             return self._remap_parameter_name(
                 checkpoint,
-                self._modify_ckpt_for_compatibility(paddleformers_load(path, map_location="cpu")),
+                self._modify_ckpt_for_compatibility(opt_state),
                 is_opt=True,
             )
         logger.info(f"{path} not exists")
@@ -449,9 +461,12 @@ def _need_reshard(self, checkpoint):
         if sharding_strategy == SHARDING_STRATEGY_V1:
             param2rank = sharding_meta["param2rank"]
             optimizer = unwrap_optimizer(self.optimizer, DygraphShardingOptimizer)
-            assert optimizer
-            if len(param2rank) == 0:
-                logger.warning("The param2rank is empty. Force reshard would be performed.")
+            if self.args.sharding_parallel_degree > 1:
+                assert optimizer is not None
+            else:
+                assert optimizer is None
+            if len(param2rank) == 0 or optimizer is None:
+                logger.warning("The param2rank is empty or sharding degree is 1. Force reshard would be performed.")
                 return True
             assert len(param2rank) == len(optimizer._param2rank)
             for (k, v) in param2rank.items():
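For reference, a minimal sketch of how the new is_ema flag could be exercised at a call site. This is hypothetical and not part of this commit: checkpoint_dir, the args/model/optimizer/hcg objects, and the no-argument sharded_name_suffix() call are placeholders.

# Hypothetical usage sketch (not part of this commit). With is_ema=True the
# loader rewrites checkpoint file names ("model_state" -> "ema", "pdparams" ->
# "pdopt" for weights, "optimizer" -> "ema" for optimizer state) and keeps only
# the EMA-relevant entries, so EMA weights are restored instead of the raw ones.
sharding_io = ShardingIO(args, model, optimizer=optimizer, hcg=hcg, is_ema=True)

# _load_one_state_dict_from_checkpoint now resolves an "ema.pdopt"-style file
# and drops the "master_weights" key from the loaded dict before remapping.
ema_state = sharding_io._load_one_state_dict_from_checkpoint(
    checkpoint_dir,              # placeholder: path of the checkpoint to resume from
    "model_state.pdparams",      # rewritten internally to "ema.pdopt" when is_ema=True
    args.sharded_name_suffix(),  # placeholder: shard file-name suffix for this rank
)
model.set_state_dict(ema_state)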