Skip to content

Commit 0240036

Browse files
Initial working version with local LLM #4
1 parent 0724822 commit 0240036

File tree

7 files changed

+221
-122
lines changed

7 files changed

+221
-122
lines changed

Makefile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,3 @@ setup: download_models ## Setup
1313

1414
download_models:
1515
mkdir -p ./models/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2
16-
$(sentence_transformer)

sidekick/configs/.env.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
[OPENAI]
2-
OPENAI_API_KEY = ""
1+
[MODEL_INFO]
2+
OPENAI_API_KEY = "" # Needed only for OpenAI models
33
MODEL_NAME = "h2ogpt-sql" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003
44

55
[LOCAL_DB_CONFIG]
@@ -13,9 +13,10 @@ PORT = "5432"
1313
LOG-LEVEL = "INFO"
1414

1515
[DB-DIALECT]
16-
DB_TYPE = "sqlite" # postgresql
16+
DB_TYPE = "sqlite"
1717

1818
[TABLE_INFO]
1919
TABLE_INFO_PATH = "/examples/test/table_info.jsonl"
2020
TABLE_SAMPLES_PATH = "/examples/test/masked_data_and_columns.csv"
2121
TABLE_NAME = "demo"
22+
DB_TYPE = "sqlite"

sidekick/configs/prompt_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@
6363
NSQL_QUERY_PROMPT = """
6464
For SQL TABLE '{table_name}' sample question/answer pairs,\n({sample_queries})
6565
66-
CREATE TABLE '{table_name}'({data_info}
66+
CREATE TABLE '{table_name}'({column_info}
6767
)
6868
6969
Table '{table_name}' has sample values ({data_info_detailed})
7070
7171
7272
73-
-- Using valid {_dialect}, answer the following questions with the information for '{table_name}' provided above; for final SQL only use values from the question.
73+
-- Using valid SQLite, answer the following questions with the information for '{table_name}' provided above; for final SQL only use values from the question.
7474
7575
7676
-- Using reference for TABLES '{table_name}' {context}; {question_txt}?

sidekick/db_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _extract_schema_info(self, schema_info_path=None):
104104
sample_values.append(_ds)
105105
_new_samples = f"{col_name} {col_type}"
106106
res.append(_new_samples)
107-
if len(sample_values):
107+
if len(sample_values) > 0:
108108
# cache it for future use
109109
with open(
110110
f"{self.base_path}/var/lib/tmp/data/{self._table_name}_column_values.json", "w"

sidekick/prompter.py

Lines changed: 74 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
base_path = (Path(__file__).parent / "../").resolve()
2121
env_settings = toml.load(f"{base_path}/sidekick/configs/.env.toml")
2222
db_dialect = env_settings["DB-DIALECT"]["DB_TYPE"]
23+
model_name = env_settings["MODEL_INFO"]["MODEL_NAME"]
2324
os.environ["TOKENIZERS_PARALLELISM"] = "False"
2425
__version__ = "0.0.4"
2526

@@ -127,9 +128,30 @@ def update_table_info(cache_path: str, table_info_path: str = None, table_name:
127128
@click.option("--port", "-P", default=5432, help="Database port", prompt="Enter port (default 5432)")
128129
@click.option("--table-info-path", "-t", help="Table info path", default=None)
129130
def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: int, table_info_path: str):
130-
db_setup_api(db_name=db_name, hostname=hostname, user_name=user_name, password=password, port=port, table_info_path=table_info_path, table_samples_path=None, table_name=None, is_command=True)
131+
db_setup_api(
132+
db_name=db_name,
133+
hostname=hostname,
134+
user_name=user_name,
135+
password=password,
136+
port=port,
137+
table_info_path=table_info_path,
138+
table_samples_path=None,
139+
table_name=None,
140+
is_command=True,
141+
)
142+
131143

132-
def db_setup_api(db_name: str, hostname: str, user_name: str, password: str, port: int, table_info_path: str, table_samples_path: str, table_name: str, is_command:bool=False):
144+
def db_setup_api(
145+
db_name: str,
146+
hostname: str,
147+
user_name: str,
148+
password: str,
149+
port: int,
150+
table_info_path: str,
151+
table_samples_path: str,
152+
table_name: str,
153+
is_command: bool = False,
154+
):
133155
"""Creates context for the new Database"""
134156
click.echo(f" Information supplied:\n {db_name}, {hostname}, {user_name}, {password}, {port}")
135157
try:
@@ -145,7 +167,7 @@ def db_setup_api(db_name: str, hostname: str, user_name: str, password: str, por
145167
path = f"{base_path}/var/lib/tmp/data"
146168
# For current session
147169
db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect)
148-
if db_obj.dialect == 'sqlite' and not os.path.isfile(f"{base_path}/db/sqlite/{db_name}.db"):
170+
if db_obj.dialect == "sqlite" and not os.path.isfile(f"{base_path}/db/sqlite/{db_name}.db"):
149171
db_obj.create_db()
150172
click.echo("Database created successfully!")
151173
elif not db_obj.db_exists():
@@ -176,7 +198,11 @@ def db_setup_api(db_name: str, hostname: str, user_name: str, password: str, por
176198
# Check if table exists; TODO (pending): also verify that it doesn't have any rows
177199
if db_obj.has_table():
178200
click.echo(f"Checked table {db_obj.table_name} exists in the DB.")
179-
val = input(color(F.GREEN, "", "Would you like to add few sample rows (at-least 3)? (y/n):")) if is_command else "y"
201+
val = (
202+
input(color(F.GREEN, "", "Would you like to add few sample rows (at-least 3)? (y/n):"))
203+
if is_command
204+
else "y"
205+
)
180206
if val.lower().strip() == "y" or val.lower().strip() == "yes":
181207
val = input("Path to a CSV file to insert data from:") if is_command else table_samples_path
182208
db_obj.add_samples(val)
@@ -259,9 +285,10 @@ def update_context():
259285
@click.option("--table-info-path", "-t", help="Table info path", default=None)
260286
@click.option("--sample-queries", "-s", help="Samples path", default=None)
261287
def query(question: str, table_info_path: str, sample_queries: str):
262-
query_api(question= question, table_info_path=table_info_path, sample_queries=sample_queries, is_command=True)
288+
query_api(question=question, table_info_path=table_info_path, sample_queries=sample_queries, is_command=True)
263289

264-
def query_api(question: str, table_info_path: str, sample_queries: str, is_command:bool=False):
290+
291+
def query_api(question: str, table_info_path: str, sample_queries: str, is_command: bool = False):
265292
"""Asks question and returns SQL."""
266293
results = []
267294
# Book-keeping
@@ -283,27 +310,31 @@ def query_api(question: str, table_info_path: str, sample_queries: str, is_comma
283310
json.dump(table_context, outfile, indent=4, sort_keys=False)
284311
logger.info(f"Table in use: {table_names}")
285312
# Look up the OpenAI API key in the settings loaded from .env.toml
286-
api_key = env_settings["OPENAI"]["OPENAI_API_KEY"]
287-
if api_key is None or api_key == "":
288-
if os.getenv("OPENAI_API_KEY") is None or os.getenv("OPENAI_API_KEY") == "":
289-
if is_command:
290-
val = input(
291-
color(F.GREEN, "", "Looks like API key is not set, would you like to set OPENAI_API_KEY? (y/n):")
292-
)
293-
if val.lower() == "y":
294-
api_key = input(color(F.GREEN, "", "Enter OPENAI_API_KEY :"))
295-
296-
if api_key is None and is_command:
297-
return ["Looks like API key is not set, please set OPENAI_API_KEY!"]
298-
299-
os.environ["OPENAI_API_KEY"] = api_key
300-
env_settings["OPENAI"]["OPENAI_API_KEY"] = api_key
301-
302-
# Update settings file for future use.
303-
f = open(f"{base_path}/sidekick/configs/.env.toml", "w")
304-
toml.dump(env_settings, f)
305-
f.close()
306-
openai.api_key = api_key
313+
api_key = None
314+
if model_name != "h2ogpt-sql":
315+
api_key = env_settings["MODEL_INFO"]["OPENAI_API_KEY"]
316+
if api_key is None or api_key == "":
317+
if os.getenv("OPENAI_API_KEY") is None or os.getenv("OPENAI_API_KEY") == "":
318+
if is_command:
319+
val = input(
320+
color(
321+
F.GREEN, "", "Looks like API key is not set, would you like to set OPENAI_API_KEY? (y/n):"
322+
)
323+
)
324+
if val.lower() == "y":
325+
api_key = input(color(F.GREEN, "", "Enter OPENAI_API_KEY :"))
326+
327+
if api_key is None and is_command:
328+
return ["Looks like API key is not set, please set OPENAI_API_KEY!"]
329+
330+
os.environ["OPENAI_API_KEY"] = api_key
331+
env_settings["MODEL_INFO"]["OPENAI_API_KEY"] = api_key
332+
333+
# Update settings file for future use.
334+
f = open(f"{base_path}/sidekick/configs/.env.toml", "w")
335+
toml.dump(env_settings, f)
336+
f.close()
337+
openai.api_key = api_key
307338

308339
# Set context
309340
logger.info("Setting context...")
@@ -327,22 +358,22 @@ def query_api(question: str, table_info_path: str, sample_queries: str, is_comma
327358
sql_g = SQLGenerator(
328359
db_url, api_key, job_path=base_path, data_input_path=table_info_path, samples_queries=sample_queries
329360
)
330-
sql_g._tasks = sql_g.generate_tasks(table_names, question)
331-
results.extend(["List of Actions Generated: \n", sql_g._tasks, "\n"])
332-
click.echo(sql_g._tasks)
333-
334-
updated_tasks = None
335-
if sql_g._tasks is not None and is_command:
336-
edit_val = click.prompt("Would you like to edit the tasks? (y/n)")
337-
if edit_val.lower() == "y":
338-
updated_tasks = click.edit(sql_g._tasks)
339-
click.echo(f"Tasks:\n {updated_tasks}")
340-
else:
341-
click.echo("Skipping edit...")
342-
if updated_tasks is not None:
343-
sql_g._tasks = updated_tasks
361+
if "h2ogpt-sql" not in model_name:
362+
sql_g._tasks = sql_g.generate_tasks(table_names, question)
363+
results.extend(["List of Actions Generated: \n", sql_g._tasks, "\n"])
364+
click.echo(sql_g._tasks)
365+
366+
updated_tasks = None
367+
if sql_g._tasks is not None and is_command:
368+
edit_val = click.prompt("Would you like to edit the tasks? (y/n)")
369+
if edit_val.lower() == "y":
370+
updated_tasks = click.edit(sql_g._tasks)
371+
click.echo(f"Tasks:\n {updated_tasks}")
372+
else:
373+
click.echo("Skipping edit...")
374+
if updated_tasks is not None:
375+
sql_g._tasks = updated_tasks
344376

345-
model_name = env_settings["OPENAI"]["MODEL_NAME"]
346377
res = sql_g.generate_sql(table_names, question, model_name=model_name, _dialect=db_dialect)
347378
logger.info(f"Input query: {question}")
348379
logger.info(f"Generated response:\n\n{res}")
@@ -431,5 +462,6 @@ def query_api(question: str, table_info_path: str, sample_queries: str, is_comma
431462

432463
return results
433464

465+
434466
if __name__ == "__main__":
435467
cli()

0 commit comments

Comments
 (0)