Skip to content

Commit fb52574

Browse files
committed
Update prompts and agents to support programmatic sources
1 parent 0e11c29 commit fb52574

File tree

6 files changed

+131
-38
lines changed

6 files changed

+131
-38
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from autogen_agentchat.conditions import (
44
TextMentionTermination,
55
MaxMessageTermination,
6-
SourceMatchTermination,
76
)
87
from autogen_agentchat.teams import SelectorGroupChat
98
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
@@ -13,6 +12,9 @@
1312
from autogen_text_2_sql.custom_agents.sql_schema_selection_agent import (
1413
SqlSchemaSelectionAgent,
1514
)
15+
from autogen_text_2_sql.custom_agents.answer_and_sources_agent import (
16+
AnswerAndSourcesAgent,
17+
)
1618
from autogen_agentchat.agents import UserProxyAgent
1719
from autogen_agentchat.messages import TextMessage
1820
from autogen_agentchat.base import Response
@@ -99,6 +101,8 @@ def get_all_agents(self):
99101
**self.kwargs,
100102
)
101103

104+
self.answer_and_sources_agent = AnswerAndSourcesAgent()
105+
102106
# Auto-responding UserProxyAgent
103107
self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy")
104108

@@ -109,6 +113,7 @@ def get_all_agents(self):
109113
self.sql_schema_selection_agent,
110114
self.sql_query_correction_agent,
111115
self.sql_disambiguation_agent,
116+
self.answer_and_sources_agent,
112117
]
113118

114119
if self.use_query_cache:
@@ -122,11 +127,7 @@ def termination_condition(self):
122127
"""Define the termination condition for the chat."""
123128
termination = (
124129
TextMentionTermination("TERMINATE")
125-
| (
126-
TextMentionTermination("answer")
127-
& TextMentionTermination("sources")
128-
& SourceMatchTermination("sql_query_correction_agent")
129-
)
130+
| (TextMentionTermination("answer") & TextMentionTermination("sources"))
130131
| MaxMessageTermination(20)
131132
)
132133
return termination
@@ -166,14 +167,20 @@ def unified_selector(self, messages):
166167
decision = "sql_query_generation_agent"
167168

168169
elif messages[-1].source == "sql_query_correction_agent":
169-
decision = "sql_query_generation_agent"
170+
if "answer" in messages[-1].content is not None:
171+
decision = "answer_and_sources_agent"
172+
else:
173+
decision = "sql_query_generation_agent"
170174

171175
elif messages[-1].source == "sql_query_generation_agent":
172-
decision = "sql_query_correction_agent"
173-
elif messages[-1].source == "sql_query_correction_agent":
174-
decision = "sql_query_correction_agent"
175-
elif messages[-1].source == "answer_agent":
176-
return "user_proxy" # Let user_proxy send TERMINATE
176+
if "query_execution_with_limit" in messages[-1].content:
177+
decision = "sql_query_correction_agent"
178+
else:
179+
# Rerun
180+
decision = "sql_query_generation_agent"
181+
182+
elif messages[-1].source == "answer_and_sources_agent":
183+
decision = "user_proxy" # Let user_proxy send TERMINATE
177184

178185
logging.info("Decision: %s", decision)
179186
return decision

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
from autogen_core.components.tools import FunctionTool
3+
from autogen_core.components.tools import FunctionToolAlias
44
from autogen_agentchat.agents import AssistantAgent
55
from text_2_sql_core.connectors.factory import ConnectorFactory
66
from text_2_sql_core.prompts.load import load
@@ -32,25 +32,25 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
3232
tool_name (str): The name of the tool to retrieve.
3333
3434
Returns:
35-
FunctionTool: The tool."""
35+
FunctionToolAlias: The tool."""
3636

