@@ -604,7 +604,7 @@ def generate_sql(
         # Reset temperature to 0.5
         current_temperature = 0.5
         if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
-            m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder2")
+            m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
             query_txt = [{"role": "user", "content": query},]
             logger.debug(f"Generation with default temperature: {current_temperature}")
             completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create(
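The only functional change in this hunk is the fallback argument to `dict.get`: unknown model names now resolve to the larger `h2ogpt-sql-sqlcoder-34b-alpha` deployment instead of `h2ogpt-sql-sqlcoder2`. A minimal sketch of that lookup behavior, with illustrative map entries that are assumptions rather than the repository's actual contents:

```python
# Illustrative stand-in for the real MODEL_CHOICE_MAP_EVAL_MODE defined
# elsewhere in the repo; the identity mappings here are assumed values.
MODEL_CHOICE_MAP_EVAL_MODE = {
    "h2ogpt-sql-sqlcoder2": "h2ogpt-sql-sqlcoder2",
    "h2ogpt-sql-sqlcoder-34b-alpha": "h2ogpt-sql-sqlcoder-34b-alpha",
    "h2ogpt-sql-nsql-llama-2-7B": "h2ogpt-sql-nsql-llama-2-7B",
}

# A known name resolves through the map; anything unexpected now falls
# back to the 34B model instead of sqlcoder2.
print(MODEL_CHOICE_MAP_EVAL_MODE.get("h2ogpt-sql-sqlcoder2", "h2ogpt-sql-sqlcoder-34b-alpha"))
print(MODEL_CHOICE_MAP_EVAL_MODE.get("some-unlisted-model", "h2ogpt-sql-sqlcoder-34b-alpha"))
```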
@@ -633,79 +633,104 @@ def generate_sql(
             # throttle temperature for different result
             logger.info("Regeneration requested on previous query ...")
             logger.debug(f"Selected temperature for fast regeneration: {random_temperature}")
-            output = model.generate(
-                **inputs.to(device_type),
-                max_new_tokens=512,
-                temperature=random_temperature,
-                output_scores=True,
-                do_sample=True,
-                return_dict_in_generate=True,
-            )
-            generated_tokens = output.sequences[:, input_length:][0]
+            if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
+                m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
+                query_txt = [{"role": "user", "content": query},]
+                completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create(
+                    model=m_name,
+                    messages=query_txt,
+                    max_tokens=512,
+                    temperature=random_temperature,
+                    stop="```",
+                    seed=random_seed)
+                generated_tokens = completion.choices[0].message.content
+            else:
+                output = model.generate(
+                    **inputs.to(device_type),
+                    max_new_tokens=512,
+                    temperature=random_temperature,
+                    output_scores=True,
+                    do_sample=True,
+                    return_dict_in_generate=True,
+                )
+                generated_tokens = output.sequences[:, input_length:][0]
             self.current_temps[model_name] = random_temperature
             logger.debug(f"Temperature saved: {self.current_temps[model_name]}")
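`random_temperature` and `random_seed` are computed earlier in `generate_sql` and only consumed in this branch. A rough sketch of the throttling idea; the helper name, range, and bounds below are assumptions for illustration, not the method's actual code:

```python
import random

# Hypothetical helper: the real method derives random_temperature and
# random_seed before this branch; the range and bounds here are assumptions.
def draw_regen_params():
    random_temperature = round(random.uniform(0.5, 1.0), 2)  # assumed throttle range
    random_seed = random.randint(0, 2**31 - 1)                # assumed seed space
    return random_temperature, random_seed

random_temperature, random_seed = draw_regen_params()
```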
         else:
             logger.info("Regeneration with options requested on previous query ...")
-            # Diverse beam search decoding to explore more options
-            logger.debug(f"Selected temperature for diverse beam search: {random_temperature}")
-            output_re = model.generate(
-                **inputs.to(device_type),
-                max_new_tokens=512,
-                temperature=random_temperature,
-                top_k=5,
-                top_p=0.9,
-                num_beams=5,
-                num_beam_groups=5,
-                num_return_sequences=5,
-                output_scores=True,
-                do_sample=False,
-                diversity_penalty=2.0,
-                return_dict_in_generate=True,
-            )
+            if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
+                logger.info("Generating diverse options is not enabled for remote models")
+                m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
+                query_txt = [{"role": "user", "content": query},]
+                completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create(
+                    model=m_name,
+                    messages=query_txt,
+                    max_tokens=512,
+                    temperature=random_temperature,
+                    stop="```",
+                    seed=random_seed)
+                generated_tokens = completion.choices[0].message.content
+            else:
+                # Diverse beam search decoding to explore more options
+                logger.debug(f"Selected temperature for diverse beam search: {random_temperature}")
+                output_re = model.generate(
+                    **inputs.to(device_type),
+                    max_new_tokens=512,
+                    temperature=random_temperature,
+                    top_k=5,
+                    top_p=0.9,
+                    num_beams=5,
+                    num_beam_groups=5,
+                    num_return_sequences=5,
+                    output_scores=True,
+                    do_sample=False,  # diverse (group) beam search does not support sampling
+                    diversity_penalty=2.0,
+                    return_dict_in_generate=True,
+                )

-            transition_scores = model.compute_transition_scores(
-                output_re.sequences, output_re.scores, output_re.beam_indices, normalize_logits=False
-            )
+                transition_scores = model.compute_transition_scores(
+                    output_re.sequences, output_re.scores, output_re.beam_indices, normalize_logits=False
+                )

-            # Create a boolean tensor where elements are True if the corresponding element in transition_scores is less than 0
-            mask = transition_scores < 0
-            # Sum the True values along axis 1
-            counts = torch.sum(mask, dim=1)
-            output_length = inputs.input_ids.shape[1] + counts
-            length_penalty = model.generation_config.length_penalty
-            reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
-
-            # Converting logit scores to prob scores
-            probabilities_scores = F.softmax(reconstructed_scores, dim=-1)
-            out_idx = torch.argmax(probabilities_scores)
-            # Final output
-            output = output_re.sequences[out_idx]
-            generated_tokens = output[input_length:]
-
-            logger.info(f"Generated options:\n")
-            prob_sorted_idxs = sorted(
-                range(len(probabilities_scores)), key=lambda k: probabilities_scores[k], reverse=True
-            )
-            for idx, sorted_idx in enumerate(prob_sorted_idxs):
-                _out = output_re.sequences[sorted_idx]
-                res = tokenizer.decode(_out[input_length:], skip_special_tokens=True)
-                result = res.replace("table_name", _table_name)
-                # Remove the last semi-colon if exists at the end
-                # we will add it later
-                if result.endswith(";"):
-                    result = result.replace(";", "")
-                if "LIMIT".lower() not in result.lower():
-                    res = "SELECT " + result.strip() + " LIMIT 100;"
-                else:
-                    res = "SELECT " + result.strip() + ";"
-
-                pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper")
-                syntax_highlight = f"""```sql\n{pretty_sql}\n```\n\n"""
-                alt_res = (
-                    f"Option {idx + 1}: (_probability_: {probabilities_scores[sorted_idx]})\n{syntax_highlight}\n"
+                # Create a boolean tensor where elements are True if the corresponding element in transition_scores is less than 0
+                mask = transition_scores < 0
+                # Sum the True values along axis 1
+                counts = torch.sum(mask, dim=1)
+                output_length = inputs.input_ids.shape[1] + counts
+                length_penalty = model.generation_config.length_penalty
+                reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
+
+                # Converting logit scores to prob scores
+                probabilities_scores = F.softmax(reconstructed_scores, dim=-1)
+                out_idx = torch.argmax(probabilities_scores)
+                # Final output
+                output = output_re.sequences[out_idx]
+                generated_tokens = output[input_length:]
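To make the scoring arithmetic concrete: each candidate's per-token transition scores are summed and divided by `output_length ** length_penalty` (the length-normalized sequence-score recipe from the transformers docs), and a softmax over those sums yields one probability per candidate. A self-contained sketch with made-up numbers; the shapes, values, and the zero-padded short beam are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Made-up transition scores for 3 candidates x 4 generated tokens.
# Real tokens carry negative log-scores; a 0.0 marks a padded position.
transition_scores = torch.tensor([
    [-0.2, -0.1, -0.3, -0.4],
    [-0.5, -0.2, -0.2,  0.0],   # this beam stopped one token early
    [-1.0, -0.9, -0.8, -0.7],
])
input_length = 10      # prompt length, assumed
length_penalty = 1.0   # transformers' default generation_config value

# Count generated tokens per candidate, then length-normalize the summed
# log-scores, exactly as in the diff above.
counts = torch.sum(transition_scores < 0, dim=1)
output_length = input_length + counts
reconstructed_scores = transition_scores.sum(axis=1) / (output_length ** length_penalty)

# Softmax turns the per-candidate scores into a distribution; argmax picks
# the most probable SQL candidate.
probabilities_scores = F.softmax(reconstructed_scores, dim=-1)
out_idx = torch.argmax(probabilities_scores)
print(out_idx.item(), probabilities_scores.tolist())
```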
+
+                logger.info(f"Generated options:\n")
+                prob_sorted_idxs = sorted(
+                    range(len(probabilities_scores)), key=lambda k: probabilities_scores[k], reverse=True
                )
-            alternate_queries.append(alt_res)
-            logger.info(alt_res)
+                for idx, sorted_idx in enumerate(prob_sorted_idxs):
+                    _out = output_re.sequences[sorted_idx]
+                    res = tokenizer.decode(_out[input_length:], skip_special_tokens=True)
+                    result = res.replace("table_name", _table_name)
+                    # Remove the trailing semicolon if it exists at the end;
+                    # we will add it back later
+                    if result.endswith(";"):
+                        result = result.replace(";", "")
+                    if "LIMIT".lower() not in result.lower():
+                        res = "SELECT " + result.strip() + " LIMIT 100;"
+                    else:
+                        res = "SELECT " + result.strip() + ";"
+
+                    pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper")
+                    syntax_highlight = f"""```sql\n{pretty_sql}\n```\n\n"""
+                    alt_res = (
+                        f"Option {idx + 1}: (_probability_: {probabilities_scores[sorted_idx]})\n{syntax_highlight}\n"
+                    )
+                    alternate_queries.append(alt_res)
+                    logger.info(f"Alternate options:\n{alt_res}")
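The prettifying step relies on `sqlparse.format` with `reindent` and `keyword_case`, and the LIMIT guard caps unbounded queries at 100 rows. A standalone example of the same post-processing on a made-up fragment:

```python
import sqlparse

# A made-up model response, mirroring the post-processing in the loop above.
result = "name, age FROM users WHERE age > 21"
if "LIMIT".lower() not in result.lower():
    res = "SELECT " + result.strip() + " LIMIT 100;"
else:
    res = "SELECT " + result.strip() + ";"

pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper")
print(pretty_sql)
# Output (roughly):
# SELECT name,
#        age
# FROM users
# WHERE age > 21 LIMIT 100;
```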

        _res = generated_tokens
        if not self.remote_model and tokenizer:
@@ -721,7 +746,7 @@ def generate_sql(
             # TODO Below should not happen, will have to check why its getting generated as part of response.
             # Not sure, if its a vllm or prompt issue.
             _temp = _temp.replace("/[/INST]", "").replace("[INST]", "").replace("[/INST]", "").strip()
-            if "SELECT".lower() not in _temp.lower():
+            if not _temp.lower().startswith('SELECT'.lower()):
                 _temp = "SELECT " + _temp.strip()
             res = _temp
             if "LIMIT".lower() not in _temp.lower():
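The replaced condition fixes a real edge case: the old substring test skipped prepending `SELECT` whenever the keyword appeared anywhere in the response, for example inside a subquery fragment, while the new `startswith` test only trusts responses that actually begin with it. A quick illustration with a hypothetical fragment:

```python
# Hypothetical model response: contains SELECT in a subquery but does not
# itself start with SELECT.
_temp = "name FROM sleep WHERE duration > (SELECT AVG(duration) FROM sleep)"

# Old check: "select" occurs in the subquery, so the prefix was never added
# and the final SQL stayed broken.
old_needs_prefix = "SELECT".lower() not in _temp.lower()   # False

# New check: the response does not start with SELECT, so the prefix is added.
if not _temp.lower().startswith("SELECT".lower()):
    _temp = "SELECT " + _temp.strip()
print(_temp)  # SELECT name FROM sleep WHERE duration > (SELECT AVG(duration) FROM sleep)
```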