@@ -112,10 +112,10 @@ def update_table_info(cache_path: str, table_info_path: str = None, table_name:
112112 json .dump (table_metadata , outfile , indent = 4 , sort_keys = False )
113113
114114
115- @configure .command ("db-setup" , help = "Enter information to configure postgres database locally" )
115+ @configure .command ("db-setup" , help = f "Enter information to configure { db_dialect } database locally" )
116116@click .option ("--db_name" , "-n" , default = "querydb" , help = "Database name" , prompt = "Enter Database name" )
117117@click .option ("--hostname" , "-h" , default = "localhost" , help = "Database hostname" , prompt = "Enter hostname name" )
118- @click .option ("--user_name" , "-u" , default = "postgres " , help = "Database username" , prompt = "Enter username name" )
118+ @click .option ("--user_name" , "-u" , default = f" { db_dialect } " , help = "Database username" , prompt = "Enter username name" )
119119@click .option (
120120 "--password" ,
121121 "-p" ,
@@ -141,8 +141,11 @@ def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: i
141141 f .close ()
142142 path = f"{ base_path } /var/lib/tmp/data"
143143 # For current session
144- db_obj = DBConfig (db_name , hostname , user_name , password , port , base_path = base_path )
145- if not db_obj .db_exists ():
144+ db_obj = DBConfig (db_name , hostname , user_name , password , port , base_path = base_path , dialect = db_dialect )
145+ if db_obj .dialect == 'sqlite' and not os .path .isfile (f"{ base_path } /db/sqlite/{ db_name } .db" ):
146+ db_obj .create_db ()
147+ click .echo ("Database created successfully!" )
148+ elif not db_obj .db_exists ():
146149 db_obj .create_db ()
147150 click .echo ("Database created successfully!" )
148151 else :
@@ -293,9 +296,12 @@ def query(question: str, table_info_path: str, sample_queries: str):
293296 passwd = env_settings ["LOCAL_DB_CONFIG" ]["PASSWORD" ]
294297 db_name = env_settings ["LOCAL_DB_CONFIG" ]["DB_NAME" ]
295298
296- db_url = f"{ db_dialect } +psycopg2://{ user_name } :{ passwd } @{ host_name } /{ db_name } " .format (
297- user_name , passwd , host_name , db_name
298- )
299+ if db_dialect == "sqlite" :
300+ db_url = f"sqlite:///{ base_path } /db/sqlite/{ db_name } .db"
301+ else :
302+ db_url = f"{ db_dialect } +psycopg2://{ user_name } :{ passwd } @{ host_name } /{ db_name } " .format (
303+ user_name , passwd , host_name , db_name
304+ )
299305
300306 if table_info_path is None :
301307 table_info_path = _get_table_info (path )
@@ -318,7 +324,7 @@ def query(question: str, table_info_path: str, sample_queries: str):
318324 sql_g ._tasks = updated_tasks
319325
320326 model_name = env_settings ["OPENAI" ]["MODEL_NAME" ]
321- res = sql_g .generate_sql (table_names , question , model_name = model_name )
327+ res = sql_g .generate_sql (table_names , question , model_name = model_name , _dialect = db_dialect )
322328 logger .info (f"Input query: { question } " )
323329 logger .info (f"Generated response:\n \n { res } " )
324330
@@ -335,7 +341,7 @@ def query(question: str, table_info_path: str, sample_queries: str):
335341 click .echo (f"Updated SQL:\n { updated_sql } " )
336342 elif res_val .lower () == "r" or res_val .lower () == "regenerate" :
337343 click .echo ("Attempting to regenerate..." )
338- res = sql_g .generate_sql (table_names , question , model_name = model_name )
344+ res = sql_g .generate_sql (table_names , question , model_name = model_name , _dialect = db_dialect )
339345 logger .info (f"Input query: { question } " )
340346 logger .info (f"Generated response:\n \n { res } " )
341347
@@ -351,7 +357,8 @@ def query(question: str, table_info_path: str, sample_queries: str):
351357 port = env_settings ["LOCAL_DB_CONFIG" ]["PORT" ]
352358 db_name = env_settings ["LOCAL_DB_CONFIG" ]["DB_NAME" ]
353359
354- db_obj = DBConfig (db_name , hostname , user_name , password , port , base_path = base_path )
360+ db_obj = DBConfig (db_name , hostname , user_name , password , port , base_path = base_path , dialect = db_dialect )
361+
355362 output_res = db_obj .execute_query_db (query = _val )
356363 click .echo (f"The query results are:\n { output_res } " )
357364 elif option == "pandas" :
0 commit comments