Skip to content

Commit be1c879

Browse files
Parameterize model name and minor prompt adjustments
1 parent ecf0847 commit be1c879

File tree

4 files changed

+10
-7
lines changed

4 files changed

+10
-7
lines changed

sidekick/configs/.env.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[OPENAI]
22
OPENAI_API_KEY = ""
3+
MODEL_NAME = "gpt-3.5-turbo-0301" # Others: e.g. text-davinci-003
34

45
[LOCAL_DB_CONFIG]
56
HOST_NAME = "localhost"

sidekick/configs/prompt_template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
DEBUGGING_PROMPT = {
5353
"system_prompt": "Act as a SQL expert for PostgreSQL code",
5454
"user_prompt": """
55-
### Fix syntax errors for provided SQL Query.
55+
### Fix syntax errors for provided incorrect SQL Query.
5656
# Add ``` as prefix and ``` as suffix to generated SQL
5757
# Error: {ex_traceback}
5858
# Add explanation and reasoning for each SQL query

sidekick/prompter.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def color(fore="", back="", text=None):
2626
return f"{fore}{back}{text}{Style.RESET_ALL}"
2727

2828

29-
msg = """Welcome to the SQL Sidekick!\nI am AI assistant that helps you with SQL queries.
29+
msg = """Welcome to the SQL Sidekick!\nI am an AI assistant that helps you with SQL queries.
3030
I can help you with the following:\n
3131
1. Configure a local database(for schema validation and syntax checking): `sql-sidekick configure db-setup`.\n
3232
2. Learn contextual query/answer pairs: `sql-sidekick learn add-samples`.\n
@@ -307,7 +307,9 @@ def query(question: str, table_info_path: str, sample_queries: str):
307307
click.echo("Skipping edit...")
308308
if updated_tasks is not None:
309309
sql_g._tasks = updated_tasks
310-
res = sql_g.generate_sql(table_names, question)
310+
311+
model_name = env_settings["OPENAI"]["MODEL_NAME"]
312+
res = sql_g.generate_sql(table_names, question, model_name=model_name)
311313
logger.info(f"Input query: {question}")
312314
logger.info(f"Generated response:\n\n{res}")
313315

@@ -324,7 +326,7 @@ def query(question: str, table_info_path: str, sample_queries: str):
324326
click.echo(f"Updated SQL:\n {updated_sql}")
325327
elif res_val.lower() == "r" or res_val.lower() == "regenerate":
326328
click.echo("Attempting to regenerate...")
327-
res = sql_g.generate_sql(table_names, question)
329+
res = sql_g.generate_sql(table_names, question, model_name=model_name)
328330
logger.info(f"Input query: {question}")
329331
logger.info(f"Generated response:\n\n{res}")
330332

sidekick/query.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def generate_response(self, context_container, sql_index, input_prompt, attempt_
138138
try:
139139
# Attempt to heal with simple feedback
140140
# Reference: Teaching Large Language Models to Self-Debug, https://arxiv.org/abs/2304.05128
141-
logger.info(f"Attempting to heal ...,\n {se}")
141+
logger.info(f"Attempting to fix syntax error ...,\n {se}")
142142
system_prompt = DEBUGGING_PROMPT["system_prompt"]
143143
user_prompt = DEBUGGING_PROMPT["user_prompt"].format(ex_traceback=ex_traceback, qry_txt=qry_txt)
144144
# Role and content
@@ -193,7 +193,7 @@ def generate_tasks(self, table_names: list, input_question: str):
193193
except Exception as se:
194194
raise se
195195

196-
def generate_sql(self, table_name: list, input_question: str, _dialect: str = "postgres"):
196+
def generate_sql(self, table_name: list, input_question: str, _dialect: str = "postgres", model_name: str = 'gpt-3.5-turbo-0301'):
197197
_tasks = self.task_formatter(self._tasks)
198198
context_file = f"{self.path}/var/lib/tmp/data/context.json"
199199
additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}
@@ -217,7 +217,7 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "p
217217
context_container = self.context_builder.build_context_container()
218218

219219
# Reference: https://github.com/jerryjliu/llama_index/issues/987
220-
llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.7, model_name="text-davinci-003"))
220+
llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name))
221221
service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512)
222222

223223
index = GPTSQLStructStoreIndex(

0 commit comments

Comments
 (0)