@@ -86,6 +86,11 @@ async def chat(q: Q):
86
86
original_name = meta_data [table ].get ("original_name" , q .user .original_name )
87
87
table_names .append (ui .choice (table , f"{ original_name } " ))
88
88
89
+ model_choices = [
90
+ ui .choice ("h2ogpt-sql-nsql-llama-2-7B" , "h2ogpt-sql-nsql-llama-2-7B" ),
91
+ ui .choice ("h2ogpt-sql-sqlcoder2" , "h2ogpt-sql-sqlcoder2" ),
92
+ ]
93
+ q .user .model_choice_dropdown = "h2ogpt-sql-sqlcoder2"
89
94
add_card (
90
95
q ,
91
96
"background_card" ,
@@ -111,7 +116,15 @@ async def chat(q: Q):
111
116
choices = table_names ,
112
117
value = q .user .table_name if q .user .table_name else None ,
113
118
trigger = True ,
114
- )
119
+ ),
120
+ ui .dropdown (
121
+ name = "model_choice_dropdown" ,
122
+ label = "Model Choice" ,
123
+ required = True ,
124
+ choices = model_choices ,
125
+ value = q .user .model_choice_dropdown if q .user .model_choice_dropdown else None ,
126
+ trigger = True ,
127
+ ),
115
128
],
116
129
),
117
130
)
@@ -209,6 +222,7 @@ async def chatbot(q: Q):
209
222
sample_queries_path = q .user .sample_qna_path ,
210
223
table_info_path = q .user .table_info_path ,
211
224
table_name = q .user .table_name ,
225
+ model_name = q .user .model_choice_dropdown ,
212
226
is_regenerate = True ,
213
227
is_regen_with_options = False ,
214
228
)
@@ -227,6 +241,7 @@ async def chatbot(q: Q):
227
241
sample_queries_path = q .user .sample_qna_path ,
228
242
table_info_path = q .user .table_info_path ,
229
243
table_name = q .user .table_name ,
244
+ model_name = q .user .model_choice_dropdown ,
230
245
is_regenerate = False ,
231
246
is_regen_with_options = True ,
232
247
)
@@ -248,6 +263,7 @@ async def chatbot(q: Q):
248
263
sample_queries_path = q .user .sample_qna_path ,
249
264
table_info_path = q .user .table_info_path ,
250
265
table_name = q .user .table_name ,
266
+ model_name = q .user .model_choice_dropdown ,
251
267
)
252
268
llm_response = "\n " .join (llm_response )
253
269
except (MemoryError , RuntimeError ) as e :
@@ -567,12 +583,23 @@ async def on_event(q: Q):
567
583
elif q .args .regenerate :
568
584
q .args .chatbot = "regenerate"
569
585
570
- if q .args .table_dropdown and not q .args .chatbot :
586
+ if q .args .table_dropdown and not q .args .chatbot and q . user . table_name != q . args . table_dropdown :
571
587
logging .info (f"User selected table: { q .args .table_dropdown } " )
572
588
await submit_table (q )
573
589
q .args .chatbot = f"Table { q .args .table_dropdown } selected"
574
590
# Refresh response is triggered when user selects a table via dropdown
575
591
event_handled = True
592
+ if (
593
+ q .args .model_choice_dropdown
594
+ and not q .args .chatbot
595
+ and q .user .model_choice_dropdown != q .args .model_choice_dropdown
596
+ ):
597
+ logging .info (f"User selected model type: { q .args .model_choice_dropdown } " )
598
+ q .user .model_choice_dropdown = q .args .model_choice_dropdown
599
+ q .page ["select_tables" ].model_choice_dropdown .value = q .user .model_choice_dropdown
600
+ q .args .chatbot = f"Model { q .args .model_choice_dropdown } selected"
601
+ # Refresh response is triggered when user selects a table via dropdown
602
+ event_handled = True
576
603
577
604
if q .args .save_conversation or (q .args .chatbot and "save the qna pair:" in q .args .chatbot .lower ()):
578
605
question = q .client .query
0 commit comments