Skip to content

Commit 2e930f7

Browse files
added validation for tuple schema relationships (#1289)
1 parent b994dbe commit 2e930f7

File tree

2 files changed

+39
-13
lines changed

2 files changed

+39
-13
lines changed

backend/score.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from src.main import *
55
from src.QA_integration import *
66
from src.shared.common_fn import *
7+
from src.shared.llm_graph_builder_exception import LLMGraphBuilderException
78
import uvicorn
89
import asyncio
910
import base64

backend/src/llm.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
import boto3
1515
import google.auth
1616
from src.shared.constants import ADDITIONAL_INSTRUCTIONS
17+
from src.shared.llm_graph_builder_exception import LLMGraphBuilderException
1718
import re
18-
import json
19+
from typing import List
1920

2021
def get_llm(model: str):
2122
"""Retrieve the specified language model based on the model name."""
@@ -209,21 +210,45 @@ async def get_graph_document_list(
209210
return graph_document_list
210211

211212
async def get_graph_from_llm(model, chunkId_chunkDoc_list, allowedNodes, allowedRelationship, chunks_to_combine, additional_instructions=None):
213+
try:
214+
llm, model_name = get_llm(model)
215+
logging.info(f"Using model: {model_name}")
212216

213-
llm, model_name = get_llm(model)
214-
combined_chunk_document_list = get_combined_chunks(chunkId_chunkDoc_list, chunks_to_combine)
217+
combined_chunk_document_list = get_combined_chunks(chunkId_chunkDoc_list, chunks_to_combine)
218+
logging.info(f"Combined {len(combined_chunk_document_list)} chunks")
215219

216-
allowedNodes = allowedNodes.split(',') if allowedNodes else []
220+
allowed_nodes = [node.strip() for node in allowedNodes.split(',') if node.strip()]
221+
logging.info(f"Allowed nodes: {allowed_nodes}")
222+
223+
allowed_relationships = []
224+
if allowedRelationship:
225+
items = [item.strip() for item in allowedRelationship.split(',') if item.strip()]
226+
if len(items) % 3 != 0:
227+
raise LLMGraphBuilderException("allowedRelationship must be a multiple of 3 (source, relationship, target)")
228+
for i in range(0, len(items), 3):
229+
source, relation, target = items[i:i + 3]
230+
if source not in allowed_nodes or target not in allowed_nodes:
231+
raise LLMGraphBuilderException(
232+
f"Invalid relationship ({source}, {relation}, {target}): "
233+
f"source or target not in allowedNodes"
234+
)
235+
allowed_relationships.append((source, relation, target))
236+
logging.info(f"Allowed relationships: {allowed_relationships}")
237+
else:
238+
logging.info("No allowed relationships provided")
217239

218-
if not allowedRelationship:
219-
allowedRelationship = []
220-
else:
221-
items = allowedRelationship.split(',')
222-
allowedRelationship = [tuple(items[i:i+3]) for i in range(0, len(items), 3)]
223-
graph_document_list = await get_graph_document_list(
224-
llm, combined_chunk_document_list, allowedNodes, allowedRelationship, additional_instructions
225-
)
226-
return graph_document_list
240+
graph_document_list = await get_graph_document_list(
241+
llm,
242+
combined_chunk_document_list,
243+
allowed_nodes,
244+
allowed_relationships,
245+
additional_instructions
246+
)
247+
logging.info(f"Generated {len(graph_document_list)} graph documents")
248+
return graph_document_list
249+
except Exception as e:
250+
logging.error(f"Error in get_graph_from_llm: {e}", exc_info=True)
251+
raise LLMGraphBuilderException(f"Error in getting graph from llm: {e}")
227252

228253
def sanitize_additional_instruction(instruction: str) -> str:
229254
"""

0 commit comments

Comments
 (0)