Skip to content

Commit 81510e3

Browse files
Enable sql coder2
2 parents a03d46c + a0b03c5 commit 81510e3

File tree

7 files changed

+87
-14
lines changed

7 files changed

+87
-14
lines changed

sidekick/configs/prompt_template.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,32 @@
7676
-- Using reference for TABLES '{table_name}' {context}; {question_txt}?
7777
7878
SELECT"""
79+
80+
# https://colab.research.google.com/drive/13BIKsqHnPOBcQ-ba2p77L5saiepTIwu0#scrollTo=0eI-VpCkf-fN
81+
STARCODER2_PROMPT = """
82+
### Instructions:
83+
Your task is convert a question into a SQL query, given a sqlite database schema.
84+
Only use the column names from the CREATE TABLE statement.
85+
Adhere to these rules:
86+
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
87+
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
88+
- When creating a ratio, always cast the numerator as float
89+
- Use COUNT(1) instead of COUNT(*)
90+
- If the question is asking for a rate, use COUNT to compute percentage
91+
- Avoid overly complex SQL queries
92+
- Avoid using the WITH statement
93+
- Don't use aggregate and window function together
94+
- Prefer NOT EXISTS to LEFT JOIN ON null id
95+
- When using DESC keep NULLs at the end
96+
- If JSONB format found in Table schema, do pattern matching on keywords from the question and use SQL functions such as ->> or ->
97+
98+
### Input:
99+
For SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries}), create a SQL query to answer the following question:\n{question_txt}.
100+
This query will run on a database whose schema is represented in this string:
101+
CREATE TABLE '{table_name}' ({column_info}
102+
);
103+
104+
-- Table '{table_name}', {context}, has sample values ({data_info_detailed})
105+
106+
### Response:
107+
SELECT"""

sidekick/prompter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ def query_api(
335335
table_info_path: str,
336336
sample_queries_path: str,
337337
table_name: str,
338+
model_name: str = "h2ogpt-sql-nsql-llama-2-7B",
338339
is_regenerate: bool = False,
339340
is_regen_with_options: bool = False,
340341
is_command: bool = False,
@@ -365,7 +366,7 @@ def query_api(
365366
logger.info(f"Table in use: {table_names}")
366367
# Check if env.toml file exists
367368
api_key = None
368-
if model_name != "h2ogpt-sql":
369+
if "h2ogpt-sql" not in model_name:
369370
api_key = env_settings["MODEL_INFO"]["OPENAI_API_KEY"]
370371
if api_key is None or api_key == "":
371372
if os.getenv("OPENAI_API_KEY") is None or os.getenv("OPENAI_API_KEY") == "":
@@ -414,6 +415,7 @@ def query_api(
414415
sql_g = SQLGenerator(
415416
db_url,
416417
api_key,
418+
model_name=model_name,
417419
job_path=base_path,
418420
data_input_path=table_info_path,
419421
sample_queries_path=sample_queries_path,

sidekick/query.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from llama_index.indices.struct_store import SQLContextContainerBuilder
1818
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT,
1919
NSQL_QUERY_PROMPT, QUERY_PROMPT,
20-
TASK_PROMPT)
20+
STARCODER2_PROMPT, TASK_PROMPT)
2121
from sidekick.logger import logger
2222
from sidekick.utils import (_check_file_info, filter_samples, is_resource_low,
2323
load_causal_lm_model, load_embedding_model,
@@ -32,7 +32,7 @@ def __new__(
3232
cls,
3333
db_url: str,
3434
openai_key: str = None,
35-
model_name="NumbersStation/nsql-llama-2-7B",
35+
model_name="h2ogpt-sql-nsql-llama-2-7B",
3636
data_input_path: str = "./table_info.jsonl",
3737
sample_queries_path: str = "./samples.csv",
3838
job_path: str = "./",
@@ -65,7 +65,7 @@ def __init__(
6565
self,
6666
db_url: str,
6767
openai_key: str = None,
68-
model_name="NumbersStation/nsql-llama-2-7B",
68+
model_name="h2ogpt-sql-nsql-llama-2-7B",
6969
data_input_path: str = "./table_info.jsonl",
7070
sample_queries_path: str = "./samples.csv",
7171
job_path: str = "./",
@@ -281,7 +281,7 @@ def generate_sql(
281281
context_queries = self.content_queries
282282
self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)
283283

284-
if model_name != "h2ogpt-sql":
284+
if "h2ogpt-sql" not in model_name:
285285
_tasks = self.task_formatter(self._tasks)
286286

287287
# TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
@@ -427,7 +427,11 @@ def generate_sql(
427427
logger.debug(f"Relevant sample column values: {data_samples_list}")
428428
_table_name = ", ".join(table_names)
429429

430-
query = NSQL_QUERY_PROMPT.format(
430+
query_prompt_format = STARCODER2_PROMPT
431+
if model_name == "h2ogpt-sql-nsql-llama-2-7B":
432+
query_prompt_format = NSQL_QUERY_PROMPT
433+
434+
query = query_prompt_format.format(
431435
table_name=_table_name,
432436
column_info=_column_info,
433437
data_info_detailed=data_samples_list,
@@ -449,7 +453,7 @@ def generate_sql(
449453
# 3. Maybe positional interpolation --> https://arxiv.org/abs/2306.15595
450454
if int(input_length) > 4000:
451455
logger.info("Input length is greater than 1748, removing column description from the prompt")
452-
query = NSQL_QUERY_PROMPT.format(
456+
query = query_prompt_format.format(
453457
table_name=_table_name,
454458
column_info=_column_info,
455459
data_info_detailed="",

sidekick/schema_generator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ def generate_schema(data_path, output_path):
99
schema = df.dtypes.to_dict()
1010
schema_list = []
1111
special_characters = {" ": "_", ":": "_", "/": "_", "-": "_"}
12+
syntax_names = ["default"]
1213

1314
for key, value in schema.items():
1415
new_key = "".join(special_characters[s] if s in special_characters.keys() else s for s in key)
15-
16+
if new_key.lower() in syntax_names:
17+
new_key = new_key + "_col"
1618
if value == "object":
1719
value = "TEXT"
1820
unique_values = df[key].dropna().unique().tolist()

sidekick/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def read_sample_pairs(input_path: str, model_name: str = "h2ogpt-sql"):
193193
df = df.reset_index(drop=True)
194194

195195
# NSQL format
196-
if model_name != "h2ogpt-sql":
196+
if "h2ogpt-sql" not in model_name:
197197
# Open AI format
198198
# Convert frame to below format
199199
# [
@@ -281,7 +281,7 @@ def is_resource_low():
281281

282282

283283
def load_causal_lm_model(
284-
model_name: str,
284+
model_type: str,
285285
cache_path: str,
286286
device: str,
287287
load_in_8bit: bool = False,
@@ -290,7 +290,13 @@ def load_causal_lm_model(
290290
re_generate: bool = False,
291291
):
292292
try:
293-
# Load h2oGPT.NSQL model
293+
model_choices_map = {
294+
"h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B",
295+
"h2ogpt-sql-sqlcoder2": "defog/sqlcoder2",
296+
}
297+
model_name = model_choices_map[model_type]
298+
logger.info(f"Loading model: {model_name}")
299+
# Load h2oGPT.SQL model
294300
device = {"": 0} if torch.cuda.is_available() else "cpu" if device == "auto" else device
295301
total_memory = int(torch.cuda.get_device_properties(0).total_memory / 1024**3)
296302
free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
@@ -402,7 +408,7 @@ def check_vulnerability(input_query: str):
402408
r'\b(SELECT\s+\*\s+FROM\s+\w+\s+WHERE\s+\w+\s*=\s*[\'"].*?[\'"]\s*;?\s*--)',
403409
r'\b(INSERT\s+INTO\s+\w+\s+\(\s*\w+\s*,\s*\w+\s*\)\s+VALUES\s*\(\s*[\'"].*?[\'"]\s*,\s*[\'"].*?[\'"]\s*\)\s*;?\s*--)',
404410
r"\b(DROP\s+TABLE|ALTER\s+TABLE|admin\'--)", # DROP TABLE/ALTER TABLE
405-
r"(?:'|\”|--|#|‘\s*OR\s*‘1|‘\s*OR\s*\d+\s*--\s*-|\"\s*OR\s*\"\" = \"|\"\s*OR\s*\d+\s*=\s*\d+\s*--\s*-|’\s*OR\s*''\s*=\s*‘|‘=‘|'=0--+|OR\s*\d+\s*=\s*\d+|‘\s*OR\s*‘x’=‘x’|AND\s*id\s*IS\s*NULL;\s*--|‘’’’’’’’’’’’’UNION\s*SELECT\s*‘\d+|%00|/\*.*?\*/|\+|\|\||%|@\w+|@@\w+)",
411+
r"(?:'|\”|--|#|‘\s*OR\s*‘1|‘\s*OR\s*\d+\s*--\s*-|\"\s*OR\s*\"\" = \"|\"\s*OR\s*\d+\s*=\s*\d+\s*--\s*-|’\s*OR\s*''\s*=\s*‘|‘=‘|'=0--+|OR\s*\d+\s*=\s*\d+|‘\s*OR\s*‘x’=‘x’|AND\s*id\s*IS\s*NULL;\s*--|‘’’’’’’’’’’’’UNION\s*SELECT\s*‘\d+|%00|/\*.*?\*/|\+|\|\||@\w+|@@\w+)",
406412
r"AND\s[01]|AND\s(true|false)|[01]-((true|false))",
407413
r"\d+'\s*ORDER\s*BY\s*\d+--\+|\d+'\s*GROUP\s*BY\s*(\d+,)*\d+--\+|'\s*GROUP\s*BY\s*columnnames\s*having\s*1=1\s*--",
408414
r"\bUNION\b\s+\b(?:ALL\s+)?\bSELECT\b\s+[A-Za-z0-9]+", # Union Based

start.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
logging.info(f"Download model...")
1010
base_path = (Path(__file__).parent).resolve()
11+
# Model 1:
1112
snapshot_download(repo_id="NumbersStation/nsql-llama-2-7B", cache_dir=f"{base_path}/models/")
13+
# Model 2:
14+
snapshot_download(repo_id="defog/sqlcoder2", cache_dir=f"{base_path}/models/")
1215
logging.info(f"Download embedding model...")
1316
snapshot_download(repo_id="BAAI/bge-base-en", cache_dir=f"{base_path}/models/sentence_transformers/")
1417

ui/app.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ async def chat(q: Q):
8686
original_name = meta_data[table].get("original_name", q.user.original_name)
8787
table_names.append(ui.choice(table, f"{original_name}"))
8888

89+
model_choices = [
90+
ui.choice("h2ogpt-sql-nsql-llama-2-7B", "h2ogpt-sql-nsql-llama-2-7B"),
91+
ui.choice("h2ogpt-sql-sqlcoder2", "h2ogpt-sql-sqlcoder2"),
92+
]
93+
q.user.model_choice_dropdown = "h2ogpt-sql-sqlcoder2"
8994
add_card(
9095
q,
9196
"background_card",
@@ -111,7 +116,15 @@ async def chat(q: Q):
111116
choices=table_names,
112117
value=q.user.table_name if q.user.table_name else None,
113118
trigger=True,
114-
)
119+
),
120+
ui.dropdown(
121+
name="model_choice_dropdown",
122+
label="Model Choice",
123+
required=True,
124+
choices=model_choices,
125+
value=q.user.model_choice_dropdown if q.user.model_choice_dropdown else None,
126+
trigger=True,
127+
),
115128
],
116129
),
117130
)
@@ -209,6 +222,7 @@ async def chatbot(q: Q):
209222
sample_queries_path=q.user.sample_qna_path,
210223
table_info_path=q.user.table_info_path,
211224
table_name=q.user.table_name,
225+
model_name=q.user.model_choice_dropdown,
212226
is_regenerate=True,
213227
is_regen_with_options=False,
214228
)
@@ -227,6 +241,7 @@ async def chatbot(q: Q):
227241
sample_queries_path=q.user.sample_qna_path,
228242
table_info_path=q.user.table_info_path,
229243
table_name=q.user.table_name,
244+
model_name=q.user.model_choice_dropdown,
230245
is_regenerate=False,
231246
is_regen_with_options=True,
232247
)
@@ -248,6 +263,7 @@ async def chatbot(q: Q):
248263
sample_queries_path=q.user.sample_qna_path,
249264
table_info_path=q.user.table_info_path,
250265
table_name=q.user.table_name,
266+
model_name=q.user.model_choice_dropdown,
251267
)
252268
llm_response = "\n".join(llm_response)
253269
except (MemoryError, RuntimeError) as e:
@@ -567,12 +583,23 @@ async def on_event(q: Q):
567583
elif q.args.regenerate:
568584
q.args.chatbot = "regenerate"
569585

570-
if q.args.table_dropdown and not q.args.chatbot:
586+
if q.args.table_dropdown and not q.args.chatbot and q.user.table_name != q.args.table_dropdown:
571587
logging.info(f"User selected table: {q.args.table_dropdown}")
572588
await submit_table(q)
573589
q.args.chatbot = f"Table {q.args.table_dropdown} selected"
574590
# Refresh response is triggered when user selects a table via dropdown
575591
event_handled = True
592+
if (
593+
q.args.model_choice_dropdown
594+
and not q.args.chatbot
595+
and q.user.model_choice_dropdown != q.args.model_choice_dropdown
596+
):
597+
logging.info(f"User selected model type: {q.args.model_choice_dropdown}")
598+
q.user.model_choice_dropdown = q.args.model_choice_dropdown
599+
q.page["select_tables"].model_choice_dropdown.value = q.user.model_choice_dropdown
600+
q.args.chatbot = f"Model {q.args.model_choice_dropdown} selected"
601+
# Refresh response is triggered when user selects a table via dropdown
602+
event_handled = True
576603

577604
if q.args.save_conversation or (q.args.chatbot and "save the qna pair:" in q.args.chatbot.lower()):
578605
question = q.client.query

0 commit comments

Comments
 (0)