|
| 1 | +import json |
| 2 | +import os |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +import requests |
| 6 | + |
| 7 | +from UnionChatBot.utils.BasicManager import BasicManager |
| 8 | +from UnionChatBot.utils.EmbeddingAPI import MyEmbeddingFunction |
| 9 | +from UnionChatBot.utils.ChromaAdapter import ChromaAdapter |
| 10 | +from UnionChatBot.utils.RedisAdapters import SemanticRedisCache |
| 11 | +from UnionChatBot.utils.ChatHistoryManager import ChatHistoryManager |
| 12 | +from UnionChatBot.utils.QueryRewriteManager import QueryRewriteManager |
| 13 | + |
| 14 | + |
| 15 | +class CoreQueryProcessor(BasicManager): |
| 16 | + """Центральный класс позволяющий реализовать логику работы чат-бота. |
| 17 | +
|
| 18 | + Args: |
| 19 | + embedding_function: объект класса отвечающий за векторизацию текста. |
| 20 | + chroma_adapter: объект класса отвечающий за взаимодействие с векторной БД. |
| 21 | + redis_cache: объект класса отвечающий за взаимодействие с горячей БД Redis. |
| 22 | + chat_manager: объект класса отвечающий за контроль истории пользователя при общении с чат-ботом. |
| 23 | + """ |
| 24 | + |
| 25 | + core_prompt_file = os.getenv("DEFAULT_PROMPT_FILE", "default_prompt.txt") |
| 26 | + core_prompt_dir = os.getenv("DEFAULT_DIR_PROMPT", "./prompts") |
| 27 | + |
| 28 | + def __init__( |
| 29 | + self, |
| 30 | + temperature: float = 0.3, |
| 31 | + stream: bool = False, |
| 32 | + maxTokens: int = 2000, |
| 33 | + model_name: str = "deepseek-r1-distill-qwen-32b", |
| 34 | + embedding_function: MyEmbeddingFunction = None, |
| 35 | + chroma_adapter: ChromaAdapter = None, |
| 36 | + redis_cache: SemanticRedisCache = None, |
| 37 | + chat_manager: ChatHistoryManager = None, |
| 38 | + query_rewriter: Optional[QueryRewriteManager] = None, |
| 39 | + **kwargs, |
| 40 | + ): |
| 41 | + super().__init__( |
| 42 | + model_name=model_name, |
| 43 | + temperature=temperature, |
| 44 | + stream=stream, |
| 45 | + maxTokens=maxTokens, |
| 46 | + **kwargs, |
| 47 | + ) |
| 48 | + |
| 49 | + self.embedding_function = embedding_function |
| 50 | + self.redis_cache = redis_cache |
| 51 | + self.chroma_adapter = chroma_adapter |
| 52 | + self.chat_manager = chat_manager |
| 53 | + self.query_rewriter = query_rewriter |
| 54 | + |
| 55 | + def modify_system_prompt(self, prompt: str, data: dict, user_id: str) -> str: |
| 56 | + """Модифицируем системный промт исходя из ответов из базы данных. |
| 57 | +
|
| 58 | + Args: |
| 59 | + prompt: системый промт по умолчанию. |
| 60 | + data: словарь с релевантной информацией из БД. |
| 61 | + user_id: уникальный индетефикатор пользователя. |
| 62 | +
|
| 63 | + Return: |
| 64 | + Модифицированный системный промт исходя из дополнительной информации из БД и истории диалога. |
| 65 | + """ |
| 66 | + history_data = self.chat_manager.get_formatted_history(user_id=user_id) |
| 67 | + prompt += "<RAG>" |
| 68 | + context = ( |
| 69 | + " ".join( |
| 70 | + [ |
| 71 | + "№" |
| 72 | + + str(idx) |
| 73 | + + " <Информация>: " |
| 74 | + + info[0] |
| 75 | + + " " |
| 76 | + + "<Источник>: " |
| 77 | + + info[1].get(list(info[1].keys())[0]) |
| 78 | + + " </Источник> <Файл> " |
| 79 | + + list(info[1].keys())[0] |
| 80 | + + "</Файл>" |
| 81 | + + "</Информация> \n" |
| 82 | + for idx, info in enumerate( |
| 83 | + zip(data.get("documents"), data.get("metadatas")) |
| 84 | + ) |
| 85 | + ] |
| 86 | + ) |
| 87 | + + "</RAG>" |
| 88 | + ) |
| 89 | + prompt += " " + context + history_data |
| 90 | + return prompt |
| 91 | + |
| 92 | + def ask(self, query: str, collection_name: str, user_id: str) -> str: |
| 93 | + """Инициализация диалога с чат-ботом. |
| 94 | +
|
| 95 | + Args: |
| 96 | + query: вопрос пользователя & сообщение. |
| 97 | + collection_name: название коллекции к которой необходимо обратиться в ChromaDB. |
| 98 | + user_id: уникальный идентификатор пользователя. |
| 99 | +
|
| 100 | + Return: |
| 101 | + Текстовый ответ модели для пользователя. |
| 102 | + """ |
| 103 | + system_prompt = self.read_prompt( |
| 104 | + prompt_file=self.core_prompt_file, prompt_dir=self.core_prompt_dir |
| 105 | + ) |
| 106 | + |
| 107 | + if self.query_rewriter: |
| 108 | + query, status = self.query_rewriter.rewrite(query=query, user_id=user_id) |
| 109 | + if status != 200: |
| 110 | + return query |
| 111 | + |
| 112 | + query_embedding = self.embedding_function(query) |
| 113 | + |
| 114 | + cached = self.redis_cache.get(query, query_embedding) |
| 115 | + if cached: |
| 116 | + self.chat_manager.add_message_to_history( |
| 117 | + user_id=user_id, message=cached["response"] |
| 118 | + ) |
| 119 | + return cached["response"] |
| 120 | + |
| 121 | + data = self.chroma_adapter.get_info( |
| 122 | + query=query, collection_name=collection_name |
| 123 | + ) |
| 124 | + new_prompt = self.modify_system_prompt( |
| 125 | + prompt=system_prompt, data=data, user_id=user_id |
| 126 | + ) |
| 127 | + response = requests.post( |
| 128 | + url=self.url, |
| 129 | + headers=self.setup_header(), |
| 130 | + data=self.setup_data(text=query, prompt=new_prompt), |
| 131 | + ) |
| 132 | + |
| 133 | + if response.status_code == 200: |
| 134 | + dict_response = json.loads(response.content) |
| 135 | + answer = ( |
| 136 | + dict_response.get("result") |
| 137 | + .get("alternatives")[0] |
| 138 | + .get("message") |
| 139 | + .get("text") |
| 140 | + ) |
| 141 | + self.redis_cache.set(query, query_embedding, answer) |
| 142 | + self.chat_manager.add_message_to_history(user_id=user_id, message=answer) |
| 143 | + else: |
| 144 | + answer = ( |
| 145 | + f"Код ответа {response.status_code}. Попробуйте задать вопрос позднее." |
| 146 | + ) |
| 147 | + return answer |
0 commit comments