
Commit 28f7e3c

Fix parallel query generation exception by moving to structured output mode (#129)
* Update prompts and agents
* Map engine specific deals
* Fix agents

1 parent b4662a6 · commit 28f7e3c

18 files changed: +287 −164 lines changed

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 3 additions & 4 deletions
@@ -24,7 +24,7 @@ def load_agent_file(cls, name: str) -> dict:
        return load(name.lower())

    @classmethod
-    def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
+    def get_tool(cls, sql_helper, tool_name: str):
        """Gets the tool based on the tool name.
        Args:
        ----

@@ -46,7 +46,7 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
            )
        elif tool_name == "sql_get_column_values_tool":
            return FunctionToolAlias(
-                ai_search_helper.get_column_values,
+                sql_helper.get_column_values,
                description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.",
            )
        else:

@@ -88,12 +88,11 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
        agent_file = cls.load_agent_file(name)

        sql_helper = ConnectorFactory.get_database_connector()
-        ai_search_helper = ConnectorFactory.get_ai_search_connector()

        tools = []
        if "tools" in agent_file and len(agent_file["tools"]) > 0:
            for tool in agent_file["tools"]:
-                tools.append(cls.get_tool(sql_helper, ai_search_helper, tool))
+                tools.append(cls.get_tool(sql_helper, tool))

        agent = AssistantAgent(
            name=name,
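
The mechanics being edited here: `get_tool` maps a tool name declared in an agent definition file to a function tool that wraps a connector method, and `create` resolves every name in the file's `tools` list. A minimal sketch of that mapping, with a plain tuple standing in for the project's `FunctionToolAlias` wrapper (the tuple shape and the `build_tools` helper are illustrative only, not the repository's API):

```python
from typing import Any, Callable, Tuple


def get_tool(sql_helper: Any, tool_name: str) -> Tuple[Callable, str]:
    """Resolve a tool name from an agent definition file to a (callable, description) pair."""
    if tool_name == "sql_get_column_values_tool":
        # After this commit the column-value lookup is served by the database
        # connector itself rather than a separate AI Search helper.
        return (
            sql_helper.get_column_values,
            "Gets the values of a column in the SQL Database for use in filters.",
        )
    # ... other tool names resolve to other sql_helper methods ...
    raise ValueError(f"Unknown tool name: {tool_name}")


def build_tools(sql_helper: Any, agent_file: dict) -> list:
    """Resolve every tool listed under the agent file's 'tools' key, as create() does."""
    return [get_tool(sql_helper, name) for name in agent_file.get("tools", [])]
```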

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 38 additions & 23 deletions
@@ -3,14 +3,20 @@
from typing import AsyncGenerator, List, Sequence

from autogen_agentchat.agents import BaseChatAgent
-from autogen_agentchat.base import Response, TaskResult
-from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
+from autogen_agentchat.base import Response
+from autogen_agentchat.messages import (
+    AgentMessage,
+    ChatMessage,
+    TextMessage,
+    ToolCallResultMessage,
+)
from autogen_core import CancellationToken
import json
import logging
from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql
from aiostream import stream
from json import JSONDecodeError
+import re


class ParallelQuerySolvingAgent(BaseChatAgent):

@@ -53,9 +59,6 @@ def parse_inner_message(self, message):
        except JSONDecodeError:
            pass

-        # Try to extract JSON from markdown code blocks
-        import re
-
        json_match = re.search(r"```json\s*(.*?)\s*```", message, re.DOTALL)
        if json_match:
            try:
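
For context, the `parse_inner_message` helper touched by these hunks uses a two-step fallback: parse the message as JSON directly, and if that fails, pull a fenced ```json block out of the text with the regex shown above (now relying on the module-level `import re`). A minimal standalone sketch of that idea, outside the agent class:

```python
import json
import re
from json import JSONDecodeError


def parse_possibly_wrapped_json(message: str):
    """Best-effort parse of a message that may be raw JSON or JSON inside a ```json fence."""
    try:
        return json.loads(message)
    except JSONDecodeError:
        pass

    # Fall back to extracting a fenced ```json ... ``` block, as in the hunk above.
    json_match = re.search(r"```json\s*(.*?)\s*```", message, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group(1))
        except JSONDecodeError:
            pass

    # Nothing parseable; hand back the original text.
    return message
```

Hoisting `import re` to module level, rather than importing it inside the method, is the cleanup the second hunk makes.
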
@@ -103,30 +106,42 @@ async def consume_inner_messages_from_agentic_flow(

                logging.info(f"Checking Inner Message: {inner_message}")

-                if isinstance(inner_message, TaskResult) is False:
-                    try:
+                try:
+                    if isinstance(inner_message, ToolCallResultMessage):
+                        for call_result in inner_message.content:
+                            # Check for SQL query results
+                            parsed_message = self.parse_inner_message(
+                                call_result.content
+                            )
+                            logging.info(f"Inner Loaded: {parsed_message}")
+
+                            if isinstance(parsed_message, dict):
+                                if (
+                                    "type" in parsed_message
+                                    and parsed_message["type"]
+                                    == "query_execution_with_limit"
+                                ):
+                                    logging.info("Contains query results")
+                                    database_results[identifier].append(
+                                        {
+                                            "sql_query": parsed_message[
+                                                "sql_query"
+                                            ].replace("\n", " "),
+                                            "sql_rows": parsed_message["sql_rows"],
+                                        }
+                                    )
+
+                    elif isinstance(inner_message, TextMessage):
                        parsed_message = self.parse_inner_message(inner_message.content)
+
                        logging.info(f"Inner Loaded: {parsed_message}")

                        # Search for specific message types and add them to the final output object
                        if isinstance(parsed_message, dict):
-                            if (
-                                "type" in parsed_message
-                                and parsed_message["type"]
-                                == "query_execution_with_limit"
-                            ):
-                                database_results[identifier].append(
-                                    {
-                                        "sql_query": parsed_message[
-                                            "sql_query"
-                                        ].replace("\n", " "),
-                                        "sql_rows": parsed_message["sql_rows"],
-                                    }
-                                )
-
                            if ("contains_pre_run_results" in parsed_message) and (
                                parsed_message["contains_pre_run_results"] is True
                            ):
+                                logging.info("Contains pre-run results")
                                for pre_run_sql_query, pre_run_result in parsed_message[
                                    "cached_questions_and_schemas"
                                ].items():

@@ -139,8 +154,8 @@ async def consume_inner_messages_from_agentic_flow(
                                        }
                                    )

-                    except Exception as e:
-                        logging.warning(f"Error processing message: {e}")
+                except Exception as e:
+                    logging.warning(f"Error processing message: {e}")

                yield inner_message

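
The substantive change in this hunk is the routing: SQL results now arrive as structured `ToolCallResultMessage` tool outputs rather than as free-form text that has to be re-parsed, while `TextMessage` payloads are still checked for cached pre-run results. A condensed, self-contained sketch of the tool-result branch, using local stand-in classes so it runs without autogen installed (the real message types come from `autogen_agentchat.messages`):

```python
import json
import logging
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToolResult:
    content: str  # JSON string emitted by a tool call


@dataclass
class ToolCallResultMessage:  # local stand-in for the autogen_agentchat message type
    content: List[ToolResult] = field(default_factory=list)


def collect_sql_results(
    inner_message: ToolCallResultMessage,
    identifier: str,
    database_results: Dict[str, list],
) -> None:
    """Append query_execution_with_limit tool outputs to the per-engine result list."""
    try:
        for call_result in inner_message.content:
            parsed = json.loads(call_result.content)
            if isinstance(parsed, dict) and parsed.get("type") == "query_execution_with_limit":
                database_results.setdefault(identifier, []).append(
                    {
                        "sql_query": parsed["sql_query"].replace("\n", " "),
                        "sql_rows": parsed["sql_rows"],
                    }
                )
    except Exception as e:
        # As in the diff, a malformed message is logged and skipped, not raised.
        logging.warning(f"Error processing message: {e}")
```

Reading results from structured tool outputs is what the commit title refers to: the rows no longer need to be recovered from agent prose, which appears to be where the parallel-query parsing exceptions came from.
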

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 7 additions & 27 deletions
@@ -6,7 +6,9 @@
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
from autogen_core import CancellationToken
-from text_2_sql_core.connectors.factory import ConnectorFactory
+from text_2_sql_core.custom_agents.sql_query_cache_agent import (
+    SqlQueryCacheAgentCustomAgent,
+)
import json
import logging

@@ -18,7 +20,7 @@ def __init__(self):
            "An agent that fetches the queries from the cache based on the user question.",
        )

-        self.sql_connector = ConnectorFactory.get_database_connector()
+        self.agent = SqlQueryCacheAgentCustomAgent()

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:

@@ -49,31 +51,9 @@ async def on_messages_stream(
            # If not JSON array, process as single question
            raise ValueError("Could not load message")

-        # Initialize results dictionary
-        cached_results = {
-            "cached_questions_and_schemas": [],
-            "contains_pre_run_results": False,
-        }
-
-        # Process each question sequentially
-        for question in user_questions:
-            # Fetch the queries from the cache based on the question
-            logging.info(f"Fetching queries from cache for question: {question}")
-            cached_query = await self.sql_connector.fetch_queries_from_cache(
-                question, injected_parameters=injected_parameters
-            )
-
-            # If any question has pre-run results, set the flag
-            if cached_query.get("contains_pre_run_results", False):
-                cached_results["contains_pre_run_results"] = True
-
-            # Add the cached results for this question
-            if cached_query.get("cached_questions_and_schemas"):
-                cached_results["cached_questions_and_schemas"].extend(
-                    cached_query["cached_questions_and_schemas"]
-                )
-
-        logging.info(f"Final cached results: {cached_results}")
+        cached_results = await self.agent.process_message(
+            user_questions, injected_parameters
+        )
        yield Response(
            chat_message=TextMessage(
                content=json.dumps(cached_results), source=self.name
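
The cache-lookup loop removed above is not gone; it moves behind the engine-agnostic `SqlQueryCacheAgentCustomAgent` in `text_2_sql_core`. Reconstructed purely from the deleted lines, its `process_message` plausibly looks like the sketch below (the real implementation lives in a file not shown in this excerpt):

```python
import logging

from text_2_sql_core.connectors.factory import ConnectorFactory  # as used before this commit


class SqlQueryCacheAgentCustomAgent:
    """Framework-agnostic cache lookup, sketched from the loop removed above."""

    def __init__(self):
        self.sql_connector = ConnectorFactory.get_database_connector()

    async def process_message(self, user_questions, injected_parameters):
        cached_results = {
            "cached_questions_and_schemas": [],
            "contains_pre_run_results": False,
        }

        # Process each question sequentially against the query cache.
        for question in user_questions:
            logging.info(f"Fetching queries from cache for question: {question}")
            cached_query = await self.sql_connector.fetch_queries_from_cache(
                question, injected_parameters=injected_parameters
            )

            # If any question has pre-run results, set the flag.
            if cached_query.get("contains_pre_run_results", False):
                cached_results["contains_pre_run_results"] = True

            # Accumulate the cached questions and schemas across questions.
            if cached_query.get("cached_questions_and_schemas"):
                cached_results["cached_questions_and_schemas"].extend(
                    cached_query["cached_questions_and_schemas"]
                )

        logging.info(f"Final cached results: {cached_results}")
        return cached_results
```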

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 13 additions & 71 deletions
@@ -6,12 +6,11 @@
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
from autogen_core import CancellationToken
-from text_2_sql_core.connectors.factory import ConnectorFactory
import json
import logging
-from text_2_sql_core.prompts.load import load
-from jinja2 import Template
-import asyncio
+from text_2_sql_core.custom_agents.sql_schema_selection_agent import (
+    SqlSchemaSelectionAgentCustomAgent,
+)


class SqlSchemaSelectionAgent(BaseChatAgent):

@@ -21,15 +20,7 @@ def __init__(self, **kwargs):
            "An agent that fetches the schemas from the cache based on the user question.",
        )

-        self.ai_search_connector = ConnectorFactory.get_ai_search_connector()
-
-        self.open_ai_connector = ConnectorFactory.get_open_ai_connector()
-
-        self.sql_connector = ConnectorFactory.get_database_connector()
-
-        system_prompt = load("sql_schema_selection_agent")["system_message"]
-
-        self.system_prompt = Template(system_prompt).render(kwargs)
+        self.agent = SqlSchemaSelectionAgentCustomAgent(**kwargs)

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:

@@ -49,64 +40,15 @@ async def on_messages(
    async def on_messages_stream(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[AgentMessage | Response, None]:
-        last_response = messages[-1].content
-
-        # load the json of the last message and get the user question's
-
-        user_questions = json.loads(last_response)
-
-        logging.info(f"User questions: {user_questions}")
-
-        entity_tasks = []
-
-        for user_question in user_questions:
-            messages = [
-                {"role": "system", "content": self.system_prompt},
-                {"role": "user", "content": user_question},
-            ]
-            entity_tasks.append(self.open_ai_connector.run_completion_request(messages))
-
-        entity_results = await asyncio.gather(*entity_tasks)
-
-        entity_search_tasks = []
-        column_search_tasks = []
-
-        for entity_result in entity_results:
-            loaded_entity_result = json.loads(entity_result)
-
-            logging.info(f"Loaded entity result: {loaded_entity_result}")
-
-            for entity_group in loaded_entity_result["entities"]:
-                entity_search_tasks.append(
-                    self.sql_connector.get_entity_schemas(
-                        " ".join(entity_group), as_json=False
-                    )
-                )
-
-            for filter_condition in loaded_entity_result["filter_conditions"]:
-                column_search_tasks.append(
-                    self.ai_search_connector.get_column_values(
-                        filter_condition, as_json=False
-                    )
-                )
-
-        schemas_results = await asyncio.gather(*entity_search_tasks)
-        column_value_results = await asyncio.gather(*column_search_tasks)
-
-        # deduplicate schemas
-        final_schemas = []
-
-        for schema_result in schemas_results:
-            for schema in schema_result:
-                if schema not in final_schemas:
-                    final_schemas.append(schema)
-
-        final_results = {
-            "COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results,
-            "SCHEMA_OPTIONS": final_schemas,
-        }
-
-        logging.info(f"Final results: {final_results}")
+        try:
+            request_details = json.loads(messages[0].content)
+            user_questions = request_details["question"]
+            logging.info(f"Processing questions: {user_questions}")
+        except json.JSONDecodeError:
+            # If not JSON array, process as single question
+            raise ValueError("Could not load message")
+
+        final_results = await self.agent.process_message(user_questions)

        yield Response(
            chat_message=TextMessage(
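
Likewise, the entity-extraction and schema-lookup pipeline deleted here moves into `text_2_sql_core`'s `SqlSchemaSelectionAgentCustomAgent`. A condensed sketch of what its `process_message` plausibly does, reconstructed only from the removed code above (the connector and prompt-loading calls are exactly the ones that were deleted; the actual core implementation is not part of this excerpt):

```python
import asyncio
import json
import logging

from jinja2 import Template
from text_2_sql_core.connectors.factory import ConnectorFactory
from text_2_sql_core.prompts.load import load


class SqlSchemaSelectionAgentCustomAgent:
    """Framework-agnostic schema selection, sketched from the code removed above."""

    def __init__(self, **kwargs):
        self.ai_search_connector = ConnectorFactory.get_ai_search_connector()
        self.open_ai_connector = ConnectorFactory.get_open_ai_connector()
        self.sql_connector = ConnectorFactory.get_database_connector()
        system_prompt = load("sql_schema_selection_agent")["system_message"]
        self.system_prompt = Template(system_prompt).render(kwargs)

    async def process_message(self, user_questions):
        # Ask the LLM to extract entities and filter conditions for each question.
        entity_tasks = [
            self.open_ai_connector.run_completion_request(
                [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": question},
                ]
            )
            for question in user_questions
        ]
        entity_results = [json.loads(r) for r in await asyncio.gather(*entity_tasks)]

        # Fan out schema and column-value lookups in parallel.
        entity_search_tasks = [
            self.sql_connector.get_entity_schemas(" ".join(group), as_json=False)
            for result in entity_results
            for group in result["entities"]
        ]
        column_search_tasks = [
            self.ai_search_connector.get_column_values(condition, as_json=False)
            for result in entity_results
            for condition in result["filter_conditions"]
        ]
        schemas_results = await asyncio.gather(*entity_search_tasks)
        column_value_results = await asyncio.gather(*column_search_tasks)

        # Deduplicate schemas while preserving order.
        final_schemas = []
        for schema_result in schemas_results:
            for schema in schema_result:
                if schema not in final_schemas:
                    final_schemas.append(schema)

        final_results = {
            "COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results,
            "SCHEMA_OPTIONS": final_schemas,
        }
        logging.info(f"Final results: {final_results}")
        return final_results
```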
