Skip to content

Commit b33db8d

Browse files
authored
Code to add samples and execute query (#9)
Added: 1. Added prompt to ask user if they want to execute the query 2. A method to extract table names from the query generated 3. Added method to use PandaSQL to execute the query in memory sqllite
1 parent be1c879 commit b33db8d

File tree

4 files changed

+147
-15
lines changed

4 files changed

+147
-15
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,4 @@ typing-inspect==0.9.0 ; python_full_version >= "3.8.16" and python_version < "3.
6565
urllib3==2.0.2 ; python_full_version >= "3.8.16" and python_version < "3.10"
6666
win32-setctime==1.1.0 ; python_full_version >= "3.8.16" and python_version < "3.10" and sys_platform == "win32"
6767
yarl==1.9.2 ; python_full_version >= "3.8.16" and python_version < "3.10"
68+
pandasql==0.7.3 ; python_full_version >= "3.8.16" and python_version < "3.10"

sidekick/db_config.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# create db with supplied info
22
import json
33
from pathlib import Path
4+
import pandas as pd
45

56
import psycopg2 as pg
67
import sqlalchemy
78
from psycopg2.extras import Json
9+
from pandasql import sqldf
810
from sidekick.logger import logger
911
from sqlalchemy import create_engine
1012
from sqlalchemy_utils import database_exists
@@ -127,16 +129,61 @@ def has_table(self):
127129
)
128130
return sqlalchemy.inspect(engine).has_table(self.table_name)
129131

130-
def add_samples(self):
131-
# Non-functional for now.
132-
conn = pg.connect(
133-
database=self.db_name, user=self.user_name, password=self.password, host=self.hostname, port=self.port
134-
)
135-
# Creating a cursor object using the cursor() method
136-
conn.autocommit = True
137-
cursor = conn.cursor()
132+
def add_samples(self, data_csv_path=None):
133+
conn_str = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/{self.db_name}"
134+
try:
135+
df = pd.read_csv(data_csv_path, infer_datetime_format=True)
136+
engine = create_engine(conn_str, isolation_level='AUTOCOMMIT')
137+
138+
sample_query = f'SELECT COUNT(*) AS ROWS FROM {self.table_name} LIMIT 1'
139+
num_rows_bef = pd.read_sql_query(sample_query, engine)
140+
141+
# Write rows to database
142+
res = df.to_sql(self.table_name, engine, if_exists='append', index=False)
143+
144+
# Fetch the number of rows from the table
145+
num_rows_aft = pd.read_sql_query(sample_query, engine)
146+
147+
logger.info(f"Number of rows inserted: {num_rows_aft.iloc[0, 0] - num_rows_bef.iloc[0, 0]}")
148+
149+
engine.dispose()
138150

139-
cursor.execute()
151+
except Exception as e:
152+
logger.info(f"Error occurred : {format(e)}")
153+
finally:
154+
engine.dispose()
140155

141-
# Commit your changes in the database
142-
conn.commit()
156+
def execute_query_db(self, query=None, n_rows=100):
157+
try:
158+
if query:
159+
# Create an engine
160+
conn_str = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/{self.db_name}"
161+
engine = create_engine(conn_str)
162+
163+
# Create a connection
164+
connection = engine.connect()
165+
166+
result = connection.execute(query)
167+
168+
# Process the query results
169+
cnt = 0
170+
logger.info("Here are the results from the queries: ")
171+
for row in result:
172+
if cnt <= n_rows:
173+
# Access row data using row[column_name]
174+
logger.info(row)
175+
cnt += 1
176+
else:
177+
break
178+
# Close the connection
179+
connection.close()
180+
181+
# Close the engine
182+
engine.dispose()
183+
else:
184+
logger.info("Query Empty or None!")
185+
except Exception as e:
186+
logger.info(f"Error occurred : {format(e)}")
187+
finally:
188+
connection.close()
189+
engine.dispose()

sidekick/prompter.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sidekick.db_config import DBConfig
1313
from sidekick.memory import EntityMemory
1414
from sidekick.query import SQLGenerator
15-
from sidekick.utils import save_query, setup_dir
15+
from sidekick.utils import save_query, setup_dir, extract_table_names, execute_query_pd
1616

1717
# Load the config file and initialize required paths
1818
base_path = (Path(__file__).parent / "../").resolve()
@@ -51,6 +51,10 @@ def enter_table_name():
5151
val = input(color(F.GREEN, "", "Would you like to create a table for the database? (y/n): "))
5252
return val
5353

