Skip to content

Commit 8b50af6

Browse files
Load default model during app init #4
1 parent b0d0261 commit 8b50af6

File tree

3 files changed

+32
-12
lines changed

3 files changed

+32
-12
lines changed

sidekick/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def __init__(
7979
is_regenerate_with_options: bool = False,
8080
):
8181
self.db_url = db_url
82-
self.engine = create_engine(db_url)
83-
self.sql_database = SQLDatabase(self.engine)
82+
self.engine = create_engine(db_url) if db_url else None
83+
self.sql_database = SQLDatabase(self.engine) if self.engine else None
8484
self.context_builder = None
8585
self.data_input_path = _check_file_info(data_input_path)
8686
self.sample_queries_path = sample_queries_path

sidekick/utils.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
from sentence_transformers import SentenceTransformer
1616
from sidekick.logger import logger
1717
from sklearn.metrics.pairwise import cosine_similarity
18-
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
19-
BitsAndBytesConfig)
18+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
2019

2120

2221
def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = None):
@@ -269,9 +268,9 @@ def get_table_keys(file_path: str, table_key: str):
269268
return res, data
270269

271270

272-
def is_resource_low():
273-
free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
274-
total_memory = int(torch.cuda.get_device_properties(0).total_memory / 1024**3)
271+
def is_resource_low(device_index: int = 0):
272+
free_in_GB = int(torch.cuda.mem_get_info(device_index)[0] / 1024**3)
273+
total_memory = int(torch.cuda.get_device_properties(device_index).total_memory / 1024**3)
275274
logger.info(f"Total Memory: {total_memory}GB")
276275
logger.info(f"Free GPU memory: {free_in_GB}GB")
277276
off_load = True
@@ -296,20 +295,21 @@ def load_causal_lm_model(
296295
}
297296
model_name = model_choices_map[model_type]
298297
logger.info(f"Loading model: {model_name}")
298+
device_index = 0
299299
# Load h2oGPT.SQL model
300-
device = {"": 0} if torch.cuda.is_available() else "cpu" if device == "auto" else device
300+
device = {"": device_index} if torch.cuda.is_available() else "cpu" if device == "auto" else device
301301
total_memory = int(torch.cuda.get_device_properties(0).total_memory / 1024**3)
302302
free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
303303
logger.info(f"Free GPU memory: {free_in_GB}GB")
304304
n_gpus = torch.cuda.device_count()
305+
logger.info(f"Total GPUs: {n_gpus}")
305306
_load_in_8bit = load_in_8bit
306307

307308
# 22GB (Least requirement on GPU) is a magic number for the current model size.
308309
if off_load and re_generate and total_memory < 22:
309310
# To prevent the system from crashing in-case memory runs low.
310311
# TODO: Performance when offloading to CPU.
311-
max_memory = f"{4}GB"
312-
max_memory = {i: max_memory for i in range(n_gpus)}
312+
max_memory = {device_index: f"{4}GB"}
313313
logger.info(f"Max Memory: {max_memory}, offloading to CPU")
314314
with init_empty_weights():
315315
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_path, offload_folder=cache_path)
@@ -322,8 +322,7 @@ def load_causal_lm_model(
322322
_load_in_8bit = True
323323
load_in_4bit = False
324324
else:
325-
max_memory = f"{int(free_in_GB)-2}GB"
326-
max_memory = {i: max_memory for i in range(n_gpus)}
325+
max_memory = {device_index: f"{int(free_in_GB)-2}GB"}
327326
_offload_state_dict = False
328327
_llm_int8_enable_fp32_cpu_offload = False
329328

ui/app.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from h2o_wave import Q, app, data, handle_on, main, on, ui
1212
from h2o_wave.core import expando_to_dict
1313
from sidekick.prompter import db_setup_api, query_api
14+
from sidekick.query import SQLGenerator
1415
from sidekick.utils import get_table_keys, save_query, setup_dir, update_tables
1516

1617
# Load the config file and initialize required paths
@@ -197,6 +198,8 @@ async def chatbot(q: Q):
197198
question = f"{q.args.chatbot}"
198199
logging.info(f"Question: {question}")
199200

201+
if q.args.table_dropdown or q.args.model_choice_dropdown:
202+
return
200203
# For regeneration, currently there are 2 modes
201204
# 1. Quick fast approach by throttling the temperature
202205
# 2. "Try harder mode (THM)" Slow approach by using the diverse beam search
@@ -662,6 +665,23 @@ async def on_event(q: Q):
662665
return event_handled
663666

664667

668+
def on_startup():
669+
logging.info("SQL-Assistant started!")
670+
logging.info(f"Initializing default model")
671+
672+
_ = SQLGenerator(
673+
None,
674+
None,
675+
model_name="h2ogpt-sql-sqlcoder2",
676+
job_path=base_path,
677+
data_input_path="",
678+
sample_queries_path="",
679+
is_regenerate_with_options="",
680+
is_regenerate="",
681+
)
682+
return
683+
684+
665685
@app("/", on_shutdown=on_shutdown)
666686
async def serve(q: Q):
667687
# Run only once per client connection.
@@ -670,6 +690,7 @@ async def serve(q: Q):
670690
setup_dir(base_path)
671691
await init(q)
672692
q.client.initialized = True
693+
on_startup()
673694
logging.info("App initialized.")
674695

675696
# Handle routing.

0 commit comments

Comments (0)