Skip to content

Commit a03d46c

Browse files
Save history/cache QnA pair wrt to each table
1 parent b60fb7b commit a03d46c

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

sidekick/memory.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ def save_context(self, info: str, extract_context: bool = True) -> Dict:
3131
split_token = ";"
3232
query = " ".join(info.partition(":")[2].split(split_token)[0].strip().split())
3333
response = " ".join(info.partition(":")[2].split(split_token)[1].partition(":")[2].strip().split())
34-
# TODO add additional guardrails to check if the response is a valid response.
35-
# At-least syntactically correct SQL.
3634

3735
# Check if entity extraction is enabled
3836
# Add logic for entity extraction
@@ -66,7 +64,8 @@ def save_context(self, info: str, extract_context: bool = True) -> Dict:
6664

6765
# Persist added information locally
6866
if chat_history:
69-
with open(f"{self.path}/var/lib/tmp/data/history.jsonl", "a") as outfile:
67+
# TODO: Persist history for each user. This flow is currently only affects openai models.
68+
with open(f"{self.path}/var/lib/tmp/.cache/history.jsonl", "a") as outfile:
7069
json.dump(chat_history, outfile)
7170
outfile.write("\n")
7271
if extract_context:

sidekick/query.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sidekick.logger import logger
2222
from sidekick.utils import (_check_file_info, filter_samples, is_resource_low,
2323
load_causal_lm_model, load_embedding_model,
24-
read_sample_pairs, remove_duplicates)
24+
make_dir, read_sample_pairs, remove_duplicates)
2525
from sqlalchemy import create_engine
2626

2727

@@ -119,19 +119,21 @@ def build_index(self, persist: bool = True):
119119

120120
def update_context_queries(self):
121121
# Check if seed samples were provided
122+
cache_path = f"{self.path}/var/lib/tmp/.cache/{self.table_name}/"
122123
new_context_queries = []
123124
if self.sample_queries_path is not None and Path(self.sample_queries_path).exists():
124125
logger.info(f"Using QnA samples from path {self.sample_queries_path}")
125126
new_context_queries = read_sample_pairs(self.sample_queries_path, "h2ogpt-sql")
126127
# cache the samples for future use
127-
with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "w") as f:
128+
make_dir(cache_path)
129+
with open(f"{cache_path}/queries_cache.json", "w") as f:
128130
json.dump(new_context_queries, f, indent=2)
129-
elif self.sample_queries_path is None and Path(f"{self.path}/var/lib/tmp/data/queries_cache.json").exists():
131+
elif self.sample_queries_path is None and Path(f"{cache_path}/queries_cache.json").exists():
130132
logger.info(f"Using samples from cache")
131-
with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "r") as f:
133+
with open(f"{cache_path}/queries_cache.json", "r") as f:
132134
new_context_queries = json.load(f)
133135
# Read the history file and update the context queries
134-
history_file = f"{self.path}/var/lib/tmp/data/{self.table_name}/history.jsonl"
136+
history_file = f"{self.path}/var/lib/tmp/.cache/{self.table_name}/history.jsonl"
135137
try:
136138
if Path(history_file).exists():
137139
with open(history_file, "r") as in_file:

sidekick/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import errno
12
import glob
23
import json
34
import os
@@ -134,14 +135,19 @@ def remove_duplicates(
134135
return res
135136

136137

137-
def save_query(output_path: str, query, response, extracted_entity: Optional[dict] = ""):
138+
def save_query(output_path: str, table_name: str, query, response, extracted_entity: Optional[dict] = ""):
138139
_response = response
139140
# Probably need to find a better way to extra the info rather than depending on key phrases
140141
if response and "Generated response for question,".lower() in response.lower():
141-
_response = response.split("**Generated response for question,**")[1].split("\n")[3].strip()
142+
_response = (
143+
response.split("**Generated response for question,**")[1].split("``` sql")[1].split("```")[0].strip()
144+
)
142145
chat_history = {"Query": query, "Answer": _response, "Entity": extracted_entity}
143146

144-
with open(f"{output_path}/var/lib/tmp/data/history.jsonl", "a") as outfile:
147+
# Persist history for contextual reference wrt to the table.
148+
dir_name = f"{output_path}/var/lib/tmp/.cache/{table_name}"
149+
make_dir(dir_name)
150+
with open(f"{dir_name}/history.jsonl", "a") as outfile:
145151
json.dump(chat_history, outfile)
146152
outfile.write("\n")
147153

@@ -378,7 +384,7 @@ def make_dir(path: str):
378384
try:
379385
os.makedirs(path)
380386
except OSError as exc:
381-
if exc.errno == errno.EXIST and os.path.isdir(path):
387+
if exc.errno == errno.EEXIST and os.path.isdir(path):
382388
pass
383389
else:
384390
raise Exception("Error reported while creating default directory path.")

ui/app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,7 @@ async def on_event(q: Q):
578578
question = q.client.query
579579
_val = q.client.llm_response
580580
# Currently, any manual input by the user is a Question by default
581+
table_name = q.user.table_name if q.user.table_name else "default"
581582
if (
582583
question is not None
583584
and "SELECT" in question
@@ -586,11 +587,11 @@ async def on_event(q: Q):
586587
_q = question.lower().split("q:")[1].split("r:")[0].strip()
587588
_r = question.lower().split("r:")[1].strip()
588589
logging.info(f"Saving conversation for question: {_q} and response: {_r}")
589-
save_query(base_path, query=_q, response=_r)
590+
save_query(base_path, table_name, query=_q, response=_r)
590591
_msg = "Conversation saved successfully!"
591592
elif question is not None and _val is not None and _val.strip() != "":
592593
logging.info(f"Saving conversation for question: {question} and response: {_val}")
593-
save_query(base_path, query=question, response=_val)
594+
save_query(base_path, table_name, query=question, response=_val)
594595
_msg = "Conversation saved successfully!"
595596
else:
596597
_msg = "Sorry, try generating a conversation to save."

0 commit comments

Comments
 (0)