Skip to content

Commit 06279be

Browse files
Quick bug fixes, adjustments, cosmetics #44
1 parent 0edcbbc commit 06279be

File tree

9 files changed

+65
-44
lines changed

9 files changed

+65
-44
lines changed

Makefile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ setup: download_demo_data ## Setup
1010
./.sidekickvenv/bin/python3 -m pip install --upgrade pip
1111
./.sidekickvenv/bin/python3 -m pip install wheel
1212
./.sidekickvenv/bin/python3 -m pip install -r requirements.txt
13-
mkdir -p ./db/sqlite
1413
mkdir -p ./examples/demo/
1514

1615
download_models:

about.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
**Actively Being Maintained:** Yes (Demo release: _In active RnD_)
66

7-
**Last Updated:** October, 2023
7+
**Last Updated:** November, 2023
88

99
**Allows uploading and using new model and data:** Yes
1010

app.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ title = "SQL-Sidekick"
44
description = "QnA with tabular data using NLQ"
55
LongDescription = "about.md"
66
Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"]
7-
Version = "0.1.3"
7+
Version = "0.1.4"
88

99
[Runtime]
1010
MemoryLimit = "64Gi"
11-
MemoryReservation = "16Gi"
11+
MemoryReservation = "64Gi"
1212
module = "start"
1313
VolumeMount = "/meta_data"
1414
VolumeSize = "100Gi"

sidekick/db_config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
self.base_path = base_path
4040
self.column_names = []
4141
if dialect == "sqlite":
42+
logger.debug(f"Creating SQLite DB: sqlite:///{base_path}/db/sqlite/{db_name}.db")
4243
self._url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db"
4344
else:
4445
self._url = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/"
@@ -181,9 +182,10 @@ def add_samples(self, data_csv_path=None):
181182
# Fetch the number of rows from the table
182183
sample_query = f"SELECT COUNT(*) AS ROWS FROM {self.table_name} LIMIT 1"
183184
num_rows = pd.read_sql_query(sample_query, engine)
184-
logger.info(f"Number of rows inserted: {num_rows.values[0][0]}")
185+
res = num_rows.values[0][0]
186+
logger.info(f"Number of rows inserted: {res}")
185187
engine.dispose()
186-
return num_rows, None
188+
return res, None
187189
except SQLAlchemyError as sqla_error:
188190
logger.debug("SQLAlchemy error:", sqla_error)
189191
return None, sqla_error

sidekick/prompter.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@
2828
)
2929

3030
# Load the config file and initialize required paths
31-
base_path = (Path(__file__).parent / "../").resolve()
32-
env_settings = toml.load(f"{base_path}/sidekick/configs/env.toml")
31+
app_base_path = (Path(__file__).parent / "../").resolve()
32+
# Below check is to handle the case when the app is running on the h2o.ai cloud or locally
33+
base_path = app_base_path if os.path.isdir("./.sidekickvenv/bin/") else "/meta_data"
34+
env_settings = toml.load(f"{app_base_path}/sidekick/configs/env.toml")
3335
db_dialect = env_settings["DB-DIALECT"]["DB_TYPE"]
3436
model_name = env_settings["MODEL_INFO"]["MODEL_NAME"]
3537
os.environ["TOKENIZERS_PARALLELISM"] = "False"
36-
__version__ = "0.0.4"
38+
__version__ = "0.1.4"
3739

3840

