Commit 6108d85

Fix runtime error for CPU #4
1 parent c0bc110 commit 6108d85

File tree

1 file changed: +7 −7 lines changed


sidekick/query.py

Lines changed: 7 additions & 7 deletions
@@ -9,12 +9,9 @@
 import sqlglot
 import torch
 from langchain import OpenAI
-from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
-                         LLMPredictor, ServiceContext, SQLDatabase)
+from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
 from llama_index.indices.struct_store import SQLContextContainerBuilder
-from sidekick.configs.prompt_template import (DEBUGGING_PROMPT,
-                                              NSQL_QUERY_PROMPT, QUERY_PROMPT,
-                                              TASK_PROMPT)
+from sidekick.configs.prompt_template import DEBUGGING_PROMPT, NSQL_QUERY_PROMPT, QUERY_PROMPT, TASK_PROMPT
 from sidekick.logger import logger
 from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
 from sqlalchemy import create_engine
@@ -270,10 +267,12 @@ def generate_sql(
         else:
             # Load h2oGPT.NSQL model
             device = {"": 0} if torch.cuda.is_available() else "cpu"
+            # https://github.com/pytorch/pytorch/issues/52291
+            _load_in_8bit = False if "cpu" in device else True
             if self.model is None:
                 self.tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device)
                 self.model = AutoModelForCausalLM.from_pretrained(
-                    "NumbersStation/nsql-6B", device_map=device, load_in_8bit=True
+                    "NumbersStation/nsql-6B", device_map=device, load_in_8bit=_load_in_8bit
                 )

                 # TODO Update needed for multiple tables
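For readers hitting the CPU runtime error, the hunk above conditions 8-bit loading on having a GPU, since the bitsandbytes 8-bit path expects a CUDA device. Below is a minimal, self-contained sketch of that loading pattern: the model name and the device/flag logic are taken from the diff, while the surrounding class and the rest of generate_sql are omitted, so treat it as an illustration rather than the repository's code.

# Sketch of the device-aware loading pattern from the hunk above (not the full generate_sql logic).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Map the whole model to GPU 0 when CUDA is available, otherwise stay on CPU.
device = {"": 0} if torch.cuda.is_available() else "cpu"
# 8-bit quantization needs CUDA, so fall back to full precision on CPU
# (see the pytorch issue linked in the diff, pytorch/pytorch#52291).
_load_in_8bit = False if "cpu" in device else True

tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device)
model = AutoModelForCausalLM.from_pretrained(
    "NumbersStation/nsql-6B", device_map=device, load_in_8bit=_load_in_8bit
)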
@@ -322,7 +321,8 @@ def generate_sql(
                     threshold=0.9,
                 )
                 if len(context_queries) > 1
-                else (context_queries, _)
+                else context_queries,
+                None,
             )
             logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}")
             # If QnA pairs > 5, we keep top 5 for focused context
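One note on the tuple change in this last hunk, since the binding is easy to misread: a conditional expression used as the first element of a parenthesized tuple binds only to that element, so with the added `None,` the outer parentheses always yield a 2-tuple. A minimal sketch with hypothetical stand-in values follows; the list comprehension merely stands in for the filter_samples(...) call and is not the repository's logic.

# Sketch (hypothetical values): the conditional expression is the first tuple
# element and None is the second, so the result is always a 2-tuple.
context_queries = ["SELECT 1", "SELECT 2"]

result = (
    [q for q in context_queries if "2" in q]  # stand-in for filter_samples(...)
    if len(context_queries) > 1
    else context_queries,
    None,
)

assert isinstance(result, tuple) and len(result) == 2 and result[1] is None
print(result[0])  # ['SELECT 2']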
