Skip to content

Commit ca04ac9

Browse files
Improvements to method of query rewriting (#156)
1 parent 2644e2d commit ca04ac9

File tree

8 files changed

+171
-183
lines changed

8 files changed

+171
-183
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 88 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ async def on_messages_stream(
9696
injected_parameters = {}
9797

9898
# Load the json of the last message to populate the final output object
99-
message_rewrites = json.loads(last_response)
99+
sequential_rounds = json.loads(last_response)
100100

101-
logging.info(f"Query Rewrites: {message_rewrites}")
101+
logging.info(f"Query Rewrites: {sequential_rounds}")
102102

103103
async def consume_inner_messages_from_agentic_flow(
104104
agentic_flow, identifier, filtered_parallel_messages
@@ -197,7 +197,7 @@ async def consume_inner_messages_from_agentic_flow(
197197

198198
# Convert all_non_database_query to lowercase string and compare
199199
all_non_database_query = str(
200-
message_rewrites.get("all_non_database_query", "false")
200+
sequential_rounds.get("all_non_database_query", "false")
201201
).lower()
202202

203203
if all_non_database_query == "true":
@@ -210,84 +210,97 @@ async def consume_inner_messages_from_agentic_flow(
210210
return
211211

212212
# Start processing sub-queries
213-
for message_rewrite in message_rewrites["decomposed_user_messages"]:
214-
logging.info(f"Processing sub-query: {message_rewrite}")
215-
# Create an instance of the InnerAutoGenText2Sql class
216-
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)
217-
218-
identifier = ", ".join(message_rewrite)
219-
220-
# Add database connection info to injected parameters
221-
query_params = injected_parameters.copy() if injected_parameters else {}
222-
if "Text2Sql__Tsql__ConnectionString" in os.environ:
223-
query_params["database_connection_string"] = os.environ[
224-
"Text2Sql__Tsql__ConnectionString"
225-
]
226-
if "Text2Sql__Tsql__Database" in os.environ:
227-
query_params["database_name"] = os.environ["Text2Sql__Tsql__Database"]
228-
229-
# Launch tasks for each sub-query
230-
inner_solving_generators.append(
231-
consume_inner_messages_from_agentic_flow(
232-
inner_autogen_text_2_sql.process_user_message(
233-
user_message=message_rewrite,
234-
injected_parameters=query_params,
235-
),
236-
identifier,
237-
filtered_parallel_messages,
213+
for sequential_round in sequential_rounds["decomposed_user_messages"]:
214+
logging.info(f"Processing round: {sequential_round}")
215+
216+
for parallel_message in sequential_round:
217+
logging.info(f"Parallel Message: {parallel_message}")
218+
219+
# Create an instance of the InnerAutoGenText2Sql class
220+
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)
221+
222+
# Add database connection info to injected parameters
223+
query_params = injected_parameters.copy() if injected_parameters else {}
224+
if "Text2Sql__Tsql__ConnectionString" in os.environ:
225+
query_params["database_connection_string"] = os.environ[
226+
"Text2Sql__Tsql__ConnectionString"
227+
]
228+
if "Text2Sql__Tsql__Database" in os.environ:
229+
query_params["database_name"] = os.environ[
230+
"Text2Sql__Tsql__Database"
231+
]
232+
233+
# Launch tasks for each sub-query
234+
inner_solving_generators.append(
235+
consume_inner_messages_from_agentic_flow(
236+
inner_autogen_text_2_sql.process_user_message(
237+
user_message=parallel_message,
238+
injected_parameters=query_params,
239+
database_results=filtered_parallel_messages.database_results,
240+
),
241+
parallel_message,
242+
filtered_parallel_messages,
243+
)
238244
)
245+
246+
logging.info(
247+
"Created %i Inner Solving Generators", len(inner_solving_generators)
248+
)
249+
logging.info("Starting Inner Solving Generators")
250+
combined_message_streams = stream.merge(*inner_solving_generators)
251+
252+
async with combined_message_streams.stream() as streamer:
253+
async for inner_message in streamer:
254+
if isinstance(inner_message, TextMessage):
255+
logging.debug(f"Inner Solving Message: {inner_message}")
256+
yield inner_message
257+
258+
# Log final results for debugging or auditing
259+
logging.info(
260+
"Database Results: %s", filtered_parallel_messages.database_results
261+
)
262+
logging.info(
263+
"Disambiguation Requests: %s",
264+
filtered_parallel_messages.disambiguation_requests,
239265
)
240266

241-
logging.info(
242-
"Created %i Inner Solving Generators", len(inner_solving_generators)
243-
)
244-
logging.info("Starting Inner Solving Generators")
245-
combined_message_streams = stream.merge(*inner_solving_generators)
246-
247-
async with combined_message_streams.stream() as streamer:
248-
async for inner_message in streamer:
249-
if isinstance(inner_message, TextMessage):
250-
logging.debug(f"Inner Solving Message: {inner_message}")
251-
yield inner_message
252-
253-
# Log final results for debugging or auditing
254-
logging.info(
255-
"Database Results: %s", filtered_parallel_messages.database_results
256-
)
257-
logging.info(
258-
"Disambiguation Requests: %s",
259-
filtered_parallel_messages.disambiguation_requests,
260-
)
267+
# Check for disambiguation requests before processing the next round
261268

262-
if (
263-
max(map(len, filtered_parallel_messages.disambiguation_requests.values()))
264-
> 0
265-
):
266-
# Final response
267-
yield Response(
268-
chat_message=TextMessage(
269-
content=json.dumps(
270-
{
271-
"contains_disambiguation_requests": True,
272-
"disambiguation_requests": filtered_parallel_messages.disambiguation_requests,
273-
}
274-
),
275-
source=self.name,
276-
),
277-
)
278-
else:
279-
# Final response
280-
yield Response(
281-
chat_message=TextMessage(
282-
content=json.dumps(
283-
{
284-
"contains_database_results": True,
285-
"database_results": filtered_parallel_messages.database_results,
286-
}
269+
if (
270+
max(
271+
map(
272+
len, filtered_parallel_messages.disambiguation_requests.values()
273+
)
274+
)
275+
> 0
276+
):
277+
# Disambiguation response: return the requests and stop processing further rounds
278+
yield Response(
279+
chat_message=TextMessage(
280+
content=json.dumps(
281+
{
282+
"contains_disambiguation_requests": True,
283+
"disambiguation_requests": filtered_parallel_messages.disambiguation_requests,
284+
}
285+
),
286+
source=self.name,
287287
),
288-
source=self.name,
288+
)
289+
290+
break
291+
292+
# Final response
293+
yield Response(
294+
chat_message=TextMessage(
295+
content=json.dumps(
296+
{
297+
"contains_database_results": True,
298+
"database_results": filtered_parallel_messages.database_results,
299+
}
289300
),
290-
)
301+
source=self.name,
302+
),
303+
)
291304

292305
async def on_reset(self, cancellation_token: CancellationToken) -> None:
293306
pass

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ async def on_messages_stream(
4444
try:
4545
request_details = json.loads(messages[0].content)
4646
injected_parameters = request_details["injected_parameters"]
47-
user_messages = request_details["user_message"]
48-
logging.info(f"Processing messages: {user_messages}")
47+
user_message = request_details["user_message"]
48+
logging.info(f"Processing messages: {user_message}")
4949
logging.info(f"Input Parameters: {injected_parameters}")
5050
except json.JSONDecodeError:
5151
# Payload is not valid JSON, so the request cannot be processed; raise an error
5252
raise ValueError("Could not load message")
5353

5454
cached_results = await self.agent.process_message(
55-
user_messages, injected_parameters
55+
user_message, injected_parameters
5656
)
5757
yield Response(
5858
chat_message=TextMessage(

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,14 @@ async def on_messages_stream(
4343
# Try to parse as JSON first
4444
try:
4545
request_details = json.loads(messages[0].content)
46-
messages = request_details["question"]
46+
message = request_details["user_message"]
4747
except (json.JSONDecodeError, KeyError):
4848
# If not JSON or missing the user_message key, use content directly
49-
messages = messages[0].content
49+
message = messages[0].content
5050

51-
if isinstance(messages, str):
52-
messages = [messages]
53-
elif not isinstance(messages, list):
54-
messages = [str(messages)]
51+
logging.info("Processing message: %s", message)
5552

56-
logging.info(f"Processing questions: {messages}")
57-
58-
final_results = await self.agent.process_message(messages)
53+
final_results = await self.agent.process_message(message)
5954

6055
yield Response(
6156
chat_message=TextMessage(

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def process_user_message(
177177
self,
178178
user_message: str,
179179
injected_parameters: dict = None,
180+
database_results: dict = None,
180181
):
181182
"""Process the complete question through the unified system.
182183
@@ -200,6 +201,9 @@ def process_user_message(
200201
"injected_parameters": injected_parameters,
201202
}
202203

204+
if database_results:
205+
agent_input["database_results"] = database_results
206+
203207
return self.agentic_flow.run_stream(task=json.dumps(agent_input))
204208
finally:
205209
# Restore original environment

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,39 +8,35 @@ class SqlQueryCacheAgentCustomAgent:
88
def __init__(self):
99
self.sql_connector = ConnectorFactory.get_database_connector()
1010

11-
async def process_message(
12-
self, messages: list[str], injected_parameters: dict
13-
) -> dict:
11+
async def process_message(self, message: str, injected_parameters: dict) -> dict:
1412
# Initialize results dictionary
1513
cached_results = {
1614
"cached_sql_queries_with_schemas_from_cache": [],
1715
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
1816
}
1917

20-
# Process each question sequentially
21-
for message in messages:
22-
# Fetch the queries from the cache based on the question
23-
logging.info(f"Fetching queries from cache for question: {message}")
24-
cached_query = (
25-
await self.sql_connector.fetch_sql_queries_with_schemas_from_cache(
26-
message, injected_parameters=injected_parameters
27-
)
18+
# Fetch the queries from the cache based on the question
19+
logging.info(f"Fetching queries from cache for question: {message}")
20+
cached_query = (
21+
await self.sql_connector.fetch_sql_queries_with_schemas_from_cache(
22+
message, injected_parameters=injected_parameters
2823
)
24+
)
2925

30-
# If any question has pre-run results, set the flag
31-
if cached_query.get(
32-
"contains_cached_sql_queries_with_schemas_from_cache_database_results",
33-
False,
34-
):
35-
cached_results[
36-
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
37-
] = True
26+
# If the question has pre-run results, set the flag
27+
if cached_query.get(
28+
"contains_cached_sql_queries_with_schemas_from_cache_database_results",
29+
False,
30+
):
31+
cached_results[
32+
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
33+
] = True
3834

39-
# Add the cached results for this question
40-
if cached_query.get("cached_sql_queries_with_schemas_from_cache"):
41-
cached_results["cached_sql_queries_with_schemas_from_cache"].extend(
42-
cached_query["cached_sql_queries_with_schemas_from_cache"]
43-
)
35+
# Add the cached results for this question
36+
if cached_query.get("cached_sql_queries_with_schemas_from_cache"):
37+
cached_results["cached_sql_queries_with_schemas_from_cache"].extend(
38+
cached_query["cached_sql_queries_with_schemas_from_cache"]
39+
)
4440

4541
logging.info(f"Final cached results: {cached_results}")
4642
return cached_results

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,47 +22,35 @@ def __init__(self, **kwargs):
2222

2323
self.system_prompt = Template(system_prompt).render(kwargs)
2424

25-
async def process_message(self, messages: list[str]) -> dict:
26-
logging.info(f"user inputs: {messages}")
27-
28-
entity_tasks = []
29-
30-
for message in messages:
31-
messages = [
32-
{"role": "system", "content": self.system_prompt},
33-
{"role": "user", "content": message},
34-
]
35-
entity_tasks.append(
36-
self.open_ai_connector.run_completion_request(
37-
messages, response_format=SQLSchemaSelectionAgentOutput
38-
)
39-
)
25+
async def process_message(self, message: str) -> dict:
26+
logging.info(f"Processing message: {message}")
4027

41-
entity_results = await asyncio.gather(*entity_tasks)
28+
messages = [
29+
{"role": "system", "content": self.system_prompt},
30+
{"role": "user", "content": message},
31+
]
32+
entity_result = await self.open_ai_connector.run_completion_request(
33+
messages, response_format=SQLSchemaSelectionAgentOutput
34+
)
4235

4336
entity_search_tasks = []
4437
column_search_tasks = []
4538

46-
for entity_result in entity_results:
47-
logging.info(f"Entity result: {entity_result}")
39+
logging.info(f"Entity result: {entity_result}")
4840

49-
for entity_group in entity_result.entities:
50-
logging.info("Searching for schemas for entity group: %s", entity_group)
51-
entity_search_tasks.append(
52-
self.sql_connector.get_entity_schemas(
53-
" ".join(entity_group), as_json=False
54-
)
41+
for entity_group in entity_result.entities:
42+
logging.info("Searching for schemas for entity group: %s", entity_group)
43+
entity_search_tasks.append(
44+
self.sql_connector.get_entity_schemas(
45+
" ".join(entity_group), as_json=False
5546
)
47+
)
5648

57-
for filter_condition in entity_result.filter_conditions:
58-
logging.info(
59-
"Searching for column values for filter: %s", filter_condition
60-
)
61-
column_search_tasks.append(
62-
self.sql_connector.get_column_values(
63-
filter_condition, as_json=False
64-
)
65-
)
49+
for filter_condition in entity_result.filter_conditions:
50+
logging.info("Searching for column values for filter: %s", filter_condition)
51+
column_search_tasks.append(
52+
self.sql_connector.get_column_values(filter_condition, as_json=False)
53+
)
6654

6755
schemas_results = await asyncio.gather(*entity_search_tasks)
6856
column_value_results = await asyncio.gather(*column_search_tasks)

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,5 @@ system_message:
274274
TERMINATE
275275
</output_format>
276276
"
277+
tools:
278+
- sql_get_entity_schemas_tool

0 commit comments

Comments
 (0)