|
9 | 9 | import sqlglot |
10 | 10 | import torch |
11 | 11 | from langchain import OpenAI |
12 | | -from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex, |
13 | | - LLMPredictor, ServiceContext, SQLDatabase) |
| 12 | +from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase |
14 | 13 | from llama_index.indices.struct_store import SQLContextContainerBuilder |
15 | | -from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, |
16 | | - NSQL_QUERY_PROMPT, QUERY_PROMPT, |
17 | | - TASK_PROMPT) |
| 14 | +from sidekick.configs.prompt_template import DEBUGGING_PROMPT, NSQL_QUERY_PROMPT, QUERY_PROMPT, TASK_PROMPT |
18 | 15 | from sidekick.logger import logger |
19 | 16 | from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates |
20 | 17 | from sqlalchemy import create_engine |
@@ -270,10 +267,12 @@ def generate_sql( |
270 | 267 | else: |
271 | 268 | # Load h2oGPT.NSQL model |
272 | 269 | device = {"": 0} if torch.cuda.is_available() else "cpu" |
| 270 | + # https://github.com/pytorch/pytorch/issues/52291 |
| 271 | + _load_in_8bit = False if "cpu" in device else True |
273 | 272 | if self.model is None: |
274 | 273 | self.tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device) |
275 | 274 | self.model = AutoModelForCausalLM.from_pretrained( |
276 | | - "NumbersStation/nsql-6B", device_map=device, load_in_8bit=True |
| 275 | + "NumbersStation/nsql-6B", device_map=device, load_in_8bit=_load_in_8bit |
277 | 276 | ) |
278 | 277 |
|
279 | 278 | # TODO Update needed for multiple tables |
@@ -322,7 +321,8 @@ def generate_sql( |
322 | 321 | threshold=0.9, |
323 | 322 | ) |
324 | 323 | if len(context_queries) > 1 |
325 | | - else (context_queries, _) |
| 324 | + else context_queries, |
| 325 | + None, |
326 | 326 | ) |
327 | 327 | logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}") |
328 | 328 | # If QnA pairs > 5, we keep top 5 for focused context |
|
0 commit comments