2020base_path = (Path (__file__ ).parent / "../" ).resolve ()
2121env_settings = toml .load (f"{ base_path } /sidekick/configs/.env.toml" )
2222db_dialect = env_settings ["DB-DIALECT" ]["DB_TYPE" ]
23+ model_name = env_settings ["MODEL_INFO" ]["MODEL_NAME" ]
2324os .environ ["TOKENIZERS_PARALLELISM" ] = "False"
2425__version__ = "0.0.4"
2526
@@ -127,9 +128,30 @@ def update_table_info(cache_path: str, table_info_path: str = None, table_name:
127128@click .option ("--port" , "-P" , default = 5432 , help = "Database port" , prompt = "Enter port (default 5432)" )
128129@click .option ("--table-info-path" , "-t" , help = "Table info path" , default = None )
129130def db_setup (db_name : str , hostname : str , user_name : str , password : str , port : int , table_info_path : str ):
130- db_setup_api (db_name = db_name , hostname = hostname , user_name = user_name , password = password , port = port , table_info_path = table_info_path , table_samples_path = None , table_name = None , is_command = True )
131+ db_setup_api (
132+ db_name = db_name ,
133+ hostname = hostname ,
134+ user_name = user_name ,
135+ password = password ,
136+ port = port ,
137+ table_info_path = table_info_path ,
138+ table_samples_path = None ,
139+ table_name = None ,
140+ is_command = True ,
141+ )
142+
131143
132- def db_setup_api (db_name : str , hostname : str , user_name : str , password : str , port : int , table_info_path : str , table_samples_path : str , table_name : str , is_command :bool = False ):
144+ def db_setup_api (
145+ db_name : str ,
146+ hostname : str ,
147+ user_name : str ,
148+ password : str ,
149+ port : int ,
150+ table_info_path : str ,
151+ table_samples_path : str ,
152+ table_name : str ,
153+ is_command : bool = False ,
154+ ):
133155 """Creates context for the new Database"""
134156 click .echo (f" Information supplied:\n { db_name } , { hostname } , { user_name } , { password } , { port } " )
135157 try :
@@ -145,7 +167,7 @@ def db_setup_api(db_name: str, hostname: str, user_name: str, password: str, por
145167 path = f"{ base_path } /var/lib/tmp/data"
146168 # For current session
147169 db_obj = DBConfig (db_name , hostname , user_name , password , port , base_path = base_path , dialect = db_dialect )
148- if db_obj .dialect == ' sqlite' and not os .path .isfile (f"{ base_path } /db/sqlite/{ db_name } .db" ):
170+ if db_obj .dialect == " sqlite" and not os .path .isfile (f"{ base_path } /db/sqlite/{ db_name } .db" ):
149171 db_obj .create_db ()
150172 click .echo ("Database created successfully!" )
151173 elif not db_obj .db_exists ():
@@ -176,7 +198,11 @@ def db_setup_api(db_name: str, hostname: str, user_name: str, password: str, por
176198 # Check if table exists; pending --> and doesn't have any rows
177199 if db_obj .has_table ():
178200 click .echo (f"Checked table { db_obj .table_name } exists in the DB." )
179- val = input (color (F .GREEN , "" , "Would you like to add few sample rows (at-least 3)? (y/n):" )) if is_command else "y"
201+ val = (
202+ input (color (F .GREEN , "" , "Would you like to add few sample rows (at-least 3)? (y/n):" ))
203+ if is_command
204+ else "y"
205+ )
180206 if val .lower ().strip () == "y" or val .lower ().strip () == "yes" :
181207 val = input ("Path to a CSV file to insert data from:" ) if is_command else table_samples_path
182208 db_obj .add_samples (val )
@@ -259,9 +285,10 @@ def update_context():
259285@click .option ("--table-info-path" , "-t" , help = "Table info path" , default = None )
260286@click .option ("--sample-queries" , "-s" , help = "Samples path" , default = None )
261287def query (question : str , table_info_path : str , sample_queries : str ):
262- query_api (question = question , table_info_path = table_info_path , sample_queries = sample_queries , is_command = True )
288+ query_api (question = question , table_info_path = table_info_path , sample_queries = sample_queries , is_command = True )
263289
264- def query_api (question : str , table_info_path : str , sample_queries : str , is_command :bool = False ):
290+
291+ def query_api (question : str , table_info_path : str , sample_queries : str , is_command : bool = False ):
265292 """Asks question and returns SQL."""
266293 results = []
267294 # Book-keeping
@@ -283,27 +310,31 @@ def query_api(question: str, table_info_path: str, sample_queries: str, is_comma
283310 json .dump (table_context , outfile , indent = 4 , sort_keys = False )
284311 logger .info (f"Table in use: { table_names } " )
285312 # Check if .env.toml file exists
286- api_key = env_settings ["OPENAI" ]["OPENAI_API_KEY" ]
287- if api_key is None or api_key == "" :
288- if os .getenv ("OPENAI_API_KEY" ) is None or os .getenv ("OPENAI_API_KEY" ) == "" :
289- if is_command :
290- val = input (
291- color (F .GREEN , "" , "Looks like API key is not set, would you like to set OPENAI_API_KEY? (y/n):" )
292- )
293- if val .lower () == "y" :
294- api_key = input (color (F .GREEN , "" , "Enter OPENAI_API_KEY :" ))
295-
296- if api_key is None and is_command :
297- return ["Looks like API key is not set, please set OPENAI_API_KEY!" ]
298-
299- os .environ ["OPENAI_API_KEY" ] = api_key
300- env_settings ["OPENAI" ]["OPENAI_API_KEY" ] = api_key
301-
302- # Update settings file for future use.
303- f = open (f"{ base_path } /sidekick/configs/.env.toml" , "w" )
304- toml .dump (env_settings , f )
305- f .close ()
306- openai .api_key = api_key
313+ api_key = None
314+ if model_name != "h2ogpt-sql" :
315+ api_key = env_settings ["MODEL_INFO" ]["OPENAI_API_KEY" ]
316+ if api_key is None or api_key == "" :
317+ if os .getenv ("OPENAI_API_KEY" ) is None or os .getenv ("OPENAI_API_KEY" ) == "" :
318+ if is_command :
319+ val = input (
320+ color (
321+ F .GREEN , "" , "Looks like API key is not set, would you like to set OPENAI_API_KEY? (y/n):"
322+ )
323+ )
324+ if val .lower () == "y" :
325+ api_key = input (color (F .GREEN , "" , "Enter OPENAI_API_KEY :" ))
326+
327+ if api_key is None and is_command :
328+ return ["Looks like API key is not set, please set OPENAI_API_KEY!" ]
329+
330+ os .environ ["OPENAI_API_KEY" ] = api_key
331+ env_settings ["MODEL_INFO" ]["OPENAI_API_KEY" ] = api_key
332+
333+ # Update settings file for future use.
334+ f = open (f"{ base_path } /sidekick/configs/.env.toml" , "w" )
335+ toml .dump (env_settings , f )
336+ f .close ()
337+ openai .api_key = api_key
307338
308339 # Set context
309340 logger .info ("Setting context..." )
@@ -327,22 +358,22 @@ def query_api(question: str, table_info_path: str, sample_queries: str, is_comma
327358 sql_g = SQLGenerator (
328359 db_url , api_key , job_path = base_path , data_input_path = table_info_path , samples_queries = sample_queries
329360 )
330- sql_g ._tasks = sql_g .generate_tasks (table_names , question )
331- results .extend (["List of Actions Generated: \n " , sql_g ._tasks , "\n " ])
332- click .echo (sql_g ._tasks )
333-
334- updated_tasks = None
335- if sql_g ._tasks is not None and is_command :
336- edit_val = click .prompt ("Would you like to edit the tasks? (y/n)" )
337- if edit_val .lower () == "y" :
338- updated_tasks = click .edit (sql_g ._tasks )
339- click .echo (f"Tasks:\n { updated_tasks } " )
340- else :
341- click .echo ("Skipping edit..." )
342- if updated_tasks is not None :
343- sql_g ._tasks = updated_tasks
361+ if "h2ogpt-sql" not in model_name :
362+ sql_g ._tasks = sql_g .generate_tasks (table_names , question )
363+ results .extend (["List of Actions Generated: \n " , sql_g ._tasks , "\n " ])
364+ click .echo (sql_g ._tasks )
365+
366+ updated_tasks = None
367+ if sql_g ._tasks is not None and is_command :
368+ edit_val = click .prompt ("Would you like to edit the tasks? (y/n)" )
369+ if edit_val .lower () == "y" :
370+ updated_tasks = click .edit (sql_g ._tasks )
371+ click .echo (f"Tasks:\n { updated_tasks } " )
372+ else :
373+ click .echo ("Skipping edit..." )
374+ if updated_tasks is not None :
375+ sql_g ._tasks = updated_tasks
344376
345- model_name = env_settings ["OPENAI" ]["MODEL_NAME" ]
346377 res = sql_g .generate_sql (table_names , question , model_name = model_name , _dialect = db_dialect )
347378 logger .info (f"Input query: { question } " )
348379 logger .info (f"Generated response:\n \n { res } " )
@@ -431,5 +462,6 @@ def query_api(question: str, table_info_path: str, sample_queries: str, is_comma
431462
432463 return results
433464
465+
434466if __name__ == "__main__" :
435467 cli ()
0 commit comments