Skip to content

Commit e1a1e03

Browse files
committed
Fix agents
1 parent 4e00504 commit e1a1e03

File tree

5 files changed

+84
-50
lines changed

5 files changed

+84
-50
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def load_agent_file(cls, name: str) -> dict:
2424
return load(name.lower())
2525

2626
@classmethod
27-
def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
27+
def get_tool(cls, sql_helper, tool_name: str):
2828
"""Gets the tool based on the tool name.
2929
Args:
3030
----
@@ -46,7 +46,7 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
4646
)
4747
elif tool_name == "sql_get_column_values_tool":
4848
return FunctionToolAlias(
49-
ai_search_helper.get_column_values,
49+
sql_helper.get_column_values,
5050
description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.",
5151
)
5252
else:
@@ -88,12 +88,11 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
8888
agent_file = cls.load_agent_file(name)
8989

9090
sql_helper = ConnectorFactory.get_database_connector()
91-
ai_search_helper = ConnectorFactory.get_ai_search_connector()
9291

9392
tools = []
9493
if "tools" in agent_file and len(agent_file["tools"]) > 0:
9594
for tool in agent_file["tools"]:
96-
tools.append(cls.get_tool(sql_helper, ai_search_helper, tool))
95+
tools.append(cls.get_tool(sql_helper, tool))
9796

9897
agent = AssistantAgent(
9998
name=name,

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -108,25 +108,28 @@ async def consume_inner_messages_from_agentic_flow(
108108

109109
try:
110110
if isinstance(inner_message, ToolCallResultMessage):
111-
# Check for SQL query results
112-
parsed_message = self.parse_inner_message(inner_message.content)
113-
114-
logging.info(f"Inner Loaded: {parsed_message}")
115-
116-
if isinstance(parsed_message, dict):
117-
if (
118-
"type" in parsed_message
119-
and parsed_message["type"]
120-
== "query_execution_with_limit"
121-
):
122-
database_results[identifier].append(
123-
{
124-
"sql_query": parsed_message[
125-
"sql_query"
126-
].replace("\n", " "),
127-
"sql_rows": parsed_message["sql_rows"],
128-
}
129-
)
111+
for call_result in inner_message.content:
112+
# Check for SQL query results
113+
parsed_message = self.parse_inner_message(
114+
call_result.content
115+
)
116+
logging.info(f"Inner Loaded: {parsed_message}")
117+
118+
if isinstance(parsed_message, dict):
119+
if (
120+
"type" in parsed_message
121+
and parsed_message["type"]
122+
== "query_execution_with_limit"
123+
):
124+
logging.info("Contains query results")
125+
database_results[identifier].append(
126+
{
127+
"sql_query": parsed_message[
128+
"sql_query"
129+
].replace("\n", " "),
130+
"sql_rows": parsed_message["sql_rows"],
131+
}
132+
)
130133

131134
elif isinstance(inner_message, TextMessage):
132135
parsed_message = self.parse_inner_message(inner_message.content)
@@ -138,6 +141,7 @@ async def consume_inner_messages_from_agentic_flow(
138141
if ("contains_pre_run_results" in parsed_message) and (
139142
parsed_message["contains_pre_run_results"] is True
140143
):
144+
logging.info("Contains pre-run results")
141145
for pre_run_sql_query, pre_run_result in parsed_message[
142146
"cached_questions_and_schemas"
143147
].items():

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import logging
1010
import base64
1111
from datetime import datetime, timezone
12-
import json
1312
from typing import Annotated
1413
from text_2_sql_core.connectors.open_ai import OpenAIConnector
1514

@@ -111,7 +110,6 @@ async def get_column_values(
111110
str,
112111
"The text to run a semantic search against. Relevant entities will be returned.",
113112
],
114-
as_json: bool = True,
115113
):
116114
"""Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.
117115
@@ -139,24 +137,7 @@ async def get_column_values(
139137
minimum_score=5,
140138
)
141139

