Skip to content

Commit 061822c

Browse files
Cache column sample values for future use #4
1 parent c1e2cc4 commit 061822c

File tree

4 files changed

+77
-16
lines changed

4 files changed

+77
-16
lines changed

sidekick/configs/data_template.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@
1111
"Column Type": "",
1212
"Sample Values": []
1313
}
14+
15+
data_samples_template = "Column {column_name} contains values similar to {comma_separated_sample_values}."

sidekick/db_config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sqlalchemy
88
from pandasql import sqldf
99
from psycopg2.extras import Json
10+
from sidekick.configs.data_template import data_samples_template
1011
from sidekick.logger import logger
1112
from sqlalchemy import create_engine
1213
from sqlalchemy_utils import database_exists
@@ -84,6 +85,7 @@ def _extract_schema_info(self, schema_info_path=None):
8485
with open(table_info_file, "w") as outfile:
8586
schema_info_path = json.load(outfile)["schema_info_path"]
8687
res = []
88+
sample_values = []
8789
try:
8890
if Path(schema_info_path).exists():
8991
with open(schema_info_path, "r") as in_file:
@@ -93,8 +95,21 @@ def _extract_schema_info(self, schema_info_path=None):
9395
if "Column Name" in data and "Column Type" in data:
9496
col_name = data["Column Name"]
9597
col_type = data["Column Type"]
98+
# if column has sample values, save in cache for future use.
99+
if "Sample Values" in data:
100+
_sample_values = data["Sample Values"]
101+
_ds = data_samples_template.format(
102+
column_name=col_name, comma_separated_sample_values=",".join(_sample_values)
103+
)
104+
sample_values.append(_ds)
96105
_new_samples = f"{col_name} {col_type}"
97106
res.append(_new_samples)
107+
if len(sample_values):
108+
# cache it for future use
109+
with open(
110+
f"{self.base_path}/var/lib/tmp/data/{self._table_name}_column_values.json", "w"
111+
) as outfile:
112+
json.dump(sample_values, outfile, indent=2, sort_keys=False)
98113
except ValueError as ve:
99114
logger.error(f"Error in reading table context file: {ve}")
100115
pass

