Skip to content

Commit 4565071

Browse files
authored
feat: optimize get_edges (#1138)
* feat: optimize get_edges * feat: optimize get_edges
1 parent 7778189 commit 4565071

File tree

1 file changed

+37
-27
lines changed

1 file changed

+37
-27
lines changed

src/memos/graph_dbs/polardb.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -796,8 +796,8 @@ def add_edge(
796796

797797
start_time = time.time()
798798
if not source_id or not target_id:
799-
logger.info(f"Edge '{source_id}' and '{target_id}' are both None")
800-
raise ValueError("[add_edge] source_id and target_id must be provided")
799+
logger.error(f"Edge '{source_id}' and '{target_id}' are both None")
800+
return
801801

802802
source_exists = self.get_node(source_id) is not None
803803
target_exists = self.get_node(target_id) is not None
@@ -806,11 +806,6 @@ def add_edge(
806806
logger.warning(
807807
"[add_edge] Source %s or target %s does not exist.", source_exists, target_exists
808808
)
809-
logger.info(
810-
"[add_edge_error] Source %s or target %s does not exist.",
811-
source_exists,
812-
target_exists,
813-
)
814809
return
815810

816811
properties = {}
@@ -4039,34 +4034,47 @@ def get_edges(
40394034
...
40404035
]
40414036
"""
4037+
start_time = time.time()
4038+
logger.info(f" get_edges id:{id},type:{type},direction:{direction},user_name:{user_name}")
40424039
user_name = user_name if user_name else self._get_config_value("user_name")
4043-
4044-
if direction == "OUTGOING":
4045-
pattern = "(a:Memory)-[r]->(b:Memory)"
4046-
where_clause = f"a.id = '{id}'"
4047-
elif direction == "INCOMING":
4048-
pattern = "(a:Memory)<-[r]-(b:Memory)"
4049-
where_clause = f"a.id = '{id}'"
4050-
elif direction == "ANY":
4051-
pattern = "(a:Memory)-[r]-(b:Memory)"
4052-
where_clause = f"a.id = '{id}' OR b.id = '{id}'"
4053-
else:
4040+
if direction not in ("OUTGOING", "INCOMING", "ANY"):
40544041
raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")
40554042

4056-
# Add type filter
4057-
if type != "ANY":
4058-
where_clause += f" AND type(r) = '{type}'"
4059-
4060-
# Add user filter
4061-
where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
4043+
# Escape single quotes for safe embedding in Cypher string
4044+
id_esc = (id or "").replace("'", "''")
4045+
user_esc = (user_name or "").replace("'", "''")
4046+
type_esc = (type or "").replace("'", "''")
4047+
type_filter = f" AND type(r) = '{type_esc}'" if type != "ANY" else ""
4048+
logger.info(f"type_filter:{type_filter}")
40624049

4050+
if direction == "OUTGOING":
4051+
cypher_body = f"""
4052+
MATCH (a:Memory)-[r:{type}]->(b:Memory)
4053+
WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'
4054+
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
4055+
"""
4056+
elif direction == "INCOMING":
4057+
cypher_body = f"""
4058+
MATCH (b:Memory)<-[r:{type}]-(a:Memory)
4059+
WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'
4060+
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
4061+
"""
4062+
else: # ANY: union of OUTGOING and INCOMING
4063+
cypher_body = f"""
4064+
MATCH (a:Memory)-[r]->(b:Memory)
4065+
WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'{type_filter}
4066+
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
4067+
UNION ALL
4068+
MATCH (b:Memory)<-[r]-(a:Memory)
4069+
WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'{type_filter}
4070+
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
4071+
"""
40634072
query = f"""
40644073
SELECT * FROM cypher('{self.db_name}_graph', $$
4065-
MATCH {pattern}
4066-
WHERE {where_clause}
4067-
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
4074+
{cypher_body.strip()}
40684075
$$) AS (from_id agtype, to_id agtype, edge_type agtype)
40694076
"""
4077+
logger.info(f"get_edges query:{query}")
40704078
conn = None
40714079
try:
40724080
conn = self._get_connection()
@@ -4110,6 +4118,8 @@ def get_edges(
41104118
edge_type = str(edge_type_raw)
41114119

41124120
edges.append({"from": from_id, "to": to_id, "type": edge_type})
4121+
elapsed_time = time.time() - start_time
4122+
logger.info(f"polardb get_edges query completed time in {elapsed_time:.2f}s")
41134123
return edges
41144124

41154125
except Exception as e:

0 commit comments

Comments
 (0)