From ba1e6f0cf84ba7558373271ced7cf7e7865ff9ea Mon Sep 17 00:00:00 2001 From: pramitchoudhary Date: Tue, 12 Dec 2023 17:42:24 -0800 Subject: [PATCH] Add support for sqlcoder-34b-alpha #4 --- Makefile | 3 --- sidekick/utils.py | 2 ++ start.py | 3 +++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index c396853..5bfa913 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,3 @@ -sentence_transformer = s3cmd get --recursive --skip-existing s3://h2o-model-gym/models/nlp/sentence_trasnsformer/all-MiniLM-L6-v2/ ./models/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2 demo_data = s3cmd get --recursive --skip-existing s3://h2o-sql-sidekick/demo/sleepEDA/ ./examples/demo/ .PHONY: download_demo_data @@ -12,8 +11,6 @@ setup: download_demo_data ## Setup ./.sidekickvenv/bin/python3 -m pip install -r requirements.txt mkdir -p ./examples/demo/ -download_models: - mkdir -p ./models/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2 download_demo_data: mkdir -p ./examples/demo/ diff --git a/sidekick/utils.py b/sidekick/utils.py index 0c94c6b..a163893 100644 --- a/sidekick/utils.py +++ b/sidekick/utils.py @@ -23,6 +23,7 @@ MODEL_CHOICE_MAP_EVAL_MODE = { "h2ogpt-sql-sqlcoder2": "defog/sqlcoder2", + "h2ogpt-sql-sqlcoder-34b-alpha": "defog/sqlcoder-34b-alpha", "h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B", "gpt-3.5-turbo": "gpt-3.5-turbo-1106", "gpt-4-8k": "gpt-4", @@ -32,6 +33,7 @@ MODEL_CHOICE_MAP_DEFAULT = { "h2ogpt-sql-sqlcoder2": "defog/sqlcoder2", + "h2ogpt-sql-sqlcoder-34b-alpha": "defog/sqlcoder-34b-alpha", "h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B" } diff --git a/start.py b/start.py index 6b65852..d9fc116 100644 --- a/start.py +++ b/start.py @@ -25,6 +25,9 @@ def setup_dir(base_path: str): # Model 2: print(f"Download model 2...") snapshot_download(repo_id="defog/sqlcoder2", cache_dir=f"{base_path}/models/") +# Model 3: +print(f"Download model 3...") +snapshot_download(repo_id="defog/sqlcoder-34b-alpha", cache_dir=f"{base_path}/models/") print(f"Download embedding model...") snapshot_download(repo_id="BAAI/bge-base-en", cache_dir=f"{base_path}/models/sentence_transformers/")