Commit 7ebed00

Adjust fix when generated token is missing SELECT (#44)

1 parent 4dd3806 · commit 7ebed00

3 files changed: +97 -72 lines

sidekick/configs/prompt_template.py (2 additions, 2 deletions)

@@ -35,7 +35,7 @@
 ### *History*:\n{_sample_queries}
 ### *Question*: For table {_table_name}, {_question}
 # SELECT 1
-### *Tasks for table {_table_name}*:\n{_tasks}
+### *Plan for table {_table_name}*:\n{_tasks}
 ### *Policies for SQL generation*:
 # Avoid overly complex SQL queries, favor concise human readable SQL queries which are easy to understand and debug
 # Avoid patterns that might be vulnerable to SQL injection
@@ -118,7 +118,7 @@
 - Only use supplied table names: **{table_name}** for generation
 - Only use column names from the CREATE TABLE statement: **{column_info}** for generation. DO NOT USE any other column names outside of this.
 - Avoid overly complex SQL queries, favor concise human readable SQL queries which are easy to understand and debug
-- Avoid patterns that might be vulnerable to SQL injection, e.g. sanitize inputs
+- Avoid patterns that might be vulnerable to SQL injection, e.g. use proper sanitization and escaping for raw user input
 - Always cast the numerator as float when computing ratios
 - Always use COUNT(1) instead of COUNT(*)
 - If the question is asking for a rate, use COUNT to compute percentage
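
Note: the templates above are Python format strings; placeholders such as {_table_name} and {_tasks} are presumably filled in with str.format (or an equivalent). A minimal sketch of that rendering step follows; the template text and values are invented for illustration and are not the actual contents of prompt_template.py.

```python
# Minimal sketch: rendering a prompt template with str.format.
# The template text and values are hypothetical, for illustration only.
TASK_TEMPLATE = (
    "### *Question*: For table {_table_name}, {_question}\n"
    "### *Plan for table {_table_name}*:\n{_tasks}\n"
)

prompt = TASK_TEMPLATE.format(
    _table_name="sleep_health",
    _question="what is the average sleep duration?",
    _tasks="1. Compute AVG(sleep_duration) over all rows",
)
print(prompt)
```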

sidekick/prompter.py (2 additions, 2 deletions)

@@ -463,7 +463,7 @@ def ask(
     """
 
     results = []
-    err = None  # TODO - Need to handle errors if occurred
+    res = err = alt_res = None  # TODO - Need to handle errors if occurred
     # Book-keeping
     base_path = local_base_path if local_base_path else default_base_path
     setup_dir(base_path)
@@ -575,7 +575,7 @@ def ask(
                 click.echo("Skipping edit...")
             if updated_tasks is not None:
                 sql_g._tasks = updated_tasks
-            alt_res = None
+
             # The interface could also be used to simply execute user provided SQL
             # Keyword: "Execute SQL: <SQL query>"
             if (
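
The second hunk drops the late `alt_res = None` assignment, and the first hunk instead initializes res, err, and alt_res together at the top of ask, so every later branch can read them without risking an UnboundLocalError. A minimal, self-contained sketch of the pattern (function name and values are hypothetical, not the actual ask() implementation):

```python
# Sketch of why the up-front multi-target initialization matters.
# Names and values are hypothetical; this is not the actual ask().
def ask_sketch(generate_alternatives: bool):
    res = err = alt_res = None  # every path below can safely read all three
    try:
        res = "SELECT COUNT(1) FROM demo_table;"
        if generate_alternatives:
            alt_res = "Option 1: SELECT * FROM demo_table LIMIT 100;"
    except Exception as e:
        err = str(e)
    # Without the initialization, reading alt_res here would raise
    # UnboundLocalError whenever generate_alternatives is False.
    return res, err, alt_res

print(ask_sketch(False))  # ('SELECT COUNT(1) FROM demo_table;', None, None)
```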

sidekick/query.py (93 additions, 68 deletions)

@@ -604,7 +604,7 @@ def generate_sql(
             # Reset temperature to 0.5
             current_temperature = 0.5
             if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
-                m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder2")
+                m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
                 query_txt = [{"role": "user", "content": query},]
                 logger.debug(f"Generation with default temperature : {current_temperature}")
                 completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create(
@@ -633,79 +633,104 @@
             # throttle temperature for different result
             logger.info("Regeneration requested on previous query ...")
             logger.debug(f"Selected temperature for fast regeneration : {random_temperature}")
-            output = model.generate(
-                **inputs.to(device_type),
-                max_new_tokens=512,
-                temperature=random_temperature,
-                output_scores=True,
-                do_sample=True,
-                return_dict_in_generate=True,
-            )
-            generated_tokens = output.sequences[:, input_length:][0]
+            if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
+                m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
+                query_txt = [{"role": "user", "content": query},]
+                completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create(
+                    model=m_name,
+                    messages=query_txt,
+                    max_tokens=512,
+                    temperature=random_temperature,
+                    stop="```",
+                    seed=random_seed)
+                generated_tokens = completion.choices[0].message.content
+            else:
+                output = model.generate(
+                    **inputs.to(device_type),
+                    max_new_tokens=512,
+                    temperature=random_temperature,
+                    output_scores=True,
+                    do_sample=True,
+                    return_dict_in_generate=True,
+                )
+                generated_tokens = output.sequences[:, input_length:][0]
             self.current_temps[model_name] = random_temperature
             logger.debug(f"Temperature saved: {self.current_temps[model_name]}")
         else:
             logger.info("Regeneration with options requested on previous query ...")
-            # Diverse beam search decoding to explore more options
-            logger.debug(f"Selected temperature for diverse beam search: {random_temperature}")
-            output_re = model.generate(
-                **inputs.to(device_type),
-                max_new_tokens=512,
-                temperature=random_temperature,
-                top_k=5,
-                top_p=0.9,
-                num_beams=5,
-                num_beam_groups=5,
-                num_return_sequences=5,
-                output_scores=True,
-                do_sample=False,
-                diversity_penalty=2.0,
-                return_dict_in_generate=True,
-            )
+            if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
+                logger.info("Generating diverse options, not enabled for remote models")
+                m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
+                query_txt = [{"role": "user", "content": query},]
+                completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create(
+                    model=m_name,
+                    messages=query_txt,
+                    max_tokens=512,
+                    temperature=random_temperature,
+                    stop="```",
+                    seed=random_seed)
+                generated_tokens = completion.choices[0].message.content
+            else:
+                # Diverse beam search decoding to explore more options
+                logger.debug(f"Selected temperature for diverse beam search: {random_temperature}")
+                output_re = model.generate(
+                    **inputs.to(device_type),
+                    max_new_tokens=512,
+                    temperature=random_temperature,
+                    top_k=5,
+                    top_p=0.9,
+                    num_beams=5,
+                    num_beam_groups=5,
+                    num_return_sequences=5,
+                    output_scores=True,
+                    do_sample=True,
+                    diversity_penalty=2.0,
+                    return_dict_in_generate=True,
+                )
 
-            transition_scores = model.compute_transition_scores(
-                output_re.sequences, output_re.scores, output_re.beam_indices, normalize_logits=False
-            )
+                transition_scores = model.compute_transition_scores(
+                    output_re.sequences, output_re.scores, output_re.beam_indices, normalize_logits=False
+                )
 
-            # Create a boolean tensor where elements are True if the corresponding element in transition_scores is less than 0
-            mask = transition_scores < 0
-            # Sum the True values along axis 1
-            counts = torch.sum(mask, dim=1)
-            output_length = inputs.input_ids.shape[1] + counts
-            length_penalty = model.generation_config.length_penalty
-            reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
-
-            # Converting logit scores to prob scores
-            probabilities_scores = F.softmax(reconstructed_scores, dim=-1)
-            out_idx = torch.argmax(probabilities_scores)
-            # Final output
-            output = output_re.sequences[out_idx]
-            generated_tokens = output[input_length:]
-
-            logger.info(f"Generated options:\n")
-            prob_sorted_idxs = sorted(
-                range(len(probabilities_scores)), key=lambda k: probabilities_scores[k], reverse=True
-            )
-            for idx, sorted_idx in enumerate(prob_sorted_idxs):
-                _out = output_re.sequences[sorted_idx]
-                res = tokenizer.decode(_out[input_length:], skip_special_tokens=True)
-                result = res.replace("table_name", _table_name)
-                # Remove the last semi-colon if exists at the end
-                # we will add it later
-                if result.endswith(";"):
-                    result = result.replace(";", "")
-                if "LIMIT".lower() not in result.lower():
-                    res = "SELECT " + result.strip() + " LIMIT 100;"
-                else:
-                    res = "SELECT " + result.strip() + ";"
-
-                pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper")
-                syntax_highlight = f"""``` sql\n{pretty_sql}\n```\n\n"""
-                alt_res = (
-                    f"Option {idx+1}: (_probability_: {probabilities_scores[sorted_idx]})\n{syntax_highlight}\n"
-                )
-                alternate_queries.append(alt_res)
-                logger.info(alt_res)
+                # Create a boolean tensor where elements are True if the corresponding element in transition_scores is less than 0
+                mask = transition_scores < 0
+                # Sum the True values along axis 1
+                counts = torch.sum(mask, dim=1)
+                output_length = inputs.input_ids.shape[1] + counts
+                length_penalty = model.generation_config.length_penalty
+                reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
+
+                # Converting logit scores to prob scores
+                probabilities_scores = F.softmax(reconstructed_scores, dim=-1)
+                out_idx = torch.argmax(probabilities_scores)
+                # Final output
+                output = output_re.sequences[out_idx]
+                generated_tokens = output[input_length:]
+
+                logger.info(f"Generated options:\n")
+                prob_sorted_idxs = sorted(
+                    range(len(probabilities_scores)), key=lambda k: probabilities_scores[k], reverse=True
+                )
+                for idx, sorted_idx in enumerate(prob_sorted_idxs):
+                    _out = output_re.sequences[sorted_idx]
+                    res = tokenizer.decode(_out[input_length:], skip_special_tokens=True)
+                    result = res.replace("table_name", _table_name)
+                    # Remove the last semi-colon if exists at the end
+                    # we will add it later
+                    if result.endswith(";"):
+                        result = result.replace(";", "")
+                    if "LIMIT".lower() not in result.lower():
+                        res = "SELECT " + result.strip() + " LIMIT 100;"
+                    else:
+                        res = "SELECT " + result.strip() + ";"
+
+                    pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper")
+                    syntax_highlight = f"""``` sql\n{pretty_sql}\n```\n\n"""
+                    alt_res = (
+                        f"Option {idx+1}: (_probability_: {probabilities_scores[sorted_idx]})\n{syntax_highlight}\n"
+                    )
+                    alternate_queries.append(alt_res)
+                    logger.info(f"Alternate options:\n{alt_res}")
 
         _res = generated_tokens
         if not self.remote_model and tokenizer:
@@ -721,7 +746,7 @@ def generate_sql(
             # TODO Below should not happen, will have to check why its getting generated as part of response.
             # Not sure, if its a vllm or prompt issue.
             _temp = _temp.replace("/[/INST]", "").replace("[INST]", "").replace("[/INST]", "").strip()
-            if "SELECT".lower() not in _temp.lower():
+            if not _temp.lower().startswith('SELECT'.lower()):
                 _temp = "SELECT " + _temp.strip()
             res = _temp
             if "LIMIT".lower() not in _temp.lower():
