
Commit d0daa71: Add huggingface support.
1 parent 2e509e6

6 files changed: +22 additions, -4 deletions

bioscanclip/config/global_config.yaml
Lines changed: 2 additions & 0 deletions

@@ -62,4 +62,6 @@ general_fine_tune_setting:
 epoch: 15
 batch_size: 200
 
+hf_repo_id: bioscan-ml/clibd
+
 default_seed: 42

bioscanclip/config/model_config/for_bioscan_1m/final_experiments/image_dna_seed_42.yaml
Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ dna:
 model_output_name: image_dna_4gpu
 evaluation_period: 1
 ckpt_path: ${project_root_path}/ckpt/bioscan_clip/final_experiments/image_dna_4gpu_50epoch/best.pth
+hf_model_name: ckpt/bioscan_clip/final_experiments/image_dna_4gpu_50epoch/best.pth
 output_dim: 768
 port: 29532
 
bioscanclip/config/model_config/for_bioscan_1m/final_experiments/image_dna_text_seed_42.yaml
Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ language:
 model_output_name: image_dna_text_4gpu
 evaluation_period: 1
 ckpt_path: ${project_root_path}/ckpt/bioscan_clip/final_experiments/image_dna_text_4gpu_50epoch/best.pth
+hf_model_name: ckpt/bioscan_clip/final_experiments/image_dna_text_4gpu_50epoch/best.pth
 output_dim: 768
 port: 29531
 

bioscanclip/config/model_config/for_bioscan_5m/final_experiments/image_dna_seed_42.yaml
Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ dna:
 model_output_name: image_dna_4gpu
 evaluation_period: 1
 ckpt_path: ${project_root_path}/ckpt/bioscan_clip/new_5M_training/trained_with_5M_image_dna/best.pth
+hf_model_name: ckpt/bioscan_clip/new_5M_training/trained_with_5M_image_dna/best.pth
 output_dim: 768
 port: 29531
 

bioscanclip/config/model_config/for_bioscan_5m/final_experiments/image_dna_text_seed_42.yaml
Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@ language:
 model_output_name: image_dna_text_4gpu
 evaluation_period: 1
 ckpt_path: ${project_root_path}/ckpt/bioscan_clip/new_5M_training/trained_with_5M_image_dna_text/best.pth
+hf_model_name: ckpt/bioscan_clip/new_5M_training/trained_with_5M_image_dna_text/best.pth
 output_dim: 768
 port: 29531
 
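Across all four model configs above, the added `hf_model_name` is simply the repo-relative tail of the existing `ckpt_path`: the same path with the `${project_root_path}/` interpolation prefix removed, so it can serve as the `filename` inside the Hugging Face repo. A minimal sketch of that mapping; `hf_model_name_from_ckpt_path` is a hypothetical helper for illustration, not part of the commit (which hard-codes the values):

```python
def hf_model_name_from_ckpt_path(ckpt_path: str) -> str:
    """Strip the Hydra-style ${project_root_path}/ prefix from a checkpoint
    path, leaving the repo-relative filename used on the Hugging Face Hub.

    Hypothetical helper, shown only to make the config convention explicit.
    """
    prefix = "${project_root_path}/"
    if ckpt_path.startswith(prefix):
        return ckpt_path[len(prefix):]
    return ckpt_path
```

For example, the 5M image+DNA config's `ckpt_path` maps to its `hf_model_name`: `${project_root_path}/ckpt/bioscan_clip/new_5M_training/trained_with_5M_image_dna/best.pth` becomes `ckpt/bioscan_clip/new_5M_training/trained_with_5M_image_dna/best.pth`.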

scripts/inference_and_eval.py
Lines changed: 16 additions & 4 deletions

@@ -26,6 +26,7 @@
     make_prediction,
     All_TYPE_OF_FEATURES_OF_KEY,
 )
+from huggingface_hub import hf_hub_download
 
 PLOT_FOLDER = "html_plots"
 RETRIEVAL_FOLDER = "image_retrieval"
@@ -595,12 +596,23 @@ def main(args: DictConfig) -> None:
 
     if hasattr(args.model_config, "load_ckpt") and args.model_config.load_ckpt is False:
         pass
+    # elif os.path.exists(args.model_config.ckpt_path):
+    #     checkpoint = torch.load(args.model_config.ckpt_path, map_location="cuda:0")
+    #     print(f"Loading model from {args.model_config.ckpt_path}")
+    #     print()
+    #     model.load_state_dict(checkpoint)
+    elif hasattr(args.model_config, "hf_model_name"):
+        checkpoint_path = hf_hub_download(
+            repo_id=args.hf_repo_id,
+            filename=args.model_config.hf_model_name,
+        )
+        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+        model.load_state_dict(checkpoint)
+        print(f"Loading model from {args.hf_repo_id}/{args.model_config.hf_model_name}")
     else:
-        checkpoint = torch.load(args.model_config.ckpt_path, map_location="cuda:0")
-        print(f"Loading model from {args.model_config.ckpt_path}")
-        print()
+        raise ValueError("No checkpoint found. Please specify a valid checkpoint path.")
+
 
-    model.load_state_dict(checkpoint)
 
     # Load data
     # args.model_config.batch_size = 24
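The rewritten loading block falls through three cases: skip loading when `load_ckpt` is explicitly `False`, download and load the checkpoint from the Hub when `hf_model_name` is configured, and raise otherwise. A dependency-free sketch of that control flow under stated assumptions: `resolve_checkpoint` and `download_fn` are stand-ins of my own naming (in the commit, `download_fn` is `hf_hub_download` and the result is fed to `torch.load` with `map_location=torch.device('cpu')` before `model.load_state_dict`):

```python
from types import SimpleNamespace

def resolve_checkpoint(args, download_fn):
    """Mirror the commit's fallthrough: return None when loading is disabled,
    return a downloaded checkpoint path when an HF model name is configured,
    and raise otherwise. download_fn stands in for hf_hub_download."""
    mc = args.model_config
    if getattr(mc, "load_ckpt", True) is False:
        return None  # explicit opt-out, as in the first branch of the diff
    if hasattr(mc, "hf_model_name"):
        # repo_id comes from global_config.yaml, filename from the model config
        return download_fn(repo_id=args.hf_repo_id, filename=mc.hf_model_name)
    raise ValueError("No checkpoint found. Please specify a valid checkpoint path.")

# Stubbed usage with the values this commit introduces:
args = SimpleNamespace(
    hf_repo_id="bioscan-ml/clibd",
    model_config=SimpleNamespace(
        hf_model_name="ckpt/bioscan_clip/final_experiments/image_dna_4gpu_50epoch/best.pth"
    ),
)
path = resolve_checkpoint(args, lambda repo_id, filename: f"{repo_id}/{filename}")
```

Note that the commented-out `os.path.exists(args.model_config.ckpt_path)` branch in the diff suggests local-path loading was deliberately disabled in favor of the Hub download, so configs without `hf_model_name` now fail fast instead of silently falling back.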

0 commit comments