Skip to content

Commit 6d35db8

Browse files
authored
Change the default dialect to SQLite (#13)
1 parent 25a9570 commit 6d35db8

File tree

6 files changed

+54
-31
lines changed

6 files changed

+54
-31
lines changed

Makefile

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ setup: download_models ## Setup
99
./.sidekickvenv/bin/python3 -m pip install --upgrade pip
1010
./.sidekickvenv/bin/python3 -m pip install wheel
1111
./.sidekickvenv/bin/python3 -m pip install -r requirements.txt
12+
mkdir -p ./db/sqlite
1213

1314
download_models:
1415
mkdir -p ./models/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2

sidekick/configs/.env.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ MODEL_NAME = "gpt-3.5-turbo-0301" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-
44

55
[LOCAL_DB_CONFIG]
66
HOST_NAME = "localhost"
7-
USER_NAME = "postgres"
7+
USER_NAME = "sqlite"
88
PASSWORD = "abc"
99
DB_NAME = "querydb"
1010
PORT = "5432"
@@ -13,4 +13,4 @@ PORT = "5432"
1313
LOG-LEVEL = "INFO"
1414

1515
[DB-DIALECT]
16-
DB_TYPE = "postgresql"
16+
DB_TYPE = "sqlite"

sidekick/configs/prompt_template.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@
5050
"""
5151

5252
DEBUGGING_PROMPT = {
53-
"system_prompt": "Act as a SQL expert for PostgreSQL code",
53+
"system_prompt": "Act as a SQL expert for {_dialect} code",
5454
"user_prompt": """
5555
### Fix syntax errors for provided incorrect SQL Query.
5656
# Add ``` as prefix and ``` as suffix to generated SQL

sidekick/db_config.py

Lines changed: 29 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def __init__(
2323
base_path,
2424
schema_info_path=None,
2525
schema_info=None,
26-
dialect="postgresql",
26+
dialect="sqlite",
2727
) -> None:
2828
self.db_name = db_name
2929
self.hostname = hostname
@@ -36,7 +36,10 @@ def __init__(
3636
self._engine = None
3737
self.dialect = dialect
3838
self.base_path = base_path
39-
self._url = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/"
39+
if dialect == "sqlite":
40+
self._url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db"
41+
else:
42+
self._url = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/"
4043

4144
@property
4245
def table_name(self):
@@ -51,18 +54,26 @@ def engine(self):
5154
return self._engine
5255

5356
def db_exists(self):
54-
engine = create_engine(f"{self._url}{self.db_name}", echo=True)
57+
if self.dialect == "sqlite":
58+
engine = create_engine(f"{self._url}", echo=True)
59+
else:
60+
engine = create_engine(f"{self._url}{self.db_name}", echo=True)
5561
return database_exists(f"{engine.url}")
5662

5763
def create_db(self):
5864
engine = create_engine(self._url)
5965
self._engine = engine
6066

6167
with engine.connect() as conn:
62-
conn.execute("commit")
68+
# conn.execute("commit")
6369
# Do not substitute user-supplied database names here.
64-
res = conn.execute(f"CREATE DATABASE {self.db_name}")
65-
return res
70+
if self.dialect != "sqlite":
71+
conn.execute("commit")
72+
res = conn.execute(f"CREATE DATABASE {self.db_name}")
73+
self._url = f"{self._url}{self.db_name}"
74+
return res
75+
else:
76+
logger.debug("SQLite DB is created when 'engine.connect()' is called")
6677

6778
def _extract_schema_info(self, schema_info_path=None):
6879
# From jsonl format
@@ -91,7 +102,7 @@ def _extract_schema_info(self, schema_info_path=None):
91102

92103
def create_table(self, schema_info_path=None, schema_info=None):
93104
engine = create_engine(
94-
f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/{self.db_name}"
105+
self._url, isolation_level="AUTOCOMMIT"
95106
)
96107
self._engine = engine
97108
if self.schema_info is None:
@@ -119,20 +130,22 @@ def create_table(self, schema_info_path=None, schema_info=None):
119130
)
120131
"""
121132
with engine.connect() as conn:
122-
conn.execute("commit")
133+
if self.dialect != "sqlite":
134+
conn.execute("commit")
123135
conn.execute(create_syntax)
124136
return
125137

126138
def has_table(self):
127139
engine = create_engine(
128-
f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/{self.db_name}"
140+
self._url
129141
)
142+
130143
return sqlalchemy.inspect(engine).has_table(self.table_name)
131144

132145
def add_samples(self, data_csv_path=None):
133-
conn_str = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/{self.db_name}"
146+
conn_str = self._url
134147
try:
135-
df = pd.read_csv(data_csv_path, infer_datetime_format=True)
148+
df = pd.read_csv(data_csv_path)
136149
engine = create_engine(conn_str, isolation_level="AUTOCOMMIT")
137150

138151
sample_query = f"SELECT COUNT(*) AS ROWS FROM {self.table_name} LIMIT 1"
@@ -153,12 +166,14 @@ def add_samples(self, data_csv_path=None):
153166

154167
def execute_query_db(self, query=None, n_rows=100):
155168
output = []
169+
if self.dialect != "sqlite":
170+
conn_str = f"{self._url}{self.db_name}"
171+
else:
172+
conn_str = self._url
173+
156174
try:
157175
if query:
158176
# Create an engine
159-
conn_str = (
160-
f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/{self.db_name}"
161-
)
162177
engine = create_engine(conn_str)
163178

164179
# Create a connection

sidekick/prompter.py

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -112,10 +112,10 @@ def update_table_info(cache_path: str, table_info_path: str = None, table_name:
112112
json.dump(table_metadata, outfile, indent=4, sort_keys=False)
113113

114114

115-
@configure.command("db-setup", help="Enter information to configure postgres database locally")
115+
@configure.command("db-setup", help=f"Enter information to configure {db_dialect} database locally")
116116
@click.option("--db_name", "-n", default="querydb", help="Database name", prompt="Enter Database name")
117117
@click.option("--hostname", "-h", default="localhost", help="Database hostname", prompt="Enter hostname name")
118-
@click.option("--user_name", "-u", default="postgres", help="Database username", prompt="Enter username name")
118+
@click.option("--user_name", "-u", default=f"{db_dialect}", help="Database username", prompt="Enter username name")
119119
@click.option(
120120
"--password",
121121
"-p",
@@ -141,8 +141,11 @@ def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: i
141141
f.close()
142142
path = f"{base_path}/var/lib/tmp/data"
143143
# For current session
144-
db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path)
145-
if not db_obj.db_exists():
144+
db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect)
145+
if db_obj.dialect == 'sqlite' and not os.path.isfile(f"{base_path}/db/sqlite/{db_name}.db"):
146+
db_obj.create_db()
147+
click.echo("Database created successfully!")
148+
elif not db_obj.db_exists():
146149
db_obj.create_db()
147150
click.echo("Database created successfully!")
148151
else:
@@ -293,9 +296,12 @@ def query(question: str, table_info_path: str, sample_queries: str):
293296
passwd = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
294297
db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]
295298

296-
db_url = f"{db_dialect}+psycopg2://{user_name}:{passwd}@{host_name}/{db_name}".format(
297-
user_name, passwd, host_name, db_name
298-
)
299+
if db_dialect == "sqlite":
300+
db_url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db"
301+
else:
302+
db_url = f"{db_dialect}+psycopg2://{user_name}:{passwd}@{host_name}/{db_name}".format(
303+
user_name, passwd, host_name, db_name
304+
)
299305

300306
if table_info_path is None:
301307
table_info_path = _get_table_info(path)
@@ -318,7 +324,7 @@ def query(question: str, table_info_path: str, sample_queries: str):
318324
sql_g._tasks = updated_tasks
319325

320326
model_name = env_settings["OPENAI"]["MODEL_NAME"]
321-
res = sql_g.generate_sql(table_names, question, model_name=model_name)
327+
res = sql_g.generate_sql(table_names, question, model_name=model_name, _dialect=db_dialect)
322328
logger.info(f"Input query: {question}")
323329
logger.info(f"Generated response:\n\n{res}")
324330

@@ -335,7 +341,7 @@ def query(question: str, table_info_path: str, sample_queries: str):
335341
click.echo(f"Updated SQL:\n {updated_sql}")
336342
elif res_val.lower() == "r" or res_val.lower() == "regenerate":
337343
click.echo("Attempting to regenerate...")
338-
res = sql_g.generate_sql(table_names, question, model_name=model_name)
344+
res = sql_g.generate_sql(table_names, question, model_name=model_name, _dialect=db_dialect)
339345
logger.info(f"Input query: {question}")
340346
logger.info(f"Generated response:\n\n{res}")
341347

@@ -351,7 +357,8 @@ def query(question: str, table_info_path: str, sample_queries: str):
351357
port = env_settings["LOCAL_DB_CONFIG"]["PORT"]
352358
db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]
353359

354-
db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path)
360+
db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect)
361+
355362
output_res = db_obj.execute_query_db(query=_val)
356363
click.echo(f"The query results are:\n {output_res}")
357364
elif option == "pandas":

sidekick/query.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -123,7 +123,7 @@ def _query_tasks(self, question_str, data_info, sample_queries, table_name: list
123123
res = ex_value.statement if ex_value.statement else None
124124
return res
125125

126-
def generate_response(self, context_container, sql_index, input_prompt, attempt_fix_on_error: bool = True):
126+
def generate_response(self, context_container, sql_index, input_prompt, attempt_fix_on_error: bool = True, _dialect: str = "sqlite"):
127127
try:
128128
response = sql_index.query(input_prompt, sql_context_container=context_container)
129129
res = response.extra_info["sql_query"]
@@ -137,7 +137,7 @@ def generate_response(self, context_container, sql_index, input_prompt, attempt_
137137
# Attempt to heal with simple feedback
138138
# Reference: Teaching Large Language Models to Self-Debug, https://arxiv.org/abs/2304.05128
139139
logger.info(f"Attempting to fix syntax error ...,\n {se}")
140-
system_prompt = DEBUGGING_PROMPT["system_prompt"]
140+
system_prompt = DEBUGGING_PROMPT["system_prompt"].format(_dialect=_dialect)
141141
user_prompt = DEBUGGING_PROMPT["user_prompt"].format(ex_traceback=ex_traceback, qry_txt=qry_txt)
142142
# Role and content
143143
query_msg = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
@@ -192,7 +192,7 @@ def generate_tasks(self, table_names: list, input_question: str):
192192
raise se
193193

194194
def generate_sql(
195-
self, table_name: list, input_question: str, _dialect: str = "postgres", model_name: str = "gpt-3.5-turbo-0301"
195+
self, table_name: list, input_question: str, _dialect: str = "sqlite", model_name: str = "gpt-3.5-turbo-0301"
196196
):
197197
_tasks = self.task_formatter(self._tasks)
198198
context_file = f"{self.path}/var/lib/tmp/data/context.json"
@@ -223,7 +223,7 @@ def generate_sql(
223223
index = GPTSQLStructStoreIndex(
224224
[], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3
225225
)
226-
res = self.generate_response(context_container, sql_index=index, input_prompt=query_str)
226+
res = self.generate_response(context_container, sql_index=index, input_prompt=query_str, _dialect = _dialect)
227227
try:
228228
# Check if `SQL` is formatted ---> ``` SQL_text ```
229229
if "```" in str(res):

0 commit comments

Comments
 (0)