Skip to content

Commit 843b219

Browse files
Fix syntax for multiple alternate responses #44
1 parent 796bda6 commit 843b219

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

sidekick/configs/prompt_template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
100100
101101
### Input:
102-
For SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries}), create a SQL (dialect:SQLite) query to answer the following question:\n{question_txt}.
102+
For SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries}), create a valid SQL (dialect:SQLite) query to answer the following question:\n{question_txt}.
103103
This query will run on a database whose schema is represented in this string:
104104
CREATE TABLE '{table_name}' ({column_info}
105105
);

sidekick/query.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,10 @@ def generate_sql(
584584
_out = output_re.sequences[sorted_idx]
585585
res = tokenizer.decode(_out[input_length:], skip_special_tokens=True)
586586
result = res.replace("table_name", _table_name)
587+
# Remove the last semi-colon if exists at the end
588+
# we will add it later
589+
if result.endswith(";"):
590+
result = result.replace(";", "")
587591
if "LIMIT".lower() not in result.lower():
588592
res = "SELECT " + result.strip() + " LIMIT 100;"
589593
else:
@@ -602,6 +606,8 @@ def generate_sql(
602606
# COLLATE NOCASE is used to ignore case sensitivity, this might be specific to sqlite
603607
_temp = _res.replace("table_name", table_name).split(";")[0]
604608

609+
if _temp.endswith(";"):
610+
_temp = _temp.replace(";", "")
605611
if "LIMIT".lower() not in _temp.lower():
606612
res = "SELECT " + _temp.strip() + " LIMIT 100;"
607613
else:

0 commit comments

Comments
 (0)