
Commit 192da49

Initial skeleton to support local LLM #4
1 parent 25a9570 commit 192da49

File tree: 5 files changed, +137 −63 lines

sidekick/configs/.env.toml

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 [OPENAI]
 OPENAI_API_KEY = ""
-MODEL_NAME = "gpt-3.5-turbo-0301" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003
+MODEL_NAME = "nsql" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003

 [LOCAL_DB_CONFIG]
 HOST_NAME = "localhost"
@@ -13,4 +13,4 @@ PORT = "5432"
 LOG-LEVEL = "INFO"

 [DB-DIALECT]
-DB_TYPE = "postgresql"
+DB_TYPE = "SQLite" # postgresql
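For orientation, these keys are presumably read back through the `toml` package that sidekick/query.py imports (see its diff below). A minimal sketch of consuming the updated values — not part of the commit; only the file path and key names come from the diff, the variable names are illustrative:

    import toml

    # Hypothetical read-back of the two keys changed above.
    cfg = toml.load("sidekick/configs/.env.toml")
    model_name = cfg["OPENAI"]["MODEL_NAME"]   # "nsql" after this commit
    db_type = cfg["DB-DIALECT"]["DB_TYPE"]     # "SQLite" after this commit
    print(model_name, db_type)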

sidekick/configs/prompt_template.py

Lines changed: 17 additions & 0 deletions

@@ -59,3 +59,20 @@
 Query:\n {qry_txt}
 """,
 }
+
+NSQL_QUERY_PROMPT = """
+For SQL TABLE '{table_name}' sample question/answer pairs,\n({sample_queries})
+
+CREATE TABLE '{table_name}'({data_info}
+)
+
+Table '{table_name}' has sample values ({data_info_detailed})
+
+
+
+-- Using valid {_dialect}, answer the following questions with the information for '{table_name}' provided above; for final SQL only use values from the question.
+
+
+-- Using reference for TABLES '{table_name}' {context}; {question_txt}?
+
+SELECT"""

sidekick/db_config.py

Lines changed: 0 additions & 12 deletions

@@ -101,18 +101,6 @@ def create_table(self, schema_info_path=None, schema_info=None):
        # If schema information is not provided, extract from the template.
        self.schema_info = """,\n""".join(self._extract_schema_info(schema_info_path)).strip()
        logger.debug(f"Schema info used for creating table:\n {self.schema_info}")
-        # self.schema_info = """
-        # id uuid PRIMARY KEY,
-        # ts TIMESTAMP WITH TIME ZONE NOT NULL,
-        # kind TEXT NOT NULL, -- or int?,
-        # user_id TEXT,
-        # user_name TEXT,
-        # resource_type TEXT NOT NULL, -- or int?,
-        # resource_id TEXT,
-        # stream TEXT NOT NULL,
-        # source TEXT NOT NULL,
-        # payload jsonb NOT NULL
-        # """
        create_syntax = f"""
        CREATE TABLE IF NOT EXISTS {self.table_name} (
        {self.schema_info}

sidekick/query.py

Lines changed: 100 additions & 42 deletions

@@ -7,13 +7,17 @@
 import openai
 import sqlglot
 import toml
+import torch
 from langchain import OpenAI
-from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
+from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
+                         LLMPredictor, ServiceContext, SQLDatabase)
 from llama_index.indices.struct_store import SQLContextContainerBuilder
-from sidekick.configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
+from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT,
+                                              TASK_PROMPT)
 from sidekick.logger import logger
-from sidekick.utils import csv_parser, filter_samples, remove_duplicates
+from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
 from sqlalchemy import create_engine
+from transformers import AutoModelForCausalLM, AutoTokenizer


 def _check_file_info(file_path: str):
@@ -63,7 +67,7 @@ def update_context_queries(self):
         new_context_queries = []
         if self.sample_queries_path is not None and Path(self.sample_queries_path).exists():
             logger.info(f"Using samples from path {self.sample_queries_path}")
-            new_context_queries = csv_parser(self.sample_queries_path)
+            new_context_queries = read_sample_pairs(self.sample_queries_path, "gpt")
             # cache the samples for future use
             with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "w") as f:
                 json.dump(new_context_queries, f, indent=2)
@@ -191,51 +195,105 @@ def generate_tasks(self, table_names: list, input_question: str):
         except Exception as se:
             raise se

-    def generate_sql(
-        self, table_name: list, input_question: str, _dialect: str = "postgres", model_name: str = "gpt-3.5-turbo-0301"
-    ):
-        _tasks = self.task_formatter(self._tasks)
+    def generate_sql(self, table_name: list, input_question: str, _dialect: str = "postgres", model_name: str = "nsql"):
         context_file = f"{self.path}/var/lib/tmp/data/context.json"
         additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}
-
         context_queries = self.content_queries
-        # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
-        query_str = QUERY_PROMPT.format(
-            _dialect=_dialect,
-            _data_info=self._data_info,
-            _question=input_question,
-            _table_name=table_name,
-            _sample_queries=context_queries,
-            _tasks=_tasks,
-        )

-        table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
-        self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)
+        if model_name != "nsql":
+            _tasks = self.task_formatter(self._tasks)

-        table_schema_index = self.build_index(persist=False)
-        self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
-        context_container = self.context_builder.build_context_container()
+            # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
+            query_str = QUERY_PROMPT.format(
+                _dialect=_dialect,
+                _data_info=self._data_info,
+                _question=input_question,
+                _table_name=table_name,
+                _sample_queries=context_queries,
+                _tasks=_tasks,
+            )

-        # Reference: https://github.com/jerryjliu/llama_index/issues/987
-        llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name))
-        service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512)
+            table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
+            self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)

-        index = GPTSQLStructStoreIndex(
-            [], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3
-        )
-        res = self.generate_response(context_container, sql_index=index, input_prompt=query_str)
-        try:
-            # Check if `SQL` is formatted ---> ``` SQL_text ```
-            if "```" in str(res):
-                res = (
-                    str(res).split("```", 1)[1].split(";", 1)[0].strip().replace("```", "").replace("sql\n", "").strip()
-                )
-            else:
-                res = str(res).split("Explanation:", 1)[0].strip()
-            sqlglot.transpile(res)
-        except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
-            logger.info("We did the best we could, there might be still be some error:\n")
-            logger.info(f"Realized query so far:\n {res}")
+            table_schema_index = self.build_index(persist=False)
+            self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
+            context_container = self.context_builder.build_context_container()
+
+            # Reference: https://github.com/jerryjliu/llama_index/issues/987
+            llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name))
+            service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512)
+
+            index = GPTSQLStructStoreIndex(
+                [], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3
+            )
+            res = self.generate_response(context_container, sql_index=index, input_prompt=query_str)
+            try:
+                # Check if `SQL` is formatted ---> ``` SQL_text ```
+                if "```" in str(res):
+                    res = (
+                        str(res)
+                        .split("```", 1)[1]
+                        .split(";", 1)[0]
+                        .strip()
+                        .replace("```", "")
+                        .replace("sql\n", "")
+                        .strip()
+                    )
+                else:
+                    res = str(res).split("Explanation:", 1)[0].strip()
+                sqlglot.transpile(res)
+            except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
+                logger.info("We did the best we could, there might be still be some error:\n")
+                logger.info(f"Realized query so far:\n {res}")
+        else:
+            # Load h2oGPT.NSQL model
+            tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B")
+            model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")
+
+            data_samples = context_queries
+
+            _context = {
+                "if patterns like 'current time' or 'now' occurs in question": "always use NOW() - INTERVAL",
+                "if patterns like 'total number', or 'List' occurs in question": "always use DISTINCT",
+            }
+
+            filtered_context = filter_samples(input_question, probable_qs=list(_context.keys()),
+                                              model_path='', threshold=0.845)
+
+            print(f"Filter Context: {filtered_context}")
+
+            contextual_context = []
+            for _item in filtered_context:
+                _val = _context.get(_item, None)
+                if _val:
+                    contextual_context.append(f"{_item}: {_val}")
+
+            print("Filtering Question/Query pairs")
+            _samples = filter_samples(input_question, probable_qs=sample_pairs,
+                                      model_path=local_model_path, threshold=0.90)
+
+            # If QnA pairs > 5, we keep only 5 of them for focused context
+            if len(_samples) > 5:
+                _samples = _samples[0:5][::-1]
+            qna_samples = '\n'.join(_samples)
+
+            contextual_context_val = ', '.join(contextual_context)
+
+            if len(_samples) > 2:
+                # Check for the columns in the QnA samples provided, if exists keep them
+                context_columns = [_c for _c in column_names if _c.lower() in qna_samples.lower()]
+                if len(context_columns) > 0:
+                    contextual_data_samples = [_d for _cc in context_columns for _d in data_samples_list if _cc.lower() in _d.lower()]
+                    data_samples = contextual_data_samples
+                relevant_columns = context_columns if len(context_columns) > 0 else column_names
+                _data_info = ', '.join(relevant_columns)
+
+            query = prompt_template.format(table_name=_table_name, data_info=_data_info, data_info_detailed=data_samples,
+                                           sample_queries=qna_samples, context=contextual_context_val,
+                                           question_txt=input_question)
+
+            input_ids = tokenizer(query, return_tensors="pt").input_ids
         return res

     def task_formatter(self, input_task: str):
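Note that the new else-branch is a skeleton in the literal sense: it builds `input_ids` but never calls the model or assigns `res`, and it references names (`sample_pairs`, `local_model_path`, `column_names`, `data_samples_list`, `_table_name`, `prompt_template`) that the hunk does not define. For orientation only, a minimal sketch of the missing generation step, assuming the same tokenizer/model pair — this is not code from the repository:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B")
    model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")

    query = "..."  # the formatted NSQL_QUERY_PROMPT, which ends with "SELECT"
    input_ids = tokenizer(query, return_tensors="pt").input_ids
    with torch.no_grad():  # inference only, no gradients needed
        generated = model.generate(input_ids, max_length=500)
    # Decode only the newly generated tokens and re-attach the "SELECT"
    # stub that the prompt ends with.
    completion = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)
    res = "SELECT" + completion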

sidekick/utils.py

Lines changed: 18 additions & 7 deletions

@@ -95,18 +95,29 @@ def setup_dir(base_path: str):
        p.mkdir(parents=True, exist_ok=True)


-def csv_parser(input_path: str):
+def read_sample_pairs(input_path: str, model_name: str = "nsql"):
     df = pd.read_csv(input_path)
     df = df.dropna()
     df = df.drop_duplicates()
     df = df.reset_index(drop=True)

-    # Convert frame to below format
-    # [
-    #  "# query": ""
-    #  "# answer": ""
-    # ]
-    res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
+    # NSQL format
+    if model_name != 'nsql':
+        # Open AI format
+        # Convert frame to below format
+        # [
+        #  "# query": ""
+        #  "# answer": ""
+        # ]
+        res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
+    else:
+        # Convert frame to below format
+        # [
+        #  "Question": <question_text>
+        #  "Answer":
+        #  <response_text>
+        # ]
+        res = df.apply(lambda row: f"Question: {row['query']}\nAnswer:\n{row['answer']}", axis=1).to_list()
     return res

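A usage illustration (not from the commit; the CSV path and row contents are hypothetical) showing how the two branches frame the same `query`/`answer` rows:

    from sidekick.utils import read_sample_pairs

    # "gpt" (or any non-"nsql" name) keeps the old OpenAI-style framing.
    pairs_gpt = read_sample_pairs("samples.csv", "gpt")
    # ["# query: How many users signed up?\n# answer: SELECT COUNT(*) FROM users", ...]

    pairs_nsql = read_sample_pairs("samples.csv")  # model_name defaults to "nsql"
    # ["Question: How many users signed up?\nAnswer:\nSELECT COUNT(*) FROM users", ...]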
