|
14 | 14 | import boto3
|
15 | 15 | import google.auth
|
16 | 16 | from src.shared.constants import ADDITIONAL_INSTRUCTIONS
|
| 17 | +from src.shared.llm_graph_builder_exception import LLMGraphBuilderException |
17 | 18 | import re
|
18 |
| -import json |
| 19 | +from typing import List |
19 | 20 |
|
20 | 21 | def get_llm(model: str):
|
21 | 22 | """Retrieve the specified language model based on the model name."""
|
@@ -209,21 +210,45 @@ async def get_graph_document_list(
|
209 | 210 | return graph_document_list
|
210 | 211 |
|
async def get_graph_from_llm(model, chunkId_chunkDoc_list, allowedNodes, allowedRelationship, chunks_to_combine, additional_instructions=None):
    """Generate graph documents for the given chunks with the selected LLM.

    Args:
        model: Model name handed to ``get_llm``.
        chunkId_chunkDoc_list: Chunk records handed to ``get_combined_chunks``.
        allowedNodes: Comma-separated node labels. ``None`` or an empty
            string yields an empty list.
        allowedRelationship: Comma-separated flat sequence of
            ``source,relationship,target`` triples. A falsy value yields an
            empty list.
        chunks_to_combine: Forwarded to ``get_combined_chunks``.
        additional_instructions: Forwarded unchanged to
            ``get_graph_document_list``.

    Returns:
        The value returned by ``get_graph_document_list``.

    Raises:
        LLMGraphBuilderException: If ``allowedRelationship`` is not a multiple
            of three items, if a triple's source or target is not listed in
            ``allowedNodes``, or if any step of the pipeline fails (the
            original error is attached as ``__cause__``).
    """
    try:
        llm, model_name = get_llm(model)
        logging.info(f"Using model: {model_name}")

        combined_chunk_document_list = get_combined_chunks(chunkId_chunkDoc_list, chunks_to_combine)
        logging.info(f"Combined {len(combined_chunk_document_list)} chunks")

        # Guard the falsy case explicitly: callers may pass None, and
        # None.split(...) would otherwise surface as an opaque wrapped error.
        if allowedNodes:
            allowed_nodes = [node.strip() for node in allowedNodes.split(',') if node.strip()]
        else:
            allowed_nodes = []
        logging.info(f"Allowed nodes: {allowed_nodes}")

        allowed_relationships = []
        if allowedRelationship:
            items = [item.strip() for item in allowedRelationship.split(',') if item.strip()]
            if len(items) % 3 != 0:
                raise LLMGraphBuilderException("allowedRelationship must be a multiple of 3 (source, relationship, target)")
            # Build the lookup once; membership is tested twice per triple.
            known_nodes = set(allowed_nodes)
            for i in range(0, len(items), 3):
                source, relation, target = items[i:i + 3]
                if source not in known_nodes or target not in known_nodes:
                    raise LLMGraphBuilderException(
                        f"Invalid relationship ({source}, {relation}, {target}): "
                        f"source or target not in allowedNodes"
                    )
                allowed_relationships.append((source, relation, target))
            logging.info(f"Allowed relationships: {allowed_relationships}")
        else:
            logging.info("No allowed relationships provided")

        graph_document_list = await get_graph_document_list(
            llm,
            combined_chunk_document_list,
            allowed_nodes,
            allowed_relationships,
            additional_instructions
        )
        logging.info(f"Generated {len(graph_document_list)} graph documents")
        return graph_document_list
    except Exception as e:
        logging.error(f"Error in get_graph_from_llm: {e}", exc_info=True)
        # Chain the cause so the original traceback is preserved for callers.
        raise LLMGraphBuilderException(f"Error in getting graph from llm: {e}") from e
227 | 252 |
|
228 | 253 | def sanitize_additional_instruction(instruction: str) -> str:
|
229 | 254 | """
|
|
0 commit comments