Commit fe41812

Fix #90: Implement Query Rewrite Agent for comprehensive preprocessing

- Handles relative date disambiguation (e.g., 'last month' to actual dates) and question decomposition in a single preprocessing step before cache lookup
- Replaces previous question_decomposition_agent with more capable query_rewrite_agent
- Updates documentation to reflect current processing flow (#100)

1 parent fae4f71 commit fe41812

4 files changed, +180 -71 lines changed

text_2_sql/autogen/README.md

Lines changed: 13 additions & 8 deletions
@@ -8,7 +8,7 @@ The implementation is written for [AutoGen](https://github.com/microsoft/autogen
 
 ## Full Logical Flow for Agentic Vector Based Approach
 
-The following diagram shows the logical flow within mutlti agent system. In an ideal scenario, the questions will follow the _Pre-Fetched Cache Results Path** which leads to the quickest answer generation. In cases where the question is not known, the group chat selector will fall back to the other agents accordingly and generate the SQL query using the LLMs. The cache is then updated with the newly generated query and schemas.
+The following diagram shows the logical flow within multi agent system. The flow begins with query rewriting to preprocess questions - this includes resolving relative dates (e.g., "last month" to "November 2024") and breaking down complex queries into simpler components. For each preprocessed question, if query cache is enabled, the system checks the cache for previously asked similar questions. In an ideal scenario, the preprocessed questions will be found in the cache, leading to the quickest answer generation. In cases where the question is not known, the group chat selector will fall back to the other agents accordingly and generate the SQL query using the LLMs. The cache is then updated with the newly generated query and schemas.
 
 Unlike the previous approaches, **gpt4o-mini** can be used as each agent's prompt is small and focuses on a single simple task.
 
@@ -24,26 +24,31 @@ As the query cache is shared between users (no data is stored in the cache), a n
 
 ## Agents
 
-This approach builds on the the Vector Based SQL Plugin approach, but adds a agentic approach to the solution.
+This approach builds on the Vector Based SQL Plugin approach, but adds a agentic approach to the solution.
 
 This agentic system contains the following agents:
 
-- **Query Cache Agent:** Responsible for checking the cache for previously asked questions.
-- **Query Decomposition Agent:** Responsible for decomposing complex questions, into sub questions that can be answered with SQL.
-- **Schema Selection Agent:** Responsible for extracting key terms from the question and checking the index store for the queries.
+- **Query Rewrite Agent:** The first agent in the flow, responsible for two key preprocessing tasks:
+  1. Resolving relative dates to absolute dates (e.g., "last month" → "November 2024")
+  2. Decomposing complex questions into simpler sub-questions
+  This preprocessing happens before cache lookup to maximize cache effectiveness.
+- **Query Cache Agent:** Responsible for checking the cache for previously asked questions. After preprocessing, each sub-question is checked against the cache if caching is enabled.
+- **Schema Selection Agent:** Responsible for extracting key terms from the question and checking the index store for the queries. This agent is used when a cache miss occurs.
 - **SQL Query Generation Agent:** Responsible for using the previously extracted schemas and generated SQL queries to answer the question. This agent can request more schemas if needed. This agent will run the query.
 - **SQL Query Verification Agent:** Responsible for verifying that the SQL query and results question will answer the question.
 - **Answer Generation Agent:** Responsible for taking the database results and generating the final answer for the user.
 
-The combination of this agent allows the system to answer complex questions, whilst staying under the token limits when including the database schemas. The query cache ensures that previously asked questions, can be answered quickly to avoid degrading user experience.
+The combination of these agents allows the system to answer complex questions, whilst staying under the token limits when including the database schemas. The query cache ensures that previously asked questions can be answered quickly to avoid degrading user experience.
 
 All agents can be found in `/agents/`.
 
 ## agentic_text_2_sql.py
 
-This is the main entry point for the agentic system. In here, the `Selector Group Chat` is configured with the termination conditions to orchestrate the agents within the system.
+This is the main entry point for the agentic system. In here, the system is configured with the following processing flow:
 
-A customer transition selector is used to automatically transition between agents dependent on the last one that was used. In some cases, this choice is delegated to an LLM to decide on the most appropriate action. This mixed approach allows for speed when needed (e.g. always calling Query Cache Agent first), but will allow the system to react dynamically to the events.
+The preprocessed questions from the Query Rewrite Agent are processed sequentially through the rest of the agent pipeline. A custom transition selector automatically transitions between agents dependent on the last one that was used. The flow starts with the Query Rewrite Agent for preprocessing, followed by cache checking for each sub-question if caching is enabled. In some cases, this choice is delegated to an LLM to decide on the most appropriate action. This mixed approach allows for speed when needed (e.g. cache hits for known questions), but will allow the system to react dynamically to the events.
+
+Note: Future development aims to implement independent processing where each preprocessed question would run in its own isolated context to prevent confusion between different parts of complex queries.
 
 ## Utils
 
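The README change describes resolving "last month" to an absolute month such as "November 2024". In the commit this is delegated to the LLM via the Query Rewrite Agent's prompt; the following is only a deterministic sketch of the same idea, useful for sanity-checking expected outputs. The helper names `resolve_last_month` and `rewrite_last_month` are hypothetical and do not appear in the repository.

```python
from datetime import date

# English month names spelled out to avoid locale-dependent formatting.
MONTH_NAMES = [
    "January", "February", "March", "April", "May", "June",
    "July", "August", "September", "October", "November", "December",
]


def resolve_last_month(today: date) -> str:
    """Return the calendar month before `today`, formatted as 'Month YYYY'."""
    if today.month == 1:
        # January rolls back to December of the previous year.
        year, month = today.year - 1, 12
    else:
        year, month = today.year, today.month - 1
    return f"{MONTH_NAMES[month - 1]} {year}"


def rewrite_last_month(question: str, today: date) -> str:
    """Replace the phrase 'last month' with 'in <Month YYYY>' (illustrative only)."""
    return question.replace("last month", f"in {resolve_last_month(today)}")


if __name__ == "__main__":
    print(rewrite_last_month("How much did we make in sales last month?", date(2024, 12, 10)))
```

Because the rewritten question no longer depends on when it was asked, two users asking about the same absolute period can share one cache entry, which is the motivation the README gives for rewriting before the cache lookup.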

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 82 additions & 53 deletions
@@ -13,9 +13,27 @@
 from autogen_text_2_sql.custom_agents.sql_schema_selection_agent import (
     SqlSchemaSelectionAgent,
 )
+from autogen_agentchat.agents import UserProxyAgent
+from autogen_agentchat.messages import TextMessage
+from autogen_agentchat.base import Response
 import json
 import os
-
+import asyncio
+from datetime import datetime
+
+class EmptyResponseUserProxyAgent(UserProxyAgent):
+    """UserProxyAgent that automatically responds with empty messages."""
+    def __init__(self, name):
+        super().__init__(name=name)
+        self._has_responded = False
+
+    async def on_messages_stream(self, messages, sender=None, config=None):
+        """Auto-respond with empty message and return Response object."""
+        message = TextMessage(content="", source=self.name)
+        if not self._has_responded:
+            self._has_responded = True
+            yield message
+        yield Response(chat_message=message)
 
 class AutoGenText2Sql:
     def __init__(self, engine_specific_rules: str, **kwargs: dict):
@@ -43,45 +61,58 @@ def set_mode(self):
             os.environ.get("Text2Sql__UseColumnValueStore", "False").lower() == "true"
         )
 
-    @property
-    def agents(self):
-        """Define the agents for the chat."""
+    def get_all_agents(self):
+        """Get all agents for the complete flow."""
+        # Get current datetime for the Query Rewrite Agent
+        current_datetime = datetime.now()
+
+        QUERY_REWRITE_AGENT = LLMAgentCreator.create(
+            "query_rewrite_agent",
+            current_datetime=current_datetime
+        )
+
         SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
             "sql_query_generation_agent",
             target_engine=self.target_engine,
             engine_specific_rules=self.engine_specific_rules,
             **self.kwargs,
         )
+
         SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent(
             target_engine=self.target_engine,
             engine_specific_rules=self.engine_specific_rules,
             **self.kwargs,
         )
+
         SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
             "sql_query_correction_agent",
             target_engine=self.target_engine,
             engine_specific_rules=self.engine_specific_rules,
             **self.kwargs,
         )
+
         SQL_DISAMBIGUATION_AGENT = LLMAgentCreator.create(
             "sql_disambiguation_agent",
             target_engine=self.target_engine,
             engine_specific_rules=self.engine_specific_rules,
             **self.kwargs,
         )
-
+
         ANSWER_AGENT = LLMAgentCreator.create("answer_agent")
-        QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
-            "question_decomposition_agent"
+
+        # Auto-responding UserProxyAgent
+        USER_PROXY = EmptyResponseUserProxyAgent(
+            name="user_proxy"
         )
 
         agents = [
+            USER_PROXY,
+            QUERY_REWRITE_AGENT,
             SQL_QUERY_GENERATION_AGENT,
             SQL_SCHEMA_SELECTION_AGENT,
             SQL_QUERY_CORRECTION_AGENT,
-            ANSWER_AGENT,
-            QUESTION_DECOMPOSITION_AGENT,
             SQL_DISAMBIGUATION_AGENT,
+            ANSWER_AGENT,
         ]
 
         if self.use_query_cache:
@@ -101,67 +132,65 @@ def termination_condition(self):
         return termination
 
     @staticmethod
-    def selector(messages):
+    def unified_selector(messages):
+        """Unified selector for the complete flow."""
         logging.info("Messages: %s", messages)
-        decision = None # Initialize decision variable
+        decision = None
 
+        # If this is the first message, start with query_rewrite_agent
         if len(messages) == 1:
-            decision = "sql_query_cache_agent"
-
-        elif (
-            messages[-1].source == "sql_query_cache_agent"
-            and messages[-1].content is not None
-        ):
-            cache_result = json.loads(messages[-1].content)
-            if cache_result.get(
-                "cached_questions_and_schemas"
-            ) is not None and cache_result.get("contains_pre_run_results"):
-                decision = "sql_query_correction_agent"
-            if (
-                cache_result.get("cached_questions_and_schemas") is not None
-                and cache_result.get("contains_pre_run_results") is False
-            ):
-                decision = "sql_query_generation_agent"
-            else:
-                decision = "question_decomposition_agent"
-
-        elif messages[-1].source == "question_decomposition_agent":
-            decision = "sql_schema_selection_agent"
+            return "query_rewrite_agent"
 
+        # Handle transition after query rewriting
+        if messages[-1].source == "query_rewrite_agent":
+            # Keep the array structure but process sequentially
+            if os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true":
+                decision = "sql_query_cache_agent"
+            else:
+                decision = "sql_schema_selection_agent"
+        # Handle subsequent agent transitions
+        elif messages[-1].source == "sql_query_cache_agent":
+            try:
+                cache_result = json.loads(messages[-1].content)
+                if cache_result.get("cached_questions_and_schemas") is not None:
+                    if cache_result.get("contains_pre_run_results"):
+                        decision = "sql_query_correction_agent"
+                    else:
+                        decision = "sql_query_generation_agent"
+                else:
+                    decision = "sql_schema_selection_agent"
+            except json.JSONDecodeError:
+                decision = "sql_schema_selection_agent"
         elif messages[-1].source == "sql_schema_selection_agent":
             decision = "sql_disambiguation_agent"
-
         elif messages[-1].source == "sql_disambiguation_agent":
-            # This would be user proxy agent tbc
             decision = "sql_query_generation_agent"
-
-        elif (
-            messages[-1].source == "sql_query_correction_agent"
-            and messages[-1].content == "VALIDATED"
-        ):
-            decision = "answer_agent"
-
-        elif messages[-1].source == "sql_query_correction_agent":
+        elif messages[-1].source == "sql_query_generation_agent":
             decision = "sql_query_correction_agent"
+        elif messages[-1].source == "sql_query_correction_agent":
+            if messages[-1].content == "VALIDATED":
+                decision = "answer_agent"
+            else:
+                decision = "sql_query_correction_agent"
+        elif messages[-1].source == "answer_agent":
+            return "user_proxy" # Let user_proxy send TERMINATE
 
-        # Log the decision
         logging.info("Decision: %s", decision)
-
         return decision
 
     @property
     def agentic_flow(self):
-        """Run the agentic flow for the given question.
-
-        Args:
-        ----
-            question (str): The question to run the agentic flow on."""
-        agentic_flow = SelectorGroupChat(
-            self.agents,
+        """Create the unified flow for the complete process."""
+        flow = SelectorGroupChat(
+            self.get_all_agents(),
             allow_repeated_speaker=False,
             model_client=LLMModelCreator.get_model("4o-mini"),
             termination_condition=self.termination_condition,
-            selector_func=AutoGenText2Sql.selector,
+            selector_func=AutoGenText2Sql.unified_selector,
         )
+        return flow
 
-        return agentic_flow
+    async def process_question(self, task: str):
+        """Process the complete question through the unified system."""
+        result = await self.agentic_flow.run_stream(task=task)
+        return result
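The transitions in `unified_selector` can be exercised without AutoGen by treating messages as simple records. The sketch below mirrors the branch order in the diff above; the `Msg` dataclass, the `next_agent` name, and passing `use_query_cache` as an argument (instead of reading `Text2Sql__UseQueryCache` from the environment) are assumptions made so the example is self-contained.

```python
import json
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Msg:
    """Minimal stand-in for a chat message (hypothetical, not the AutoGen type)."""
    source: str
    content: str


# Transitions that depend only on which agent spoke last.
STATIC_TRANSITIONS = {
    "sql_schema_selection_agent": "sql_disambiguation_agent",
    "sql_disambiguation_agent": "sql_query_generation_agent",
    "sql_query_generation_agent": "sql_query_correction_agent",
    "answer_agent": "user_proxy",
}


def next_agent(messages: List[Msg], use_query_cache: bool) -> Optional[str]:
    """Return the name of the next agent, or None when no rule applies."""
    if len(messages) == 1:
        return "query_rewrite_agent"
    last = messages[-1]
    if last.source == "query_rewrite_agent":
        return "sql_query_cache_agent" if use_query_cache else "sql_schema_selection_agent"
    if last.source == "sql_query_cache_agent":
        try:
            cache = json.loads(last.content)
        except json.JSONDecodeError:
            # Unparseable cache output is treated as a miss.
            return "sql_schema_selection_agent"
        if cache.get("cached_questions_and_schemas") is None:
            return "sql_schema_selection_agent"
        if cache.get("contains_pre_run_results"):
            return "sql_query_correction_agent"
        return "sql_query_generation_agent"
    if last.source == "sql_query_correction_agent":
        return "answer_agent" if last.content == "VALIDATED" else "sql_query_correction_agent"
    return STATIC_TRANSITIONS.get(last.source)
```

Keeping the selector a pure function of the message list makes each branch straightforward to cover with unit tests, which is harder when the cache flag is read from the environment inside the function.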

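One point worth checking in `process_question`: in AutoGen AgentChat, a team's `run_stream` is, to the best of my knowledge, an async generator, and async generators are consumed with `async for` rather than `await`. The sketch below shows that consumption pattern with a stand-in generator; `fake_run_stream` and `collect_stream` are hypothetical names, and no AutoGen APIs are called.

```python
import asyncio
from typing import AsyncGenerator, List


async def fake_run_stream(task: str) -> AsyncGenerator[str, None]:
    """Stand-in for a streaming team run: yields one event per agent turn."""
    for source in ("query_rewrite_agent", "answer_agent"):
        yield f"{source}: handled {task!r}"


async def collect_stream(task: str) -> List[str]:
    """Consume the async generator with `async for` and collect its events."""
    return [event async for event in fake_run_stream(task)]


if __name__ == "__main__":
    for line in asyncio.run(collect_stream("What were total sales last quarter?")):
        print(line)
```

If only the final result is needed, awaiting the non-streaming `run(task=...)` method is the usual alternative.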
text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 39 additions & 10 deletions
@@ -38,20 +38,49 @@ async def on_messages(
     async def on_messages_stream(
         self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
     ) -> AsyncGenerator[AgentMessage | Response, None]:
-        user_question = messages[0].content
+        # Get the decomposed questions from the query_rewrite_agent
+        last_response = messages[-1].content
+        try:
+            user_questions = json.loads(last_response)
+            logging.info(f"Processing questions: {user_questions}")
 
-        # Fetch the queries from the cache based on the user question.
-        logging.info("Fetching queries from cache based on the user question...")
+            # Initialize results dictionary
+            cached_results = {
+                "cached_questions_and_schemas": [],
+                "contains_pre_run_results": False
+            }
 
-        cached_queries = await self.sql_connector.fetch_queries_from_cache(
-            user_question
-        )
+            # Process each question sequentially
+            for question in user_questions:
+                # Fetch the queries from the cache based on the question
+                logging.info(f"Fetching queries from cache for question: {question}")
+                cached_query = await self.sql_connector.fetch_queries_from_cache(question)
+
+                # If any question has pre-run results, set the flag
+                if cached_query.get("contains_pre_run_results", False):
+                    cached_results["contains_pre_run_results"] = True
+
+                # Add the cached results for this question
+                if cached_query.get("cached_questions_and_schemas"):
+                    cached_results["cached_questions_and_schemas"].extend(
+                        cached_query["cached_questions_and_schemas"]
+                    )
 
-        yield Response(
-            chat_message=TextMessage(
-                content=json.dumps(cached_queries), source=self.name
+            logging.info(f"Final cached results: {cached_results}")
+            yield Response(
+                chat_message=TextMessage(
+                    content=json.dumps(cached_results), source=self.name
+                )
+            )
+        except json.JSONDecodeError:
+            # If not JSON array, process as single question
+            logging.info(f"Processing single question: {last_response}")
+            cached_queries = await self.sql_connector.fetch_queries_from_cache(last_response)
+            yield Response(
+                chat_message=TextMessage(
+                    content=json.dumps(cached_queries), source=self.name
+                )
             )
-        )
 
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         pass
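The per-question aggregation in `on_messages_stream` can be tried in isolation with an in-memory cache. This is a sketch under stated assumptions: `FAKE_CACHE`, `fetch_from_cache`, and `aggregate_cache_results` are hypothetical stand-ins for the connector's `fetch_queries_from_cache` and the loop in the diff, and the cache entry shape is invented for illustration.

```python
import asyncio
from typing import Any, Dict, List

# Hypothetical in-memory cache keyed by rewritten question.
FAKE_CACHE: Dict[str, Dict[str, Any]] = {
    "What were total sales in November 2024?": {
        "cached_questions_and_schemas": [{"question": "total sales", "sql": "SELECT 1"}],
        "contains_pre_run_results": True,
    },
}


async def fetch_from_cache(question: str) -> Dict[str, Any]:
    """Stand-in for the connector lookup: returns an empty dict on a miss."""
    return FAKE_CACHE.get(question, {})


async def aggregate_cache_results(questions: List[str]) -> Dict[str, Any]:
    """Merge per-question cache hits into one result, as the loop above does."""
    merged: Dict[str, Any] = {
        "cached_questions_and_schemas": [],
        "contains_pre_run_results": False,
    }
    for question in questions:
        hit = await fetch_from_cache(question)
        # One pre-run hit is enough to set the shared flag.
        if hit.get("contains_pre_run_results", False):
            merged["contains_pre_run_results"] = True
        merged["cached_questions_and_schemas"].extend(
            hit.get("cached_questions_and_schemas") or []
        )
    return merged
```

Note that a complete miss yields an empty list rather than `None` for `cached_questions_and_schemas`, so a consumer that tests `is not None` would treat the miss as a hit; testing truthiness of the list is one way to distinguish the two cases.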
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+model:
+  4o-mini
+description:
+  "An agent that preprocesses user questions by decomposing complex queries and resolving relative dates. This preprocessing happens before cache lookup to maximize cache utility."
+system_message:
+  "You are a helpful AI Assistant that specializes in preprocessing user questions for SQL query generation. You have two main responsibilities:
+
+  1. Decompose complex questions into simpler parts
+  2. Resolve any relative date references to absolute dates
+
+  Current date/time is: {{ current_datetime }}
+
+  For date resolution:
+  - Use the current date/time above as reference point
+  - Replace relative dates like 'last month', 'this year', 'previous quarter' with absolute dates
+  - Maintain consistency in date formats (YYYY-MM-DD)
+
+  Examples of date resolution (assuming current date is {{ current_datetime }}):
+  - 'last month' -> specific month name and year
+  - 'this year' -> {{ current_datetime.year }}
+  - 'last 3 months' -> specific date range
+  - 'yesterday' -> specific date
+
+  Rules:
+  1. ALWAYS resolve relative dates before decomposing questions
+  2. If a question contains multiple parts AND relative dates, resolve dates first, then decompose
+  3. Each decomposed question should be self-contained and not depend on context from other parts
+  4. Do not reference the original question in decomposed parts
+  5. Ensure each decomposed question includes its full context
+
+  Output Format:
+  Return an array of rewritten questions in valid, loadable JSON:
+  [\"<rewritten_question_1>\", \"<rewritten_question_2>\"]
+
+  If the question is simple and doesn't need decomposition (but might need date resolution):
+  [\"<rewritten_question>\"]
+
+  Examples:
+  Input: 'How much did we make in sales last month and what were our top products?'
+  Output: [\"How much did we make in sales in November 2024?\", \"What were our top products in November 2024?\"]
+
+  Input: 'What were total sales last quarter?'
+  Output: [\"What were total sales in Q4 2024 (October 2024 to December 2024)?\"]
+
+  Input: 'Show me customer details'
+  Output: [\"Show me customer details\"]"
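The prompt asks the model for a JSON array of rewritten questions, but a model can still return plain text or some other JSON shape. A consumer might validate the output along these lines; this is a minimal sketch, and `parse_rewritten_questions` is a hypothetical helper, not part of the commit.

```python
import json
from typing import List


def parse_rewritten_questions(raw: str) -> List[str]:
    """Return the rewritten questions, falling back to the raw text as one question."""
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Not JSON at all: treat the whole response as a single question.
        return [raw]
    is_string_list = (
        isinstance(parsed, list)
        and len(parsed) > 0
        and all(isinstance(item, str) for item in parsed)
    )
    # Any other JSON shape (object, number, empty list) also falls back.
    return parsed if is_string_list else [raw]
```

Falling back to a single-question list keeps downstream agents on one code path regardless of how well the model followed the output format.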