54+
def enter_file_path(table: str):
55+
val = input(color(F.GREEN, "", f"Please input the CSV file path to table: {table} : "))
56+
return val
57+
5458

5559
@configure.command("log", help="Adjust log settings")
5660
@click.option("--set_level", "-l", help="Set log level (Default: INFO)")
@@ -162,9 +166,10 @@ def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: i
162166
# Check if table exists; pending --> and doesn't have any rows
163167
if db_obj.has_table():
164168
click.echo(f"Checked table {db_obj.table_name} exists in the DB.")
165-
val = input(color(F.GREEN, "", "Would you like to add few sample rows (at-least 3)? (y/n): "))
166-
if val.lower() == "y":
167-
db_obj.add_samples()
169+
val = input(color(F.GREEN, "", "Would you like to add few sample rows (at-least 3)? (y/n):"))
170+
if val.lower().strip() == "y" or val.lower().strip() == "yes":
171+
val = input("Path to a CSV file to insert data from:")
172+
db_obj.add_samples(val)
168173
else:
169174
click.echo("Exiting...")
170175
return
@@ -336,6 +341,44 @@ def query(question: str, table_info_path: str, sample_queries: str):
336341
_val = updated_sql if updated_sql else res
337342
save_query(base_path, query=question, response=_val)
338343

344+
exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?")
345+
if exe_sql.lower() == "y" or exe_sql.lower() == "yes":
346+
# For the time being, the default option is Pandas, but the user can be asked to select Database or Panadas DF later.
347+
option = "pandas" # or DB
348+
_val = updated_sql if updated_sql else res
349+
if option == "DB":
350+
hostname = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
351+
user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"]
352+
password = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
353+
port = env_settings["LOCAL_DB_CONFIG"]["PORT"]
354+
db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"]
355+
356+
db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path)
357+
db_obj.execute_query(query=_val)
358+
elif option == "pandas":
359+
tables = extract_table_names(_val)
360+
tables_path = dict()
361+
for table in tables:
362+
while True:
363+
val = enter_file_path(table)
364+
if not os.path.isfile(val):
365+
click.echo("In-correct Path. Please enter again! Yes(y) or no(n)")
366+
# val = enter_file_path(table)
367+
else:
368+
tables_path[table] = val
369+
break
370+
371+
assert len(tables) == len(tables_path)
372+
373+
res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100)
374+
375+
logger.info("The query results are:")
376+
logger.info(res)
377+
378+
else:
379+
click.echo("Exiting...")
380+
381+
339382

340383
if __name__ == "__main__":
341384
cli()

sidekick/utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import numpy as np
77
import pandas as pd
8+
from pandasql import sqldf
9+
import re
810
from sentence_transformers import SentenceTransformer
911
from sidekick.logger import logger
1012
from sklearn.metrics.pairwise import cosine_similarity
@@ -106,3 +108,42 @@ def csv_parser(input_path: str):
106108
# ]
107109
res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
108110
return res
111+
112+
def extract_table_names(query: str):
113+
"""
114+
Extracts table names from a SQL query.
115+
116+
Parameters:
117+
query (str): The SQL query to extract table names from.
118+
119+
Returns:
120+
list: A list of table names.
121+
"""
122+
table_names = re.findall(r'\bFROM\s+(\w+)', query, re.IGNORECASE)
123+
table_names += re.findall(r'\bJOIN\s+(\w+)', query, re.IGNORECASE)
124+
table_names += re.findall(r'\bUPDATE\s+(\w+)', query, re.IGNORECASE)
125+
table_names += re.findall(r'\bINTO\s+(\w+)', query, re.IGNORECASE)
126+
127+
# Below keywords may not be relevant for the project but adding for sake for completness
128+
table_names += re.findall(r'\bINSERT\s+INTO\s+(\w+)', query, re.IGNORECASE)
129+
table_names += re.findall(r'\bDELETE\s+FROM\s+(\w+)', query, re.IGNORECASE)
130+
131+
return table_names
132+
133+
def execute_query_pd(query=None, tables_path=None, n_rows=100):
134+
"""
135+
Runs an SQL query on a pandas DataFrame.
136+
137+
Parameters:
138+
df (pandas DataFrame): The DataFrame to query.
139+
query (str): The SQL query to execute.
140+
141+
Returns:
142+
pandas DataFrame: The result of the SQL query.
143+
"""
144+
for table in tables_path:
145+
locals()[f"{table}"] = pd.read_csv(tables_path[table])
146+
147+
res_df = sqldf(query, locals())
148+
149+
return res_df

0 commit comments

Comments
 (0)