
Commit e7aed31

Optimize model download and improve data upload errors #44
1 parent bdf020c commit e7aed31

File tree

2 files changed: +57 / -25 lines


start.py

Lines changed: 16 additions & 4 deletions
@@ -1,3 +1,4 @@
+import multiprocessing
 import os
 import shlex
 import subprocess
@@ -7,10 +8,21 @@

 print(f"Download model...")
 base_path = (Path(__file__).parent).resolve()
-# Model 1:
-snapshot_download(repo_id="NumbersStation/nsql-llama-2-7B", cache_dir=f"{base_path}/models/")
-# Model 2:
-snapshot_download(repo_id="defog/sqlcoder2", cache_dir=f"{base_path}/models/")
+
+MODEL_CHOICE_MAP = {
+    "h2ogpt-sql-sqlcoder2": "defog/sqlcoder2",
+    "h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B",
+}
+
+
+def f(model_id):
+    print(f"Downloading {model_id}...")
+    snapshot_download(repo_id=model_id, cache_dir=f"{base_path}/models/")
+
+
+with multiprocessing.Pool(len(MODEL_CHOICE_MAP)) as pool:
+    _ = pool.map(f, [_m for _m in MODEL_CHOICE_MAP.values()])
+
 print(f"Download embedding model...")
 snapshot_download(repo_id="BAAI/bge-base-en", cache_dir=f"{base_path}/models/sentence_transformers/")
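The two serial snapshot_download calls are replaced with a small process pool, so both model snapshots are fetched concurrently. A minimal standalone sketch of the same pattern follows; it assumes huggingface_hub is installed and reuses the two repo ids from MODEL_CHOICE_MAP. The __main__ guard is extra here, since spawn-based platforms (macOS, Windows) re-import the module in each worker.

# Standalone sketch (not part of the commit) of the parallel prefetch pattern
# introduced in start.py. Assumes huggingface_hub is installed; the repo ids
# mirror MODEL_CHOICE_MAP above.
import multiprocessing
from pathlib import Path

from huggingface_hub import snapshot_download

BASE_PATH = Path(__file__).parent.resolve()
REPOS = ["defog/sqlcoder2", "NumbersStation/nsql-llama-2-7B"]


def fetch(repo_id: str) -> None:
    # Each worker process downloads one model snapshot into the shared cache dir.
    print(f"Downloading {repo_id}...")
    snapshot_download(repo_id=repo_id, cache_dir=f"{BASE_PATH}/models/")


if __name__ == "__main__":
    # One worker per repo; the downloads overlap instead of running back to back.
    with multiprocessing.Pool(len(REPOS)) as pool:
        pool.map(fetch, REPOS)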

ui/app.py

Lines changed: 41 additions & 21 deletions
@@ -343,6 +343,7 @@ async def chatbot(q: Q):
 @on("file_upload")
 async def fileupload(q: Q):
     q.page["dataset"].error_bar.visible = False
+    q.page["dataset"].error_upload_bar.visible = False
     q.page["dataset"].success_bar.visible = False
     q.page["dataset"].progress_bar.visible = True
@@ -360,7 +361,7 @@ async def fileupload(q: Q):
     remove_chars = [" ", "-"]
     org_table_name = usr_table_name = None
     if (
-        q.args.table_name == "" or q.args.table_name is None and sample_data
+        (q.args.table_name == "" or q.args.table_name is None) and sample_data and len(sample_data) > 0
     ):  # User did not provide a table name, use the filename as table name
         org_table_name = sample_data[0].split(".")[0].split("/")[-1]
         logging.info(f"Using provided filename as table name: {org_table_name}")
@@ -374,6 +375,7 @@ async def fileupload(q: Q):
     logging.info(f"Upload initiated for {org_table_name} with scheme input: {sample_schema}")
     if sample_data is None:
         q.page["dataset"].error_bar.visible = True
+        q.page["dataset"].error_upload_bar.visible = False
         q.page["dataset"].progress_bar.visible = False
     else:
         if sample_data:
@@ -396,27 +398,39 @@
                 "samples_path": usr_samples_path,
                 "samples_qa": usr_sample_qa,
             }
-            logging.info(f"Table metadata: {table_metadata}")
-            update_tables(f"{tmp_path}/data/tables.json", table_metadata)
+            try:
+                logging.info(f"Table metadata: {table_metadata}")
+                update_tables(f"{tmp_path}/data/tables.json", table_metadata)

-            q.user.table_name = usr_table_name
-            q.user.table_samples_path = usr_samples_path
-            q.user.table_info_path = usr_info_path
-            q.user.sample_qna_path = usr_sample_qa
+                q.user.table_name = usr_table_name
+                q.user.table_samples_path = usr_samples_path
+                q.user.table_info_path = usr_info_path
+                q.user.sample_qna_path = usr_sample_qa

-            db_resp = db_setup_api(
-                db_name=q.user.db_name,
-                hostname=q.user.host_name,
-                user_name=q.user.user_name,
-                password=q.user.password,
-                port=q.user.port,
-                table_info_path=q.user.table_info_path,
-                table_samples_path=q.user.table_samples_path,
-                table_name=q.user.table_name,
-            )
-            logging.info(f"DB updates: \n {db_resp}")
-            q.page["dataset"].progress_bar.visible = False
-            q.page["dataset"].success_bar.visible = True
+                db_resp = db_setup_api(
+                    db_name=q.user.db_name,
+                    hostname=q.user.host_name,
+                    user_name=q.user.user_name,
+                    password=q.user.password,
+                    port=q.user.port,
+                    table_info_path=q.user.table_info_path,
+                    table_samples_path=q.user.table_samples_path,
+                    table_name=q.user.table_name,
+                )
+                logging.info(f"DB updates: \n {db_resp}")
+                if "error" in str(db_resp).lower():
+                    q.page["dataset"].error_upload_bar.visible = True
+                    q.page["dataset"].error_bar.visible = False
+                    q.page["dataset"].progress_bar.visible = False
+                else:
+                    q.page["dataset"].progress_bar.visible = False
+                    q.page["dataset"].success_bar.visible = True
+            except Exception as e:
+                logging.error(f"Something went wrong while uploading the dataset: {e}")
+                q.page["dataset"].error_upload_bar.visible = True
+                q.page["dataset"].error_bar.visible = False
+                q.page["dataset"].progress_bar.visible = False
+                return
@on("#datasets")
@@ -434,7 +448,13 @@ async def datasets(q: Q):
434448
ui.message_bar(
435449
name="error_bar",
436450
type="error",
437-
text="Please input table name, data & schema files to upload!",
451+
text="Please input table name and upload data to get started!",
452+
visible=False,
453+
),
454+
ui.message_bar(
455+
name="error_upload_bar",
456+
type="error",
457+
text="Upload failed; something went wrong. Please check the dataset name/column name for special characters and try again!",
438458
visible=False,
439459
),
440460
ui.message_bar(name="success_bar", type="success", text="Files Uploaded Successfully!", visible=False),
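Taken together, the ui/app.py changes make the upload path fail visibly: the new error_upload_bar is registered next to the existing bars, the upload work is wrapped in try/except, and a textual "error" in the db_setup_api response is treated as a failure. The sketch below condenses that pattern; handle_upload and do_upload are hypothetical stand-ins (the real handler calls update_tables and db_setup_api directly), while the bar names match the commit.

# Condensed illustration (not the actual handler) of the error-surfacing
# pattern added to fileupload(). `do_upload` is a hypothetical stand-in for
# the update_tables()/db_setup_api() work; bar names mirror the commit.
import logging

from h2o_wave import Q


async def handle_upload(q: Q, do_upload) -> None:
    page = q.page["dataset"]
    page.error_bar.visible = False
    page.error_upload_bar.visible = False
    page.success_bar.visible = False
    page.progress_bar.visible = True
    try:
        db_resp = do_upload()
        logging.info(f"DB updates: \n {db_resp}")
        # db_setup_api can report failures in its response text without raising,
        # hence the string check in addition to the except block.
        failed = "error" in str(db_resp).lower()
        page.error_upload_bar.visible = failed
        page.success_bar.visible = not failed
    except Exception as e:
        # Unexpected failures (e.g. special characters in table/column names)
        # land in the dedicated upload error bar instead of being swallowed.
        logging.error(f"Something went wrong while uploading the dataset: {e}")
        page.error_upload_bar.visible = True
    finally:
        page.progress_bar.visible = False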

0 commit comments
