Skip to content

Commit 25a9570

Browse files
Partial logic for basic UI workflow + More improvements (#10)
* Initial demo UI - courtesy Megan/narasimhard * Update version and add h2o-wave dependency * UI-related changes * A few more updates and corrections * Save the right generated SQL * Add pandasql dependency
1 parent b33db8d commit 25a9570

File tree

9 files changed

+298
-65
lines changed

9 files changed

+298
-65
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sql-sidekick"
3-
version = "0.0.3"
3+
version = "0.0.4"
44
license = "Proprietary"
55
description = "An AI assistant for SQL"
66
authors = [
@@ -36,6 +36,8 @@ transformers = "^4.29.0"
3636
sentence-transformers = "^2.2.2"
3737
torch = "^2.0.1"
3838
sqlalchemy-utils = "^0.41.1"
39+
h2o-wave = "0.26.1"
40+
pandasql = "0.7.3"
3941

4042
[tool.poetry.scripts]
4143
sql-sidekick = "sidekick.prompter:cli"

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ fsspec==2023.5.0 ; python_full_version >= "3.8.16" and python_version < "3.10"
1414
gptcache==0.1.29.1 ; python_full_version >= "3.8.16" and python_version < "3.10"
1515
greenlet==2.0.2 ; python_full_version >= "3.8.16" and platform_machine == "aarch64" and python_version < "3.10" or python_full_version >= "3.8.16" and platform_machine == "ppc64le" and python_version < "3.10" or python_full_version >= "3.8.16" and platform_machine == "x86_64" and python_version < "3.10" or python_full_version >= "3.8.16" and platform_machine == "amd64" and python_version < "3.10" or python_full_version >= "3.8.16" and platform_machine == "AMD64" and python_version < "3.10" or python_full_version >= "3.8.16" and platform_machine == "win32" and python_version < "3.10" or python_full_version >= "3.8.16" and platform_machine == "WIN32" and python_version < "3.10"
1616
huggingface-hub==0.15.1 ; python_full_version >= "3.8.16" and python_version < "3.10"
17+
h2o-wave==0.26.1 ; python_full_version >= "3.8.16" and python_version < "3.10"
1718
idna==3.4 ; python_full_version >= "3.8.16" and python_version < "3.10"
1819
jinja2==3.1.2 ; python_full_version >= "3.8.16" and python_version < "3.10"
1920
joblib==1.2.0 ; python_full_version >= "3.8.16" and python_version < "3.10"

sidekick/configs/.env.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
[OPENAI]
22
OPENAI_API_KEY = ""
3-
MODEL_NAME = "gpt-3.5-turbo-0301" # Others: e.g. text-davinci-003
3+
MODEL_NAME = "gpt-3.5-turbo-0301" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003
44

55
[LOCAL_DB_CONFIG]
66
HOST_NAME = "localhost"
77
USER_NAME = "postgres"
88
PASSWORD = "abc"
99
DB_NAME = "querydb"
10+
PORT = "5432"
1011

1112
[LOGGING]
1213
LOG-LEVEL = "INFO"

sidekick/db_config.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# create db with supplied info
22
import json
33
from pathlib import Path
4-
import pandas as pd
54

5+
import pandas as pd
66
import psycopg2 as pg
77
import sqlalchemy
8-
from psycopg2.extras import Json
98
from pandasql import sqldf
9+
from psycopg2.extras import Json
1010
from sidekick.logger import logger
1111
from sqlalchemy import create_engine
1212
from sqlalchemy_utils import database_exists
@@ -133,19 +133,17 @@ def add_samples(self, data_csv_path=None):
133133
conn_str = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/{self.db_name}"
134134
try:
135135
df = pd.read_csv(data_csv_path, infer_datetime_format=True)
136-
engine = create_engine(conn_str, isolation_level='AUTOCOMMIT')
136+
engine = create_engine(conn_str, isolation_level="AUTOCOMMIT")
137137

138-
sample_query = f'SELECT COUNT(*) AS ROWS FROM {self.table_name} LIMIT 1'
138+
sample_query = f"SELECT COUNT(*) AS ROWS FROM {self.table_name} LIMIT 1"
139139
num_rows_bef = pd.read_sql_query(sample_query, engine)
140140

141141
# Write rows to database
142-
res = df.to_sql(self.table_name, engine, if_exists='append', index=False)
142+
df.to_sql(self.table_name, engine, if_exists="append", index=False)
143143

144144
# Fetch the number of rows from the table
145145
num_rows_aft = pd.read_sql_query(sample_query, engine)
146-
147146
logger.info(f"Number of rows inserted: {num_rows_aft.iloc[0, 0] - num_rows_bef.iloc[0, 0]}")
148-
149147
engine.dispose()
150148

151149
except Exception as e:
@@ -154,24 +152,25 @@ def add_samples(self, data_csv_path=None):
154152
engine.dispose()
155153

156154
def execute_query_db(self, query=None, n_rows=100):
155+
output = []
157156
try:
158157
if query:
159158
# Create an engine
160-
conn_str = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/{self.db_name}"
159+
conn_str = (
160+
f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/{self.db_name}"
161+
)
161162
engine = create_engine(conn_str)
162163

163164
# Create a connection
164165
connection = engine.connect()
165-
166166
result = connection.execute(query)
167167

168168
# Process the query results
169169
cnt = 0
170-
logger.info("Here are the results from the queries: ")
171170
for row in result:
172171
if cnt <= n_rows:
173172
# Access row data using row[column_name]
174-
logger.info(row)
173+
output.append(row)
175174
cnt += 1
176175
else:
177176
break
@@ -182,6 +181,7 @@ def execute_query_db(self, query=None, n_rows=100):
182181
engine.dispose()
183182
else:
184183
logger.info("Query Empty or None!")
184+
return output
185185
except Exception as e:
186186
logger.info(f"Error occurred : {format(e)}")
187187
finally:

sidekick/prompter.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,19 @@
99
from colorama import Fore as F
1010
from colorama import Style
1111
from loguru import logger
12+
from pandasql import sqldf
1213
from sidekick.db_config import DBConfig
1314
from sidekick.memory import EntityMemory
1415
from sidekick.query import SQLGenerator
15-
from sidekick.utils import save_query, setup_dir, extract_table_names, execute_query_pd
16+
from sidekick.utils import (execute_query_pd, extract_table_names, save_query,
17+
setup_dir)
1618

1719
# Load the config file and initialize required paths
1820
base_path = (Path(__file__).parent / "../").resolve()
1921
env_settings = toml.load(f"{base_path}/sidekick/configs/.env.toml")
2022
db_dialect = env_settings["DB-DIALECT"]["DB_TYPE"]
2123
os.environ["TOKENIZERS_PARALLELISM"] = "False"
22-
__version__ = "0.0.3"
24+
__version__ = "0.0.4"
2325

2426

2527
def color(fore="", back="", text=None):
@@ -51,8 +53,9 @@ def enter_table_name():
5153
val = input(color(F.GREEN, "", "Would you like to create a table for the database? (y/n): "))
5254
return val
5355

56+
5457
def enter_file_path(table: str):
55-
val = input(color(F.GREEN, "", f"Please input the CSV file path to table: {table} : "))
58+
val = input(color(F.GREEN, "", f"Please input the CSV file path to table {table} : "))
5659
return val
5760

5861

@@ -80,12 +83,12 @@ def _get_table_info(cache_path: str):
8083
else:
8184
table_info_path = click.prompt("Enter table info path")
8285
table_metadata["schema_info_path"] = table_info_path
83-
with open(f"{cache_path}/table_context.json", "a") as outfile:
86+
with open(f"{cache_path}/table_context.json", "w") as outfile:
8487
json.dump(table_metadata, outfile, indent=4, sort_keys=False)
8588
else:
8689
table_info_path = click.prompt("Enter table info path")
8790
table_metadata = {"schema_info_path": table_info_path}
88-
with open(f"{cache_path}/table_context.json", "a") as outfile:
91+
with open(f"{cache_path}/table_context.json", "w") as outfile:
8992
json.dump(table_metadata, outfile, indent=4, sort_keys=False)
9093
return table_info_path
9194

@@ -104,6 +107,7 @@ def update_table_info(cache_path: str, table_info_path: str = None, table_name:
104107
if table_info_path:
105108
table_metadata = {"schema_info_path": table_info_path}
106109

110+
table_metadata["data_table_map"] = {}
107111
with open(f"{cache_path}/table_context.json", "w") as outfile:
108112
json.dump(table_metadata, outfile, indent=4, sort_keys=False)
109113

@@ -335,16 +339,10 @@ def query(question: str, table_info_path: str, sample_queries: str):
335339
logger.info(f"Input query: {question}")
336340
logger.info(f"Generated response:\n\n{res}")
337341

338-
save_sql = click.prompt("Would you like to save the generated SQL (y/n)?")
339-
if save_sql.lower() == "y" or save_sql.lower() == "yes":
340-
# Persist for future use
341-
_val = updated_sql if updated_sql else res
342-
save_query(base_path, query=question, response=_val)
343-
344342
exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?")
345343
if exe_sql.lower() == "y" or exe_sql.lower() == "yes":
346-
# For the time being, the default option is Pandas, but the user can be asked to select Database or Panadas DF later.
347-
option = "pandas" # or DB
344+
# For the time being, the default option is Pandas, but the user can be asked to select Database or pandas DF later.
345+
option = "DB" # or DB
348346
_val = updated_sql if updated_sql else res
349347
if option == "DB":
350348
hostname = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
@@ -354,30 +352,44 @@ def query(question: str, table_info_path: str, sample_queries: str):
354352
db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]
355353

356354
db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path)
357-
db_obj.execute_query(query=_val)
355+
output_res = db_obj.execute_query_db(query=_val)
356+
click.echo(f"The query results are:\n {output_res}")
358357
elif option == "pandas":
359358
tables = extract_table_names(_val)
360359
tables_path = dict()
361-
for table in tables:
362-
while True:
363-
val = enter_file_path(table)
364-
if not os.path.isfile(val):
365-
click.echo("In-correct Path. Please enter again! Yes(y) or no(n)")
366-
# val = enter_file_path(table)
360+
if Path(f"{path}/table_context.json").exists():
361+
f = open(f"{path}/table_context.json", "r")
362+
table_metadata = json.load(f)
363+
for table in tables:
364+
# Check if the local table_path exists in the cache
365+
if table not in table_metadata["data_table_map"].keys():
366+
val = enter_file_path(table)
367+
if not os.path.isfile(val):
368+
click.echo("In-correct Path. Please enter again! Yes(y) or no(n)")
369+
else:
370+
tables_path[table] = val
371+
table_metadata["data_table_map"][table] = val
372+
break
367373
else:
368-
tables_path[table] = val
369-
break
370-
371-
assert len(tables) == len(tables_path)
372-
373-
res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100)
374-
375-
logger.info("The query results are:")
376-
logger.info(res)
377-
378-
else:
379-
click.echo("Exiting...")
374+
tables_path[table] = table_metadata["data_table_map"][table]
375+
assert len(tables) == len(tables_path)
376+
with open(f"{path}/table_context.json", "w") as outfile:
377+
json.dump(table_metadata, outfile, indent=4, sort_keys=False)
378+
try:
379+
res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100)
380+
click.echo(f"The query results are:\n {res}")
381+
except sqldf.PandaSQLException as e:
382+
logger.error(f"Error in executing the query: {e}")
383+
click.echo("Error in executing the query. Validate generate SQL and try again.")
384+
click.echo("No result to display.")
380385

