@@ -8,6 +8,7 @@
 import numpy as np
 import openai
 import sqlglot
+import sqlparse
 import torch
 import torch.nn.functional as F
 from langchain import OpenAI
@@ -87,6 +88,7 @@ def __init__(
         self.is_regenerate_with_options = is_regenerate_with_options
         self.is_regenerate = is_regenerate
         self.device = device
+        self.table_name = None
 
     def clear(self):
         del SQLGenerator._instance
@@ -129,7 +131,7 @@ def update_context_queries(self):
         with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "r") as f:
             new_context_queries = json.load(f)
         # Read the history file and update the context queries
-        history_file = f"{self.path}/var/lib/tmp/data/history.jsonl"
+        history_file = f"{self.path}/var/lib/tmp/data/{self.table_name}/history.jsonl"
         try:
             if Path(history_file).exists():
                 with open(history_file, "r") as in_file:
@@ -218,6 +220,7 @@ def generate_tasks(self, table_names: list, input_question: str):
         try:
             # Step 1: Given a question, generate tasks to possibly answer the question and persist the result -> tasks.txt
             # Step 2: Append task list to 'query_prompt_template', generate SQL code to answer the question and persist the result -> sql.txt
+            self.table_name = table_names[0]
             context_queries: list = self.update_context_queries()
             logger.info(f"Number of context queries found: {len(context_queries)}")
 
@@ -263,6 +266,7 @@ def generate_sql(
     ):
         # TODO: Update needed to support multiple tables
         table_name = str(table_names[0].replace(" ", "_")).lower()
+        self.table_name = table_name
         alternate_queries = []
         describe_keywords = ["describe table", "describe", "describe table schema", "describe data"]
         enable_describe_qry = any([True for _dk in describe_keywords if _dk in input_question.lower()])
@@ -546,7 +550,12 @@ def generate_sql(
                     res = "SELECT " + result.strip() + " LIMIT 100;"
                 else:
                     res = "SELECT " + result.strip() + ";"
-                alt_res = f"Option {idx+1}: (_probability_: {probabilities_scores[sorted_idx]})\n{res}\n"
+
+                pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper")
+                syntax_highlight = f"""``` sql\n{pretty_sql}\n```\n\n"""
+                alt_res = (
+                    f"Option {idx+1}: (_probability_: {probabilities_scores[sorted_idx]})\n{syntax_highlight}\n"
+                )
                 alternate_queries.append(alt_res)
                 logger.info(alt_res)
 
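
Note: the formatting change in the last hunk can be sanity-checked in isolation. The snippet below is a minimal sketch, not part of this diff; it runs the same sqlparse.format(..., reindent=True, keyword_case="upper") call on a made-up query string to show roughly what one rendered alternate-query option looks like. The query text and variable names here are illustrative only.

import sqlparse

# Illustrative query string standing in for `res` as built inside generate_sql.
res = "SELECT customer_id, SUM(amount) FROM payments GROUP BY customer_id LIMIT 100;"

# Same call the diff adds: reindent the SQL and upper-case keywords before display.
pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper")

# Wrap in a ``` sql fence so downstream UIs can syntax-highlight the option.
syntax_highlight = f"``` sql\n{pretty_sql}\n```\n\n"

# Prints the query inside the fence, with clauses split across lines by sqlparse's reindenting.
print(syntax_highlight)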