sidekick/query.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
1313
LLMPredictor, ServiceContext, SQLDatabase)
1414
from llama_index.indices.struct_store import SQLContextContainerBuilder
15-
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT,
15+
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT, NSQL_QUERY_PROMPT,
1616
TASK_PROMPT)
1717
from sidekick.logger import logger
1818
from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
@@ -50,6 +50,16 @@ def __init__(
5050
self.openai_key = openai_key
5151
self.content_queries = None
5252

53+
def load_table_info(self):
54+
# Read table_info.jsonl
55+
table_info_file = f"{self.path}/var/lib/tmp/data/table_context.json"
56+
def setup(self):
57+
58+
# Load the table information
59+
self.load_table_info()
60+
61+
62+
5363
def build_index(self, persist: bool = True):
5464
# Below re-assignment of the OPENAI API key is weird but without that, it throws an error.
5565
os.environ["OPENAI_API_KEY"] = self.openai_key
@@ -271,16 +281,16 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "s
271281
contextual_context.append(f"{_item}: {_val}")
272282

273283
print("Filtering Question/Query pairs")
274-
_samples = filter_samples(input_question, probable_qs=sample_pairs,
275-
model_path=local_model_path, threshold=0.90)
284+
_samples = filter_samples(input_question, probable_qs=context_queries,
285+
model_path='', threshold=0.90)
276286

277287
# If QnA pairs > 5, we keep only 5 of them for focused context
278288
if len(_samples) > 5:
279289
_samples = _samples[0:5][::-1]
280290
qna_samples = '\n'.join(_samples)
281291

282292
contextual_context_val = ', '.join(contextual_context)
283-
293+
column_names = [str(_c) for _c in self.sql_database.get_column_names(table_name[0])]
284294
if len(_samples) > 2:
285295
# Check for the columns in the QnA samples provided, if exists keep them
286296
context_columns = [_c for _c in column_names if _c.lower() in qna_samples.lower()]
@@ -290,7 +300,7 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "s
290300
relevant_columns = context_columns if len(context_columns) > 0 else column_names
291301
_data_info = ', '.join(relevant_columns)
292302

293-
query = prompt_template.format(table_name=_table_name, data_info=_data_info, data_info_detailed=data_samples,
303+
query = NSQL_QUERY_PROMPT.format(table_name=table_name, data_info=_data_info, data_info_detailed=data_samples,
294304
sample_queries=qna_samples, context=contextual_context_val,
295305
question_txt=input_question)
296306

sidekick/utils.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
from pathlib import Path
55
from typing import Optional
66

7+
import torch
78
import numpy as np
89
import pandas as pd
910
from pandasql import sqldf
1011
from sentence_transformers import SentenceTransformer
12+
from InstructorEmbedding import INSTRUCTOR
1113
from sidekick.logger import logger
1214
from sklearn.metrics.pairwise import cosine_similarity
1315

@@ -37,6 +39,38 @@ def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, devic
3739
return all_res
3840

3941

42+
def generate_text_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = 'cpu'):
43+
# Reference:
44+
# 1. https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models
45+
# 2. Evaluation result: https://www.sbert.net/_static/html/models_en_sentence_embeddings.html
46+
# 3. Model Card: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
47+
# 4. Reference: https://huggingface.co/spaces/mteb/leaderboard
48+
# Maps sentence & paragraphs to a 384 dimensional dense vector space.
49+
model_name_path = f"{model_path}/text_embedding/instructor-large"
50+
current_torch_home = os.environ.get("TORCH_HOME", "")
51+
if Path(model_name_path).is_dir():
52+
is_empty = not any(Path(model_name_path).iterdir())
53+
if is_empty:
54+
# Download n cache at the specified location
55+
os.environ["TORCH_HOME"] = model_path
56+
model_name_path = "hkunlp/instructor-large"
57+
sentence_model = INSTRUCTOR(model_name_path, device=device)
58+
if device != 'cuda':
59+
# Issue https://github.com/pytorch/pytorch/issues/69364
60+
# # In the initial experimentation, quantized model is generates slightly better results
61+
_model = torch.quantization.quantize_dynamic(
62+
sentence_model, {torch.nn.Linear}, dtype=torch.qint8)
63+
else:
64+
_model = sentence_model
65+
_sentences = [['Represent the Financial question for retrieving duplicate examples: ', _item] for _item in x]
66+
67+
res = _model.encode(_sentences)
68+
del sentence_model
69+
del _model
70+
os.environ["TORCH_HOME"] = current_torch_home
71+
return res
72+
73+
4074
def filter_samples(input_q: str, probable_qs: list, model_path: str, threshold: float = 0.45):
4175
# Only consider the questions, note: this might change in future.
4276
_inq = ("# query: " + input_q).strip().lower()
@@ -102,21 +136,21 @@ def read_sample_pairs(input_path: str, model_name: str = "nsql"):
102136
df = df.reset_index(drop=True)
103137

104138
# NSQL format
105-
if model_name != 'nsql':
139+
if model_name != "nsql":
106140
# Open AI format
107-
# Convert frame to below format
108-
# [
109-
# "# query": ""
110-
# "# answer": ""
111-
# ]
141+
# Convert frame to below format
142+
# [
143+
# "# query": ""
144+
# "# answer": ""
145+
# ]
112146
res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
113147
else:
114148
# Convert frame to below format
115-
# [
116-
# "Question": <question_text>
117-
# "Answer":
118-
# <response_text>
119-
# ]
149+
# [
150+
# "Question": <question_text>
151+
# "Answer":
152+
# <response_text>
153+
# ]
120154
res = df.apply(lambda row: f"Question: {row['query']}\nAnswer:\n{row['answer']}", axis=1).to_list()
121155
return res
122156

0 commit comments

Comments
 (0)