@@ -51,6 +51,8 @@ def __init__(
51
51
self ._tasks = None
52
52
self .openai_key = openai_key
53
53
self .content_queries = None
54
+ self .model = None # Used for local LLMs
55
+ self .tokenizer = None # Used for local tokenizer
54
56
55
57
def load_column_samples (self , tables : list ):
56
58
# TODO: Maybe we add table name as a member variable
@@ -267,8 +269,12 @@ def generate_sql(
267
269
logger .info (f"Realized query so far:\n { res } " )
268
270
else :
269
271
# Load h2oGPT.NSQL model
270
- tokenizer = AutoTokenizer .from_pretrained ("NumbersStation/nsql-6B" )
271
- model = AutoModelForCausalLM .from_pretrained ("NumbersStation/nsql-6B" )
272
+ device = {"" : 0 } if torch .cuda .is_available () else "cpu"
273
+ if self .model is None :
274
+ self .tokenizer = tokenizer = AutoTokenizer .from_pretrained ("NumbersStation/nsql-6B" , device_map = device )
275
+ self .model = AutoModelForCausalLM .from_pretrained (
276
+ "NumbersStation/nsql-6B" , device_map = device , load_in_8bit = True
277
+ )
272
278
273
279
# TODO Update needed for multiple tables
274
280
columns_w_type = (
@@ -321,8 +327,8 @@ def generate_sql(
321
327
logger .info (f"Number of possible contextual queries to question: { len (filtered_context )} " )
322
328
# If there are more than 3 QnA pairs, we keep only the top 3 for focused context
323
329
_samples = filtered_context
324
- if len (filtered_context ) > 5 :
325
- _samples = filtered_context [0 :5 ][::- 1 ]
330
+ if len (filtered_context ) > 3 :
331
+ _samples = filtered_context [0 :3 ][::- 1 ]
326
332
qna_samples = "\n " .join (_samples )
327
333
328
334
contextual_context_val = ", " .join (contextual_context )
@@ -357,24 +363,28 @@ def generate_sql(
357
363
358
364
logger .debug (f"Query Text:\n { query } " )
359
365
inputs = tokenizer ([query ], return_tensors = "pt" )
360
- input_length = 1 if model .config .is_encoder_decoder else inputs .input_ids .shape [1 ]
366
+ input_length = 1 if self . model .config .is_encoder_decoder else inputs .input_ids .shape [1 ]
361
367
# Generate SQL
362
368
random_seed = random .randint (0 , 50 )
363
369
torch .manual_seed (random_seed )
364
370
365
371
# Greedy search for quick response
366
- output = model .generate (
367
- ** inputs ,
372
+ self .model .eval ()
373
+ device_type = "cuda" if torch .cuda .is_available () else "cpu"
374
+ output = self .model .generate (
375
+ ** inputs .to (device_type ),
368
376
max_new_tokens = 300 ,
369
377
temperature = 0.5 ,
370
378
output_scores = True ,
371
379
return_dict_in_generate = True ,
372
380
)
373
381
374
382
generated_tokens = output .sequences [:, input_length :]
375
- _res = tokenizer .decode (generated_tokens [0 ], skip_special_tokens = True )
383
+ _res = self . tokenizer .decode (generated_tokens [0 ], skip_special_tokens = True )
376
384
# Below is a precaution in case of an error in the table name during generation
377
- res = "SELECT" + _res .replace ("table_name" , table_names [0 ])
385
+ # COLLATE NOCASE makes the comparison case-insensitive; this may be specific to SQLite
386
+ _temp = _res .replace ("table_name" , table_names [0 ]).split (";" )[0 ]
387
+ res = "SELECT" + _temp + " COLLATE NOCASE;"
378
388
return res
379
389
380
390
def task_formatter (self , input_task : str ):
0 commit comments