Skip to content

Commit 796bda6

Browse files
Enable task mode #44
1 parent 405991a commit 796bda6

File tree

3 files changed

+51
-13
lines changed

3 files changed

+51
-13
lines changed

sidekick/prompter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,9 @@ def query_api(
452452
logger.info("Executing user provided SQL without re-generation...")
453453
res = question.strip().lower().split("execute sql:")[1].strip()
454454
else:
455+
_check_cond = question.strip().lower().split("execute sql:")
456+
if len(_check_cond) > 1:
457+
question = question.strip().lower().split("execute sql:")[1].strip()
455458
res, alt_res = sql_g.generate_sql(table_names, question, model_name=model_name, _dialect=db_dialect)
456459
logger.info(f"Input query: {question}")
457460
logger.info(f"Generated response:\n\n{res}")

sidekick/utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,23 @@
1515
from sentence_transformers import SentenceTransformer
1616
from sidekick.logger import logger
1717
from sklearn.metrics.pairwise import cosine_similarity
18-
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
18+
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
19+
BitsAndBytesConfig)
1920

20-
model_choices_map = {
21+
MODEL_CHOICE_MAP = {
2122
"h2ogpt-sql-sqlcoder2": "defog/sqlcoder2",
2223
"h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B",
2324
}
2425

25-
model_device_map = {
26+
MODEL_DEVICE_MAP = {
2627
"h2ogpt-sql-sqlcoder2": 0,
2728
"h2ogpt-sql-nsql-llama-2-7B": 1,
2829
}
2930

31+
TASK_CHOICE = {
32+
"q_a": "Question/Answering",
33+
"sqld": "SQL Debugging",
34+
}
3035

3136
def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = None):
3237
# Reference:
@@ -290,7 +295,7 @@ def is_resource_low(model_name: str):
290295
off_load = False
291296
else:
292297
n_gpus = torch.cuda.device_count()
293-
device_index = model_device_map[model_name] if model_name and n_gpus > 1 else 0
298+
device_index = MODEL_DEVICE_MAP[model_name] if model_name and n_gpus > 1 else 0
294299
logger.debug(f"Information on device: {device_index}")
295300
free_in_GB = int(torch.cuda.mem_get_info(device_index)[0] / 1024**3)
296301
total_memory = int(torch.cuda.get_device_properties(device_index).total_memory / 1024**3)
@@ -382,14 +387,14 @@ def _load_llm(model_type: str, device_index: int = 0, load_in_4bit=True):
382387

383388
if not model_type: # if None, load all models
384389
for device_index in range(n_gpus):
385-
model_name = list(model_choices_map.values())[device_index]
390+
model_name = list(MODEL_CHOICE_MAP.values())[device_index]
386391
model, tokenizer = _load_llm(model_name, device_index)
387-
_name = list(model_choices_map.keys())[device_index]
392+
_name = list(MODEL_CHOICE_MAP.keys())[device_index]
388393
models[_name] = model
389394
tokenizers[_name] = tokenizer
390395
else:
391-
model_name = model_choices_map[model_type]
392-
d_index = model_device_map[model_type] if n_gpus > 1 else 0
396+
model_name = MODEL_CHOICE_MAP[model_type]
397+
d_index = MODEL_DEVICE_MAP[model_type] if n_gpus > 1 else 0
393398
model, tokenizer = _load_llm(model_name, d_index)
394399
models[model_type] = model
395400
tokenizers[model_type] = tokenizer

ui/app.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from h2o_wave.core import expando_to_dict
1313
from sidekick.prompter import db_setup_api, query_api
1414
from sidekick.query import SQLGenerator
15-
from sidekick.utils import get_table_keys, save_query, setup_dir, update_tables
15+
from sidekick.utils import get_table_keys, save_query, setup_dir, update_tables, TASK_CHOICE
1616

1717
# Load the config file and initialize required paths
1818
base_path = (Path(__file__).parent / "../").resolve()
@@ -91,7 +91,7 @@ def clear_cards(q, ignore: Optional[List[str]] = []) -> None:
9191
async def chat(q: Q):
9292
q.page["sidebar"].value = "#chat"
9393

94-
if q.args.table_dropdown or q.args.model_choice_dropdown:
94+
if q.args.table_dropdown or q.args.model_choice_dropdown or q.args.task_dropdown:
9595
# If a table/model is selected, the trigger causes refresh of the page
9696
# so we update chat history with table name selection and return
9797
# avoiding re-drawing.
@@ -113,6 +113,9 @@ async def chat(q: Q):
113113
ui.choice("h2ogpt-sql-sqlcoder2", "h2ogpt-sql-sqlcoder2"),
114114
]
115115
q.user.model_choice_dropdown = "h2ogpt-sql-sqlcoder2"
116+
117+
task_choices = [ui.choice("q_a", "Question/Answering"), ui.choice("sqld", "SQL Debugging")]
118+
q.user.task_choice_dropdown = "q_a"
116119
add_card(
117120
q,
118121
"background_card",
@@ -123,7 +126,7 @@ async def chat(q: Q):
123126
ui.inline(items=[ui.toggle(name="demo_mode", label="Demo", trigger=True)], justify="end"),
124127
],
125128
),
126-
)
129+
),
127130

128131
add_card(
129132
q,
@@ -149,7 +152,24 @@ async def chat(q: Q):
149152
),
150153
],
151154
),
152-
)
155+
),
156+
add_card(
157+
q,
158+
"task_choice",
159+
ui.form_card(
160+
box="vertical",
161+
items=[
162+
ui.dropdown(
163+
name="task_dropdown",
164+
label="Task",
165+
required=True,
166+
choices=task_choices,
167+
value=q.user.task_choice_dropdown if q.user.task_choice_dropdown else None,
168+
trigger=True,
169+
)
170+
],
171+
),
172+
),
153173
add_card(
154174
q,
155175
"chat_card",
@@ -228,11 +248,15 @@ async def chatbot(q: Q):
228248
if (
229249
f"Table {q.user.table_dropdown} selected" in q.args.chatbot
230250
or f"Model {q.user.model_choice_dropdown} selected" in q.args.chatbot
251+
or f"Task {q.user.task_dropdown} selected" in q.args.chatbot
231252
):
232253
return
233254

234255
# Append bot response.
235256
question = f"{q.args.chatbot}"
257+
# Check on task choice.
258+
if q.user.task_dropdown == "sqld":
259+
question = f"Execute SQL:\n{q.args.chatbot}"
236260
logging.info(f"Question: {question}")
237261

238262
# For regeneration, currently there are 2 modes
@@ -638,7 +662,13 @@ async def on_event(q: Q):
638662
q.args.chatbot = f"Model {q.user.model_choice_dropdown} selected"
639663
# Refresh response is triggered when user selects a table via dropdown
640664
event_handled = True
641-
665+
if q.args.task_dropdown and not q.args.chatbot and q.user.task_dropdown != q.args.task_dropdown:
666+
logging.info(f"User selected task: {q.args.task_dropdown}")
667+
q.user.task_dropdown = q.args.task_dropdown
668+
q.page["task_choice"].task_dropdown.value = q.user.task_dropdown
669+
q.args.chatbot = f"Task '{TASK_CHOICE[q.user.task_dropdown]}' selected"
670+
# Refresh response is triggered when user selects a task via dropdown
671+
event_handled = True
642672
if (
643673
q.args.save_conversation
644674
or q.args.save_rejected_conversation

0 commit comments

Comments (0)