20
20
base_path = (Path (__file__ ).parent / "../" ).resolve ()
21
21
env_settings = toml .load (f"{ base_path } /sidekick/configs/.env.toml" )
22
22
db_dialect = env_settings ["DB-DIALECT" ]["DB_TYPE" ]
23
+ model_name = env_settings ["MODEL_INFO" ]["MODEL_NAME" ]
23
24
os .environ ["TOKENIZERS_PARALLELISM" ] = "False"
24
25
__version__ = "0.0.4"
25
26
@@ -127,9 +128,30 @@ def update_table_info(cache_path: str, table_info_path: str = None, table_name:
127
128
@click .option ("--port" , "-P" , default = 5432 , help = "Database port" , prompt = "Enter port (default 5432)" )
128
129
@click .option ("--table-info-path" , "-t" , help = "Table info path" , default = None )
129
130
def db_setup (db_name : str , hostname : str , user_name : str , password : str , port : int , table_info_path : str ):
130
- db_setup_api (db_name = db_name , hostname = hostname , user_name = user_name , password = password , port = port , table_info_path = table_info_path , table_samples_path = None , table_name = None , is_command = True )
131
+ db_setup_api (
132
+ db_name = db_name ,
133
+ hostname = hostname ,
134
+ user_name = user_name ,
135
+ password = password ,
136
+ port = port ,
137
+ table_info_path = table_info_path ,
138
+ table_samples_path = None ,
139
+ table_name = None ,
140
+ is_command = True ,
141
+ )
142
+
131
143
132
- def db_setup_api (db_name : str , hostname : str , user_name : str , password : str , port : int , table_info_path : str , table_samples_path : str , table_name : str , is_command :bool = False ):
144
+ def db_setup_api (
145
+ db_name : str ,
146
+ hostname : str ,
147
+ user_name : str ,
148
+ password : str ,
149
+ port : int ,
150
+ table_info_path : str ,
151
+ table_samples_path : str ,
152
+ table_name : str ,
153
+ is_command : bool = False ,
154
+ ):
133
155
"""Creates context for the new Database"""
134
156
click .echo (f" Information supplied:\n { db_name } , { hostname } , { user_name } , { password } , { port } " )
135
157
try :
@@ -145,7 +167,7 @@ def db_setup_api(db_name: str, hostname: str, user_name: str, password: str, por
145
167
path = f"{ base_path } /var/lib/tmp/data"
146
168
# For current session
147
169
db_obj = DBConfig (db_name , hostname , user_name , password , port , base_path = base_path , dialect = db_dialect )
148
- if db_obj .dialect == ' sqlite' and not os .path .isfile (f"{ base_path } /db/sqlite/{ db_name } .db" ):
170
+ if db_obj .dialect == " sqlite" and not os .path .isfile (f"{ base_path } /db/sqlite/{ db_name } .db" ):
149
171
db_obj .create_db ()
150
172
click .echo ("Database created successfully!" )
151
173
elif not db_obj .db_exists ():
@@ -176,7 +198,11 @@ def db_setup_api(db_name: str, hostname: str, user_name: str, password: str, por
176
198
# Check if table exists; pending --> and doesn't have any rows
177
199
if db_obj .has_table ():
178
200
click .echo (f"Checked table { db_obj .table_name } exists in the DB." )
179
- val = input (color (F .GREEN , "" , "Would you like to add few sample rows (at-least 3)? (y/n):" )) if is_command else "y"
201
+ val = (
202
+ input (color (F .GREEN , "" , "Would you like to add few sample rows (at-least 3)? (y/n):" ))
203
+ if is_command
204
+ else "y"
205
+ )
180
206
if val .lower ().strip () == "y" or val .lower ().strip () == "yes" :
181
207
val = input ("Path to a CSV file to insert data from:" ) if is_command else table_samples_path
182
208
db_obj .add_samples (val )
@@ -259,9 +285,10 @@ def update_context():
259
285
@click .option ("--table-info-path" , "-t" , help = "Table info path" , default = None )
260
286
@click .option ("--sample-queries" , "-s" , help = "Samples path" , default = None )
261
287
def query (question : str , table_info_path : str , sample_queries : str ):
262
- query_api (question = question , table_info_path = table_info_path , sample_queries = sample_queries , is_command = True )
288
+ query_api (question = question , table_info_path = table_info_path , sample_queries = sample_queries , is_command = True )
263
289
264
- def query_api (question : str , table_info_path : str , sample_queries : str , is_command :bool = False ):
290
+
291
+ def query_api (question : str , table_info_path : str , sample_queries : str , is_command : bool = False ):
265
292
"""Asks question and returns SQL."""
266
293
results = []
267
294
# Book-keeping
@@ -283,27 +310,31 @@ def query_api(question: str, table_info_path: str, sample_queries: str, is_comma
283
310
json .dump (table_context , outfile , indent = 4 , sort_keys = False )
284
311
logger .info (f"Table in use: { table_names } " )
285
312
# Check if .env.toml file exists
286
- api_key = env_settings ["OPENAI" ]["OPENAI_API_KEY" ]
287
- if api_key is None or api_key == "" :
288
- if os .getenv ("OPENAI_API_KEY" ) is None or os .getenv ("OPENAI_API_KEY" ) == "" :
289
- if is_command :
290
- val = input (
291
- color (F .GREEN , "" , "Looks like API key is not set, would you like to set OPENAI_API_KEY? (y/n):" )
292
- )
293
- if val .lower () == "y" :
294
- api_key = input (color (F .GREEN , "" , "Enter OPENAI_API_KEY :" ))
295
-
296
- if api_key is None and is_command :
297
- return ["Looks like API key is not set, please set OPENAI_API_KEY!" ]
298
-
299
- os .environ ["OPENAI_API_KEY" ] = api_key
300
- env_settings ["OPENAI" ]["OPENAI_API_KEY" ] = api_key
301
-
302
- # Update settings file for future use.
303
- f = open (f"{ base_path } /sidekick/configs/.env.toml" , "w" )
304
- toml .dump (env_settings , f )
305
- f .close ()
306
- openai .api_key = api_key
313
+ api_key = None
314
+ if model_name != "h2ogpt-sql" :
315
+ api_key = env_settings ["MODEL_INFO" ]["OPENAI_API_KEY" ]
316
+ if api_key is None or api_key == "" :
317
+ if os .getenv ("OPENAI_API_KEY" ) is None or os .getenv ("OPENAI_API_KEY" ) == "" :
318
+ if is_command :
319
+ val = input (
320
+ color (
321
+ F .GREEN , "" , "Looks like API key is not set, would you like to set OPENAI_API_KEY? (y/n):"
322
+ )
323
+ )
324
+ if val .lower () == "y" :
325
+ api_key = input (color (F .GREEN , "" , "Enter OPENAI_API_KEY :" ))
326
+
327
+ if api_key is None and is_command :
328
+ return ["Looks like API key is not set, please set OPENAI_API_KEY!" ]
329
+
330
+ os .environ ["OPENAI_API_KEY" ] = api_key
331
+ env_settings ["MODEL_INFO" ]["OPENAI_API_KEY" ] = api_key
332
+
333
+ # Update settings file for future use.
334
+ f = open (f"{ base_path } /sidekick/configs/.env.toml" , "w" )
335
+ toml .dump (env_settings , f )
336
+ f .close ()
337
+ openai .api_key = api_key
307
338
308
339
# Set context
309
340
logger .info ("Setting context..." )
@@ -327,22 +358,22 @@ def query_api(question: str, table_info_path: str, sample_queries: str, is_comma
327
358
sql_g = SQLGenerator (
328
359
db_url , api_key , job_path = base_path , data_input_path = table_info_path , samples_queries = sample_queries
329
360
)
330
- sql_g ._tasks = sql_g .generate_tasks (table_names , question )
331
- results .extend (["List of Actions Generated: \n " , sql_g ._tasks , "\n " ])
332
- click .echo (sql_g ._tasks )
333
-
334
- updated_tasks = None
335
- if sql_g ._tasks is not None and is_command :
336
- edit_val = click .prompt ("Would you like to edit the tasks? (y/n)" )
337
- if edit_val .lower () == "y" :
338
- updated_tasks = click .edit (sql_g ._tasks )
339
- click .echo (f"Tasks:\n { updated_tasks } " )
340
- else :
341
- click .echo ("Skipping edit..." )
342
- if updated_tasks is not None :
343
- sql_g ._tasks = updated_tasks
361
+ if "h2ogpt-sql" not in model_name :
362
+ sql_g ._tasks = sql_g .generate_tasks (table_names , question )
363
+ results .extend (["List of Actions Generated: \n " , sql_g ._tasks , "\n " ])
364
+ click .echo (sql_g ._tasks )
365
+
366
+ updated_tasks = None
367
+ if sql_g ._tasks is not None and is_command :
368
+ edit_val = click .prompt ("Would you like to edit the tasks? (y/n)" )
369
+ if edit_val .lower () == "y" :
370
+ updated_tasks = click .edit (sql_g ._tasks )
371
+ click .echo (f"Tasks:\n { updated_tasks } " )
372
+ else :
373
+ click .echo ("Skipping edit..." )
374
+ if updated_tasks is not None :
375
+ sql_g ._tasks = updated_tasks
344
376
345
- model_name = env_settings ["OPENAI" ]["MODEL_NAME" ]
346
377
res = sql_g .generate_sql (table_names , question , model_name = model_name , _dialect = db_dialect )
347
378
logger .info (f"Input query: { question } " )
348
379
logger .info (f"Generated response:\n \n { res } " )
@@ -431,5 +462,6 @@ def query_api(question: str, table_info_path: str, sample_queries: str, is_comma
431
462
432
463
return results
433
464
465
+
434
466
if __name__ == "__main__" :
435
467
cli ()
0 commit comments