Skip to content

Commit b60fb7b

Browse files
Enable syntax highlighting for multiple responses
1 parent eb02b9e commit b60fb7b

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

sidekick/query.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import openai
1010
import sqlglot
11+
import sqlparse
1112
import torch
1213
import torch.nn.functional as F
1314
from langchain import OpenAI
@@ -87,6 +88,7 @@ def __init__(
8788
self.is_regenerate_with_options = is_regenerate_with_options
8889
self.is_regenerate = is_regenerate
8990
self.device = device
91+
self.table_name = None
9092

9193
def clear(self):
9294
del SQLGenerator._instance
@@ -129,7 +131,7 @@ def update_context_queries(self):
129131
with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "r") as f:
130132
new_context_queries = json.load(f)
131133
# Read the history file and update the context queries
132-
history_file = f"{self.path}/var/lib/tmp/data/history.jsonl"
134+
history_file = f"{self.path}/var/lib/tmp/data/{self.table_name}/history.jsonl"
133135
try:
134136
if Path(history_file).exists():
135137
with open(history_file, "r") as in_file:
@@ -218,6 +220,7 @@ def generate_tasks(self, table_names: list, input_question: str):
218220
try:
219221
# Step 1: Given a question, generate tasks to possibly answer the question and persist the result -> tasks.txt
220222
# Step 2: Append task list to 'query_prompt_template', generate SQL code to answer the question and persist the result -> sql.txt
223+
self.table_name = table_names[0]
221224
context_queries: list = self.update_context_queries()
222225
logger.info(f"Number of context queries found: {len(context_queries)}")
223226

@@ -263,6 +266,7 @@ def generate_sql(
263266
):
264267
# TODO: Update needed to support multiple tables
265268
table_name = str(table_names[0].replace(" ", "_")).lower()
269+
self.table_name = table_name
266270
alternate_queries = []
267271
describe_keywords = ["describe table", "describe", "describe table schema", "describe data"]
268272
enable_describe_qry = any([True for _dk in describe_keywords if _dk in input_question.lower()])
@@ -546,7 +550,12 @@ def generate_sql(
546550
res = "SELECT " + result.strip() + " LIMIT 100;"
547551
else:
548552
res = "SELECT " + result.strip() + ";"
549-
alt_res = f"Option {idx+1}: (_probability_: {probabilities_scores[sorted_idx]})\n{res}\n"
553+
554+
pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper")
555+
syntax_highlight = f"""``` sql\n{pretty_sql}\n```\n\n"""
556+
alt_res = (
557+
f"Option {idx+1}: (_probability_: {probabilities_scores[sorted_idx]})\n{syntax_highlight}\n"
558+
)
550559
alternate_queries.append(alt_res)
551560
logger.info(alt_res)
552561

0 commit comments

Comments
 (0)