11 changes: 10 additions & 1 deletion paddlenlp/trainer/utils/sharding_io.py
@@ -399,7 +399,7 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
         if not os.path.isfile(file_path):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}")

-        logger.info(f"Loading model from {resume_from_checkpoint} .")
+        logger.info(f"Loading model from {file_path}.")
         # We load the model state dict on the CPU to avoid an OOM error.
         state_dict = paddle.load(file_path, return_numpy=True)
         state_dict = self._remap_parameter_name(resume_from_checkpoint, state_dict, is_opt=False)
@@ -514,6 +514,11 @@ def load_optimizer_state_with_reshard(self, checkpoint, base_opt_name, model_wra
             is_matched = reshard_util.sharding_v2.is_matched_optimizer_state_dict(
                 one_shard_opt_state_dict, self.optimizer, model_wrapped
             )
+            is_matched = paddle.to_tensor([is_matched], dtype=paddle.int32)
+            dp_group = fleet.get_hybrid_communicate_group().get_data_parallel_group()
+            dp_src_rank = fleet.get_hybrid_communicate_group().get_data_parallel_group_src_rank()
+            dist.broadcast(is_matched, src=dp_src_rank, group=dp_group)
+            is_matched = bool(is_matched[0])
         else:
             is_matched = True
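The broadcast added here keeps the reshard decision consistent across the data-parallel group: if ranks computed `is_matched` independently and disagreed, they would take different reshard code paths and later collectives could hang. A minimal standalone sketch of the same pattern; the helper name `broadcast_flag` is not part of the PR, and it assumes `fleet.init` with a hybrid strategy has already been called:

```python
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet


def broadcast_flag(flag: bool) -> bool:
    """Align a per-rank boolean decision by taking the data-parallel
    source rank's value, mirroring the pattern added above."""
    hcg = fleet.get_hybrid_communicate_group()
    dp_group = hcg.get_data_parallel_group()
    dp_src_rank = hcg.get_data_parallel_group_src_rank()
    t = paddle.to_tensor([int(flag)], dtype=paddle.int32)
    dist.broadcast(t, src=dp_src_rank, group=dp_group)
    return bool(t[0])
```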

@@ -894,6 +899,10 @@ def _gather_sharding_metas(self):
         sharding_meta["param_meta_keys"] = ["shape", "dtype", "is_distributed", "no_sync"]
         sharding_meta["sharding_strategy"] = sharding_strategy
         sharding_meta["enable_overlap"] = pp_overlap
+        dp_metas_list = self._all_gather_simple_object(sharding_meta, self.hcg.get_data_parallel_group())
+        for e in dp_metas_list:
+            for key in ["structure_name_mapping", "param_meta"]:
+                sharding_meta[key].update(e[key])
         suffix = self._sharding_meta_suffix()
         sharding_metas[suffix] = sharding_meta
         sharding_metas_list = self._all_gather_simple_object(sharding_metas, self.hcg.get_model_parallel_group())
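The new loop widens the sharding metadata beyond the local rank: every data-parallel rank contributes its own `structure_name_mapping` and `param_meta`, and the dict entries are unioned so each rank ends up with the full view. A hedged sketch of that merge; `_all_gather_simple_object` is assumed to behave like `paddle.distributed.all_gather_object`, and the helper name below is invented for illustration:

```python
import paddle.distributed as dist


def gather_and_merge_meta(local_meta: dict, keys, group) -> dict:
    """Union the dict-valued entries named in `keys` across all ranks of `group`."""
    gathered = []
    dist.all_gather_object(gathered, local_meta, group=group)  # one dict per rank
    for other in gathered:
        for key in keys:
            local_meta[key].update(other[key])
    return local_meta


# e.g. gather_and_merge_meta(sharding_meta, ["structure_name_mapping", "param_meta"], dp_group)
```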
55 changes: 29 additions & 26 deletions paddlenlp/trainer/utils/zero_cost_checkpoint.py
@@ -220,36 +220,34 @@ def ema_state_dict(self):
             ema_state_dict[k] = tensor
         ema_state_dict_master_weights = {}
         for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
-            t = self.ema_buffer._slice(
-                meta["start"] - self.master_min_offset, meta["end"] - self.master_min_offset
-            ).clone()
+            s = meta["start"] - self.master_min_offset
+            e = meta["end"] - self.master_min_offset
+            t = self.ema_buffer._slice(s, e).clone()
             t.get_tensor()._set_dims(meta["shape"])
             t.name = meta["name"]
             ema_state_dict_master_weights[k] = t
         ema_state_dict["master_weights"] = ema_state_dict_master_weights
         return ema_state_dict

-    def load_ema_state_dict(self, path):
-        with device_guard("cpu"):
-            logger.info(f"[ZCC EMA] load state dict from {path}")
-            state_dict = paddle.load(path)
-            for k, tensor_meta in self.param_fusion_storage_helper.model_weights_metas.items():
-                logger.info(f"[ZCC EMA] load model weight key={k}")
-                start = tensor_meta["start"]
-                end = tensor_meta["end"]
-                if tensor_meta["buffer_index"] not in self.ema_buffer_model_params:
-                    continue  # non fp32 has no `self.ema_buffer_model_params`
+    def load_ema_state_dict(self, state_dict):
+        for k, tensor_meta in self.param_fusion_storage_helper.model_weights_metas.items():
+            logger.info(f"[ZCC EMA] load model weight key={k}")
+            start = tensor_meta["start"]
+            end = tensor_meta["end"]
+            if tensor_meta["buffer_index"] not in self.ema_buffer_model_params:
+                continue  # non fp32 has no `self.ema_buffer_model_params`
+            if k in state_dict:
                 cpu_buffer = self.ema_buffer_model_params[tensor_meta["buffer_index"]]
                 tensor = state_dict[k].flatten()
                 cpu_buffer[start:end] = tensor

-            ema_master = state_dict["master_weights"]
-            for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
-                logger.info(f"[ZCC EMA] load optimizer weight key={k}")
-                s = meta["start"] - self.master_min_offset
-                e = meta["end"] - self.master_min_offset
-                self.ema_buffer[s:e] = ema_master[k]
-            logger.info("[ZCC EMA] done loading")
+        ema_master = state_dict["master_weights"]
+        for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
+            logger.info(f"[ZCC EMA] load optimizer weight key={k}")
+            s = meta["start"] - self.master_min_offset
+            e = meta["end"] - self.master_min_offset
+            if k in ema_master:  # state-dict is filtered
+                self.ema_buffer[s:e] = ema_master[k].flatten()


 class ParamFusionStorageHelper:
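Both `ema_state_dict` and the reworked `load_ema_state_dict` address each parameter inside a fused flat buffer through `[start, end)` offsets, shifted by `master_min_offset` for the master-weight region. A small NumPy stand-in for that offset arithmetic; the buffer size, offsets, and shape below are invented for illustration:

```python
import numpy as np

# Fused flat storage standing in for the paddle EMA buffer.
ema_buffer = np.zeros(8, dtype=np.float32)
master_min_offset = 100                             # global offset where this buffer begins
meta = {"start": 102, "end": 106, "shape": [2, 2]}  # per-parameter metadata

# Writing a checkpoint tensor back in, as load_ema_state_dict does:
loaded = np.arange(4, dtype=np.float32).reshape(meta["shape"])
s = meta["start"] - master_min_offset
e = meta["end"] - master_min_offset
ema_buffer[s:e] = loaded.flatten()

# Reading it out, as ema_state_dict does: slice the flat buffer, then restore the shape.
restored = ema_buffer[s:e].reshape(meta["shape"])
assert (restored == loaded).all()
```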
@@ -408,11 +406,6 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
         logger.info("[ZCC manager] Synced checkpoints.")

     def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kwargs):
-        if not isinstance(model, PipelineLayer):
-            self.manager.zcc_pipeline_hook(0)
-        # logger.info(
-        #     f"check coef: {args.zcc_save_ema_coef} {control.should_save}, {state.global_step}, {self.zcc_ema_interval}"
-        # )
         if not control.should_save:
             if args.zcc_save_ema_coef is not None and state.global_step % self.zcc_ema_interval == 0:
                 self.maybe_update_zcc_worker(args, model, optimizer, state.global_step)
@@ -425,6 +418,8 @@ def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kw
         non_cached_objects = (lr_scheduler.state_dict(), state, self.get_rng_states(args))
         self.manager.get_idle_worker_for_saving((save_infos, non_cached_objects))
         self.runtime_timer.stop()
+        if not isinstance(model, PipelineLayer):
+            self.manager.zcc_pipeline_hook(0)

     def get_rng_states(self, args):
         if not args.save_rng_states:
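`on_step_end` now also drives the EMA schedule: when `zcc_save_ema_coef` is set and the step is not a save step, the worker's EMA copy is refreshed every `zcc_ema_interval` steps, and the pipeline hook fires only after the save work has been handed off. For orientation, the conventional recurrence that such a coefficient weights is sketched below; this is the textbook form, not necessarily the exact fused update the ZCC worker performs:

```python
def ema_update(ema_value: float, new_value: float, coef: float) -> float:
    """Textbook exponential-moving-average step: keep `coef` of the old
    average and blend in (1 - coef) of the latest value."""
    return coef * ema_value + (1.0 - coef) * new_value


# With a coefficient close to 1 the average changes slowly, e.g.:
ema = 0.0
for value in (1.0, 1.0, 1.0):
    ema = ema_update(ema, value, coef=0.99)
```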
@@ -959,7 +954,15 @@ def run(self):
                     self.optimizer_fusion_storage_helper, self.param_fusion_storage_helper, self.ema_coef
                 )
                 if ema_ckpt_path is not None:  # update ema if needed
-                    self.zcc_ema_processor.load_ema_state_dict(ema_ckpt_path)
+                    logger.info(f"[ZCC EMA] load state dict from {ema_ckpt_path}")
+                    with device_guard("cpu"):
+                        state_dict = paddle.load(ema_ckpt_path)
+                        if self.use_expert_parallel and self.dp_rank > 0:
+                            state_dict = self._filter_moe_no_sync_optimizer_params(
+                                self.model_meta_content, state_dict
+                            )
+                        self.zcc_ema_processor.load_ema_state_dict(state_dict)
+                        logger.info("[ZCC EMA] done loading")
                     ema_ckpt_path = None
             elif task_type == ZCCTaskType.PREPARE:
                 start_time = time.time()
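Under expert parallelism, data-parallel ranks other than 0 keep optimizer/EMA state only for their expert-local (`no_sync`) parameters, so the EMA checkpoint is filtered before loading; this is also why `load_ema_state_dict` now guards with `if k in ema_master`. A hypothetical sketch of such a filter, keyed on the `no_sync` flag carried in the sharding metadata; the real `_filter_moe_no_sync_optimizer_params` and its key layout may differ:

```python
def filter_no_sync_entries(param_meta: dict, state_dict: dict) -> dict:
    """Keep only entries for parameters marked `no_sync` (expert-local),
    dropping shared-parameter state that dp_rank > 0 should not reload."""
    keep = {name for name, meta in param_meta.items() if meta.get("no_sync")}
    filtered = {k: v for k, v in state_dict.items() if k in keep}
    filtered["master_weights"] = {
        k: v for k, v in state_dict.get("master_weights", {}).items() if k in keep
    }
    return filtered
```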