142-
# build into a common format
143-
column_values = {}
144-
145-
for value in values:
146-
trimmed_fqn = ".".join(value["FQN"].split(".")[:-1])
147-
if trimmed_fqn not in column_values:
148-
column_values[trimmed_fqn] = []
149-
150-
column_values[trimmed_fqn].append(value["Value"])
151-
152-
logging.info("Column Values: %s", column_values)
153-
154-
filter_to_column = {text: column_values}
155-
156-
if as_json:
157-
return json.dumps(filter_to_column, default=str)
158-
else:
159-
return filter_to_column
140+
return values
160141

161142
async def get_entity_schemas(
162143
self,
@@ -185,6 +166,8 @@ async def get_entity_schemas(
185166

186167
logging.info("Search Text: %s", text)
187168

169+
stringified_engine_specific_fields = list(map(str, engine_specific_fields))
170+
188171
retrieval_fields = [
189172
"FQN",
190173
"Entity",
@@ -194,7 +177,7 @@ async def get_entity_schemas(
194177
"Columns",
195178
"EntityRelationships",
196179
"CompleteEntityRelationshipsGraph",
197-
] + list(map(str, engine_specific_fields))
180+
] + stringified_engine_specific_fields
198181

199182
schemas = await self.run_ai_search_query(
200183
text,
@@ -207,6 +190,8 @@ async def get_entity_schemas(
207190
top=3,
208191
)
209192

193+
fqn_to_trim = ".".join(stringified_engine_specific_fields)
194+
210195
if len(excluded_entities) == 0:
211196
return schemas
212197

@@ -220,12 +205,16 @@ async def get_entity_schemas(
220205
and len(schema["CompleteEntityRelationshipsGraph"]) == 0
221206
):
222207
del schema["CompleteEntityRelationshipsGraph"]
208+
else:
209+
schema["CompleteEntityRelationshipsGraph"] = list(
210+
map(
211+
lambda x: x.replace(fqn_to_trim, ""),
212+
schema["CompleteEntityRelationshipsGraph"],
213+
)
214+
)
223215

224-
if (
225-
schema["SammpleValues"] is not None
226-
and len(schema["SammpleValues"]) == 0
227-
):
228-
del schema["SammpleValues"]
216+
if schema["SampleValues"] is not None and len(schema["SampleValues"]) == 0:
217+
del schema["SampleValues"]
229218

230219
if (
231220
schema["EntityRelationships"] is not None

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,48 @@ async def query_execution(
7474
list[dict]: The results of the SQL query.
7575
"""
7676

77+
async def get_column_values(
78+
self,
79+
text: Annotated[
80+
str,
81+
"The text to run a semantic search against. Relevant entities will be returned.",
82+
],
83+
as_json: bool = True,
84+
):
85+
"""Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.
86+
87+
Args:
88+
----
89+
text (str): The text to run the search against.
90+
91+
Returns:
92+
-------
93+
str: The values of the column in JSON format.
94+
"""
95+
96+
values = await self.ai_search_connector.get_column_values(text)
97+
98+
# build into a common format
99+
column_values = {}
100+
101+
starting = len(self.engine_specific_fields)
102+
103+
for value in values:
104+
trimmed_fqn = ".".join(value["FQN"].split(".")[starting:-1])
105+
if trimmed_fqn not in column_values:
106+
column_values[trimmed_fqn] = []
107+
108+
column_values[trimmed_fqn].append(value["Value"])
109+
110+
logging.info("Column Values: %s", column_values)
111+
112+
filter_to_column = {text: column_values}
113+
114+
if as_json:
115+
return json.dumps(filter_to_column, default=str)
116+
else:
117+
return filter_to_column
118+
77119
@abstractmethod
78120
async def get_entity_schemas(
79121
self,

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ async def process_message(self, user_questions: list[str]) -> dict:
5555

5656
for filter_condition in entity_result.filter_conditions:
5757
column_search_tasks.append(
58-
self.ai_search_connector.get_column_values(
58+
self.sql_connector.get_column_values(
5959
filter_condition, as_json=False
6060
)
6161
)

0 commit comments

Comments
 (0)