@@ -604,7 +604,7 @@ def generate_sql(
             # Reset temperature to 0.5
             current_temperature = 0.5
             if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
-                m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder2")
+                m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
                 query_txt = [{"role": "user", "content": query},]
                 logger.debug(f"Generation with default temperature : {current_temperature}")
                 completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create(
@@ -633,79 +633,104 @@ def generate_sql(
                 # throttle temperature for different result
                 logger.info("Regeneration requested on previous query ...")
                 logger.debug(f"Selected temperature for fast regeneration : {random_temperature}")
-                output = model.generate(
-                    **inputs.to(device_type),
-                    max_new_tokens=512,
-                    temperature=random_temperature,
-                    output_scores=True,
-                    do_sample=True,
-                    return_dict_in_generate=True,
-                )
-                generated_tokens = output.sequences[:, input_length:][0]
+                if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
+                    m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
+                    query_txt = [{"role": "user", "content": query},]
+                    completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create(
+                        model=m_name,
+                        messages=query_txt,
+                        max_tokens=512,
+                        temperature=random_temperature,
+                        stop="```",
+                        seed=random_seed)
+                    generated_tokens = completion.choices[0].message.content
+                else:
+                    output = model.generate(
+                        **inputs.to(device_type),
+                        max_new_tokens=512,
+                        temperature=random_temperature,
+                        output_scores=True,
+                        do_sample=True,
+                        return_dict_in_generate=True,
+                    )
+                    generated_tokens = output.sequences[:, input_length:][0]
                 self.current_temps[model_name] = random_temperature
                 logger.debug(f"Temperature saved: {self.current_temps[model_name]}")
             else:
                 logger.info("Regeneration with options requested on previous query ...")
-                # Diverse beam search decoding to explore more options
-                logger.debug(f"Selected temperature for diverse beam search: {random_temperature}")
-                output_re = model.generate(
-                    **inputs.to(device_type),
-                    max_new_tokens=512,
-                    temperature=random_temperature,
-                    top_k=5,
-                    top_p=0.9,
-                    num_beams=5,
-                    num_beam_groups=5,
-                    num_return_sequences=5,
-                    output_scores=True,
-                    do_sample=False,
-                    diversity_penalty=2.0,
-                    return_dict_in_generate=True,
-                )
+                if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
+                    logger.info("Generating diverse options is not enabled for remote models")
+                    m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
+                    query_txt = [{"role": "user", "content": query},]
+                    completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create(
+                        model=m_name,
+                        messages=query_txt,
+                        max_tokens=512,
+                        temperature=random_temperature,
+                        stop="```",
+                        seed=random_seed)
+                    generated_tokens = completion.choices[0].message.content
+                else:
+                    # Diverse beam search decoding to explore more options
+                    logger.debug(f"Selected temperature for diverse beam search: {random_temperature}")
+                    output_re = model.generate(
+                        **inputs.to(device_type),
+                        max_new_tokens=512,
+                        temperature=random_temperature,
+                        top_k=5,
+                        top_p=0.9,
+                        num_beams=5,
+                        num_beam_groups=5,
+                        num_return_sequences=5,
+                        output_scores=True,
+                        do_sample=False,  # diverse (group) beam search is incompatible with sampling; True raises a ValueError
+                        diversity_penalty=2.0,
+                        return_dict_in_generate=True,
+                    )

-                transition_scores = model.compute_transition_scores(
-                    output_re.sequences, output_re.scores, output_re.beam_indices, normalize_logits=False
-                )
+                    transition_scores = model.compute_transition_scores(
+                        output_re.sequences, output_re.scores, output_re.beam_indices, normalize_logits=False
+                    )

-                # Create a boolean tensor where elements are True if the corresponding element in transition_scores is less than 0
-                mask = transition_scores < 0
-                # Sum the True values along axis 1
-                counts = torch.sum(mask, dim=1)
-                output_length = inputs.input_ids.shape[1] + counts
-                length_penalty = model.generation_config.length_penalty
-                reconstructed_scores = transition_scores.sum(axis=1) / (output_length ** length_penalty)
-
-                # Converting logit scores to prob scores
-                probabilities_scores = F.softmax(reconstructed_scores, dim=-1)
-                out_idx = torch.argmax(probabilities_scores)
-                # Final output
-                output = output_re.sequences[out_idx]
-                generated_tokens = output[input_length:]
-
-                logger.info(f"Generated options:\n")
-                prob_sorted_idxs = sorted(
-                    range(len(probabilities_scores)), key=lambda k: probabilities_scores[k], reverse=True
-                )
-                for idx, sorted_idx in enumerate(prob_sorted_idxs):
-                    _out = output_re.sequences[sorted_idx]
-                    res = tokenizer.decode(_out[input_length:], skip_special_tokens=True)
-                    result = res.replace("table_name", _table_name)
-                    # Remove the last semi-colon if exists at the end
-                    # we will add it later
-                    if result.endswith(";"):
-                        result = result.replace(";", "")
-                    if "LIMIT".lower() not in result.lower():
-                        res = "SELECT " + result.strip() + " LIMIT 100;"
-                    else:
-                        res = "SELECT " + result.strip() + ";"
-
-                    pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper")
-                    syntax_highlight = f"""``` sql\n{pretty_sql}\n```\n\n"""
-                    alt_res = (
-                        f"Option {idx + 1}: (_probability_: {probabilities_scores[sorted_idx]})\n{syntax_highlight}\n"
+                    # Create a boolean tensor where elements are True if the corresponding element in transition_scores is less than 0
+                    mask = transition_scores < 0
+                    # Sum the True values along axis 1
+                    counts = torch.sum(mask, dim=1)
+                    output_length = inputs.input_ids.shape[1] + counts
+                    length_penalty = model.generation_config.length_penalty
+                    reconstructed_scores = transition_scores.sum(axis=1) / (output_length ** length_penalty)
+
+                    # Converting logit scores to prob scores
+                    probabilities_scores = F.softmax(reconstructed_scores, dim=-1)
+                    out_idx = torch.argmax(probabilities_scores)
+                    # Final output
+                    output = output_re.sequences[out_idx]
+                    generated_tokens = output[input_length:]
+
+                    logger.info(f"Generated options:\n")
+                    prob_sorted_idxs = sorted(
+                        range(len(probabilities_scores)), key=lambda k: probabilities_scores[k], reverse=True
                     )
-                    alternate_queries.append(alt_res)
-                    logger.info(alt_res)
+                    for idx, sorted_idx in enumerate(prob_sorted_idxs):
+                        _out = output_re.sequences[sorted_idx]
+                        res = tokenizer.decode(_out[input_length:], skip_special_tokens=True)
+                        result = res.replace("table_name", _table_name)
+                        # Remove the last semi-colon if exists at the end
+                        # we will add it later
+                        if result.endswith(";"):
+                            result = result.replace(";", "")
+                        if "LIMIT".lower() not in result.lower():
+                            res = "SELECT " + result.strip() + " LIMIT 100;"
+                        else:
+                            res = "SELECT " + result.strip() + ";"
+
+                        pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper")
+                        syntax_highlight = f"""``` sql\n{pretty_sql}\n```\n\n"""
+                        alt_res = (
+                            f"Option {idx + 1}: (_probability_: {probabilities_scores[sorted_idx]})\n{syntax_highlight}\n"
+                        )
+                        alternate_queries.append(alt_res)
+                        logger.info(f"Alternate options:\n{alt_res}")

                 _res = generated_tokens
                 if not self.remote_model and tokenizer:
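Both regeneration branches in this hunk now share the same remote path: instead of local `model.generate`, eval-mode models go through the OpenAI-compatible h2oGPT client with a throttled `temperature` and an explicit `seed`, so a regeneration differs from the previous attempt yet stays reproducible for a given seed. A hedged sketch of that call shape (endpoint URL, API key, and prompt are placeholders, not the project's configuration):

```python
from openai import OpenAI

# Placeholder endpoint and key; the class wires this up as self.h2ogpt_client.
client = OpenAI(base_url="http://localhost:5000/v1", api_key="EMPTY")

completion = client.with_options(max_retries=3).chat.completions.create(
    model="h2ogpt-sql-sqlcoder-34b-alpha",  # resolved via MODEL_CHOICE_MAP_EVAL_MODE
    messages=[{"role": "user", "content": "<schema + question prompt>"}],
    max_tokens=512,
    temperature=0.75,  # randomized per regeneration request
    stop="```",        # stop at the closing markdown fence around the SQL
    seed=42,           # best-effort determinism at a fixed temperature
)
print(completion.choices[0].message.content)
```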
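For local models, the diverse-beam branch still has to pick one of the five returned sequences. The hunk reconstructs each beam's length-normalized score from the per-token transition scores, softmaxes across candidates, and takes the argmax. A condensed, self-contained sketch of that selection step (`pick_best_beam` and its arguments are illustrative names; it assumes padded positions in `transition_scores` are zeros and real tokens carry negative log-scores, as in the code above):

```python
import torch
import torch.nn.functional as F

def pick_best_beam(transition_scores: torch.Tensor, prompt_length: int,
                   length_penalty: float = 1.0):
    # Real tokens have negative log-scores; padding is 0, so counting
    # negatives recovers each candidate's generated length.
    counts = torch.sum(transition_scores < 0, dim=1)
    output_length = prompt_length + counts
    # Length-normalized sequence score, mirroring beam-search scoring.
    reconstructed = transition_scores.sum(dim=1) / (output_length ** length_penalty)
    # Softmax turns the scores into relative probabilities over candidates.
    probs = F.softmax(reconstructed, dim=-1)
    return int(torch.argmax(probs)), probs

# Toy example: two candidates; the shorter, higher-scored one wins.
scores = torch.tensor([[-0.1, -0.2, 0.0], [-0.5, -0.6, -0.7]])
best, probs = pick_best_beam(scores, prompt_length=10)
assert best == 0
```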
@@ -721,7 +746,7 @@ def generate_sql(
             # TODO Below should not happen, will have to check why its getting generated as part of response.
             # Not sure, if its a vllm or prompt issue.
             _temp = _temp.replace("/[/INST]", "").replace("[INST]", "").replace("[/INST]", "").strip()
-            if "SELECT".lower() not in _temp.lower():
+            if not _temp.lower().startswith('SELECT'.lower()):
                 _temp = "SELECT " + _temp.strip()
             res = _temp
             if "LIMIT".lower() not in _temp.lower():
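The final hunk tightens the `SELECT` guard: the old substring test skipped the prefix whenever "select" appeared anywhere in the text, so a fragment that merely mentioned the word was left as invalid SQL. The new check only trusts a literal leading `SELECT`. A small sketch of the difference (`ensure_select_prefix` is a hypothetical helper, not the project's API):

```python
def ensure_select_prefix(sql_text: str) -> str:
    # Prepend SELECT unless the text literally starts with it.
    if not sql_text.lower().startswith("select"):
        sql_text = "SELECT " + sql_text.strip()
    return sql_text

# The old `"select" not in _temp.lower()` check would have left this untouched:
assert ensure_select_prefix("name FROM users WHERE bio LIKE '%select%'") \
    == "SELECT name FROM users WHERE bio LIKE '%select%'"
assert ensure_select_prefix("SELECT 1") == "SELECT 1"
```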