386+
save_sql = click.prompt("Would you like to save the generated SQL (y/n)?")
387+
if save_sql.lower() == "y" or save_sql.lower() == "yes":
388+
# Persist for future use
389+
_val = updated_sql if updated_sql else res
390+
save_query(base_path, query=question, response=_val)
391+
else:
392+
click.echo("Exiting...")
381393

382394

383395
if __name__ == "__main__":

sidekick/query.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88
import sqlglot
99
import toml
1010
from langchain import OpenAI
11-
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
12-
LLMPredictor, ServiceContext, SQLDatabase)
11+
from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
1312
from llama_index.indices.struct_store import SQLContextContainerBuilder
14-
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT,
15-
TASK_PROMPT)
13+
from sidekick.configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
1614
from sidekick.logger import logger
1715
from sidekick.utils import csv_parser, filter_samples, remove_duplicates
1816
from sqlalchemy import create_engine
@@ -186,14 +184,16 @@ def generate_tasks(self, table_names: list, input_question: str):
186184
data = json.loads(line)
187185
data_info += "\n" + json.dumps(data)
188186
self._data_info = data_info
189-
task_list = self._query_tasks(input_question, data_info, _queries.lower(), table_names)
187+
task_list = self._query_tasks(input_question, data_info, _queries, table_names)
190188
with open(f"{self.path}/var/lib/tmp/data/tasks.txt", "w") as f:
191189
f.write(task_list)
192190
return task_list
193191
except Exception as se:
194192
raise se
195193

