@@ -127,6 +127,9 @@ def update_table_info(cache_path: str, table_info_path: str = None, table_name:
127
127
@click .option ("--port" , "-P" , default = 5432 , help = "Database port" , prompt = "Enter port (default 5432)" )
128
128
@click .option ("--table-info-path" , "-t" , help = "Table info path" , default = None )
129
129
def db_setup (db_name : str , hostname : str , user_name : str , password : str , port : int , table_info_path : str ):
130
+ db_setup_api (db_name = db_name , hostname = hostname , user_name = user_name , password = password , port = port , table_info_path = table_info_path , table_samples_path = None , table_name = None , is_command = True )
131
+
132
+ def db_setup_api (db_name : str , hostname : str , user_name : str , password : str , port : int , table_info_path : str , table_samples_path : str , table_name : str , is_command :bool = False ):
130
133
"""Creates context for the new Database"""
131
134
click .echo (f" Information supplied:\n { db_name } , { hostname } , { user_name } , { password } , { port } " )
132
135
try :
@@ -151,7 +154,7 @@ def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: i
151
154
else :
152
155
click .echo ("Database already exists!" )
153
156
154
- val = enter_table_name ()
157
+ val = enter_table_name () if is_command else "y"
155
158
while True :
156
159
if val .lower () != "y" and val .lower () != "n" :
157
160
click .echo ("In-correct values. Enter Yes(y) or no(n)" )
@@ -163,7 +166,7 @@ def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: i
163
166
table_info_path = _get_table_info (path )
164
167
165
168
if val .lower () == "y" or val .lower () == "yes" :
166
- table_value = input ("Enter table name: " )
169
+ table_value = input ("Enter table name: " ) if is_command else table_name
167
170
click .echo (f"Table name: { table_value } " )
168
171
# set table name
169
172
db_obj .table_name = table_value .replace (" " , "_" )
@@ -173,17 +176,23 @@ def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: i
173
176
# Check if table exists; pending --> and doesn't have any rows
174
177
if db_obj .has_table ():
175
178
click .echo (f"Checked table { db_obj .table_name } exists in the DB." )
176
- val = input (color (F .GREEN , "" , "Would you like to add few sample rows (at-least 3)? (y/n):" ))
179
+ val = input (color (F .GREEN , "" , "Would you like to add few sample rows (at-least 3)? (y/n):" )) if is_command else "y"
177
180
if val .lower ().strip () == "y" or val .lower ().strip () == "yes" :
178
- val = input ("Path to a CSV file to insert data from:" )
181
+ val = input ("Path to a CSV file to insert data from:" ) if is_command else table_samples_path
179
182
db_obj .add_samples (val )
180
183
else :
181
184
click .echo ("Exiting..." )
182
185
return
183
186
else :
184
- click .echo ("Job done. Ask a question now!" )
187
+ echo_msg = "Job done. Ask a question now!"
188
+ click .echo (echo_msg )
189
+
190
+ return f"Created a Database { db_name } . Inserted sample values from { table_samples_path } into table { table_name } , please ask questions!"
185
191
except Exception as e :
186
- click .echo (f"Error creating database. Check configuration parameters.\n : { e } " )
192
+ echo_msg = f"Error creating database. Check configuration parameters.\n : { e } "
193
+ click .echo (echo_msg )
194
+ if not is_command :
195
+ return echo_msg
187
196
188
197
189
198
@cli .group ("learn" )
@@ -250,7 +259,11 @@ def update_context():
250
259
@click .option ("--table-info-path" , "-t" , help = "Table info path" , default = None )
251
260
@click .option ("--sample-queries" , "-s" , help = "Samples path" , default = None )
252
261
def query (question : str , table_info_path : str , sample_queries : str ):
262
+ query_api (question = question , table_info_path = table_info_path , sample_queries = sample_queries , is_command = True )
263
+
264
+ def query_api (question : str , table_info_path : str , sample_queries : str , is_command :bool = False ):
253
265
"""Asks question and returns SQL."""
266
+ results = []
254
267
# Book-keeping
255
268
setup_dir (base_path )
256
269
@@ -273,11 +286,16 @@ def query(question: str, table_info_path: str, sample_queries: str):
273
286
api_key = env_settings ["OPENAI" ]["OPENAI_API_KEY" ]
274
287
if api_key is None or api_key == "" :
275
288
if os .getenv ("OPENAI_API_KEY" ) is None or os .getenv ("OPENAI_API_KEY" ) == "" :
276
- val = input (
277
- color (F .GREEN , "" , "Looks like API key is not set, would you like to set OPENAI_API_KEY? (y/n):" )
278
- )
279
- if val .lower () == "y" :
280
- api_key = input (color (F .GREEN , "" , "Enter OPENAI_API_KEY :" ))
289
+ if is_command :
290
+ val = input (
291
+ color (F .GREEN , "" , "Looks like API key is not set, would you like to set OPENAI_API_KEY? (y/n):" )
292
+ )
293
+ if val .lower () == "y" :
294
+ api_key = input (color (F .GREEN , "" , "Enter OPENAI_API_KEY :" ))
295
+
296
+ if api_key is None and is_command :
297
+ return ["Looks like API key is not set, please set OPENAI_API_KEY!" ]
298
+
281
299
os .environ ["OPENAI_API_KEY" ] = api_key
282
300
env_settings ["OPENAI" ]["OPENAI_API_KEY" ] = api_key
283
301
@@ -310,10 +328,11 @@ def query(question: str, table_info_path: str, sample_queries: str):
310
328
db_url , api_key , job_path = base_path , data_input_path = table_info_path , samples_queries = sample_queries
311
329
)
312
330
sql_g ._tasks = sql_g .generate_tasks (table_names , question )
331
+ results .extend (["List of Actions Generated: \n " , sql_g ._tasks , "\n " ])
313
332
click .echo (sql_g ._tasks )
314
333
315
334
updated_tasks = None
316
- if sql_g ._tasks is not None :
335
+ if sql_g ._tasks is not None and is_command :
317
336
edit_val = click .prompt ("Would you like to edit the tasks? (y/n)" )
318
337
if edit_val .lower () == "y" :
319
338
updated_tasks = click .edit (sql_g ._tasks )
@@ -331,23 +350,27 @@ def query(question: str, table_info_path: str, sample_queries: str):
331
350
if res is not None :
332
351
updated_sql = None
333
352
res_val = "e"
334
- while res_val .lower () in ["e" , "edit" , "r" , "regenerate" ]:
335
- res_val = click .prompt (
336
- "Would you like to 'edit' or 'regenerate' the SQL? Use 'e' to edit or 'r' to regenerate. "
337
- "To skip, enter 's' or 'skip'"
338
- )
339
- if res_val .lower () == "e" or res_val .lower () == "edit" :
340
- updated_sql = click .edit (res )
341
- click .echo (f"Updated SQL:\n { updated_sql } " )
342
- elif res_val .lower () == "r" or res_val .lower () == "regenerate" :
343
- click .echo ("Attempting to regenerate..." )
344
- res = sql_g .generate_sql (table_names , question , model_name = model_name , _dialect = db_dialect )
345
- logger .info (f"Input query: { question } " )
346
- logger .info (f"Generated response:\n \n { res } " )
347
-
348
- exe_sql = click .prompt ("Would you like to execute the generated SQL (y/n)?" )
353
+ if is_command :
354
+ while res_val .lower () in ["e" , "edit" , "r" , "regenerate" ]:
355
+ res_val = click .prompt (
356
+ "Would you like to 'edit' or 'regenerate' the SQL? Use 'e' to edit or 'r' to regenerate. "
357
+ "To skip, enter 's' or 'skip'"
358
+ )
359
+ if res_val .lower () == "e" or res_val .lower () == "edit" :
360
+ updated_sql = click .edit (res )
361
+ click .echo (f"Updated SQL:\n { updated_sql } " )
362
+ elif res_val .lower () == "r" or res_val .lower () == "regenerate" :
363
+ click .echo ("Attempting to regenerate..." )
364
+ res = sql_g .generate_sql (table_names , question , model_name = model_name , _dialect = db_dialect )
365
+ logger .info (f"Input query: { question } " )
366
+ logger .info (f"Generated response:\n \n { res } " )
367
+
368
+ results .extend (["Generated Query:\n " , res , "\n " ])
369
+
370
+ exe_sql = click .prompt ("Would you like to execute the generated SQL (y/n)?" ) if is_command else "y"
349
371
if exe_sql .lower () == "y" or exe_sql .lower () == "yes" :
350
372
# For the time being, the default option is Pandas, but the user can be asked to select Database or pandas DF later.
373
+ q_res = None
351
374
option = "DB" # or DB
352
375
_val = updated_sql if updated_sql else res
353
376
if option == "DB" :
@@ -359,8 +382,8 @@ def query(question: str, table_info_path: str, sample_queries: str):
359
382
360
383
db_obj = DBConfig (db_name , hostname , user_name , password , port , base_path = base_path , dialect = db_dialect )
361
384
362
- output_res = db_obj .execute_query_db (query = _val )
363
- click . echo ( f"The query results are: \n { output_res } " )
385
+ q_res , err = db_obj .execute_query_db (query = _val )
386
+
364
387
elif option == "pandas" :
365
388
tables = extract_table_names (_val )
366
389
tables_path = dict ()
@@ -383,21 +406,30 @@ def query(question: str, table_info_path: str, sample_queries: str):
383
406
with open (f"{ path } /table_context.json" , "w" ) as outfile :
384
407
json .dump (table_metadata , outfile , indent = 4 , sort_keys = False )
385
408
try :
386
- res = execute_query_pd (query = _val , tables_path = tables_path , n_rows = 100 )
387
- click .echo (f"The query results are:\n { res } " )
409
+ q_res = execute_query_pd (query = _val , tables_path = tables_path , n_rows = 100 )
410
+ click .echo (f"The query results are:\n { q_res } " )
388
411
except sqldf .PandaSQLException as e :
389
412
logger .error (f"Error in executing the query: { e } " )
390
- click .echo ("Error in executing the query. Validate generate SQL and try again." )
413
+ click .echo ("Error in executing the query. Validate generated SQL and try again." )
391
414
click .echo ("No result to display." )
392
415
393
- save_sql = click .prompt ("Would you like to save the generated SQL (y/n)?" )
416
+ results .append ("Query Results: \n " )
417
+ if q_res :
418
+ click .echo (f"The query results are:\n { q_res } " )
419
+ results .extend ([str (q_res ), "\n " ])
420
+ else :
421
+ click .echo (f"While executing query:\n { err } " )
422
+ results .extend ([str (err ), "\n " ])
423
+ # results.extend(["Query Results:", q_res])
424
+ save_sql = click .prompt ("Would you like to save the generated SQL (y/n)?" ) if is_command else "n"
394
425
if save_sql .lower () == "y" or save_sql .lower () == "yes" :
395
426
# Persist for future use
396
427
_val = updated_sql if updated_sql else res
397
428
save_query (base_path , query = question , response = _val )
398
429
else :
399
430
click .echo ("Exiting..." )
400
431
432
+ return results
401
433
402
434
if __name__ == "__main__" :
403
435
cli ()
0 commit comments