|
9 | 9 | import sqlglot
|
10 | 10 | import torch
|
11 | 11 | from langchain import OpenAI
|
12 |
| -from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex, |
13 |
| - LLMPredictor, ServiceContext, SQLDatabase) |
| 12 | +from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase |
14 | 13 | from llama_index.indices.struct_store import SQLContextContainerBuilder
|
15 |
| -from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, |
16 |
| - NSQL_QUERY_PROMPT, QUERY_PROMPT, |
17 |
| - TASK_PROMPT) |
| 14 | +from sidekick.configs.prompt_template import DEBUGGING_PROMPT, NSQL_QUERY_PROMPT, QUERY_PROMPT, TASK_PROMPT |
18 | 15 | from sidekick.logger import logger
|
19 | 16 | from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
|
20 | 17 | from sqlalchemy import create_engine
|
@@ -271,7 +268,7 @@ def generate_sql(
|
271 | 268 | # Load h2oGPT.NSQL model
|
272 | 269 | device = {"": 0} if torch.cuda.is_available() else "cpu"
|
273 | 270 | if self.model is None:
|
274 |
| - self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device) |
| 271 | + self.tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device) |
275 | 272 | self.model = AutoModelForCausalLM.from_pretrained(
|
276 | 273 | "NumbersStation/nsql-6B", device_map=device, load_in_8bit=True
|
277 | 274 | )
|
@@ -362,7 +359,7 @@ def generate_sql(
|
362 | 359 | )
|
363 | 360 |
|
364 | 361 | logger.debug(f"Query Text:\n {query}")
|
365 |
| - inputs = tokenizer([query], return_tensors="pt") |
| 362 | + inputs = self.tokenizer([query], return_tensors="pt") |
366 | 363 | input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1]
|
367 | 364 | # Generate SQL
|
368 | 365 | random_seed = random.randint(0, 50)
|
|
0 commit comments