3737
if tool_name == "sql_query_execution_tool":
38-
return FunctionTool(
38+
return FunctionToolAlias(
3939
sql_helper.query_execution_with_limit,
4040
description="Runs an SQL query against the SQL Database to extract information",
4141
)
4242
elif tool_name == "sql_get_entity_schemas_tool":
43-
return FunctionTool(
43+
return FunctionToolAlias(
4444
sql_helper.get_entity_schemas,
4545
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the system prompt are not sufficient to answer the question.",
4646
)
4747
elif tool_name == "sql_get_column_values_tool":
48-
return FunctionTool(
48+
return FunctionToolAlias(
4949
ai_search_helper.get_column_values,
5050
description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.",
5151
)
5252
elif tool_name == "current_datetime_tool":
53-
return FunctionTool(
53+
return FunctionToolAlias(
5454
sql_helper.get_current_datetime,
5555
description="Gets the current date and time.",
5656
)
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from typing import AsyncGenerator, List, Sequence
4+
5+
from autogen_agentchat.agents import BaseChatAgent
6+
from autogen_agentchat.base import Response
7+
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
8+
from autogen_core import CancellationToken
9+
import json
10+
from json import JSONDecodeError
11+
import logging
12+
import pandas as pd
13+
14+
15+
class AnswerAndSourcesAgent(BaseChatAgent):
    """Agent that attaches executed SQL queries as sources to the final answer.

    The content of the last message is parsed as the JSON answer object. Every
    message in the conversation whose content is a JSON object with
    ``"type": "query_execution_with_limit"`` is appended to that object's
    ``sources`` list, together with a markdown rendering of its rows.
    """

    def __init__(self):
        super().__init__(
            "answer_and_sources_agent",
            "An agent that formats the final answer and sources.",
        )

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        """Return the message types this agent emits (only ``TextMessage``)."""
        return [TextMessage]

    async def on_messages(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> Response:
        """Drain ``on_messages_stream`` and return the ``Response`` it yields."""
        response: Response | None = None
        async for message in self.on_messages_stream(messages, cancellation_token):
            if isinstance(message, Response):
                response = message
        # Internal invariant: the stream always ends with exactly one Response.
        assert response is not None
        return response

    async def on_messages_stream(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[AgentMessage | Response, None]:
        """Build the final answer object and yield it as a single ``Response``.

        Args:
            messages: Conversation so far; the last message's content must be
                the JSON-encoded answer object.
            cancellation_token: Unused by this agent.

        Yields:
            One ``Response`` whose ``TextMessage`` content is the JSON-encoded
            answer object with a populated ``sources`` list.

        Raises:
            JSONDecodeError: If the last message's content is not valid JSON.
            Exception: Any error raised while turning a query-execution payload
                into a source entry (for example a missing ``sql_query`` or
                ``sql_rows`` key) is logged and re-raised.
        """
        # Load the json of the last message to populate the final output object
        final_output_object = json.loads(messages[-1].content)
        final_output_object["sources"] = []

        for message in messages:
            content = message.content

            # json.loads only accepts str/bytes; anything else cannot carry a
            # serialized query execution, so skip it instead of failing the run.
            if not isinstance(content, (str, bytes, bytearray)):
                continue

            try:
                payload = json.loads(content)
            except JSONDecodeError:
                logging.info("Could not load message: %s", message)
                continue

            logging.info("Loaded: %s", payload)

            # Valid JSON that is not an object (number, string, list, null)
            # has no "type" key to inspect and is never a query execution.
            if not isinstance(payload, dict):
                continue

            # Search for specific message types and add them to the final output object
            if payload.get("type") != "query_execution_with_limit":
                continue

            try:
                dataframe = pd.DataFrame(payload["sql_rows"])
                final_output_object["sources"].append(
                    {
                        "sql_query": payload["sql_query"].replace("\n", " "),
                        "sql_rows": payload["sql_rows"],
                        "markdown_table": dataframe.to_markdown(index=False),
                    }
                )
            except Exception as e:
                logging.error("Error processing message: %s", e)
                raise

        yield Response(
            chat_message=TextMessage(
                content=json.dumps(final_output_object), source=self.name
            )
        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """No state is kept between runs, so there is nothing to reset."""
        pass

text_2_sql/text_2_sql_core/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ dependencies = [
1616
"networkx>=3.4.2",
1717
"numpy<2.0.0",
1818
"openai>=1.55.3",
19+
"pandas[tabulate]>=2.2.3",
1920
"pydantic>=2.10.2",
2021
"python-dotenv>=1.0.1",
2122
"pyyaml>=6.0.2",
2223
"rich>=13.9.4",
2324
"sqlglot[rs]>=25.32.1",
25+
"tabulate>=0.9.0",
2426
"tenacity>=9.0.0",
2527
"typer>=0.14.0",
2628
]

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from abc import ABC, abstractmethod
1010
from datetime import datetime
1111
from jinja2 import Template
12+
import json
1213

1314

1415
class SqlConnector(ABC):
@@ -109,9 +110,23 @@ async def query_execution_with_limit(
109110
validation_result = await self.query_validation(sql_query)
110111

111112
if isinstance(validation_result, bool) and validation_result:
112-
return await self.query_execution(sql_query, cast_to=None, limit=25)
113+
result = await self.query_execution(sql_query, cast_to=None, limit=25)
114+
115+
return json.dumps(
116+
{
117+
"type": "query_execution_with_limit",
118+
"sql_query": sql_query,
119+
"sql_rows": result,
120+
}
121+
)
113122
else:
114-
return validation_result
123+
return json.dumps(
124+
{
125+
"type": "errored_query_execution_with_limit",
126+
"sql_query": sql_query,
127+
"errors": validation_result,
128+
}
129+
)
115130

116131
async def query_validation(
117132
self,

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,37 +20,25 @@ system_message:
2020
2121
<output_format>
2222
- **If the SQL query is valid and the results are correct**:
23-
```json
23+
2424
{
2525
\"answer\": \"<GENERATED ANSWER>\",
26-
\"sources\": [
27-
{
28-
\"sql_result_snippet\": \"<SQL QUERY RESULT 1>\",
29-
\"sql_query_used\": \"<SOURCE 1 SQL QUERY>\",
30-
\"explanation\": \"<EXPLANATION OF SQL QUERY 1>\"
31-
},
32-
{
33-
\"sql_result_snippet\": \"<SQL QUERY RESULT 2>\",
34-
\"sql_query_used\": \"<SOURCE 2 SQL QUERY>\",
35-
\"explanation\": \"<EXPLANATION OF SQL QUERY 2>\"
36-
}
37-
]
3826
}
39-
```
27+
4028
- **If the SQL query needs corrections**:
41-
```json
29+
4230
[
4331
{
4432
\"requested_fix\": \"<EXPLANATION OF REQUESTED FIX OF THE SQL QUERY>\"
4533
}
4634
]
47-
```
35+
4836
- **If the SQL query cannot be corrected**:
49-
```json
37+
5038
{
5139
\"error\": \"Unable to correct the SQL query. Please request a new SQL query.\"
5240
}
53-
```
41+
5442
Followed by **TERMINATE**.
5543
</output_format>
5644
"

0 commit comments

Comments
 (0)