196-
def generate_sql(self, table_name: list, input_question: str, _dialect: str = "postgres", model_name: str = 'gpt-3.5-turbo-0301'):
194+
def generate_sql(
195+
self, table_name: list, input_question: str, _dialect: str = "postgres", model_name: str = "gpt-3.5-turbo-0301"
196+
):
197197
_tasks = self.task_formatter(self._tasks)
198198
context_file = f"{self.path}/var/lib/tmp/data/context.json"
199199
additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}
@@ -203,10 +203,10 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "p
203203
query_str = QUERY_PROMPT.format(
204204
_dialect=_dialect,
205205
_data_info=self._data_info,
206-
_question=input_question.lower(),
206+
_question=input_question,
207207
_table_name=table_name,
208208
_sample_queries=context_queries,
209-
_tasks=_tasks.lower(),
209+
_tasks=_tasks,
210210
)
211211

212212
table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
@@ -230,6 +230,8 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "p
230230
res = (
231231
str(res).split("```", 1)[1].split(";", 1)[0].strip().replace("```", "").replace("sql\n", "").strip()
232232
)
233+
else:
234+
res = str(res).split("Explanation:", 1)[0].strip()
233235
sqlglot.transpile(res)
234236
except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
235237
logger.info("We did the best we could, there might be still be some error:\n")

