Skip to content

Commit a7dfae3

Browse files
Fix initialization #4
1 parent dd6ab34 commit a7dfae3

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

sidekick/query.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,9 @@
99
import sqlglot
1010
import torch
1111
from langchain import OpenAI
12-
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
13-
LLMPredictor, ServiceContext, SQLDatabase)
12+
from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
1413
from llama_index.indices.struct_store import SQLContextContainerBuilder
15-
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT,
16-
NSQL_QUERY_PROMPT, QUERY_PROMPT,
17-
TASK_PROMPT)
14+
from sidekick.configs.prompt_template import DEBUGGING_PROMPT, NSQL_QUERY_PROMPT, QUERY_PROMPT, TASK_PROMPT
1815
from sidekick.logger import logger
1916
from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
2017
from sqlalchemy import create_engine
@@ -271,7 +268,7 @@ def generate_sql(
271268
# Load h2oGPT.NSQL model
272269
device = {"": 0} if torch.cuda.is_available() else "cpu"
273270
if self.model is None:
274-
self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device)
271+
self.tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device)
275272
self.model = AutoModelForCausalLM.from_pretrained(
276273
"NumbersStation/nsql-6B", device_map=device, load_in_8bit=True
277274
)
@@ -362,7 +359,7 @@ def generate_sql(
362359
)
363360

364361
logger.debug(f"Query Text:\n {query}")
365-
inputs = tokenizer([query], return_tensors="pt")
362+
inputs = self.tokenizer([query], return_tensors="pt")
366363
input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1]
367364
# Generate SQL
368365
random_seed = random.randint(0, 50)

0 commit comments

Comments
 (0)