@@ -7,13 +7,17 @@
 import openai
 import sqlglot
 import toml
+import torch
 from langchain import OpenAI
-from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
+from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
+                         LLMPredictor, ServiceContext, SQLDatabase)
 from llama_index.indices.struct_store import SQLContextContainerBuilder
-from sidekick.configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
+from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT,
+                                              TASK_PROMPT)
 from sidekick.logger import logger
-from sidekick.utils import csv_parser, filter_samples, remove_duplicates
+from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
 from sqlalchemy import create_engine
+from transformers import AutoModelForCausalLM, AutoTokenizer


 def _check_file_info(file_path: str):
@@ -63,7 +67,7 @@ def update_context_queries(self):
         new_context_queries = []
         if self.sample_queries_path is not None and Path(self.sample_queries_path).exists():
             logger.info(f"Using samples from path {self.sample_queries_path}")
-            new_context_queries = csv_parser(self.sample_queries_path)
+            new_context_queries = read_sample_pairs(self.sample_queries_path, "gpt")
             # cache the samples for future use
             with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "w") as f:
                 json.dump(new_context_queries, f, indent=2)
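
The new `read_sample_pairs` helper replaces `csv_parser` but its implementation is not part of this diff. A minimal sketch of what it plausibly does, assuming a two-column CSV of question/query pairs and a `"gpt"`-style few-shot rendering (the column names and formatting are assumptions, not taken from the repo):

```python
# Hypothetical sketch of read_sample_pairs (sidekick/utils.py is not in this diff).
# Assumes a CSV with "question" and "query" columns.
import csv

def read_sample_pairs(input_path: str, model_name: str = "gpt"):
    pairs = []
    with open(input_path, newline="") as f:
        for row in csv.DictReader(f):
            if model_name == "gpt":
                # Render each pair as a comment-style few-shot example block
                pairs.append(f"# query: {row['question']}\n# answer: {row['query']}")
            else:
                pairs.append(f"Question: {row['question']}\nSQL: {row['query']}")
    return pairs
```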
@@ -191,51 +195,105 @@ def generate_tasks(self, table_names: list, input_question: str):
         except Exception as se:
             raise se

-    def generate_sql(
-        self, table_name: list, input_question: str, _dialect: str = "postgres", model_name: str = "gpt-3.5-turbo-0301"
-    ):
-        _tasks = self.task_formatter(self._tasks)
+    def generate_sql(self, table_name: list, input_question: str, _dialect: str = "postgres", model_name: str = "nsql"):
         context_file = f"{self.path}/var/lib/tmp/data/context.json"
         additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}
-
         context_queries = self.content_queries
-        # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
-        query_str = QUERY_PROMPT.format(
-            _dialect=_dialect,
-            _data_info=self._data_info,
-            _question=input_question,
-            _table_name=table_name,
-            _sample_queries=context_queries,
-            _tasks=_tasks,
-        )

-        table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
-        self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)
+        if model_name != "nsql":
+            _tasks = self.task_formatter(self._tasks)

-        table_schema_index = self.build_index(persist=False)
-        self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
-        context_container = self.context_builder.build_context_container()
+            # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
+            query_str = QUERY_PROMPT.format(
+                _dialect=_dialect,
+                _data_info=self._data_info,
+                _question=input_question,
+                _table_name=table_name,
+                _sample_queries=context_queries,
+                _tasks=_tasks,
+            )

-        # Reference: https://github.com/jerryjliu/llama_index/issues/987
-        llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name))
-        service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512)
+            table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
+            self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)

-        index = GPTSQLStructStoreIndex(
-            [], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3
-        )
-        res = self.generate_response(context_container, sql_index=index, input_prompt=query_str)
-        try:
-            # Check if `SQL` is formatted ---> ``` SQL_text ```
-            if "```" in str(res):
-                res = (
-                    str(res).split("```", 1)[1].split(";", 1)[0].strip().replace("```", "").replace("sql\n", "").strip()
-                )
-            else:
-                res = str(res).split("Explanation:", 1)[0].strip()
-            sqlglot.transpile(res)
-        except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
-            logger.info("We did the best we could, there might still be some error:\n")
-            logger.info(f"Realized query so far:\n {res}")
+            table_schema_index = self.build_index(persist=False)
+            self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
+            context_container = self.context_builder.build_context_container()
+
+            # Reference: https://github.com/jerryjliu/llama_index/issues/987
+            llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name))
+            service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512)
+
+            index = GPTSQLStructStoreIndex(
+                [], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3
+            )
+            res = self.generate_response(context_container, sql_index=index, input_prompt=query_str)
+            try:
+                # Check if `SQL` is formatted ---> ``` SQL_text ```
+                if "```" in str(res):
+                    res = (
+                        str(res)
+                        .split("```", 1)[1]
+                        .split(";", 1)[0]
+                        .strip()
+                        .replace("```", "")
+                        .replace("sql\n", "")
+                        .strip()
+                    )
+                else:
+                    res = str(res).split("Explanation:", 1)[0].strip()
+                sqlglot.transpile(res)
+            except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
+                logger.info("We did the best we could, there might still be some error:\n")
+                logger.info(f"Realized query so far:\n {res}")
+        else:
+            # Load NSQL model (NumbersStation/nsql-6B)
+            tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B")
+            model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")
+
+            data_samples = context_queries
+
+            _context = {
+                "if patterns like 'current time' or 'now' occurs in question": "always use NOW() - INTERVAL",
+                "if patterns like 'total number', or 'List' occurs in question": "always use DISTINCT",
+            }
+
+            filtered_context = filter_samples(input_question, probable_qs=list(_context.keys()),
+                                              model_path='', threshold=0.845)
+
+            print(f"Filter Context: {filtered_context}")
+
+            contextual_context = []
+            for _item in filtered_context:
+                _val = _context.get(_item, None)
+                if _val:
+                    contextual_context.append(f"{_item}: {_val}")
+
+            print("Filtering Question/Query pairs")
+            _samples = filter_samples(input_question, probable_qs=sample_pairs,
+                                      model_path=local_model_path, threshold=0.90)
+
+            # If QnA pairs > 5, we keep only 5 of them for focused context
+            if len(_samples) > 5:
+                _samples = _samples[0:5][::-1]
+            qna_samples = '\n'.join(_samples)
+
+            contextual_context_val = ', '.join(contextual_context)
+
+            if len(_samples) > 2:
+                # Check for the columns in the QnA samples provided; if they exist, keep them
+                context_columns = [_c for _c in column_names if _c.lower() in qna_samples.lower()]
+                if len(context_columns) > 0:
+                    contextual_data_samples = [_d for _cc in context_columns for _d in data_samples_list if _cc.lower() in _d.lower()]
+                    data_samples = contextual_data_samples
+            relevant_columns = context_columns if len(context_columns) > 0 else column_names
+            _data_info = ', '.join(relevant_columns)
+
+            query = prompt_template.format(table_name=_table_name, data_info=_data_info, data_info_detailed=data_samples,
+                                           sample_queries=qna_samples, context=contextual_context_val,
+                                           question_txt=input_question)
+
+            input_ids = tokenizer(query, return_tensors="pt").input_ids
         return res

     def task_formatter(self, input_task: str):
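
Both calls in the new `nsql` branch go through `filter_samples`, whose implementation is outside this diff. It appears to rank candidate strings against the input question and keep those above a similarity threshold; a plausible sketch, assuming sentence-embedding cosine similarity (the `sentence-transformers` dependency and the fallback model are assumptions):

```python
# Hypothetical sketch of filter_samples (sidekick/utils.py is not in this diff).
from sentence_transformers import SentenceTransformer, util

def filter_samples(input_question: str, probable_qs: list, model_path: str, threshold: float):
    # model_path may point at a locally cached embedding model; fall back to a small default
    model = SentenceTransformer(model_path or "all-MiniLM-L6-v2")
    question_emb = model.encode(input_question, convert_to_tensor=True)
    candidate_embs = model.encode(probable_qs, convert_to_tensor=True)
    # Keep candidates whose cosine similarity to the question clears the threshold
    scores = util.cos_sim(question_emb, candidate_embs)[0]
    return [q for q, score in zip(probable_qs, scores) if float(score) >= threshold]
```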
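As committed, the `nsql` branch stops at tokenization: `res` is only assigned in the GPT branch, so `return res` would raise `NameError` when `model_name == "nsql"`, and several names used here (`sample_pairs`, `local_model_path`, `column_names`, `data_samples_list`, `prompt_template`, `_table_name`) are not defined in this hunk and presumably come from elsewhere in the file. One way the generation step could be completed, following the documented usage pattern for `nsql-6B` (the `max_length` value and the `SELECT`/`;` extraction heuristic are assumptions, not part of the diff):

```python
# Continues the nsql branch after input_ids; values are illustrative assumptions.
generated_ids = model.generate(input_ids, max_length=500)
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# NSQL prompts conventionally end with "SELECT", so take the text of the final
# SELECT statement and stop at the first ";".
res = "SELECT" + output.split("SELECT")[-1].split(";")[0] + ";"
```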