Skip to content

Commit ed4721b

Browse files
Generalize model download #44
1 parent e3ccf54 commit ed4721b

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

start.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
import os
22
import shlex
33
import subprocess
4+
import time
45
from pathlib import Path
56

67
from huggingface_hub import snapshot_download
78

89
print(f"Download model...")
910
base_path = (Path(__file__).parent).resolve()
1011

11-
# Model 1:
12-
print("Downloading model 1...")
13-
snapshot_download(repo_id="NumbersStation/nsql-llama-2-7B", cache_dir=f"{base_path}/models/")
14-
# Model 2:
15-
print("Downloading model 2...")
16-
snapshot_download(repo_id="defog/sqlcoder2", cache_dir=f"{base_path}/models/")
12+
MODEL_CHOICE_MAP = {
13+
"h2ogpt-sql-sqlcoder2": "defog/sqlcoder2",
14+
"h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B",
15+
}
16+
17+
for _m in MODEL_CHOICE_MAP.values():
18+
print(f"Downloading {_m}...", flush=True)
19+
snapshot_download(repo_id=_m, cache_dir=f"{base_path}/models/")
20+
time.sleep(3)
21+
1722
print(f"Download embedding model...")
1823
snapshot_download(repo_id="BAAI/bge-base-en", cache_dir=f"{base_path}/models/sentence_transformers/")
1924

0 commit comments

Comments
 (0)