sidekick/utils.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import json
22
import os
3+
import re
34
from pathlib import Path
45
from typing import Optional
56

67
import numpy as np
78
import pandas as pd
89
from pandasql import sqldf
9-
import re
1010
from sentence_transformers import SentenceTransformer
1111
from sidekick.logger import logger
1212
from sklearn.metrics.pairwise import cosine_similarity
@@ -109,6 +109,7 @@ def csv_parser(input_path: str):
109109
res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
110110
return res
111111

112+
112113
def extract_table_names(query: str):
113114
"""
114115
Extracts table names from a SQL query.
@@ -119,16 +120,16 @@ def extract_table_names(query: str):
119120
Returns:
120121
list: A list of table names.
121122
"""
122-
table_names = re.findall(r'\bFROM\s+(\w+)', query, re.IGNORECASE)
123-
table_names += re.findall(r'\bJOIN\s+(\w+)', query, re.IGNORECASE)
124-
table_names += re.findall(r'\bUPDATE\s+(\w+)', query, re.IGNORECASE)
125-
table_names += re.findall(r'\bINTO\s+(\w+)', query, re.IGNORECASE)
123+
table_names = re.findall(r"\bFROM\s+(\w+)", query, re.IGNORECASE)
124+
table_names += re.findall(r"\bJOIN\s+(\w+)", query, re.IGNORECASE)
125+
table_names += re.findall(r"\bUPDATE\s+(\w+)", query, re.IGNORECASE)
126+
table_names += re.findall(r"\bINTO\s+(\w+)", query, re.IGNORECASE)
126127

127-
# Below keywords may not be relevant for the project but adding for sake for completness
128-
table_names += re.findall(r'\bINSERT\s+INTO\s+(\w+)', query, re.IGNORECASE)
129-
table_names += re.findall(r'\bDELETE\s+FROM\s+(\w+)', query, re.IGNORECASE)
128+
# Below keywords may not be relevant for the project but adding for sake for completeness
129+
table_names += re.findall(r"\bINSERT\s+INTO\s+(\w+)", query, re.IGNORECASE)
130+
table_names += re.findall(r"\bDELETE\s+FROM\s+(\w+)", query, re.IGNORECASE)
131+
return np.unique(table_names).tolist()
130132

131-
return table_names
132133

133134
def execute_query_pd(query=None, tables_path=None, n_rows=100):
134135
"""
@@ -142,8 +143,9 @@ def execute_query_pd(query=None, tables_path=None, n_rows=100):
142143
pandas DataFrame: The result of the SQL query.
143144
"""
144145
for table in tables_path:
145-
locals()[f"{table}"] = pd.read_csv(tables_path[table])
146+
if not table in locals():
147+
# Update the local namespace with the table name, pandas object
148+
locals()[f"{table}"] = pd.read_csv(tables_path[table])
146149

147150
res_df = sqldf(query, locals())
148-
149-
return res_df
151+
return res_df

ui/.app_config.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[WAVE_UI]
2+
TITLE = "SideKick Assistant UI"
3+
SUB_TITLE = "Get answers to your questions"

0 commit comments

Comments
 (0)