Skip to content

Commit 39fe7e7

Browse files
Fix string parsing errors #29
1 parent b4466dd commit 39fe7e7

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

sidekick/prompter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,8 @@ def query_api(
470470
res, alt_res = sql_g.generate_sql(
471471
table_names, question, model_name=model_name, _dialect=db_dialect
472472
)
473+
res = res.replace("“", '"').replace("”", '"')
474+
[res := res.replace(s, '"') for s in "‘`’'" if s in res]
473475
logger.info(f"Input query: {question}")
474476
logger.info(f"Generated response:\n\n{res}")
475477

@@ -495,7 +497,8 @@ def query_api(
495497

496498
# Before executing, check if known vulnerabilities exist in the generated SQL code.
497499
_val = _val.replace("“", '"').replace("”", '"')
498-
[_val := _val.replace(s, "'") for s in "‘`" if s in _val]
500+
[_val := _val.replace(s, '"') for s in "‘`’'" if s in _val]
501+
499502
r, m = check_vulnerability(_val)
500503
if not r:
501504
q_res, err = db_obj.execute_query_db(query=_val)

sidekick/query.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -323,16 +323,19 @@ def generate_sql(
323323
else:
324324
# TODO Update needed for multiple tables
325325
columns_w_type = (
326-
self.context_builder.full_context_dict[table_name].split(":")[2].split("and")[0].strip()
326+
self.context_builder.full_context_dict[table_name]
327+
.split(":")[2]
328+
.split(" and foreign keys")[0]
329+
.strip()
327330
)
328331

329332
data_samples_list = self.load_column_samples(table_names)
330333

331334
_context = {
332335
"if patterns like 'current time' or 'now' occurs in question": "always use NOW() - INTERVAL",
333336
"if patterns like 'total number', or 'List' occurs in question": "always use DISTINCT",
334-
"detailed summary": "include min, avg, max",
335-
"summary": "include min, avg, max",
337+
"detailed summary": "include min, avg, max for numeric columns",
338+
"summary": "include min, avg, max for numeric columns",
336339
}
337340

338341
m_path = f"{self.path}/models/sentence_transformers/"
@@ -408,7 +411,7 @@ def generate_sql(
408411
]
409412
data_samples_list = contextual_data_samples
410413

411-
relevant_columns = context_columns if len(context_columns) > 0 else clmn_names
414+
relevant_columns = context_columns if len(context_columns) > 0 else [columns_w_type]
412415
_column_info = ", ".join(relevant_columns)
413416

414417
logger.debug(f"Relevant sample column values: {data_samples_list}")

0 commit comments

Comments
 (0)