3941
def color(fore="", back="", text=None):
@@ -189,7 +191,7 @@ def db_setup_api(
189191
# env_settings["TABLE_INFO"]["TABLE_SAMPLES_PATH"] = table_samples_path
190192

191193
# Update settings file for future use.
192-
f = open(f"{base_path}/sidekick/configs/env.toml", "w")
194+
f = open(f"{app_base_path}/sidekick/configs/env.toml", "w")
193195
toml.dump(env_settings, f)
194196
f.close()
195197
path = f"{base_path}/var/lib/tmp/data"
@@ -391,7 +393,7 @@ def query_api(
391393
env_settings["MODEL_INFO"]["OPENAI_API_KEY"] = api_key
392394

393395
# Update settings file for future use.
394-
f = open(f"{base_path}/sidekick/configs/env.toml", "w")
396+
f = open(f"{app_base_path}/sidekick/configs/env.toml", "w")
395397
toml.dump(env_settings, f)
396398
f.close()
397399
openai.api_key = api_key

sidekick/query.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __new__(
7575
if cls._instance is None or (cls._instance and not cls._instance.models.get(model_name, None)):
7676
if cls._instance is None:
7777
cls._instance = super().__new__(cls)
78+
cls._instance.current_temps = {}
7879
cls._instance.models, cls._instance.tokenizers = load_causal_lm_model(
7980
model_name,
8081
cache_path=f"{job_path}/models/",
@@ -84,7 +85,7 @@ def __new__(
8485
)
8586
cls._instance.model_name = "h2ogpt-sql-sqlcoder2" if not model_name else model_name
8687
model_embed_path = f"{job_path}/models/sentence_transformers"
87-
cls._instance.models[cls._instance.model_name].current_temperature = 0.5
88+
cls._instance.current_temps[cls._instance.model_name] = 0.5
8889
device = "cuda" if torch.cuda.is_available() else "cpu" if device == "auto" else device
8990
cls._instance.similarity_model = load_embedding_model(model_path=model_embed_path, device=device)
9091
return cls._instance
@@ -479,6 +480,7 @@ def generate_sql(
479480
tokenizer = self.tokenizers[model_name]
480481
inputs = tokenizer([query], return_tensors="pt")
481482
model = self.models[model_name]
483+
current_temperature = self.current_temps.get(model_name, 0.5)
482484
input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
483485
logger.info(f"Context length: {input_length}")
484486

@@ -512,19 +514,25 @@ def generate_sql(
512514

513515
possible_temp_gt_5 = [0.6, 0.75, 0.8, 0.9, 1.0]
514516
possible_temp_lt_5 = [0.1, 0.2, 0.3, 0.4]
515-
random_temperature = model.current_temperature
516517
random_seed = random.randint(0, 50)
517518
torch.manual_seed(random_seed)
518-
if model.current_temperature >= 0.5:
519+
520+
if current_temperature >= 0.5:
519521
random_temperature = np.random.choice(possible_temp_lt_5, 1)[0]
520522
else:
521523
random_temperature = np.random.choice(possible_temp_gt_5, 1)[0]
524+
import pdb
525+
526+
pdb.set_trace()
522527
if not self.is_regenerate_with_options and not self.is_regenerate:
523528
# Greedy decoding
529+
# Reset temperature to 0.5
530+
current_temperature = 0.5
531+
logger.debug(f"Generation with default temperature : {current_temperature}")
524532
output = model.generate(
525533
**inputs.to(device_type),
526534
max_new_tokens=512,
527-
temperature=0.5,
535+
temperature=current_temperature,
528536
output_scores=True,
529537
do_sample=True,
530538
return_dict_in_generate=True,
@@ -544,7 +552,8 @@ def generate_sql(
544552
return_dict_in_generate=True,
545553
)
546554
generated_tokens = output.sequences[:, input_length:][0]
547-
model.current_temperature = random_temperature
555+
self.current_temps[model_name] = random_temperature
556+
logger.debug(f"Temperature saved: {self.current_temps[model_name]}")
548557
else:
549558
logger.info("Regeneration with options requested on previous query ...")
550559
# Diverse beam search decoding to explore more options

sidekick/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
}
3030

3131
TASK_CHOICE = {
32-
"q_a": "Question/Answering",
33-
"sqld": "SQL Debugging",
32+
"q_a": "Ask Questions",
33+
"sqld": "Debugging",
3434
}
3535

3636

@@ -60,6 +60,7 @@ def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, devic
6060

6161

6262
def load_embedding_model(model_path: str, device: str):
63+
logger.debug(f"Loading embedding model from: {model_path}")
6364
model_name_path = glob.glob(f"{model_path}/models--BAAI--bge-base-en/snapshots/*/")[0]
6465

6566
sentence_model = SentenceTransformer(model_name_path, cache_folder=model_path, device=device)
@@ -186,7 +187,7 @@ def save_query(
186187

187188

188189
def setup_dir(base_path: str):
189-
dir_list = ["var/lib/tmp/data", "var/lib/tmp/jobs", "var/lib/tmp/.cache", "models/weights"]
190+
dir_list = ["var/lib/tmp/data", "var/lib/tmp/jobs", "var/lib/tmp/.cache", "models", "db/sqlite"]
190191
for _dl in dir_list:
191192
p = Path(f"{base_path}/{_dl}")
192193
if not p.is_dir():
@@ -344,7 +345,7 @@ def _load_llm(model_type: str, device_index: int = 0, load_in_4bit=True):
344345
_load_in_8bit = load_in_8bit
345346
model_name = model_type
346347
logger.info(f"Loading model: {model_name} on device id: {device_index}")
347-
348+
logger.debug(f"Model cache: {cache_path}")
348349
# 22GB (Least requirement on GPU) is a magic number for the current model size.
349350
if off_load and re_generate and total_memory < 22:
350351
# To prevent the system from crashing in case memory runs low.

start.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,25 @@
66

77
from huggingface_hub import snapshot_download
88

9-
print(f"Download model...")
10-
base_path = (Path(__file__).parent).resolve()
11-
12-
MODEL_CHOICE_MAP = {
13-
"h2ogpt-sql-sqlcoder2": "defog/sqlcoder2",
14-
"h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B",
15-
}
16-
17-
for _m in MODEL_CHOICE_MAP.values():
18-
print(f"Downloading {_m}...", flush=True)
19-
snapshot_download(repo_id=_m, cache_dir=f"{base_path}/models/")
20-
time.sleep(3)
9+
10+
def setup_dir(base_path: str):
11+
dir_list = ["var/lib/tmp/data", "var/lib/tmp/jobs", "var/lib/tmp/.cache", "models", "db/sqlite"]
12+
for _dl in dir_list:
13+
p = Path(f"{base_path}/{_dl}")
14+
if not p.is_dir():
15+
p.mkdir(parents=True, exist_ok=True)
16+
17+
18+
print(f"Download models...")
19+
base_path = (Path(__file__).parent).resolve() if os.path.isdir("./.sidekickvenv/bin/") else "/meta_data"
20+
setup_dir(base_path)
21+
22+
# Model 1:
23+
print(f"Download model 1...")
24+
snapshot_download(repo_id="NumbersStation/nsql-llama-2-7B", cache_dir=f"{base_path}/models/")
25+
# Model 2:
26+
print(f"Download model 2...")
27+
snapshot_download(repo_id="defog/sqlcoder2", cache_dir=f"{base_path}/models/")
2128

2229
print(f"Download embedding model...")
2330
snapshot_download(repo_id="BAAI/bge-base-en", cache_dir=f"{base_path}/models/sentence_transformers/")

ui/app.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
from sidekick.utils import TASK_CHOICE, get_table_keys, save_query, setup_dir, update_tables
1616

1717
# Load the config file and initialize required paths
18-
base_path = (Path(__file__).parent / "../").resolve()
19-
env_settings = toml.load(f"{base_path}/ui/app_config.toml")
18+
app_base_path = (Path(__file__).parent / "../").resolve()
19+
env_settings = toml.load(f"{app_base_path}/ui/app_config.toml")
20+
# Below check is to handle the case when the app is running on the h2o.ai cloud or locally
21+
base_path = app_base_path if os.path.isdir("./.sidekickvenv/bin/") else "/meta_data"
2022
tmp_path = f"{base_path}/var/lib/tmp"
2123

2224
ui_title = env_settings["WAVE_UI"]["TITLE"]
@@ -45,7 +47,7 @@ def initialize_models():
4547

4648

4749
async def user_variable(q: Q):
48-
db_settings = toml.load(f"{base_path}/sidekick/configs/env.toml")
50+
db_settings = toml.load(f"{app_base_path}/sidekick/configs/env.toml")
4951

5052
q.user.db_dialect = db_settings["DB-DIALECT"]["DB_TYPE"]
5153
q.user.host_name = db_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
@@ -115,7 +117,7 @@ async def chat(q: Q):
115117
]
116118
q.user.model_choice_dropdown = "h2ogpt-sql-sqlcoder2"
117119

118-
task_choices = [ui.choice("q_a", "Question/Answering"), ui.choice("sqld", "SQL Debugging")]
120+
task_choices = [ui.choice("q_a", "Ask Questions"), ui.choice("sqld", "Debugging")]
119121
q.user.task_choice_dropdown = "q_a"
120122
add_card(
121123
q,
@@ -162,7 +164,7 @@ async def chat(q: Q):
162164
items=[
163165
ui.dropdown(
164166
name="task_dropdown",
165-
label="Task",
167+
label="Mode",
166168
required=True,
167169
choices=task_choices,
168170
value=q.user.task_choice_dropdown if q.user.task_choice_dropdown else None,
@@ -250,7 +252,7 @@ async def chatbot(q: Q):
250252
if (
251253
f"Table {q.user.table_dropdown} selected" in q.args.chatbot
252254
or f"Model {q.user.model_choice_dropdown} selected" in q.args.chatbot
253-
or f"Task {q.user.task_dropdown} selected" in q.args.chatbot
255+
or f"{q.user.task_dropdown} mode selected" in q.args.chatbot
254256
):
255257
return
256258

@@ -417,13 +419,13 @@ async def fileupload(q: Q):
417419
table_name=q.user.table_name,
418420
)
419421
logging.info(f"DB updates: \n {db_resp}")
420-
q.args.n_rows = n_rows
421422
if "error" in str(db_resp).lower():
422423
q.page["dataset"].error_upload_bar.visible = True
423424
q.page["dataset"].error_bar.visible = False
424425
q.page["dataset"].progress_bar.visible = False
425426
else:
426427
q.page["dataset"].progress_bar.visible = False
428+
q.page["dataset"].success_bar.text = f"Data successfully uploaded, it has {n_rows:,} rows!"
427429
q.page["dataset"].success_bar.visible = True
428430
except Exception as e:
429431
logging.error(f"Something went wrong while uploading the dataset: {e}")
@@ -460,7 +462,7 @@ async def datasets(q: Q):
460462
ui.message_bar(
461463
name="success_bar",
462464
type="success",
463-
text=f"Data successfully uploaded, it has {q.args.n_rows} rows!",
465+
text=f"Data successfully uploaded!",
464466
visible=False,
465467
),
466468
ui.file_upload(
@@ -653,7 +655,7 @@ def upload_demo_examples(q: Q):
653655
q.user.table_info_path = f"{sample_data_path}/table_info.jsonl"
654656
q.user.sample_qna_path = None
655657

656-
n_rows, db_resp = db_setup_api(
658+
_, db_resp = db_setup_api(
657659
db_name=q.user.db_name,
658660
hostname=q.user.host_name,
659661
user_name=q.user.user_name,
@@ -665,7 +667,6 @@ def upload_demo_examples(q: Q):
665667
)
666668
logging.info(f"DB updated with demo examples: \n {db_resp}")
667669
q.args.table_dropdown = usr_table_name
668-
return n_rows
669670

670671

671672
async def on_event(q: Q):
@@ -698,7 +699,7 @@ async def on_event(q: Q):
698699
logging.info(f"User selected task: {q.args.task_dropdown}")
699700
q.user.task_dropdown = q.args.task_dropdown
700701
q.page["task_choice"].task_dropdown.value = q.user.task_dropdown
701-
q.args.chatbot = f"Task '{TASK_CHOICE[q.user.task_dropdown]}' selected"
702+
q.args.chatbot = f"'{TASK_CHOICE[q.user.task_dropdown]}' mode selected"
702703
# Refresh response is triggered when user selects a table via dropdown
703704
event_handled = True
704705
if (
@@ -763,7 +764,7 @@ async def on_event(q: Q):
763764
elif q.args.demo_mode:
764765
logging.info(f"Switching to demo mode!")
765766
# If demo datasets are not present, register them.
766-
_ = upload_demo_examples(q)
767+
upload_demo_examples(q)
767768
logging.info(f"Demo dataset selected: {q.user.table_name}")
768769
await submit_table(q)
769770
sample_qs = """

0 commit comments

Comments
 (0)