@@ -51,6 +51,8 @@ def __init__(
5151 self ._tasks = None
5252 self .openai_key = openai_key
5353 self .content_queries = None
54+ self .model = None # Used for local LLMs
55+ self .tokenizer = None # Used for local tokenizer
5456
5557 def load_column_samples (self , tables : list ):
5658 # TODO: Maybe we add table name as a member variable
@@ -267,8 +269,12 @@ def generate_sql(
267269 logger .info (f"Realized query so far:\n { res } " )
268270 else :
269271 # Load h2oGPT.NSQL model
270- tokenizer = AutoTokenizer .from_pretrained ("NumbersStation/nsql-6B" )
271- model = AutoModelForCausalLM .from_pretrained ("NumbersStation/nsql-6B" )
272+ device = {"" : 0 } if torch .cuda .is_available () else "cpu"
273+ if self .model is None :
274+ self .tokenizer = tokenizer = AutoTokenizer .from_pretrained ("NumbersStation/nsql-6B" , device_map = device )
275+ self .model = AutoModelForCausalLM .from_pretrained (
276+ "NumbersStation/nsql-6B" , device_map = device , load_in_8bit = True
277+ )
272278
273279 # TODO Update needed for multiple tables
274280 columns_w_type = (
@@ -321,8 +327,8 @@ def generate_sql(
321327 logger .info (f"Number of possible contextual queries to question: { len (filtered_context )} " )
322328 # If QnA pairs > 5, we keep top 5 for focused context
323329 _samples = filtered_context
324- if len (filtered_context ) > 5 :
325- _samples = filtered_context [0 :5 ][::- 1 ]
330+ if len (filtered_context ) > 3 :
331+ _samples = filtered_context [0 :3 ][::- 1 ]
326332 qna_samples = "\n " .join (_samples )
327333
328334 contextual_context_val = ", " .join (contextual_context )
@@ -357,24 +363,28 @@ def generate_sql(
357363
358364 logger .debug (f"Query Text:\n { query } " )
359365 inputs = tokenizer ([query ], return_tensors = "pt" )
360- input_length = 1 if model .config .is_encoder_decoder else inputs .input_ids .shape [1 ]
366+ input_length = 1 if self . model .config .is_encoder_decoder else inputs .input_ids .shape [1 ]
361367 # Generate SQL
362368 random_seed = random .randint (0 , 50 )
363369 torch .manual_seed (random_seed )
364370
365371 # Greedy search for quick response
366- output = model .generate (
367- ** inputs ,
372+ self .model .eval ()
373+ device_type = "cuda" if torch .cuda .is_available () else "cpu"
374+ output = self .model .generate (
375+ ** inputs .to (device_type ),
368376 max_new_tokens = 300 ,
369377 temperature = 0.5 ,
370378 output_scores = True ,
371379 return_dict_in_generate = True ,
372380 )
373381
374382 generated_tokens = output .sequences [:, input_length :]
375- _res = tokenizer .decode (generated_tokens [0 ], skip_special_tokens = True )
383+ _res = self . tokenizer .decode (generated_tokens [0 ], skip_special_tokens = True )
376384 # Below is a pre-caution in-case of an error in table name during generation
377- res = "SELECT" + _res .replace ("table_name" , table_names [0 ])
385+ # COLLATE NOCASE is used to ignore case sensitivity, this might be specific to sqlite
386+ _temp = _res .replace ("table_name" , table_names [0 ]).split (";" )[0 ]
387+ res = "SELECT" + _temp + " COLLATE NOCASE;"
378388 return res
379389
380390 def task_formatter (self , input_task : str ):
0 commit comments