Skip to content

Commit a0b03c5

Browse files
Add widget to control model selection #40
1 parent 33e7de3 commit a0b03c5

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

ui/app.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ async def chat(q: Q):
8686
original_name = meta_data[table].get("original_name", q.user.original_name)
8787
table_names.append(ui.choice(table, f"{original_name}"))
8888

89+
model_choices = [
90+
ui.choice("h2ogpt-sql-nsql-llama-2-7B", "h2ogpt-sql-nsql-llama-2-7B"),
91+
ui.choice("h2ogpt-sql-sqlcoder2", "h2ogpt-sql-sqlcoder2"),
92+
]
93+
q.user.model_choice_dropdown = "h2ogpt-sql-sqlcoder2"
8994
add_card(
9095
q,
9196
"background_card",
@@ -111,7 +116,15 @@ async def chat(q: Q):
111116
choices=table_names,
112117
value=q.user.table_name if q.user.table_name else None,
113118
trigger=True,
114-
)
119+
),
120+
ui.dropdown(
121+
name="model_choice_dropdown",
122+
label="Model Choice",
123+
required=True,
124+
choices=model_choices,
125+
value=q.user.model_choice_dropdown if q.user.model_choice_dropdown else None,
126+
trigger=True,
127+
),
115128
],
116129
),
117130
)
@@ -209,6 +222,7 @@ async def chatbot(q: Q):
209222
sample_queries_path=q.user.sample_qna_path,
210223
table_info_path=q.user.table_info_path,
211224
table_name=q.user.table_name,
225+
model_name=q.user.model_choice_dropdown,
212226
is_regenerate=True,
213227
is_regen_with_options=False,
214228
)
@@ -227,6 +241,7 @@ async def chatbot(q: Q):
227241
sample_queries_path=q.user.sample_qna_path,
228242
table_info_path=q.user.table_info_path,
229243
table_name=q.user.table_name,
244+
model_name=q.user.model_choice_dropdown,
230245
is_regenerate=False,
231246
is_regen_with_options=True,
232247
)
@@ -248,6 +263,7 @@ async def chatbot(q: Q):
248263
sample_queries_path=q.user.sample_qna_path,
249264
table_info_path=q.user.table_info_path,
250265
table_name=q.user.table_name,
266+
model_name=q.user.model_choice_dropdown,
251267
)
252268
llm_response = "\n".join(llm_response)
253269
except (MemoryError, RuntimeError) as e:
@@ -567,12 +583,23 @@ async def on_event(q: Q):
567583
elif q.args.regenerate:
568584
q.args.chatbot = "regenerate"
569585

570-
if q.args.table_dropdown and not q.args.chatbot:
586+
if q.args.table_dropdown and not q.args.chatbot and q.user.table_name != q.args.table_dropdown:
571587
logging.info(f"User selected table: {q.args.table_dropdown}")
572588
await submit_table(q)
573589
q.args.chatbot = f"Table {q.args.table_dropdown} selected"
574590
# Refresh response is triggered when user selects a table via dropdown
575591
event_handled = True
592+
if (
593+
q.args.model_choice_dropdown
594+
and not q.args.chatbot
595+
and q.user.model_choice_dropdown != q.args.model_choice_dropdown
596+
):
597+
logging.info(f"User selected model type: {q.args.model_choice_dropdown}")
598+
q.user.model_choice_dropdown = q.args.model_choice_dropdown
599+
q.page["select_tables"].model_choice_dropdown.value = q.user.model_choice_dropdown
600+
q.args.chatbot = f"Model {q.args.model_choice_dropdown} selected"
601+
# Refresh response is triggered when user selects a table via dropdown
602+
event_handled = True
576603

577604
if q.args.save_conversation or (q.args.chatbot and "save the qna pair:" in q.args.chatbot.lower()):
578605
question = q.client.query

0 commit comments

Comments
 (0)