Commit c1e2cc4

Initial skeleton to support local LLM #4

1 parent 1a1e3de

File tree

5 files changed: +138 −63 lines changed


sidekick/configs/.env.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [OPENAI]
 OPENAI_API_KEY = ""
-MODEL_NAME = "gpt-3.5-turbo-0301" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003
+MODEL_NAME = "nsql" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003
 
 [LOCAL_DB_CONFIG]
 HOST_NAME = "localhost"
@@ -13,7 +13,7 @@ PORT = "5432"
 LOG-LEVEL = "INFO"
 
 [DB-DIALECT]
-DB_TYPE = "sqlite"
+DB_TYPE = "sqlite" # postgresql
 
 [TABLE_INFO]
 TABLE_INFO_PATH = "/examples/test/table_info.jsonl"
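For orientation, a minimal sketch (not part of this commit) of how the updated .env.toml could be read with the toml package that sidekick/query.py already imports; the config path here is an assumption:

import toml

# Hypothetical path; adjust to wherever the package is installed.
config = toml.load("sidekick/configs/.env.toml")
model_name = config["OPENAI"]["MODEL_NAME"]  # "nsql" after this change
db_type = config["DB-DIALECT"]["DB_TYPE"]    # "sqlite"; the new comment hints at "postgresql" as the alternative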

sidekick/configs/prompt_template.py

Lines changed: 17 additions & 0 deletions
@@ -59,3 +59,20 @@
     Query:\n {qry_txt}
     """,
 }
+
+NSQL_QUERY_PROMPT = """
+For SQL TABLE '{table_name}' sample question/answer pairs,\n({sample_queries})
+
+CREATE TABLE '{table_name}'({data_info}
+)
+
+Table '{table_name}' has sample values ({data_info_detailed})
+
+
+
+-- Using valid {_dialect}, answer the following questions with the information for '{table_name}' provided above; for final SQL only use values from the question.
+
+
+-- Using reference for TABLES '{table_name}' {context}; {question_txt}?
+
+SELECT"""

sidekick/db_config.py

Lines changed: 0 additions & 12 deletions
@@ -112,18 +112,6 @@ def create_table(self, schema_info_path=None, schema_info=None):
             # If schema information is not provided, extract from the template.
             self.schema_info = """,\n""".join(self._extract_schema_info(schema_info_path)).strip()
             logger.debug(f"Schema info used for creating table:\n {self.schema_info}")
-            # self.schema_info = """
-            # id uuid PRIMARY KEY,
-            # ts TIMESTAMP WITH TIME ZONE NOT NULL,
-            # kind TEXT NOT NULL, -- or int?,
-            # user_id TEXT,
-            # user_name TEXT,
-            # resource_type TEXT NOT NULL, -- or int?,
-            # resource_id TEXT,
-            # stream TEXT NOT NULL,
-            # source TEXT NOT NULL,
-            # payload jsonb NOT NULL
-            # """
         create_syntax = f"""
         CREATE TABLE IF NOT EXISTS {self.table_name} (
         {self.schema_info}
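For context, the create_syntax f-string that remains renders SQL of the following shape; the table name and columns here are illustrative, echoing the deleted comment block:

CREATE TABLE IF NOT EXISTS telemetry (
id uuid PRIMARY KEY,
ts TIMESTAMP WITH TIME ZONE NOT NULL,
payload jsonb NOT NULL
)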

sidekick/query.py

Lines changed: 101 additions & 42 deletions
@@ -7,13 +7,17 @@
 import openai
 import sqlglot
 import toml
+import torch
 from langchain import OpenAI
-from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
+from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
+                         LLMPredictor, ServiceContext, SQLDatabase)
 from llama_index.indices.struct_store import SQLContextContainerBuilder
-from sidekick.configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
+from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT,
+                                              TASK_PROMPT)
 from sidekick.logger import logger
-from sidekick.utils import csv_parser, filter_samples, remove_duplicates
+from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
 from sqlalchemy import create_engine
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 def _check_file_info(file_path: str):
@@ -63,7 +67,7 @@ def update_context_queries(self):
         new_context_queries = []
         if self.sample_queries_path is not None and Path(self.sample_queries_path).exists():
             logger.info(f"Using samples from path {self.sample_queries_path}")
-            new_context_queries = csv_parser(self.sample_queries_path)
+            new_context_queries = read_sample_pairs(self.sample_queries_path, "gpt")
             # cache the samples for future use
             with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "w") as f:
                 json.dump(new_context_queries, f, indent=2)
@@ -191,51 +195,106 @@ def generate_tasks(self, table_names: list, input_question: str):
         except Exception as se:
             raise se
 
-    def generate_sql(
-        self, table_name: list, input_question: str, _dialect: str = "sqlite", model_name: str = "gpt-3.5-turbo-0301"
-    ):
-        _tasks = self.task_formatter(self._tasks)
+
+    def generate_sql(self, table_name: list, input_question: str, _dialect: str = "sqlite", model_name: str = "nsql"):
         context_file = f"{self.path}/var/lib/tmp/data/context.json"
         additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}
-
         context_queries = self.content_queries
-        # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
-        query_str = QUERY_PROMPT.format(
-            _dialect=_dialect,
-            _data_info=self._data_info,
-            _question=input_question,
-            _table_name=table_name,
-            _sample_queries=context_queries,
-            _tasks=_tasks,
-        )
 
-        table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
-        self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)
+        if model_name != "nsql":
+            _tasks = self.task_formatter(self._tasks)
 
-        table_schema_index = self.build_index(persist=False)
-        self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
-        context_container = self.context_builder.build_context_container()
+            # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
+            query_str = QUERY_PROMPT.format(
+                _dialect=_dialect,
+                _data_info=self._data_info,
+                _question=input_question,
+                _table_name=table_name,
+                _sample_queries=context_queries,
+                _tasks=_tasks,
+            )
 
-        # Reference: https://github.com/jerryjliu/llama_index/issues/987
-        llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name))
-        service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512)
+            table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
+            self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)
 
-        index = GPTSQLStructStoreIndex(
-            [], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3
-        )
-        res = self.generate_response(context_container, sql_index=index, input_prompt=query_str, _dialect = _dialect)
-        try:
-            # Check if `SQL` is formatted ---> ``` SQL_text ```
-            if "```" in str(res):
-                res = (
-                    str(res).split("```", 1)[1].split(";", 1)[0].strip().replace("```", "").replace("sql\n", "").strip()
-                )
-            else:
-                res = str(res).split("Explanation:", 1)[0].strip()
-            sqlglot.transpile(res)
-        except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
-            logger.info("We did the best we could, there might be still be some error:\n")
-            logger.info(f"Realized query so far:\n {res}")
+            table_schema_index = self.build_index(persist=False)
+            self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
+            context_container = self.context_builder.build_context_container()
+
+            # Reference: https://github.com/jerryjliu/llama_index/issues/987
+            llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name))
+            service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512)
+
+            index = GPTSQLStructStoreIndex(
+                [], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3
+            )
+            res = self.generate_response(context_container, sql_index=index, input_prompt=query_str)
+            try:
+                # Check if `SQL` is formatted ---> ``` SQL_text ```
+                if "```" in str(res):
+                    res = (
+                        str(res)
+                        .split("```", 1)[1]
+                        .split(";", 1)[0]
+                        .strip()
+                        .replace("```", "")
+                        .replace("sql\n", "")
+                        .strip()
+                    )
+                else:
+                    res = str(res).split("Explanation:", 1)[0].strip()
+                sqlglot.transpile(res)
+            except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
+                logger.info("We did the best we could, there might be still be some error:\n")
+                logger.info(f"Realized query so far:\n {res}")
+        else:
+            # Load h2oGPT.NSQL model
+            tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B")
+            model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")
+
+            data_samples = context_queries
+
+            _context = {
+                "if patterns like 'current time' or 'now' occurs in question": "always use NOW() - INTERVAL",
+                "if patterns like 'total number', or 'List' occurs in question": "always use DISTINCT",
+            }
+
+            filtered_context = filter_samples(input_question, probable_qs=list(_context.keys()),
+                                              model_path='', threshold=0.845)
+
+            print(f"Filter Context: {filtered_context}")
+
+            contextual_context = []
+            for _item in filtered_context:
+                _val = _context.get(_item, None)
+                if _val:
+                    contextual_context.append(f"{_item}: {_val}")
+
+            print("Filtering Question/Query pairs")
+            _samples = filter_samples(input_question, probable_qs=sample_pairs,
+                                      model_path=local_model_path, threshold=0.90)
+
+            # If QnA pairs > 5, we keep only 5 of them for focused context
+            if len(_samples) > 5:
+                _samples = _samples[0:5][::-1]
+            qna_samples = '\n'.join(_samples)
+
+            contextual_context_val = ', '.join(contextual_context)
+
+            if len(_samples) > 2:
+                # Check for the columns in the QnA samples provided, if exists keep them
+                context_columns = [_c for _c in column_names if _c.lower() in qna_samples.lower()]
+                if len(context_columns) > 0:
+                    contextual_data_samples = [_d for _cc in context_columns for _d in data_samples_list if _cc.lower() in _d.lower()]
+                    data_samples = contextual_data_samples
+            relevant_columns = context_columns if len(context_columns) > 0 else column_names
+            _data_info = ', '.join(relevant_columns)
+
+            query = prompt_template.format(table_name=_table_name, data_info=_data_info, data_info_detailed=data_samples,
+                                           sample_queries=qna_samples, context=contextual_context_val,
+                                           question_txt=input_question)
+
+            input_ids = tokenizer(query, return_tensors="pt").input_ids
         return res
 
     def task_formatter(self, input_task: str):
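As committed, the NSQL branch stops at tokenization and still returns res, which only the OpenAI branch assigns; it also references names (sample_pairs, local_model_path, column_names, data_samples_list, prompt_template, _table_name) not yet defined in this scope, as expected for an initial skeleton. A minimal sketch, not part of this commit, of how the generation step might be completed with the torch and transformers APIs already imported above; the token budget and post-processing are assumptions:

# Hypothetical continuation after `input_ids = tokenizer(query, return_tensors="pt").input_ids`.
with torch.no_grad():
    generated = model.generate(input_ids, max_new_tokens=300)

# Causal LMs return the prompt plus the completion; keep only the new tokens.
completion = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)

# The prompt ends in "SELECT", so re-attach it and stop at the first ";".
res = "SELECT " + completion.split(";", 1)[0].strip()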

sidekick/utils.py

Lines changed: 18 additions & 7 deletions
@@ -95,18 +95,29 @@ def setup_dir(base_path: str):
         p.mkdir(parents=True, exist_ok=True)
 
 
-def csv_parser(input_path: str):
+def read_sample_pairs(input_path: str, model_name: str = "nsql"):
     df = pd.read_csv(input_path)
     df = df.dropna()
    df = df.drop_duplicates()
     df = df.reset_index(drop=True)
 
-    # Convert frame to below format
-    # [
-    # "# query": ""
-    # "# answer": ""
-    # ]
-    res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
+    # NSQL format
+    if model_name != 'nsql':
+        # Open AI format
+        # Convert frame to below format
+        # [
+        # "# query": ""
+        # "# answer": ""
+        # ]
+        res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
+    else:
+        # Convert frame to below format
+        # [
+        # "Question": <question_text>
+        # "Answer":
+        # <response_text>
+        # ]
+        res = df.apply(lambda row: f"Question: {row['query']}\nAnswer:\n{row['answer']}", axis=1).to_list()
     return res
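A minimal usage sketch (not part of this commit) showing the two output formats; the CSV path and contents are illustrative, assuming a file with 'query' and 'answer' columns:

from sidekick.utils import read_sample_pairs

# Hypothetical CSV with columns: query, answer
gpt_pairs = read_sample_pairs("samples.csv", "gpt")
# -> ["# query: How many rows?\n# answer: SELECT COUNT(*) FROM t", ...]

nsql_pairs = read_sample_pairs("samples.csv", "nsql")
# -> ["Question: How many rows?\nAnswer:\nSELECT COUNT(*) FROM t", ...]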
