diff --git a/paddlenlp/trainer/unified_checkpoint/load_dynamic.py b/paddlenlp/trainer/unified_checkpoint/load_dynamic.py index 7f34ddc145c0..bdeb20c6de3b 100644 --- a/paddlenlp/trainer/unified_checkpoint/load_dynamic.py +++ b/paddlenlp/trainer/unified_checkpoint/load_dynamic.py @@ -307,7 +307,8 @@ def load_unified_checkpoint_dynamically(args, model, resume_from_checkpoint, saf ) dist.barrier() logger.debug("Setting state dict into model ...") - error_msgs = _load_state_dict_into_model(model, state_dict, "") + model_to_load_state_dict = model.state_dict() + error_msgs = _load_state_dict_into_model(model, state_dict, "", model_to_load_state_dict) if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) raise RuntimeError(f"Error(s) in loading dynamic state_dict for {model.__class__.__name__}:\n\t{error_msg}")