diff --git a/graphgps/finetuning.py b/graphgps/finetuning.py index 3423252c..a838ef03 100644 --- a/graphgps/finetuning.py +++ b/graphgps/finetuning.py @@ -121,21 +121,17 @@ def init_model_from_pretrained(model, pretrained_dir, pretrained_dict = ckpt[MODEL_STATE] model_dict = model.state_dict() + if reset_prediction_head: + # Filter out prediction head parameter keys. + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if not k.startswith('post_mp')} + if not list(pretrained_dict.keys())[0].startswith('model.'): # Update checkpoint dict for models saved with GraphGym PyG prior v2.1 for k in list(pretrained_dict.keys()): # print(f' updating: {k} -> model.{k}') pretrained_dict[f'model.{k}'] = pretrained_dict.pop(k) - # print('>>>> pretrained dict: ') - # print(pretrained_dict.keys()) - # print('>>>> model dict: ') - # print(model_dict.keys()) - - if reset_prediction_head: - # Filter out prediction head parameter keys. - pretrained_dict = {k: v for k, v in pretrained_dict.items() - if not k.startswith('post_mp')} # Overwrite entries in the existing state dict. model_dict.update(pretrained_dict) # Load the new state dict.