Skip to content

Commit 32bc877

Browse files
Enable alternate options in chat #4
1 parent 379a1ec commit 32bc877

File tree

3 files changed

+29
-16
lines changed

3 files changed

+29
-16
lines changed

sidekick/prompter.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from sidekick.db_config import DBConfig
1414
from sidekick.memory import EntityMemory
1515
from sidekick.query import SQLGenerator
16-
from sidekick.utils import execute_query_pd, extract_table_names, save_query, setup_dir
16+
from sidekick.utils import (execute_query_pd, extract_table_names, save_query,
17+
setup_dir)
1718

1819
# Load the config file and initialize required paths
1920
base_path = (Path(__file__).parent / "../").resolve()
@@ -410,7 +411,7 @@ def query_api(
410411
if updated_tasks is not None:
411412
sql_g._tasks = updated_tasks
412413

413-
res = sql_g.generate_sql(
414+
res, alt_res = sql_g.generate_sql(
414415
table_names, question, model_name=model_name, _dialect=db_dialect, is_regenerate=is_regenerate
415416
)
416417
logger.info(f"Input query: {question}")
@@ -430,13 +431,14 @@ def query_api(
430431
click.echo(f"Updated SQL:\n {updated_sql}")
431432
elif res_val.lower() == "r" or res_val.lower() == "regenerate":
432433
click.echo("Attempting to regenerate...")
433-
res = sql_g.generate_sql(
434-
table_names, question, model_name=model_name, _dialect=db_dialect, is_regenerate=True
434+
res, alt_res = sql_g.generate_sql(
435+
table_names, question, model_name=model_name, _dialect=db_dialect, is_regenerate=is_regenerate
435436
)
436437
logger.info(f"Input query: {question}")
437438
logger.info(f"Generated response:\n\n{res}")
438439

439-
results.extend(["Generated Query:\n", res, "\n"])
440+
results.extend(["**Generated Query:**\n", res, "\n"])
441+
logger.info(f"Alternate responses:\n\n{alt_res}")
440442

441443
exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?") if is_command else "y"
442444
if exe_sql.lower() == "y" or exe_sql.lower() == "yes":
@@ -484,14 +486,14 @@ def query_api(
484486
click.echo("Error in executing the query. Validate generated SQL and try again.")
485487
click.echo("No result to display.")
486488

487-
results.append("Query Results: \n")
489+
results.append("**Query Results:** \n")
488490
if q_res:
489491
click.echo(f"The query results are:\n {q_res}")
490492
results.extend([str(q_res), "\n"])
491493
else:
492494
click.echo(f"While executing query:\n {err}")
493495
results.extend([str(err), "\n"])
494-
# results.extend(["Query Results:", q_res])
496+
495497
save_sql = click.prompt("Would you like to save the generated SQL (y/n)?") if is_command else "n"
496498
if save_sql.lower() == "y" or save_sql.lower() == "yes":
497499
# Persist for future use
@@ -500,7 +502,7 @@ def query_api(
500502
else:
501503
click.echo("Exiting...")
502504

503-
return results, err
505+
return results, alt_res, err
504506

505507

506508
if __name__ == "__main__":

sidekick/query.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,8 @@ def generate_sql(
421421
# Greedy search for quick response
422422
self.model.eval()
423423
device_type = "cuda" if torch.cuda.is_available() else "cpu"
424+
425+
alternate_queries = []
424426
if not is_regenerate:
425427
output = self.model.generate(
426428
**inputs.to(device_type),
@@ -476,7 +478,9 @@ def generate_sql(
476478
res = "SELECT " + result.strip() + " LIMIT 100;"
477479
else:
478480
res = "SELECT " + result.strip() + ";"
479-
logger.info(f"Option: {idx+1}:\n{res}\nprobability: {probabilities_scores[idx]}")
481+
alt_res = f"Option {idx+1}: (_probability_: {probabilities_scores[idx]})\n{res}"
482+
alternate_queries.append(alt_res)
483+
logger.info(alt_res)
480484

481485
_res = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
482486
# Below is a pre-caution in-case of an error in table name during generation
@@ -500,7 +504,7 @@ def generate_sql(
500504
except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
501505
logger.info("We did the best we could, there might be still be some error:\n")
502506
logger.info(f"Realized query so far:\n {res}")
503-
return result
507+
return result, alternate_queries
504508

505509
def task_formatter(self, input_task: str):
506510
# Generated format

ui/app.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import toml
88
from h2o_wave import Q, app, data, handle_on, main, on, ui
99
from sidekick.prompter import db_setup_api, query_api
10-
from sidekick.utils import setup_dir, update_tables, get_table_keys
10+
from sidekick.utils import get_table_keys, setup_dir, update_tables
1111

1212
# Load the config file and initialize required paths
1313
base_path = (Path(__file__).parent / "../").resolve()
@@ -130,20 +130,27 @@ async def chatbot(q: Q):
130130
)
131131
elif q.args.chatbot.lower() == "regenerate":
132132
if q.client.query is not None and q.client.query.strip() != "":
133-
llm_response, err = query_api(
133+
llm_response, alt_response, err = query_api(
134134
question=q.client.query,
135135
sample_queries_path=q.user.sample_qna_path,
136136
table_info_path=q.user.table_info_path,
137137
table_name=q.user.table_name,
138138
is_regenerate=True,
139139
)
140-
llm_response = "\n".join(llm_response)
140+
response = "\n".join(llm_response)
141+
if alt_response:
142+
llm_response = response + "\n\n" + "**Alternate options:**\n" + "\n".join(alt_response)
143+
logging.info(f"Regenerate response: {llm_response}")
144+
else:
145+
llm_response = response
141146
else:
142-
llm_response, err = ("Sure, I can generate a new response for you. However, in order to assist you "
143-
"effectively could you please provide me with your question?"), None
147+
llm_response, err = (
148+
"Sure, I can generate a new response for you. However, in order to assist you "
149+
"effectively could you please provide me with your question?"
150+
), None
144151
else:
145152
q.client.query = question
146-
llm_response, err = query_api(
153+
llm_response, alt_response, err = query_api(
147154
question=question,
148155
sample_queries_path=q.user.sample_qna_path,
149156
table_info_path=q.user.table_info_path,

0 commit comments

Comments
 (0)