7 | 7 | import openai |
8 | 8 | import sqlglot |
9 | 9 | import toml |
| 10 | +import torch |
10 | 11 | from langchain import OpenAI |
11 | | -from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase |
| 12 | +from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex, |
| 13 | + LLMPredictor, ServiceContext, SQLDatabase) |
12 | 14 | from llama_index.indices.struct_store import SQLContextContainerBuilder |
13 | | -from sidekick.configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT |
| 15 | +from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT, |
| 16 | + TASK_PROMPT) |
14 | 17 | from sidekick.logger import logger |
15 | | -from sidekick.utils import csv_parser, filter_samples, remove_duplicates |
| 18 | +from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates |
16 | 19 | from sqlalchemy import create_engine |
| 20 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
17 | 21 |
18 | 22 |
19 | 23 | def _check_file_info(file_path: str): |
@@ -63,7 +67,7 @@ def update_context_queries(self): |
63 | 67 | new_context_queries = [] |
64 | 68 | if self.sample_queries_path is not None and Path(self.sample_queries_path).exists(): |
65 | 69 | logger.info(f"Using samples from path {self.sample_queries_path}") |
66 | | - new_context_queries = csv_parser(self.sample_queries_path) |
| 70 | + new_context_queries = read_sample_pairs(self.sample_queries_path, "gpt") |
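| | +            # ("gpt" presumably selects the pair format used for GPT prompting)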
67 | 71 | # cache the samples for future use |
68 | 72 | with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "w") as f: |
69 | 73 | json.dump(new_context_queries, f, indent=2) |
@@ -191,51 +195,106 @@ def generate_tasks(self, table_names: list, input_question: str): |
191 | 195 | except Exception as se: |
192 | 196 | raise se |
193 | 197 |
194 | | - def generate_sql( |
195 | | - self, table_name: list, input_question: str, _dialect: str = "sqlite", model_name: str = "gpt-3.5-turbo-0301" |
196 | | - ): |
197 | | - _tasks = self.task_formatter(self._tasks) |
| 198 | + |
| 199 | + def generate_sql(self, table_name: list, input_question: str, _dialect: str = "sqlite", model_name: str = "nsql"): |
198 | 200 | context_file = f"{self.path}/var/lib/tmp/data/context.json" |
199 | 201 | additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {} |
200 | | - |
201 | 202 | context_queries = self.content_queries |
202 | | - # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate. |
203 | | - query_str = QUERY_PROMPT.format( |
204 | | - _dialect=_dialect, |
205 | | - _data_info=self._data_info, |
206 | | - _question=input_question, |
207 | | - _table_name=table_name, |
208 | | - _sample_queries=context_queries, |
209 | | - _tasks=_tasks, |
210 | | - ) |
211 | 203 |
212 | | - table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()} |
213 | | - self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict) |
| 204 | + if model_name != "nsql": |
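| | +            # Default (OpenAI) path: fold the decomposed tasks into a schema-aware query prompt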
| 205 | + _tasks = self.task_formatter(self._tasks) |
214 | 206 |
215 | | - table_schema_index = self.build_index(persist=False) |
216 | | - self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True) |
217 | | - context_container = self.context_builder.build_context_container() |
| 207 | + # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate. |
| 208 | + query_str = QUERY_PROMPT.format( |
| 209 | + _dialect=_dialect, |
| 210 | + _data_info=self._data_info, |
| 211 | + _question=input_question, |
| 212 | + _table_name=table_name, |
| 213 | + _sample_queries=context_queries, |
| 214 | + _tasks=_tasks, |
| 215 | + ) |
218 | 216 |
219 | | - # Reference: https://github.com/jerryjliu/llama_index/issues/987 |
220 | | - llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name)) |
221 | | - service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512) |
| 217 | +            table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
| 218 | +            self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)
222 | 219 |
223 | | - index = GPTSQLStructStoreIndex( |
224 | | - [], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3 |
225 | | - ) |
226 | | - res = self.generate_response(context_container, sql_index=index, input_prompt=query_str, _dialect = _dialect) |
227 | | - try: |
228 | | - # Check if `SQL` is formatted ---> ``` SQL_text ``` |
229 | | - if "```" in str(res): |
230 | | - res = ( |
231 | | - str(res).split("```", 1)[1].split(";", 1)[0].strip().replace("```", "").replace("sql\n", "").strip() |
232 | | - ) |
233 | | - else: |
234 | | - res = str(res).split("Explanation:", 1)[0].strip() |
235 | | - sqlglot.transpile(res) |
236 | | - except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e: |
237 | | - logger.info("We did the best we could, there might be still be some error:\n") |
238 | | - logger.info(f"Realized query so far:\n {res}") |
| 220 | +            table_schema_index = self.build_index(persist=False)
| 221 | +            self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
| 222 | +            context_container = self.context_builder.build_context_container()
| 223 | +
| 224 | +            # Reference: https://github.com/jerryjliu/llama_index/issues/987
| 225 | +            llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name))
| 226 | +            service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512)
| 227 | + |
| | +            # Build the struct-store index from the SQL schema alone, hence the empty document list
| 228 | +            index = GPTSQLStructStoreIndex(
| 229 | +                [], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3
| 230 | +            )
| 231 | +            res = self.generate_response(context_container, sql_index=index, input_prompt=query_str)
| 232 | +            try:
| 233 | +                # Check whether the generated SQL is fenced as ``` SQL_text ```
| 234 | +                if "```" in str(res):
| 235 | +                    res = (
| 236 | +                        str(res)
| 237 | +                        .split("```", 1)[1]
| 238 | +                        .split(";", 1)[0]
| 239 | +                        .strip()
| 240 | +                        .replace("```", "")
| 241 | +                        .replace("sql\n", "")
| 242 | +                        .strip()
| 243 | +                    )
| 244 | +                else:
| 245 | +                    res = str(res).split("Explanation:", 1)[0].strip()
| 246 | +                sqlglot.transpile(res)
| 247 | +            except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
| 248 | +                logger.info("We did the best we could; there might still be some errors:\n")
| 249 | +                logger.info(f"Realized query so far:\n {res}")
| 250 | + else: |
| 251 | +            # Load the NSQL model (NumbersStation/nsql-6B)
| 252 | + tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B") |
| 253 | + model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B") |
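| | +            # (torch is imported above, presumably so the model can be moved to GPU,
| | +            # e.g. model.to("cuda"); as written the model stays on CPU)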
| 254 | + |
| 255 | + data_samples = context_queries |
| 256 | + |
| 257 | + _context = { |
| 258 | +                "if patterns like 'current time' or 'now' occur in the question": "always use NOW() - INTERVAL",
| 259 | +                "if patterns like 'total number' or 'List' occur in the question": "always use DISTINCT",
| 260 | + } |
| 261 | + |
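| | +            # keep only the context hints semantically close to the question
| | +            # (filter_samples presumably embeds both and thresholds on similarity)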
| 262 | + filtered_context = filter_samples(input_question, probable_qs=list(_context.keys()), |
| 263 | + model_path='', threshold=0.845) |
| 264 | + |
| 265 | +            logger.debug(f"Filtered context: {filtered_context}")
| 266 | + |
| 267 | + contextual_context = [] |
| 268 | + for _item in filtered_context: |
| 269 | + _val = _context.get(_item, None) |
| 270 | + if _val: |
| 271 | + contextual_context.append(f"{_item}: {_val}") |
| 272 | + |
| 273 | +            logger.debug("Filtering question/query pairs")
| | +            # assumption: the cached question/query pairs loaded above are the intended
| | +            # candidates; the empty model path mirrors the filter_samples call above
| 274 | +            _samples = filter_samples(input_question, probable_qs=context_queries,
| 275 | +                                      model_path='', threshold=0.90)
| 276 | + |
| 277 | +            # If more than 5 QnA pairs matched, keep only the top 5 for focused context
| | +            # (reversed, presumably so the strongest match sits closest to the question)
| 278 | +            if len(_samples) > 5:
| 279 | +                _samples = _samples[0:5][::-1]
| 280 | + qna_samples = '\n'.join(_samples) |
| 281 | + |
| 282 | + contextual_context_val = ', '.join(contextual_context) |
| 283 | + |
| | +            # assumption: `column_names` holds the table's column names (e.g. read from the
| | +            # schema); it is not defined anywhere in this hunk
| | +            context_columns = []
| 284 | +            if len(_samples) > 2:
| 285 | +                # Check for columns mentioned in the QnA samples provided; if any exist, keep them
| 286 | +                context_columns = [_c for _c in column_names if _c.lower() in qna_samples.lower()]
| 287 | +                if len(context_columns) > 0:
| | +                    # narrow the cached pairs (`data_samples`, built above) to those mentioning a matched column
| 288 | +                    contextual_data_samples = [
| | +                        _d for _cc in context_columns for _d in data_samples if _cc.lower() in _d.lower()
| | +                    ]
| 289 | +                    data_samples = contextual_data_samples
| 290 | +            relevant_columns = context_columns if len(context_columns) > 0 else column_names
| 291 | +            _data_info = ', '.join(relevant_columns)
| 292 | + |
| | +            # assumption: `prompt_template` is an NSQL prompt defined alongside the other
| | +            # templates in sidekick.configs.prompt_template (not shown in this diff)
| 293 | +            query = prompt_template.format(table_name=table_name, data_info=_data_info,
| 294 | +                                           data_info_detailed=data_samples,
| | +                                           sample_queries=qna_samples,
| 295 | +                                           context=contextual_context_val,
| | +                                           question_txt=input_question)
| 296 | + |
| 297 | +            input_ids = tokenizer(query, return_tensors="pt").input_ids
| | +            # minimal completion so the nsql branch assigns `res` before the shared return;
| | +            # max_new_tokens and greedy decoding are assumptions, not part of the original change
| | +            generated_ids = model.generate(input_ids, max_new_tokens=300)
| | +            res = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
239 | 298 | return res |
240 | 299 |
241 | 300 | def task_formatter(self, input_task: str): |
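A minimal standalone sketch of the new NSQL path for reference (the model id comes from this diff; the table schema, question, and prompt layout are illustrative assumptions based on the NumbersStation model card, not this repo's template):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the same model the new branch uses
tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B")
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")

# Illustrative schema + question in NSQL's text-to-SQL prompt style
prompt = """CREATE TABLE sleep_log (id INTEGER, start_time TEXT, duration_hours REAL)

-- Using valid SQLite, answer the following question for the table provided above.

-- What is the average sleep duration in hours?

SELECT"""

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
generated_ids = model.generate(input_ids, max_new_tokens=128)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))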