Commit 3ba31f0

Added support for faster and exhaustive regeneration #4
1 parent 427ad8b commit 3ba31f0

File tree

6 files changed: +83 -35 lines changed

- about.md
- app.toml
- sidekick/prompter.py
- sidekick/query.py
- sidekick/utils.py
- ui/app.py

about.md

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@

 **Actively Being Maintained:** Yes (Demo release: _In active RnD_)

-**Last Updated:** August, 2023
+**Last Updated:** September, 2023

 **Allows uploading and using new model and data:** Yes

app.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ title = "SQL-Sidekick"
 description = "QnA with tabular data using NLI"
 LongDescription = "about.md"
 Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"]
-Version = "0.0.6"
+Version = "0.0.7"

 [Runtime]
 MemoryLimit = "64Gi"

sidekick/prompter.py

Lines changed: 5 additions & 9 deletions
@@ -321,6 +321,7 @@ def query_api(
     sample_queries_path: str,
     table_name: str,
     is_regenerate: bool = False,
+    is_regen_with_options: bool = False,
     is_command: bool = False,
 ):
     """Asks question and returns SQL."""
@@ -400,7 +401,8 @@ def query_api(
         job_path=base_path,
         data_input_path=table_info_path,
         sample_queries_path=sample_queries_path,
-        is_regenerate = is_regenerate
+        is_regenerate_with_options=is_regen_with_options,
+        is_regenerate=is_regenerate,
     )
     if "h2ogpt-sql" not in model_name:
         sql_g._tasks = sql_g.generate_tasks(table_names, question)
@@ -418,9 +420,7 @@ def query_api(
     if updated_tasks is not None:
         sql_g._tasks = updated_tasks
     alt_res = None
-    res, alt_res = sql_g.generate_sql(
-        table_names, question, model_name=model_name, _dialect=db_dialect, is_regenerate=is_regenerate
-    )
+    res, alt_res = sql_g.generate_sql(table_names, question, model_name=model_name, _dialect=db_dialect)
     logger.info(f"Input query: {question}")
     logger.info(f"Generated response:\n\n{res}")
@@ -439,11 +439,7 @@ def query_api(
         elif res_val.lower() == "r" or res_val.lower() == "regenerate":
             click.echo("Attempting to regenerate...")
             res, alt_res = sql_g.generate_sql(
-                table_names,
-                question,
-                model_name=model_name,
-                _dialect=db_dialect,
-                is_regenerate=is_regenerate,
+                table_names, question, model_name=model_name, _dialect=db_dialect
             )
             logger.info(f"Input query: {question}")
             logger.info(f"Generated response:\n\n{res}")

sidekick/query.py

Lines changed: 36 additions & 15 deletions
@@ -1,6 +1,6 @@
+import gc
 import json
 import os
-import gc
 import random
 import sys
 from pathlib import Path
@@ -18,14 +18,13 @@
 from sidekick.utils import (
     _check_file_info,
     filter_samples,
+    is_resource_low,
     load_causal_lm_model,
     load_embedding_model,
     read_sample_pairs,
     remove_duplicates,
-    is_resource_low,
 )
 from sqlalchemy import create_engine
-from transformers import AutoModelForCausalLM, AutoTokenizer


 class SQLGenerator:
@@ -41,6 +40,7 @@ def __new__(
         job_path: str = "./",
         device: str = "auto",
         is_regenerate: bool = False,
+        is_regenerate_with_options: bool = False,
     ):
         offloading = is_resource_low()
         if offloading and is_regenerate:
@@ -73,6 +73,7 @@ def __init__(
         job_path: str = "./",
         device: str = "cpu",
         is_regenerate: bool = False,
+        is_regenerate_with_options: bool = False,
     ):
         self.db_url = db_url
         self.engine = create_engine(db_url)
@@ -86,6 +87,9 @@ def __init__(
         self.model_name = model_name
         self.openai_key = openai_key
         self.content_queries = None
+        self.is_regenerate_with_options = is_regenerate_with_options
+        self.is_regenerate = is_regenerate
+        self.device = device

     def clear(self):
         del SQLGenerator._instance
@@ -252,12 +256,7 @@ def generate_tasks(self, table_names: list, input_question: str):
             raise se

     def generate_sql(
-        self,
-        table_names: list,
-        input_question: str,
-        _dialect: str = "sqlite",
-        model_name: str = "h2ogpt-sql",
-        is_regenerate: bool = False,
+        self, table_names: list, input_question: str, _dialect: str = "sqlite", model_name: str = "h2ogpt-sql"
     ):
         context_file = f"{self.path}/var/lib/tmp/data/context.json"
         additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}
@@ -361,8 +360,8 @@ def generate_sql(
         logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}")
         # If QnA pairs > 5, we keep top 5 for focused context
         _samples = filtered_context
-        if len(filtered_context) > 3:
-            _samples = filtered_context[0:3][::-1]
+        if len(filtered_context) > 5:
+            _samples = filtered_context[0:5][::-1]

         qna_samples = "\n".join(_samples)

@@ -431,7 +430,8 @@ def generate_sql(
         device_type = "cuda" if torch.cuda.is_available() else "cpu"

         alternate_queries = []
-        if not is_regenerate:
+        if not self.is_regenerate_with_options and not self.is_regenerate:
+            # Greedy decoding
             output = self.model.generate(
                 **inputs.to(device_type),
                 max_new_tokens=300,
@@ -442,17 +442,37 @@ def generate_sql(
             )

             generated_tokens = output.sequences[:, input_length:][0]
-        else:
+        elif self.is_regenerate and not self.is_regenerate_with_options:
+            # Throttle the temperature for a different result
             logger.info("Regeneration requested on previous query ...")
             random_seed = random.randint(0, 50)
             torch.manual_seed(random_seed)
-            random_temperature = round(random.uniform(0.5, 0.75), 2)
+            possible_temp_choice = [0.1, 0.2, 0.3, 0.6, 0.75, 0.9]
+            random_temperature = np.random.choice(possible_temp_choice, 1)[0]
+            logger.debug(f"Selected temperature for fast regeneration : {random_temperature}")
+            output = self.model.generate(
+                **inputs.to(device_type),
+                max_new_tokens=300,
+                temperature=random_temperature,
+                output_scores=True,
+                do_sample=True,
+                return_dict_in_generate=True,
+            )
+            generated_tokens = output.sequences[:, input_length:][0]
+        else:
+            logger.info("Regeneration with options requested on previous query ...")
+            # Diverse beam search decoding to explore more options
+            random_seed = random.randint(0, 50)
+            torch.manual_seed(random_seed)
+            possible_temp_choice = [0.1, 0.3, 0.5, 0.6, 0.75]
+            random_temperature = np.random.choice(possible_temp_choice, 1)[0]
+            logger.debug(f"Selected temperature for diverse beam search: {random_temperature}")
             output_re = self.model.generate(
                 **inputs.to(device_type),
                 max_new_tokens=300,
                 temperature=random_temperature,
                 top_k=5,
-                top_p=0.95,
+                top_p=0.9,
                 num_beams=5,
                 num_beam_groups=5,
                 num_return_sequences=5,
@@ -465,6 +485,7 @@ def generate_sql(
             transition_scores = self.model.compute_transition_scores(
                 output_re.sequences, output_re.scores, output_re.beam_indices, normalize_logits=False
             )
+
             # Create a boolean tensor where elements are True if the corresponding element in transition_scores is less than 0
             mask = transition_scores < 0
             # Sum the True values along axis 1
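
The three branches above amount to three decoding regimes: greedy for the first answer, temperature-throttled sampling for fast regeneration, and diverse beam search for the exhaustive "try harder" path. Below is a self-contained sketch of the same idea on a toy checkpoint; gpt2 stands in for the app's own causal LM, the diversity_penalty value is illustrative, and the ranking at the end is one reading of the mask/sum logic in the last hunk.

# Toy illustration of the three decoding regimes; not the commit's exact calls.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("SELECT", return_tensors="pt")
input_length = inputs.input_ids.shape[1]
pad = tokenizer.eos_token_id  # gpt2 has no pad token; reuse EOS

# 1. First answer: greedy decoding, deterministic.
out = model.generate(**inputs, max_new_tokens=20, pad_token_id=pad,
                     return_dict_in_generate=True, output_scores=True)

# 2. Fast regeneration: sample with a (randomly picked) temperature so a
#    repeated click yields a different candidate.
out = model.generate(**inputs, max_new_tokens=20, pad_token_id=pad,
                     do_sample=True, temperature=0.6,
                     return_dict_in_generate=True, output_scores=True)

# 3. "Try harder": diverse beam search returns several distinct candidates;
#    diversity_penalty > 0 is what pushes the beam groups apart.
out = model.generate(**inputs, max_new_tokens=20, pad_token_id=pad,
                     num_beams=5, num_beam_groups=5, num_return_sequences=5,
                     diversity_penalty=0.6,
                     return_dict_in_generate=True, output_scores=True)

# Score the alternates along the lines of the last hunk: count the tokens
# with negative transition scores and prefer sequences with fewer of them.
scores = model.compute_transition_scores(
    out.sequences, out.scores, out.beam_indices, normalize_logits=False
)
penalized = (scores < 0).sum(dim=1)  # per-sequence count of negative-score steps
candidates = tokenizer.batch_decode(out.sequences[:, input_length:], skip_special_tokens=True)
best_first = [candidates[i] for i in penalized.argsort()]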

sidekick/utils.py

Lines changed: 4 additions & 4 deletions
@@ -8,14 +8,14 @@
 import numpy as np
 import pandas as pd
 import torch
+from accelerate import infer_auto_device_map, init_empty_weights
 from InstructorEmbedding import INSTRUCTOR
 from pandasql import sqldf
 from sentence_transformers import SentenceTransformer
 from sidekick.logger import logger
 from sklearn.metrics.pairwise import cosine_similarity
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
-from accelerate import init_empty_weights, infer_auto_device_map
-from transformers import BitsAndBytesConfig
+from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
+                          BitsAndBytesConfig)


 def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = None):
@@ -324,7 +324,7 @@ def load_causal_lm_model(
             model_name, cache_dir=cache_path, device_map=device, quantization_config=nf4_config
         )

-        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_path, device_map=device)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_path, device_map=device, use_fast=True)

         return model, tokenizer
     except Exception as e:
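
Side note on the tokenizer change: use_fast=True requests the Rust-backed tokenizer when the checkpoint ships one, which mostly matters for batch-encoding throughput (recent transformers releases already default to it, so this makes the choice explicit). A minimal check, with gpt2 as a placeholder checkpoint:

from transformers import AutoTokenizer

# use_fast=True selects the Rust-backed implementation when available.
tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
print(tok.is_fast)  # True when a fast tokenizer exists for this checkpoint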

ui/app.py

Lines changed: 36 additions & 5 deletions
@@ -72,7 +72,7 @@ async def chat(q: Q):
     table_names = []
     tables, _ = get_table_keys(f"{tmp_path}/data/tables.json", None)
     for table in tables:
-        table_names.append(ui.choice(table, f"Table: {table}"))
+        table_names.append(ui.choice(table, f"{table}"))

     add_card(q, "background_card", ui.form_card(box="horizontal", items=[ui.text("Ask your questions:")]))

@@ -100,7 +100,15 @@ async def chat(q: Q):
             box=ui.box("vertical", height="500px"),
             name="chatbot",
             data=data(fields="content from_user", t="list", size=-50),
-            commands=[ui.command(name=f"regenerate_event", icon="RepeatAll", caption="Regenerate", label="Regenerate")],
+            commands=[
+                ui.command(name=f"regenerate", icon="RepeatOne", caption="Attempts regeneration", label="Regenerate"),
+                ui.command(
+                    name=f"regenerate_with_options",
+                    icon="RepeatAll",
+                    caption="Regenerates with options",
+                    label="Try Harder",
+                ),
+            ],
         ),
     )

@@ -121,6 +129,10 @@ async def chatbot(q: Q):
     question = f"{q.args.chatbot}"
     logging.info(f"Question: {question}")

+    # For regeneration, there are currently 2 modes:
+    # 1. Quick, fast approach by throttling the temperature
+    # 2. "Try harder mode (THM)": slower approach using diverse beam search
+
     try:
         if q.args.chatbot.lower() == "db setup":
             llm_response, err = db_setup_api(
@@ -133,15 +145,30 @@ async def chatbot(q: Q):
                 table_samples_path=q.user.table_samples_path,
                 table_name=q.user.table_name,
             )
-        elif q.args.chatbot.lower() == "regenerate" or q.args.regenerate_event:
-            # Attempts to regenerate response on the last supplie query
+        elif q.args.chatbot.lower() == "regenerate" or q.args.regenerate:
+            # Attempts to regenerate response on the last supplied query
+            logging.info(f"Attempt for regeneration")
             if q.client.query is not None and q.client.query.strip() != "":
                 llm_response, alt_response, err = query_api(
                     question=q.client.query,
                     sample_queries_path=q.user.sample_qna_path,
                     table_info_path=q.user.table_info_path,
                     table_name=q.user.table_name,
                     is_regenerate=True,
+                    is_regen_with_options=False,
+                )
+                llm_response = "\n".join(llm_response)
+        elif q.args.chatbot.lower() == "try harder" or q.args.regenerate_with_options:
+            # Attempts to regenerate response on the last supplied query, with options
+            logging.info(f"Attempt for regeneration with options.")
+            if q.client.query is not None and q.client.query.strip() != "":
+                llm_response, alt_response, err = query_api(
+                    question=q.client.query,
+                    sample_queries_path=q.user.sample_qna_path,
+                    table_info_path=q.user.table_info_path,
+                    table_name=q.user.table_name,
+                    is_regenerate=False,
+                    is_regen_with_options=True,
                 )
                 response = "\n".join(llm_response)
                 if alt_response:
@@ -412,8 +439,12 @@ async def on_event(q: Q):
     logging.info(f"Event handled ... ")
     args_dict = expando_to_dict(q.args)
     logging.debug(f"Args dict {args_dict}")
-    if q.args.regenerate_event:
+    if q.args.regenerate_with_options:
+        q.args.chatbot = "try harder"
+    elif q.args.regenerate:
         q.args.chatbot = "regenerate"
+
+    if q.args.regenerate_with_options or q.args.regenerate:
         await chatbot(q)
         event_handled = True
     else:  # default chatbot event
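
The wiring above leans on how Wave surfaces menu commands: a clicked ui.command's name arrives as a truthy attribute on q.args, and on_event rewrites it into a chatbot message before dispatching. A stripped-down sketch of that dispatch (handler body simplified; card and error handling omitted):

from h2o_wave import Q, app, main  # noqa: F401  (importing main registers the server entry point)

@app("/sketch")
async def serve(q: Q):
    # Each ui.command's `name` shows up as a truthy q.args attribute on click.
    if q.args.regenerate_with_options:  # the "Try Harder" menu item
        q.args.chatbot = "try harder"
    elif q.args.regenerate:             # the "Regenerate" menu item
        q.args.chatbot = "regenerate"
    # ...then fall through to the chatbot handler with the rewritten message,
    # as on_event does above.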
