
Commit abe06ef

Pass Db dialect properly
2 parents 9fec65a + 7bdd709

File tree

9 files changed: +1172 −73 lines


app.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ title = "SQL-Sidekick"
 description = "QnA with tabular data using NLQ"
 LongDescription = "about.md"
 Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"]
-Version = "0.1.7"
+Version = "0.1.8"

 [Runtime]
 MemoryLimit = "64Gi"

examples/notebooks/databricks_db.ipynb

Lines changed: 1065 additions & 0 deletions
Large diffs are not rendered by default.
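
The notebook itself is not rendered above. As a rough, hypothetical illustration of the flow it adds, the sketch below pulls a table's column metadata from Databricks and shapes it into the list of {"Column Name": ..., "Column Type": ...} records that the new table_schema argument of db_setup (see sidekick/prompter.py below) accepts. The hostname, HTTP path, token, and table name are placeholders, and the actual notebook contents may differ.

# Hypothetical sketch: requires the databricks-sql-connector package and real
# workspace credentials; the values below are placeholders, not from the commit.
from databricks import sql

with sql.connect(
    server_hostname="<workspace-hostname>",
    http_path="<warehouse-http-path>",
    access_token="<personal-access-token>",
) as conn, conn.cursor() as cursor:
    cursor.execute("DESCRIBE TABLE my_table")  # my_table is a placeholder name
    rows = cursor.fetchall()

# Shape the result into the jsonl-style records sidekick expects,
# e.g. {"Column Name": "id", "Column Type": "uuid PRIMARY KEY"}.
table_schema = [
    {"Column Name": row[0], "Column Type": row[1]}
    for row in rows
    if row[0] and not row[0].startswith("#")  # skip DESCRIBE section headers
]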

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sql-sidekick"
-version = "0.1.7"
+version = "0.1.8"
 license = "Proprietary"
 description = "An AI assistant for SQL"
 authors = [

sidekick/configs/prompt_template.py

Lines changed: 7 additions & 7 deletions
@@ -26,11 +26,11 @@
 # Reference: https://arxiv.org/pdf/2005.14165.pdf
 QUERY_PROMPT = """
 ### System: Act as a SQL Expert
-# For table {_table_name}, given an input *Question*, only generate syntactically correct SQL queries.
+# For table {_table_name}, given an input *Question*, only generate syntactically correct {dialect} SQL queries.
 # Let's work it out in a detailed step by step way using the reasoning from *Tasks* section.
 # Pick the SQL query which has the highest average log probability if more than one result is likely to answer the
 candidate *Question*.
-### {_dialect} SQL tables
+### {dialect} SQL tables
 ### *Data:* \nFor table {_table_name} schema info is mentioned below,\n{_data_info}
 ### *History*:\n{_sample_queries}
 ### *Question*: For table {_table_name}, {_question}
@@ -52,7 +52,7 @@
 """

 DEBUGGING_PROMPT = {
-    "system_prompt": "Act as a SQL expert for {_dialect} code",
+    "system_prompt": "Act as a SQL expert for {dialect} code",
     "user_prompt": """
     ### Fix syntax errors for provided incorrect SQL Query.
     # Add ``` as prefix and ``` as suffix to generated SQL
@@ -63,7 +63,7 @@
 }

 NSQL_QUERY_PROMPT = """
-For SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries})
+For {dialect} SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries})

 CREATE TABLE '{table_name}'({column_info}
 )
@@ -72,7 +72,7 @@



--- Using valid and syntactically correct SQLite query, answer the following questions (check for typos, grammatical and spelling errors and fix them) with the information for '{table_name}' provided above; for final SQL only use column names from the CREATE TABLE (Do not query for columns that do not exist).
+-- Using valid and syntactically correct {dialect} SQL syntax, answer the following questions (check for typos, grammatical and spelling errors and fix them) with the information for '{table_name}' provided above; for final SQL only use column names from the CREATE TABLE (Do not query for columns that do not exist).


 -- Using reference for TABLES '{table_name}' {context}; {question_txt}?
@@ -82,7 +82,7 @@
 # https://colab.research.google.com/drive/13BIKsqHnPOBcQ-ba2p77L5saiepTIwu0#scrollTo=0eI-VpCkf-fN
 STARCODER2_PROMPT = """
 ### Instructions:
-Your task is convert a question into a valid SQLite SQL query, given a sqlite database schema. Let's work this out step by step to be sure we have the right answer.
+Your task is convert a question into a valid {dialect} syntax SQL query, given a {dialect} database schema. Let's work this out step by step to be sure we have the right answer.
 Only use the column names from the CREATE TABLE statement.
 Adhere to these rules:
 - **Deliberately go through the question and database schema word by word** to appropriately answer the question
@@ -101,7 +101,7 @@


 ### Input:
-For SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries}), create a valid SQL (dialect:SQLite) query to answer the following question:\n{question_txt}.
+For SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries}), create a valid SQL (dialect:{dialect}) query to answer the following question:\n{question_txt}.
 This query will run on a database whose schema is represented in this string:
 CREATE TABLE '{table_name}' ({column_info}
 );
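
Net effect of these template edits: the target dialect is injected when the prompt is formatted instead of being hard-coded to SQLite, and the placeholder is renamed from {_dialect} to {dialect}. A minimal illustration of the substitution, mirroring how sidekick/query.py fills it further below (the dialect value is an arbitrary example):

# Minimal illustration; "databricks" is just an example value for the placeholder.
from sidekick.configs.prompt_template import DEBUGGING_PROMPT

system_prompt = DEBUGGING_PROMPT["system_prompt"].format(dialect="databricks")
print(system_prompt)  # Act as a SQL expert for databricks code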

sidekick/db_config.py

Lines changed: 58 additions & 45 deletions
@@ -39,7 +39,6 @@ def __init__(
         self.base_path = base_path
         self.column_names = []
         if dialect == "sqlite":
-            logger.debug(f"Creating SQLite DB: sqlite:///{base_path}/db/sqlite/{db_name}.db")
             self._url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db"
         else:
             self._url = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/"
@@ -86,63 +85,77 @@ def create_db(self):
             logger.debug("Error Occurred:", error)
             return None, error

-    def _extract_schema_info(self, schema_info_path=None):
+
+    def _parser(self, file_handle=None, schema_info=None):
+        sample_values = []
+        res = []
+        _lines = file_handle if file_handle else schema_info
+        for line in _lines:
+            data = json.loads(line) if isinstance(line, str) and line.strip() else line
+            if "Column Name" in data and "Column Type" in data:
+                col_name = data["Column Name"]
+                self.column_names.append(col_name)
+                col_type = data["Column Type"]
+                if col_type.lower() == "text":
+                    col_type = col_type + " COLLATE NOCASE"
+                # if column has sample values, save in cache for future use.
+                if "Sample Values" in data:
+                    _sample_values = data["Sample Values"]
+                    _ds = data_samples_template.format(
+                        column_name=col_name,
+                        comma_separated_sample_values=",".join(
+                            str(_sample_val) for _sample_val in _sample_values
+                        ),
+                    )
+                    sample_values.append(_ds)
+                _new_samples = f"{col_name} {col_type}"
+                res.append(_new_samples)
+        return res, sample_values
+
+
+    def _extract_schema_info(self, schema=None, schema_path=None):
         # From jsonl format
         # E.g. {"Column Name": "id", "Column Type": "uuid PRIMARY KEY"}
-        if schema_info_path is None:
-            table_info_file = f"{self.base_path}/var/lib/tmp/data/table_context.json"
-            if Path(table_info_file).exists():
-                with open(table_info_file, "w") as outfile:
-                    schema_info_path = json.load(outfile)["schema_info_path"]
         res = []
         sample_values = []
        try:
-            logger.debug(f"Schema path: {schema_info_path}")
-            if Path(schema_info_path).exists():
-                with open(schema_info_path, "r") as in_file:
-                    for line in in_file:
-                        if line.strip():
-                            data = json.loads(line)
-                            if "Column Name" in data and "Column Type" in data:
-                                col_name = data["Column Name"]
-                                self.column_names.append(col_name)
-                                col_type = data["Column Type"]
-                                if col_type.lower() == "text":
-                                    col_type = col_type + " COLLATE NOCASE"
-                                # if column has sample values, save in cache for future use.
-                                if "Sample Values" in data:
-                                    _sample_values = data["Sample Values"]
-                                    _ds = data_samples_template.format(
-                                        column_name=col_name,
-                                        comma_separated_sample_values=",".join(
-                                            str(_sample_val) for _sample_val in _sample_values
-                                        ),
-                                    )
-                                    sample_values.append(_ds)
-                                _new_samples = f"{col_name} {col_type}"
-                                res.append(_new_samples)
-            if len(sample_values) > 0:
-                # cache it for future use
-                with open(
-                    f"{self.base_path}/var/lib/tmp/data/{self._table_name}_column_values.json", "w"
-                ) as outfile:
-                    json.dump(sample_values, outfile, indent=2, sort_keys=False)
+            if schema is not None:
+                logger.debug(f"Using passed schema information.")
+                res, sample_values = self._parser(schema_info=schema)
+            else:
+                if schema_path is None:
+                    table_info_file = f"{self.base_path}/var/lib/tmp/data/table_context.json"
+                    if Path(table_info_file).exists():
+                        with open(table_info_file, "w") as outfile:
+                            schema_path = json.load(outfile)["schema_info_path"]
+                if Path(schema_path).exists():
+                    logger.debug(f"Using schema information from: {schema_path}")
+                    with open(schema_path, "r") as in_file:
+                        res, sample_values = self._parser(file_handle=in_file)
+            if len(sample_values) > 0:
+                # cache it for future use
+                with open(
+                    f"{self.base_path}/var/lib/tmp/data/{self._table_name}_column_values.json", "w"
+                ) as outfile:
+                    json.dump(sample_values, outfile, indent=2, sort_keys=False)
         except ValueError as ve:
             logger.error(f"Error in reading table context file: {ve}")
             pass
         return res

-    def create_table(self, schema_info_path: str=None, schema_info=None):
+    def create_table(self, schema_info_path=None, schema_info=None):
         try:
             engine = create_engine(self._url, isolation_level="AUTOCOMMIT")
             self._engine = engine
-            if self.schema_info is None:
-                if schema_info is not None:
-                    self.schema_info = schema_info
-                else:
-                    # If schema information is not provided, extract from the template.
-                    self.schema_info = """,\n""".join(self._extract_schema_info(schema_info_path)).strip()
-            logger.debug(f"Schema info used for creating table:\n {self.schema_info}")
+            if self.schema_info is None and schema_info_path:
+                # If schema information is not provided, extract from the template.
+                self.schema_info = """,\n""".join(self._extract_schema_info(schema_path=schema_info_path)).strip()
+            else:
+                self.schema_info = """,\n""".join(self._extract_schema_info(schema=schema_info)).strip()
+
+            logger.debug(f"Schema info used for creating table:\n {self.schema_info}")
+            # Currently, multiple tables is not supported.
+            # TODO https://github.com/h2oai/sql-sidekick/issues/62
             create_syntax = f"""
                 CREATE TABLE IF NOT EXISTS {self.table_name} (
                 {self.schema_info}
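
For reference, the refactored _parser consumes either a jsonl file handle or an already-parsed list of dicts, which is what lets callers skip the intermediate schema file entirely. A small sketch of the two input shapes it handles (the column names below are illustrative, not from the commit):

import json

# Illustrative records only; keys follow the jsonl format documented above.
schema_records = [
    {"Column Name": "id", "Column Type": "uuid PRIMARY KEY"},
    {"Column Name": "age", "Column Type": "integer"},
    {"Column Name": "gender", "Column Type": "text", "Sample Values": ["male", "female"]},
]

# 1) In-memory form: pass the list through, e.g. db_obj.create_table(schema_info=schema_records)
# 2) File form: write one JSON object per line and pass schema_info_path instead.
with open("table_info.jsonl", "w") as out_file:
    for record in schema_records:
        out_file.write(json.dumps(record) + "\n")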

sidekick/prompter.py

Lines changed: 23 additions & 11 deletions
@@ -3,6 +3,7 @@
 import json
 import os
 from pathlib import Path
+from typing import Optional

 import click
 import openai
@@ -212,12 +213,13 @@ def db_setup(
     user_name: str,
     password: str,
     port: int,
-    table_info_path: str,
-    table_samples_path: str,
     table_name: str,
+    table_info_path: Optional[str] = None,
+    table_schema: Optional[list] = None,
+    table_samples_path: Optional[str] = None,
     add_sample: bool=True,
     is_command: bool = False,
-    local_base_path: str = None
+    local_base_path: Optional[str] = None
 ):
     """Creates context for the new Database"""
     click.echo(f" Information supplied:\n {db_name}, {hostname}, {user_name}, {password}, {port}")
@@ -264,7 +266,7 @@ def db_setup(
        else:
            break

-    if table_info_path is None:
+    if table_info_path is None and table_schema is None:
         logger.debug(f"Retrieve meta information for table {table_name}")
         table_info_path = _get_table_info(path, table_name)
         logger.debug(f"Updated table info path: {table_info_path}")
@@ -274,7 +276,11 @@ def db_setup(
     click.echo(f"Table name: {table_value}")
     # set table name
     db_obj.table_name = table_value.lower().replace(" ", "_")
-    res, err = db_obj.create_table(table_info_path)
+    if table_schema:
+        res, err = db_obj.create_table(schema_info=table_schema)
+    else:
+        if table_info_path:
+            res, err = db_obj.create_table(schema_info_path=table_info_path)

     update_table_info(path, table_info_path, db_obj.table_name)
     # Check if table exists; pending --> and doesn't have any rows
@@ -407,11 +413,13 @@ def ask(
     sample_queries_path: str,
     table_name: str,
     model_name: str = "h2ogpt-sql-nsql-llama-2-7B",
+    db_dialect = "sqlite",
+    execute_db_dialect="sqlite",
     is_regenerate: bool = False,
     is_regen_with_options: bool = False,
     is_command: bool = False,
     execute_query: bool = True,
-    local_base_path = None
+    local_base_path = None,
 ):
     """Asks question and returns SQL."""
     results = []
@@ -438,6 +446,7 @@ def ask(
     with open(f"{path}/table_context.json", "w") as outfile:
         json.dump(table_context, outfile, indent=4, sort_keys=False)
     logger.info(f"Table in use: {table_names}")
+    logger.info(f"SQL dialect for generation: {db_dialect}")
     # Check if env.toml file exists
     api_key = os.getenv("OPENAI_API_KEY", None)
     if (model_name == 'gpt-3.5-turbo-0301' or model_name == 'gpt-3.5-turbo-1106') and api_key is None:
@@ -477,16 +486,18 @@ def ask(
     passwd = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
     db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]

-    if db_dialect == "sqlite":
+    if execute_db_dialect.lower() == "sqlite":
         db_url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db"
-    else:
-        db_url = f"{db_dialect}+psycopg2://{user_name}:{passwd}@{host_name}/{db_name}".format(
+    elif execute_db_dialect.lower() == "postgresql":
+        db_url = f"{execute_db_dialect}+psycopg2://{user_name}:{passwd}@{host_name}/{db_name}".format(
             user_name, passwd, host_name, db_name
         )
+    else:
+        db_url = None

     if table_info_path is None:
         table_info_path = _get_table_info(path, table_name)
-    logger.debug(f"Table info path: {table_info_path}")
+        logger.debug(f"Table info path: {table_info_path}")

     sql_g = SQLGenerator(
         db_url,
@@ -497,6 +508,7 @@ def ask(
         sample_queries_path=sample_queries_path,
         is_regenerate_with_options=is_regen_with_options,
         is_regenerate=is_regenerate,
+        db_dialect=db_dialect
     )
     if "h2ogpt-sql" not in model_name and not _execute_sql(question):
         sql_g._tasks = sql_g.generate_tasks(table_names, question)
@@ -531,7 +543,7 @@ def ask(
     _check_cond = question.strip().lower().split("execute sql:")
     if len(_check_cond) > 1:
         question = question.strip().lower().split("execute sql:")[1].strip()
-    res, alt_res = sql_g.generate_sql(table_names, question, model_name=model_name, _dialect=db_dialect)
+    res, alt_res = sql_g.generate_sql(table_names, question, model_name=model_name)
     logger.info(f"Input query: {question}")
     logger.info(f"Generated response:\n\n{res}")

sidekick/query.py

Lines changed: 9 additions & 6 deletions
@@ -39,6 +39,7 @@ def __new__(
         model_name="h2ogpt-sql-nsql-llama-2-7B",
         data_input_path: str = "./table_info.jsonl",
         sample_queries_path: str = "./samples.csv",
+        db_dialect = "sqlite",
         job_path: str = "./",
         device: str = "auto",
         is_regenerate: bool = False,
@@ -98,12 +99,14 @@ def __init__(
         sample_queries_path: str = "./samples.csv",
         job_path: str = "./",
         device: str = "cpu",
+        db_dialect = "sqlite",
         is_regenerate: bool = False,
         is_regenerate_with_options: bool = False,
     ):
         self.db_url = db_url
         self.engine = create_engine(db_url) if db_url else None
         self.sql_database = SQLDatabase(self.engine) if self.engine else None
+        self.dialect = db_dialect
         self.context_builder = None
         self.data_input_path = _check_file_info(data_input_path)
         self.sample_queries_path = sample_queries_path
@@ -218,7 +221,7 @@ def _query_tasks(self, question_str, data_info, sample_queries, table_name: list
         return res

     def generate_response(
-        self, sql_index, input_prompt, attempt_fix_on_error: bool = True, _dialect: str = "sqlite"
+        self, sql_index, input_prompt, attempt_fix_on_error: bool = True
     ):
         try:
             _sql_index = sql_index.as_query_engine()
@@ -234,7 +237,7 @@ def generate_response(
             # Attempt to heal with simple feedback
             # Reference: Teaching Large Language Models to Self-Debug, https://arxiv.org/abs/2304.05128
             logger.info(f"Attempting to fix syntax error ...,\n {se}")
-            system_prompt = DEBUGGING_PROMPT["system_prompt"].format(_dialect=_dialect)
+            system_prompt = DEBUGGING_PROMPT["system_prompt"].format(dialect=self.dialect)
             user_prompt = DEBUGGING_PROMPT["user_prompt"].format(ex_traceback=ex_traceback, qry_txt=qry_txt)
             # Role and content
             query_msg = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
@@ -304,7 +307,6 @@ def generate_sql(
         self,
         table_names: list,
         input_question: str,
-        _dialect: str = "sqlite",
         model_name: str = "h2ogpt-sql-nsql-llama-2-7B",
     ):
         # TODO: Update needed to support multiple tables
@@ -328,7 +330,7 @@ def generate_sql(

             # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
             query_str = QUERY_PROMPT.format(
-                _dialect=_dialect,
+                dialect=self.dialect,
                 _data_info=self._data_info,
                 _question=input_question,
                 _table_name=table_names,
@@ -368,7 +370,7 @@ def generate_sql(
                 )
             else:
                 res = str(result).split("Explanation:", 1)[0].strip()
-                res = sqlglot.transpile(res, identify=True, read=_dialect)[0]
+                res = sqlglot.transpile(res, identify=True, write=self.dialect)[0]
                 result = res
         except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
             logger.info("We did the best we could, there might be still be some error:\n")
@@ -488,6 +490,7 @@ def generate_sql(
                 sample_queries=qna_samples,
                 context=contextual_context_val,
                 question_txt=input_question,
+                dialect=self.dialect
             )

             logger.debug(f"Query Text:\n {query}")
@@ -649,7 +652,7 @@ def generate_sql(
             # Reference ticket: https://github.com/tobymao/sqlglot/issues/2011
             result = res
             try:
-                result = sqlglot.transpile(res, identify=True, write=_dialect)[0]
+                result = sqlglot.transpile(res, identify=True, write=self.dialect)[0]
             except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
                 logger.info("We did the best we could, there might be still be some error:\n")
                 logger.info(f"Realized query so far:\n {res}")

0 commit comments
