Skip to content

Commit e52d332

Browse files
Misc improvements and a few bug fixes
2 parents 405991a + 0896312 commit e52d332

File tree

9 files changed

+65
-18
lines changed

9 files changed

+65
-18
lines changed

app.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ title = "SQL-Sidekick"
44
description = "QnA with tabular data using NLQ"
55
LongDescription = "about.md"
66
Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"]
7-
Version = "0.0.17"
7+
Version = "0.1.0"
88

99
[Runtime]
1010
MemoryLimit = "64Gi"

sidekick/configs/prompt_template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
100100
101101
### Input:
102-
For SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries}), create a SQL (dialect:SQLite) query to answer the following question:\n{question_txt}.
102+
For SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries}), create a valid SQL (dialect:SQLite) query to answer the following question:\n{question_txt}.
103103
This query will run on a database whose schema is represented in this string:
104104
CREATE TABLE '{table_name}' ({column_info}
105105
);

sidekick/prompter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,9 @@ def query_api(
452452
logger.info("Executing user provided SQL without re-generation...")
453453
res = question.strip().lower().split("execute sql:")[1].strip()
454454
else:
455+
_check_cond = question.strip().lower().split("execute sql:")
456+
if len(_check_cond) > 1:
457+
question = question.strip().lower().split("execute sql:")[1].strip()
455458
res, alt_res = sql_g.generate_sql(table_names, question, model_name=model_name, _dialect=db_dialect)
456459
logger.info(f"Input query: {question}")
457460
logger.info(f"Generated response:\n\n{res}")

sidekick/query.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,10 @@ def generate_sql(
584584
_out = output_re.sequences[sorted_idx]
585585
res = tokenizer.decode(_out[input_length:], skip_special_tokens=True)
586586
result = res.replace("table_name", _table_name)
587+
# Remove the trailing semi-colon, if one exists, at the end;
588+
# we will add it back later
589+
if result.endswith(";"):
590+
result = result.replace(";", "")
587591
if "LIMIT".lower() not in result.lower():
588592
res = "SELECT " + result.strip() + " LIMIT 100;"
589593
else:
@@ -602,6 +606,8 @@ def generate_sql(
602606
# COLLATE NOCASE is used to ignore case sensitivity, this might be specific to sqlite
603607
_temp = _res.replace("table_name", table_name).split(";")[0]
604608

609+
if _temp.endswith(";"):
610+
_temp = _temp.replace(";", "")
605611
if "LIMIT".lower() not in _temp.lower():
606612
res = "SELECT " + _temp.strip() + " LIMIT 100;"
607613
else:

sidekick/utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,23 @@
1515
from sentence_transformers import SentenceTransformer
1616
from sidekick.logger import logger
1717
from sklearn.metrics.pairwise import cosine_similarity
18-
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
18+
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
19+
BitsAndBytesConfig)
1920

20-
model_choices_map = {
21+
MODEL_CHOICE_MAP = {
2122
"h2ogpt-sql-sqlcoder2": "defog/sqlcoder2",
2223
"h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B",
2324
}
2425

25-
model_device_map = {
26+
MODEL_DEVICE_MAP = {
2627
"h2ogpt-sql-sqlcoder2": 0,
2728
"h2ogpt-sql-nsql-llama-2-7B": 1,
2829
}
2930

31+
TASK_CHOICE = {
32+
"q_a": "Question/Answering",
33+
"sqld": "SQL Debugging",
34+
}
3035

3136
def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = None):
3237
# Reference:
@@ -290,7 +295,7 @@ def is_resource_low(model_name: str):
290295
off_load = False
291296
else:
292297
n_gpus = torch.cuda.device_count()
293-
device_index = model_device_map[model_name] if model_name and n_gpus > 1 else 0
298+
device_index = MODEL_DEVICE_MAP[model_name] if model_name and n_gpus > 1 else 0
294299
logger.debug(f"Information on device: {device_index}")
295300
free_in_GB = int(torch.cuda.mem_get_info(device_index)[0] / 1024**3)
296301
total_memory = int(torch.cuda.get_device_properties(device_index).total_memory / 1024**3)
@@ -382,14 +387,14 @@ def _load_llm(model_type: str, device_index: int = 0, load_in_4bit=True):
382387

383388
if not model_type: # if None, load all models
384389
for device_index in range(n_gpus):
385-
model_name = list(model_choices_map.values())[device_index]
390+
model_name = list(MODEL_CHOICE_MAP.values())[device_index]
386391
model, tokenizer = _load_llm(model_name, device_index)
387-
_name = list(model_choices_map.keys())[device_index]
392+
_name = list(MODEL_CHOICE_MAP.keys())[device_index]
388393
models[_name] = model
389394
tokenizers[_name] = tokenizer
390395
else:
391-
model_name = model_choices_map[model_type]
392-
d_index = model_device_map[model_type] if n_gpus > 1 else 0
396+
model_name = MODEL_CHOICE_MAP[model_type]
397+
d_index = MODEL_DEVICE_MAP[model_type] if n_gpus > 1 else 0
393398
model, tokenizer = _load_llm(model_name, d_index)
394399
models[model_type] = model
395400
tokenizers[model_type] = tokenizer

static/screenshot-01.png

417 KB
Loading

static/screenshot-02.png

426 KB
Loading

static/screenshot-03.png

496 KB
Loading

ui/app.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from h2o_wave.core import expando_to_dict
1313
from sidekick.prompter import db_setup_api, query_api
1414
from sidekick.query import SQLGenerator
15-
from sidekick.utils import get_table_keys, save_query, setup_dir, update_tables
15+
from sidekick.utils import get_table_keys, save_query, setup_dir, update_tables, TASK_CHOICE
1616

1717
# Load the config file and initialize required paths
1818
base_path = (Path(__file__).parent / "../").resolve()
@@ -91,7 +91,7 @@ def clear_cards(q, ignore: Optional[List[str]] = []) -> None:
9191
async def chat(q: Q):
9292
q.page["sidebar"].value = "#chat"
9393

94-
if q.args.table_dropdown or q.args.model_choice_dropdown:
94+
if q.args.table_dropdown or q.args.model_choice_dropdown or q.args.task_dropdown:
9595
# If a table/model is selected, the trigger causes refresh of the page
9696
# so we update chat history with table name selection and return
9797
# avoiding re-drawing.
@@ -113,6 +113,9 @@ async def chat(q: Q):
113113
ui.choice("h2ogpt-sql-sqlcoder2", "h2ogpt-sql-sqlcoder2"),
114114
]
115115
q.user.model_choice_dropdown = "h2ogpt-sql-sqlcoder2"
116+
117+
task_choices = [ui.choice("q_a", "Question/Answering"), ui.choice("sqld", "SQL Debugging")]
118+
q.user.task_choice_dropdown = "q_a"
116119
add_card(
117120
q,
118121
"background_card",
@@ -123,7 +126,7 @@ async def chat(q: Q):
123126
ui.inline(items=[ui.toggle(name="demo_mode", label="Demo", trigger=True)], justify="end"),
124127
],
125128
),
126-
)
129+
),
127130

