@@ -220,36 +220,34 @@ def ema_state_dict(self):
220220 ema_state_dict [k ] = tensor
221221 ema_state_dict_master_weights = {}
222222 for k , meta in self .optimizer_fusion_storage_helper .master_weights_meta .items ():
223- t = self .ema_buffer . _slice (
224- meta [ "start" ] - self . master_min_offset , meta ["end" ] - self .master_min_offset
225- ).clone ()
223+ s = meta [ "start" ] - self .master_min_offset
224+ e = meta ["end" ] - self .master_min_offset
225+ t = self . ema_buffer . _slice ( s , e ).clone ()
226226 t .get_tensor ()._set_dims (meta ["shape" ])
227227 t .name = meta ["name" ]
228228 ema_state_dict_master_weights [k ] = t
229229 ema_state_dict ["master_weights" ] = ema_state_dict_master_weights
230230 return ema_state_dict
231231
232- def load_ema_state_dict (self , path ):
233- with device_guard ("cpu" ):
234- logger .info (f"[ZCC EMA] load state dict from { path } " )
235- state_dict = paddle .load (path )
236- for k , tensor_meta in self .param_fusion_storage_helper .model_weights_metas .items ():
237- logger .info (f"[ZCC EMA] load model weight key={ k } " )
238- start = tensor_meta ["start" ]
239- end = tensor_meta ["end" ]
240- if tensor_meta ["buffer_index" ] not in self .ema_buffer_model_params :
241- continue # non fp32 has no `self.ema_buffer_model_params`
232+ def load_ema_state_dict (self , state_dict ):
233+ for k , tensor_meta in self .param_fusion_storage_helper .model_weights_metas .items ():
234+ logger .info (f"[ZCC EMA] load model weight key={ k } " )
235+ start = tensor_meta ["start" ]
236+ end = tensor_meta ["end" ]
237+ if tensor_meta ["buffer_index" ] not in self .ema_buffer_model_params :
238+ continue # non fp32 has no `self.ema_buffer_model_params`
239+ if k in state_dict :
242240 cpu_buffer = self .ema_buffer_model_params [tensor_meta ["buffer_index" ]]
243241 tensor = state_dict [k ].flatten ()
244242 cpu_buffer [start :end ] = tensor
245243
246- ema_master = state_dict ["master_weights" ]
247- for k , meta in self .optimizer_fusion_storage_helper .master_weights_meta .items ():
248- logger .info (f"[ZCC EMA] load optimizer weight key={ k } " )
249- s = meta ["start" ] - self .master_min_offset
250- e = meta ["end" ] - self .master_min_offset
251- self . ema_buffer [ s : e ] = ema_master [ k ]
252- logger . info ( "[ZCC EMA] done loading" )
244+ ema_master = state_dict ["master_weights" ]
245+ for k , meta in self .optimizer_fusion_storage_helper .master_weights_meta .items ():
246+ logger .info (f"[ZCC EMA] load optimizer weight key={ k } " )
247+ s = meta ["start" ] - self .master_min_offset
248+ e = meta ["end" ] - self .master_min_offset
249+ if k in ema_master : # state-dict is filtered
250+ self . ema_buffer [ s : e ] = ema_master [ k ]. flatten ( )
253251
254252
255253class ParamFusionStorageHelper :
@@ -408,11 +406,6 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
408406 logger .info ("[ZCC manager] Synced checkpoints." )
409407
410408 def on_step_end (self , args , state , control , model , lr_scheduler , optimizer , ** kwargs ):
411- if not isinstance (model , PipelineLayer ):
412- self .manager .zcc_pipeline_hook (0 )
413- # logger.info(
414- # f"check coef: {args.zcc_save_ema_coef} {control.should_save}, {state.global_step}, {self.zcc_ema_interval}"
415- # )
416409 if not control .should_save :
417410 if args .zcc_save_ema_coef is not None and state .global_step % self .zcc_ema_interval == 0 :
418411 self .maybe_update_zcc_worker (args , model , optimizer , state .global_step )
@@ -425,6 +418,8 @@ def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kw
425418 non_cached_objects = (lr_scheduler .state_dict (), state , self .get_rng_states (args ))
426419 self .manager .get_idle_worker_for_saving ((save_infos , non_cached_objects ))
427420 self .runtime_timer .stop ()
421+ if not isinstance (model , PipelineLayer ):
422+ self .manager .zcc_pipeline_hook (0 )
428423
429424 def get_rng_states (self , args ):
430425 if not args .save_rng_states :
@@ -959,7 +954,15 @@ def run(self):
959954 self .optimizer_fusion_storage_helper , self .param_fusion_storage_helper , self .ema_coef
960955 )
961956 if ema_ckpt_path is not None : # update ema if needed
962- self .zcc_ema_processor .load_ema_state_dict (ema_ckpt_path )
957+ logger .info (f"[ZCC EMA] load state dict from { ema_ckpt_path } " )
958+ with device_guard ("cpu" ):
959+ state_dict = paddle .load (ema_ckpt_path )
960+ if self .use_expert_parallel and self .dp_rank > 0 :
961+ state_dict = self ._filter_moe_no_sync_optimizer_params (
962+ self .model_meta_content , state_dict
963+ )
964+ self .zcc_ema_processor .load_ema_state_dict (state_dict )
965+ logger .info ("[ZCC EMA] done loading" )
963966 ema_ckpt_path = None
964967 elif task_type == ZCCTaskType .PREPARE :
965968 start_time = time .time ()
0 commit comments