Skip to content

Commit dd6ab34

Browse files
Load quantized version of the model for faster inference #4
1 parent 83ab1c2 commit dd6ab34

File tree

5 files changed

+28
-19
lines changed

5 files changed

+28
-19
lines changed

sidekick/configs/.env.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ LOG-LEVEL = "INFO"
1616
DB_TYPE = "sqlite"
1717

1818
[TABLE_INFO]
19-
TABLE_INFO_PATH = "/examples/test/table_info.jsonl"
20-
TABLE_SAMPLES_PATH = "/examples/test/masked_data_and_columns.csv"
19+
TABLE_INFO_PATH = "/examples/demo/table_info.jsonl"
20+
TABLE_SAMPLES_PATH = "/examples/demo/demo_data.csv"
2121
TABLE_NAME = "demo"
22-
DB_TYPE = "sqlite"

sidekick/db_config.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def _extract_schema_info(self, schema_info_path=None):
9595
if "Column Name" in data and "Column Type" in data:
9696
col_name = data["Column Name"]
9797
col_type = data["Column Type"]
98+
if col_type.lower() == "text":
99+
col_type = col_type + " COLLATE NOCASE"
98100
# if column has sample values, save in cache for future use.
99101
if "Sample Values" in data:
100102
_sample_values = data["Sample Values"]
@@ -116,9 +118,7 @@ def _extract_schema_info(self, schema_info_path=None):
116118
return res
117119

118120
def create_table(self, schema_info_path=None, schema_info=None):
119-
engine = create_engine(
120-
self._url, isolation_level="AUTOCOMMIT"
121-
)
121+
engine = create_engine(self._url, isolation_level="AUTOCOMMIT")
122122
self._engine = engine
123123
if self.schema_info is None:
124124
if schema_info is not None:
@@ -139,9 +139,7 @@ def create_table(self, schema_info_path=None, schema_info=None):
139139
return
140140

141141
def has_table(self):
142-
engine = create_engine(
143-
self._url
144-
)
142+
engine = create_engine(self._url)
145143

146144
return sqlalchemy.inspect(engine).has_table(self.table_name)
147145

@@ -181,6 +179,7 @@ def execute_query_db(self, query=None, n_rows=100):
181179

182180
# Create a connection
183181
connection = engine.connect()
182+
logger.debug(f"Executing query:\n {query}")
184183
result = connection.execute(query)
185184

186185
# Process the query results

sidekick/prompter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def update_context():
285285
@click.option("--table-info-path", "-t", help="Table info path", default=None)
286286
@click.option("--sample-queries", "-s", help="Samples path", default=None)
287287
def query(question: str, table_info_path: str, sample_queries: str):
288+
"""Asks question and returns SQL."""
288289
query_api(question=question, table_info_path=table_info_path, sample_queries=sample_queries, is_command=True)
289290

290291

sidekick/query.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __init__(
5151
self._tasks = None
5252
self.openai_key = openai_key
5353
self.content_queries = None
54+
self.model = None # Used for local LLMs
55+
self.tokenizer = None # Used for local tokenizer
5456

5557
def load_column_samples(self, tables: list):
5658
# TODO: Maybe we add table name as a member variable
@@ -267,8 +269,12 @@ def generate_sql(
267269
logger.info(f"Realized query so far:\n {res}")
268270
else:
269271
# Load h2oGPT.NSQL model
270-
tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B")
271-
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")
272+
device = {"": 0} if torch.cuda.is_available() else "cpu"
273+
if self.model is None:
274+
self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device)
275+
self.model = AutoModelForCausalLM.from_pretrained(
276+
"NumbersStation/nsql-6B", device_map=device, load_in_8bit=True
277+
)
272278

273279
# TODO Update needed for multiple tables
274280
columns_w_type = (
@@ -321,8 +327,8 @@ def generate_sql(
321327
logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}")
322328
# If QnA pairs > 5, we keep top 5 for focused context
323329
_samples = filtered_context
324-
if len(filtered_context) > 5:
325-
_samples = filtered_context[0:5][::-1]
330+
if len(filtered_context) > 3:
331+
_samples = filtered_context[0:3][::-1]
326332
qna_samples = "\n".join(_samples)
327333

328334
contextual_context_val = ", ".join(contextual_context)
@@ -357,24 +363,28 @@ def generate_sql(
357363

358364
logger.debug(f"Query Text:\n {query}")
359365
inputs = tokenizer([query], return_tensors="pt")
360-
input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
366+
input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1]
361367
# Generate SQL
362368
random_seed = random.randint(0, 50)
363369
torch.manual_seed(random_seed)
364370

365371
# Greedy search for quick response
366-
output = model.generate(
367-
**inputs,
372+
self.model.eval()
373+
device_type = "cuda" if torch.cuda.is_available() else "cpu"
374+
output = self.model.generate(
375+
**inputs.to(device_type),
368376
max_new_tokens=300,
369377
temperature=0.5,
370378
output_scores=True,
371379
return_dict_in_generate=True,
372380
)
373381

374382
generated_tokens = output.sequences[:, input_length:]
375-
_res = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
383+
_res = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
376384
# Below is a pre-caution in-case of an error in table name during generation
377-
res = "SELECT" + _res.replace("table_name", table_names[0])
385+
# COLLATE NOCASE is used to ignore case sensitivity, this might be specific to sqlite
386+
_temp = _res.replace("table_name", table_names[0]).split(";")[0]
387+
res = "SELECT" + _temp + " COLLATE NOCASE;"
378388
return res
379389

380390
def task_formatter(self, input_task: str):

sidekick/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def filter_samples(input_q: str, probable_qs: list, model_path: str, model_obj=N
9292
_scores.append(similarities_score[0][0])
9393

9494
sorted_res = sorted(res.items(), key=lambda x: x[1], reverse=True)
95-
logger.info(f"Sorted context: {sorted_res}")
95+
logger.debug(f"Sorted context: {sorted_res}")
9696
return list(dict(sorted_res).keys()), model_obj
9797

9898

0 commit comments

Comments
 (0)