128131
add_card(
129132
q,
@@ -149,7 +152,24 @@ async def chat(q: Q):
149152
),
150153
],
151154
),
152-
)
155+
),
156+
add_card(
157+
q,
158+
"task_choice",
159+
ui.form_card(
160+
box="vertical",
161+
items=[
162+
ui.dropdown(
163+
name="task_dropdown",
164+
label="Task",
165+
required=True,
166+
choices=task_choices,
167+
value=q.user.task_choice_dropdown if q.user.task_choice_dropdown else None,
168+
trigger=True,
169+
)
170+
],
171+
),
172+
),
153173
add_card(
154174
q,
155175
"chat_card",
@@ -228,11 +248,15 @@ async def chatbot(q: Q):
228248
if (
229249
f"Table {q.user.table_dropdown} selected" in q.args.chatbot
230250
or f"Model {q.user.model_choice_dropdown} selected" in q.args.chatbot
251+
or f"Task {q.user.task_dropdown} selected" in q.args.chatbot
231252
):
232253
return
233254

234255
# Append bot response.
235256
question = f"{q.args.chatbot}"
257+
# Check on task choice.
258+
if q.user.task_dropdown == "sqld":
259+
question = f"Execute SQL:\n{q.args.chatbot}"
236260
logging.info(f"Question: {question}")
237261

238262
# For regeneration, currently there are 2 modes
@@ -531,13 +555,16 @@ async def init(q: Q) -> None:
531555
items=[
532556
ui.nav_group(
533557
"Menu",
534-
items=[ui.nav_item(name="#datasets", label="Upload Dataset"), ui.nav_item(name="#chat", label="Chat")],
558+
items=[
559+
ui.nav_item(name="#datasets", label="Upload Dataset", icon="Database"),
560+
ui.nav_item(name="#chat", label="Chat", icon="Chat"),
561+
],
535562
),
536563
ui.nav_group(
537564
"Help",
538565
items=[
539-
ui.nav_item(name="#documentation", label="Documentation"),
540-
ui.nav_item(name="#support", label="Support"),
566+
ui.nav_item(name="#documentation", label="Documentation", icon="TextDocument"),
567+
ui.nav_item(name="#support", label="Support", icon="Telemarketer"),
541568
],
542569
),
543570
],
@@ -638,7 +665,13 @@ async def on_event(q: Q):
638665
q.args.chatbot = f"Model {q.user.model_choice_dropdown} selected"
639666
# Refresh response is triggered when user selects a model via dropdown
640667
event_handled = True
641-
668+
if q.args.task_dropdown and not q.args.chatbot and q.user.task_dropdown != q.args.task_dropdown:
669+
logging.info(f"User selected task: {q.args.task_dropdown}")
670+
q.user.task_dropdown = q.args.task_dropdown
671+
q.page["task_choice"].task_dropdown.value = q.user.task_dropdown
672+
q.args.chatbot = f"Task '{TASK_CHOICE[q.user.task_dropdown]}' selected"
673+
# Refresh response is triggered when user selects a task via dropdown
674+
event_handled = True
642675
if (
643676
q.args.save_conversation
644677
or q.args.save_rejected_conversation

0 commit comments

Comments (0)