Skip to content

Commit bb61ca1

Browse files
Resolve model state when swapping #4
1 parent 81510e3 commit bb61ca1

File tree

4 files changed

+99
-94
lines changed

4 files changed

+99
-94
lines changed

app.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ title = "SQL-Sidekick"
44
description = "QnA with tabular data using NLQ"
55
LongDescription = "about.md"
66
Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"]
7-
Version = "0.0.13"
7+
Version = "0.0.14"
88

99
[Runtime]
1010
MemoryLimit = "64Gi"

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ classifiers = [
2121
packages = [{include = "sidekick"}]
2222

2323
[tool.poetry.dependencies]
24-
python = ">=3.8.1,<=3.11"
24+
python = ">=3.8.1,<=3.10"
2525
pandas = "^1.3.3"
2626
numpy = "^1.21.2"
2727
click = "^8.0.1"
@@ -35,7 +35,7 @@ sqlglot = "^12.2.0"
3535
sqlparse = "^0.4.4"
3636
transformers = "^4.29.0"
3737
sentence-transformers = "^2.2.2"
38-
torch = "^2.0.1"
38+
torch = "2.0.1"
3939
sqlalchemy-utils = "^0.41.1"
4040
h2o-wave = "0.26.1"
4141
pandasql = "0.7.3"

requirements.txt

Lines changed: 90 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,90 @@
1-
accelerate==0.21.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
2-
aiohttp==3.8.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
3-
aiosignal==1.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
4-
ansicon==1.89.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0" and platform_system == "Windows"
5-
anyio==4.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
6-
async-timeout==4.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
7-
attrs==23.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
8-
bitsandbytes==0.41.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
9-
blessed==1.20.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
10-
cachetools==5.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
11-
certifi==2023.7.22 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
12-
charset-normalizer==3.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
13-
click==8.1.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
14-
colorama==0.4.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
15-
dataclasses-json==0.5.14 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
16-
exceptiongroup==1.1.3 ; python_full_version >= "3.8.1" and python_version < "3.11"
17-
filelock==3.12.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
18-
frozenlist==1.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
19-
fsspec==2023.9.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
20-
gptcache==0.1.42 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
21-
greenlet==3.0.0 ; python_full_version >= "3.8.1" and platform_machine == "aarch64" and python_full_version <= "3.11.0" or python_full_version >= "3.8.1" and platform_machine == "ppc64le" and python_full_version <= "3.11.0" or python_full_version >= "3.8.1" and platform_machine == "x86_64" and python_full_version <= "3.11.0" or python_full_version >= "3.8.1" and platform_machine == "amd64" and python_full_version <= "3.11.0" or python_full_version >= "3.8.1" and platform_machine == "AMD64" and python_full_version <= "3.11.0" or python_full_version >= "3.8.1" and platform_machine == "win32" and python_full_version <= "3.11.0" or python_full_version >= "3.8.1" and platform_machine == "WIN32" and python_full_version <= "3.11.0"
22-
h11==0.14.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
23-
h2o-wave==0.26.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
24-
httpcore==0.18.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
25-
httpx==0.25.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
26-
huggingface-hub==0.17.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
27-
idna==3.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
28-
inquirer==3.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
29-
instructorembedding==1.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
30-
jinja2==3.1.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
31-
jinxed==1.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0" and platform_system == "Windows"
32-
joblib==1.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
33-
langchain==0.0.142 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
34-
llama-index==0.5.27 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
35-
loguru==0.7.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
36-
markupsafe==2.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
37-
marshmallow==3.20.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
38-
mpmath==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
39-
multidict==6.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
40-
mypy-extensions==1.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
41-
networkx==3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
42-
nltk==3.8.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
43-
numexpr==2.8.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
44-
numpy==1.24.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
45-
openai==0.28.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
46-
openapi-schema-pydantic==1.2.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
47-
packaging==23.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
48-
pandas==1.5.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
49-
pandasql==0.7.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
50-
pillow==10.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
51-
psutil==5.9.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
52-
psycopg2-binary==2.9.9 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
53-
pydantic==1.10.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
54-
python-dateutil==2.8.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
55-
python-editor==1.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
56-
pytz==2023.3.post1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
57-
pyyaml==6.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
58-
readchar==4.0.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
59-
regex==2023.10.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
60-
requests==2.31.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
61-
safetensors==0.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
62-
scikit-learn==1.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
63-
scipy==1.10.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
64-
sentence-transformers==2.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
65-
sentencepiece==0.1.99 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
66-
setuptools==68.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
67-
six==1.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
68-
sniffio==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
69-
sqlalchemy-utils==0.41.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
70-
sqlalchemy==1.4.49 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
71-
sqlglot==12.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
72-
sqlparse==0.4.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
73-
starlette==0.31.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
74-
sympy==1.12 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
75-
tenacity==8.2.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
76-
threadpoolctl==3.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
77-
tiktoken==0.5.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
78-
tokenizers==0.14.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
79-
toml==0.10.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
80-
torch==2.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
81-
torchvision==0.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
82-
tqdm==4.66.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
83-
transformers==4.34.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
84-
typing-extensions==4.8.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
85-
typing-inspect==0.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
86-
urllib3==2.0.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
87-
uvicorn==0.23.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
88-
wcwidth==0.2.8 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
89-
win32-setctime==1.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0" and sys_platform == "win32"
90-
yarl==1.9.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.11.0"
1+
accelerate==0.21.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
2+
aiohttp==3.8.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
3+
aiosignal==1.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
4+
ansicon==1.89.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and platform_system == "Windows"
5+
anyio==4.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
6+
async-timeout==4.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
7+
attrs==23.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
8+
bitsandbytes==0.41.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
9+
blessed==1.20.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
10+
cachetools==5.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
11+
certifi==2023.7.22 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
12+
charset-normalizer==3.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
13+
click==8.1.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
14+
colorama==0.4.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
15+
dataclasses-json==0.5.14 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
16+
exceptiongroup==1.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
17+
filelock==3.12.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
18+
frozenlist==1.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
19+
fsspec==2023.9.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
20+
gptcache==0.1.42 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
21+
greenlet==3.0.0 ; python_full_version >= "3.8.1" and platform_machine == "aarch64" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "ppc64le" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "x86_64" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "amd64" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "AMD64" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "win32" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "WIN32" and python_full_version <= "3.10.0"
22+
h11==0.14.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
23+
h2o-wave==0.26.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
24+
httpcore==0.18.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
25+
httpx==0.25.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
26+
huggingface-hub==0.17.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
27+
idna==3.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
28+
inquirer==3.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
29+
instructorembedding==1.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
30+
jinja2==3.1.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
31+
jinxed==1.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and platform_system == "Windows"
32+
joblib==1.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
33+
langchain==0.0.142 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
34+
llama-index==0.5.27 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
35+
loguru==0.7.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
36+
markupsafe==2.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
37+
marshmallow==3.20.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
38+
mpmath==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
39+
multidict==6.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
40+
mypy-extensions==1.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
41+
networkx==3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
42+
nltk==3.8.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
43+
numexpr==2.8.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
44+
numpy==1.24.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
45+
openai==0.28.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
46+
openapi-schema-pydantic==1.2.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
47+
packaging==23.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
48+
pandas==1.5.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
49+
pandasql==0.7.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
50+
pillow==10.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
51+
psutil==5.9.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
52+
psycopg2-binary==2.9.9 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
53+
pydantic==1.10.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
54+
python-dateutil==2.8.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
55+
python-editor==1.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
56+
pytz==2023.3.post1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
57+
pyyaml==6.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
58+
readchar==4.0.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
59+
regex==2023.10.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
60+
requests==2.31.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
61+
safetensors==0.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
62+
scikit-learn==1.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
63+
scipy==1.10.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
64+
sentence-transformers==2.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
65+
sentencepiece==0.1.99 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
66+
setuptools==68.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
67+
six==1.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
68+
sniffio==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
69+
sqlalchemy-utils==0.41.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
70+
sqlalchemy==1.4.49 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
71+
sqlglot==12.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
72+
sqlparse==0.4.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
73+
starlette==0.31.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
74+
sympy==1.12 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
75+
tenacity==8.2.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
76+
threadpoolctl==3.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
77+
tiktoken==0.5.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
78+
tokenizers==0.14.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
79+
toml==0.10.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
80+
torch==2.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
81+
torchvision==0.15.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
82+
tqdm==4.66.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
83+
transformers==4.34.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
84+
typing-extensions==4.8.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
85+
typing-inspect==0.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
86+
urllib3==2.0.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
87+
uvicorn==0.23.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
88+
wcwidth==0.2.8 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
89+
win32-setctime==1.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and sys_platform == "win32"
90+
yarl==1.9.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"

sidekick/query.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,17 @@ def __new__(
4141
is_regenerate_with_options: bool = False,
4242
):
4343
offloading = is_resource_low()
44-
if offloading and is_regenerate_with_options:
44+
# Initially load one model at a time if the user swapped the model dynamically.
45+
# TODO:
46+
# 1. Keep multiple models in memory if possible
47+
# 2. Support remote model loading as an option
48+
if offloading and is_regenerate_with_options or (cls._instance and cls._instance.model_name != model_name):
4549
del cls._instance
4650
cls._instance = None
4751
gc.collect()
4852
torch.cuda.empty_cache()
4953
logger.info(f"Low memory: {offloading}/ Model re-initialization: True")
54+
5055
if cls._instance is None:
5156
cls._instance = super().__new__(cls)
5257
cls._instance.model, cls._instance.tokenizer = load_causal_lm_model(

0 commit comments

Comments
 (0)