Skip to content

Commit b582ff7

Browse files
Enable workflow to save in-correct generation #44
1 parent 0d37526 commit b582ff7

File tree

2 files changed

+37
-14
lines changed

2 files changed

+37
-14
lines changed

sidekick/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
from sentence_transformers import SentenceTransformer
1616
from sidekick.logger import logger
1717
from sklearn.metrics.pairwise import cosine_similarity
18-
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
19-
BitsAndBytesConfig)
18+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
2019

2120
model_choices_map = {
2221
"h2ogpt-sql-sqlcoder2": "defog/sqlcoder2",
@@ -145,7 +144,9 @@ def remove_duplicates(
145144
return res
146145

147146

148-
def save_query(output_path: str, table_name: str, query, response, extracted_entity: Optional[dict] = ""):
147+
def save_query(
148+
output_path: str, table_name: str, query, response, extracted_entity: Optional[dict] = "", is_invalid: bool = False
149+
):
149150
_response = response
150151
# Probably need to find a better way to extra the info rather than depending on key phrases
151152
if response and "Generated response for question,".lower() in response.lower():
@@ -155,7 +156,11 @@ def save_query(output_path: str, table_name: str, query, response, extracted_ent
155156
chat_history = {"Query": query, "Answer": _response, "Entity": extracted_entity}
156157

157158
# Persist history for contextual reference wrt to the table.
158-
dir_name = f"{output_path}/var/lib/tmp/.cache/{table_name}"
159+
dir_name = (
160+
f"{output_path}/var/lib/tmp/.cache/{table_name}"
161+
if not is_invalid
162+
else f"{output_path}/var/lib/tmp/.cache/{table_name}/invalid"
163+
)
159164
make_dir(dir_name)
160165
with open(f"{dir_name}/history.jsonl", "a") as outfile:
161166
json.dump(chat_history, outfile)

ui/app.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ async def chat(q: Q):
157157
box=ui.box("vertical", height="500px"),
158158
name="chatbot",
159159
data=data(fields="content from_user", t="list", size=-50),
160+
commands=[
161+
ui.command(name="download_accept", label="Download QnA history", icon="Download"),
162+
ui.command(name="download_reject", label="Download in-correct QnA history", icon="Download"),
163+
],
160164
),
161165
),
162166
add_card(
@@ -170,8 +174,8 @@ async def chat(q: Q):
170174
ui.button(
171175
name="regenerate",
172176
icon="RepeatOne",
173-
caption="Attempts regeneration",
174-
label="Regenerate",
177+
caption="Attempts regeneration of the last response",
178+
label="Try Again",
175179
primary=True,
176180
),
177181
ui.button(
@@ -182,9 +186,15 @@ async def chat(q: Q):
182186
),
183187
ui.button(
184188
name="save_conversation",
185-
caption="Saves the conversation for future reference/to improve response",
186-
label="Save",
187-
icon="Save",
189+
caption="Saves the conversation in the history for future reference to improve response",
190+
label="Accept",
191+
icon="Emoji2",
192+
),
193+
ui.button(
194+
name="save_rejected_conversation",
195+
caption="Saves the disappointed conversation to improve response.",
196+
label="Reject",
197+
icon="EmojiDisappointed",
188198
),
189199
],
190200
justify="center",
@@ -629,11 +639,21 @@ async def on_event(q: Q):
629639
# Refresh response is triggered when user selects a table via dropdown
630640
event_handled = True
631641

632-
if q.args.save_conversation or (q.args.chatbot and "save the qna pair:" in q.args.chatbot.lower()):
642+
if (
643+
q.args.save_conversation
644+
or q.args.save_rejected_conversation
645+
or (q.args.chatbot and "save the qna pair:" in q.args.chatbot.lower())
646+
):
633647
question = q.client.query
634648
_val = q.client.llm_response
635649
# Currently, any manual input by the user is a Question by default
636650
table_name = q.user.table_name if q.user.table_name else "default"
651+
_is_invalid = True if q.args.save_rejected_conversation else False
652+
_msg = (
653+
"Conversation saved successfully!"
654+
if not _is_invalid
655+
else "Sorry, we couldn't get it right, we will try to improve!"
656+
)
637657
if (
638658
question is not None
639659
and "SELECT" in question
@@ -642,12 +662,10 @@ async def on_event(q: Q):
642662
_q = question.lower().split("q:")[1].split("r:")[0].strip()
643663
_r = question.lower().split("r:")[1].strip()
644664
logging.info(f"Saving conversation for question: {_q} and response: {_r}")
645-
save_query(base_path, table_name, query=_q, response=_r)
646-
_msg = "Conversation saved successfully!"
665+
save_query(base_path, table_name, query=_q, response=_r, is_invalid=_is_invalid)
647666
elif question is not None and _val is not None and _val.strip() != "":
648667
logging.info(f"Saving conversation for question: {question} and response: {_val}")
649-
save_query(base_path, table_name, query=question, response=_val)
650-
_msg = "Conversation saved successfully!"
668+
save_query(base_path, table_name, query=question, response=_val, is_invalid=_is_invalid)
651669
else:
652670
_msg = "Sorry, try generating a conversation to save."
653671
q.page["chat_card"].data += [_msg, False]

0 commit comments

Comments
 (0)