Skip to content

Commit 9c46ff1

Browse files
Fix respective config paths #4
1 parent a7dfae3 commit 9c46ff1

File tree

4 files changed: 33 additions and 18 deletions

sidekick/configs/.env.toml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ LOG-LEVEL = "INFO"
1616
DB_TYPE = "sqlite"
1717

1818
[TABLE_INFO]
19-
TABLE_INFO_PATH = "/examples/demo/table_info.jsonl"
20-
TABLE_SAMPLES_PATH = "/examples/demo/demo_data.csv"
19+
TABLE_INFO_PATH = "examples/demo/table_info.jsonl"
20+
SAMPLE_QNA_PATH = "examples/demo/demo_qa.csv"
21+
TABLE_SAMPLES_PATH = "examples/demo/demo_data.csv"
2122
TABLE_NAME = "demo"

sidekick/prompter.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -283,13 +283,13 @@ def update_context():
283283
@cli.command()
284284
@click.option("--question", "-q", help="Database name", prompt="Ask a question")
285285
@click.option("--table-info-path", "-t", help="Table info path", default=None)
286-
@click.option("--sample-queries", "-s", help="Samples path", default=None)
287-
def query(question: str, table_info_path: str, sample_queries: str):
286+
@click.option("--sample_qna_path", "-s", help="Samples path", default=None)
287+
def query(question: str, table_info_path: str, sample_qna_path: str):
288288
"""Asks question and returns SQL."""
289-
query_api(question=question, table_info_path=table_info_path, sample_queries=sample_queries, is_command=True)
289+
query_api(question=question, table_info_path=table_info_path, sample_queries_path=sample_qna_path, is_command=True)
290290

291291

292-
def query_api(question: str, table_info_path: str, sample_queries: str, is_command: bool = False):
292+
def query_api(question: str, table_info_path: str, sample_queries_path: str, is_command: bool = False):
293293
"""Asks question and returns SQL."""
294294
results = []
295295
# Book-keeping
@@ -357,7 +357,7 @@ def query_api(question: str, table_info_path: str, sample_queries: str, is_comma
357357
table_info_path = _get_table_info(path)
358358

359359
sql_g = SQLGenerator(
360-
db_url, api_key, job_path=base_path, data_input_path=table_info_path, samples_queries=sample_queries
360+
db_url, api_key, job_path=base_path, data_input_path=table_info_path, sample_queries_path=sample_queries_path
361361
)
362362
if "h2ogpt-sql" not in model_name:
363363
sql_g._tasks = sql_g.generate_tasks(table_names, question)

sidekick/query.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,12 @@
99
import sqlglot
1010
import torch
1111
from langchain import OpenAI
12-
from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
12+
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
13+
LLMPredictor, ServiceContext, SQLDatabase)
1314
from llama_index.indices.struct_store import SQLContextContainerBuilder
14-
from sidekick.configs.prompt_template import DEBUGGING_PROMPT, NSQL_QUERY_PROMPT, QUERY_PROMPT, TASK_PROMPT
15+
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT,
16+
NSQL_QUERY_PROMPT, QUERY_PROMPT,
17+
TASK_PROMPT)
1518
from sidekick.logger import logger
1619
from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
1720
from sqlalchemy import create_engine
@@ -33,7 +36,7 @@ def __init__(
3336
db_url: str,
3437
openai_key: str = None,
3538
data_input_path: str = "./table_info.jsonl",
36-
samples_queries: str = "./samples.csv",
39+
sample_queries_path: str = "./samples.csv",
3740
job_path: str = "../var/lib/tmp/data",
3841
):
3942
self.db_url = db_url
@@ -42,7 +45,7 @@ def __init__(
4245
self.similarity_model = None
4346
self.context_builder = None
4447
self.data_input_path = _check_file_info(data_input_path)
45-
self.sample_queries_path = samples_queries
48+
self.sample_queries_path = sample_queries_path
4649
self.path = job_path
4750
self._data_info = None
4851
self._tasks = None
@@ -78,7 +81,7 @@ def update_context_queries(self):
7881
# Check if seed samples were provided
7982
new_context_queries = []
8083
if self.sample_queries_path is not None and Path(self.sample_queries_path).exists():
81-
logger.info(f"Using samples from path {self.sample_queries_path}")
84+
logger.info(f"Using QnA samples from path {self.sample_queries_path}")
8285
new_context_queries = read_sample_pairs(self.sample_queries_path, "gpt")
8386
# cache the samples for future use
8487
with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "w") as f:
@@ -319,7 +322,7 @@ def generate_sql(
319322
threshold=0.9,
320323
)
321324
if len(context_queries) > 1
322-
else context_queries
325+
else (context_queries, _)
323326
)
324327
logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}")
325328
# If QnA pairs > 5, we keep top 5 for focused context

ui/app.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
import logging
22
from pathlib import Path
33
from typing import List, Optional
4-
from sidekick.prompter import db_setup_api, query_api
54

65
import openai
76
import toml
87
from h2o_wave import Q, app, data, handle_on, main, on, ui
8+
from sidekick.prompter import db_setup_api, query_api
99

1010
# Load the config file and initialize required paths
1111
base_path = (Path(__file__).parent / "../").resolve()
@@ -20,8 +20,10 @@
2020
password = db_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
2121
db_name = db_settings["LOCAL_DB_CONFIG"]["DB_NAME"]
2222
port = db_settings["LOCAL_DB_CONFIG"]["PORT"]
23+
# Related to the selected table - currently demo
2324
table_info_path = f'{base_path}/{db_settings["TABLE_INFO"]["TABLE_INFO_PATH"]}'
2425
table_samples_path = f'{base_path}/{db_settings["TABLE_INFO"]["TABLE_SAMPLES_PATH"]}'
26+
sample_qna_path = db_settings["TABLE_INFO"]["SAMPLE_QNA_PATH"]
2527
table_name = db_settings["TABLE_INFO"]["TABLE_NAME"]
2628

2729
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s")
@@ -74,11 +76,20 @@ async def chatbot(q: Q):
7476
logging.info(f"Question: {question}")
7577

7678
if q.args.chatbot.lower() == "db setup":
77-
llm_response = db_setup_api(db_name=db_name, hostname=host_name, user_name=user_name, password=password, port=port, table_info_path=table_info_path, table_samples_path=table_samples_path, table_name= table_name)
79+
llm_response = db_setup_api(
80+
db_name=db_name,
81+
hostname=host_name,
82+
user_name=user_name,
83+
password=password,
84+
port=port,
85+
table_info_path=table_info_path,
86+
table_samples_path=table_samples_path,
87+
table_name=table_name,
88+
)
7889
else:
79-
llm_response = query_api(question = question,
80-
sample_queries=None,
81-
table_info_path=table_info_path)
90+
llm_response = query_api(
91+
question=question, sample_queries_path=sample_qna_path, table_info_path=table_info_path
92+
)
8293
llm_response = "\n".join(llm_response)
8394

8495
q.page["chat_card"].data += [llm_response, False]

0 commit comments

Comments
 (0)