Adds data contract and reduces agent calls. #119

Merged · 4 commits · Dec 31, 2024
@@ -50,7 +50,7 @@
"source": [
"import dotenv\n",
"import logging\n",
"from autogen_text_2_sql import AutoGenText2Sql"
"from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody"
]
},
{
@@ -100,7 +100,7 @@
"metadata": {},
"outputs": [],
"source": [
"async for message in agentic_text_2_sql.process_question(question=\"What total number of orders in June 2008?\"):\n",
"async for message in agentic_text_2_sql.process_question(AgentRequestBody(question=\"What total number of orders in June 2008?\")):\n",
" logging.info(\"Received %s Message from Text2SQL System\", message)"
]
},
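Note on the notebook change above: `process_question` no longer takes a bare `question` string; the question is now wrapped in an `AgentRequestBody`. A minimal sketch of the updated call pattern, assuming `agentic_text_2_sql` has been constructed earlier in the notebook as before:

```python
import logging

from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody

async def ask(agentic_text_2_sql: AutoGenText2Sql) -> None:
    # The new data contract: wrap the question in a request object instead of
    # passing a loose string (injected_parameters can ride along on it too).
    request = AgentRequestBody(question="What total number of orders in June 2008?")
    async for message in agentic_text_2_sql.process_question(request):
        logging.info("Received %s Message from Text2SQL System", message)
```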
5 changes: 4 additions & 1 deletion text_2_sql/autogen/src/autogen_text_2_sql/__init__.py
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
from text_2_sql_core.payloads.agent_response import AgentRequestBody

__all__ = ["AutoGenText2Sql"]
__all__ = ["AutoGenText2Sql", "AgentRequestBody"]
68 changes: 44 additions & 24 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
@@ -18,11 +18,17 @@
import os
from datetime import datetime

from text_2_sql_core.payloads import (
from text_2_sql_core.payloads.agent_response import (
AgentResponse,
AgentRequestBody,
AnswerWithSources,
UserInformationRequest,
Source,
DismabiguationRequests,
)
from text_2_sql_core.payloads.chat_history import ChatHistoryItem
from text_2_sql_core.payloads.processing_update import (
ProcessingUpdateBody,
ProcessingUpdate,
ChatHistoryItem,
)
from autogen_agentchat.base import Response, TaskResult
from typing import AsyncGenerator
@@ -123,6 +129,16 @@ def agentic_flow(self):
)
return flow

def extract_disambiguation_request(self, messages: list) -> DismabiguationRequests:
"""Extract the disambiguation request from the answer."""

disambiguation_request = messages[-1].content

# TODO: Properly extract the disambiguation request
return DismabiguationRequests(
disambiguation_request=disambiguation_request,
)

def extract_sources(self, messages: list) -> AnswerWithSources:
"""Extract the sources from the answer."""

@@ -147,10 +163,10 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
for sql_query_result in sql_query_result_list:
logging.info("SQL Query Result: %s", sql_query_result)
sources.append(
{
"sql_query": sql_query_result["sql_query"],
"sql_rows": sql_query_result["sql_rows"],
}
Source(
sql_query=sql_query_result["sql_query"],
sql_rows=sql_query_result["sql_rows"],
)
)

except json.JSONDecodeError:
@@ -164,10 +180,9 @@ def extract_sources(self, messages: list) -> AnswerWithSources:

async def process_question(
self,
question: str,
request: AgentRequestBody,
chat_history: list[ChatHistoryItem] = None,
injected_parameters: dict = None,
) -> AsyncGenerator[AnswerWithSources | UserInformationRequest, None]:
) -> AsyncGenerator[AgentResponse | ProcessingUpdate, None]:
"""Process the complete question through the unified system.

Args:
@@ -180,58 +195,63 @@ async def process_question(
-------
dict: The response from the system.
"""
logging.info("Processing question: %s", question)
logging.info("Processing question: %s", request.question)
logging.info("Chat history: %s", chat_history)

agent_input = {
"question": question,
"question": request.question,
"chat_history": {},
"injected_parameters": injected_parameters,
"injected_parameters": request.injected_parameters,
}

if chat_history is not None:
# Update input
for idx, chat in enumerate(chat_history):
# For now only consider the user query
agent_input[f"chat_{idx}"] = chat.user_query
agent_input[f"chat_{idx}"] = chat.request.question

async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
logging.debug("Message: %s", message)

payload = None

if isinstance(message, TextMessage):
processing_update = None
if message.source == "query_rewrite_agent":
# If the message is from the query_rewrite_agent, we need to update the chat history
payload = ProcessingUpdate(
processing_update = ProcessingUpdateBody(
message="Rewriting the query...",
)
elif message.source == "parallel_query_solving_agent":
# If the message is from the parallel_query_solving_agent, we need to update the chat history
payload = ProcessingUpdate(
processing_update = ProcessingUpdateBody(
message="Solving the query...",
)
elif message.source == "answer_agent":
# If the message is from the answer_agent, we need to update the chat history
payload = ProcessingUpdate(
processing_update = ProcessingUpdateBody(
message="Generating the answer...",
)

if processing_update is not None:
payload = ProcessingUpdate(
processing_update=processing_update,
)

elif isinstance(message, TaskResult):
# Now we need to return the final answer or the disambiguation request
logging.info("TaskResult: %s", message)

response = None
if message.messages[-1].source == "answer_agent":
# If the message is from the answer_agent, we need to return the final answer
payload = self.extract_sources(message.messages)
response = self.extract_sources(message.messages)
elif message.messages[-1].source == "parallel_query_solving_agent":
payload = UserInformationRequest(
**json.loads(message.messages[-1].content),
)
# Load into disambiguation request
response = self.extract_disambiguation_request(message.messages)
else:
logging.error("Unexpected TaskResult: %s", message)
raise ValueError("Unexpected TaskResult")

payload = AgentResponse(request=request, response=response)

if payload is not None:
logging.debug("Payload: %s", payload)
yield payload
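With this rework, `process_question` yields exactly two payload shapes: interim `ProcessingUpdate` messages while the agents run, and one final `AgentResponse` that echoes the originating request and carries either an `AnswerWithSources` or a `DismabiguationRequests` in `response`. A hedged consumer sketch (attribute paths inferred from the constructor calls above, not from documented API):

```python
from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody
from text_2_sql_core.payloads.agent_response import AgentResponse
from text_2_sql_core.payloads.processing_update import ProcessingUpdate

async def answer(agentic_text_2_sql: AutoGenText2Sql, question: str):
    async for payload in agentic_text_2_sql.process_question(
        AgentRequestBody(question=question)
    ):
        if isinstance(payload, ProcessingUpdate):
            # Interim status, e.g. "Rewriting the query..."
            print(payload.processing_update.message)
        elif isinstance(payload, AgentResponse):
            # Final answer (AnswerWithSources) or a disambiguation request
            return payload.response
```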
@@ -42,18 +42,13 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
elif tool_name == "sql_get_entity_schemas_tool":
return FunctionToolAlias(
sql_helper.get_entity_schemas,
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the system prompt are not sufficient to answer the question.",
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.",
)
elif tool_name == "sql_get_column_values_tool":
return FunctionToolAlias(
ai_search_helper.get_column_values,
description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.",
)
elif tool_name == "current_datetime_tool":
return FunctionToolAlias(
sql_helper.get_current_datetime,
description="Gets the current date and time.",
)
else:
raise ValueError(f"Tool {tool_name} not found")

@@ -44,10 +44,10 @@ async def on_messages_stream(
last_response = messages[-1].content
parameter_input = messages[0].content
try:
user_parameters = json.loads(parameter_input)["parameters"]
injected_parameters = json.loads(parameter_input)["injected_parameters"]
except json.JSONDecodeError:
logging.error("Error decoding the user parameters.")
user_parameters = {}
injected_parameters = {}

# Load the json of the last message to populate the final output object
query_rewrites = json.loads(last_response)
@@ -75,7 +75,7 @@ async def consume_inner_messages_from_agentic_flow(
if isinstance(inner_message, TaskResult) is False:
try:
inner_message = json.loads(inner_message.content)
logging.info(f"Loaded: {inner_message}")
logging.info(f"Inner Loaded: {inner_message}")

# Search for specific message types and add them to the final output object
if (
@@ -91,6 +91,21 @@
}
)

if ("contains_pre_run_results" in inner_message) and (
inner_message["contains_pre_run_results"] is True
):
for pre_run_sql_query, pre_run_result in inner_message[
"cached_questions_and_schemas"
].items():
database_results[identifier].append(
{
"sql_query": pre_run_sql_query.replace(
"\n", " "
),
"sql_rows": pre_run_result["sql_rows"],
}
)

except (JSONDecodeError, TypeError) as e:
logging.error("Could not load message: %s", inner_message)
logging.warning(f"Error processing message: {e}")
@@ -113,13 +128,15 @@ async def consume_inner_messages_from_agentic_flow(
self.engine_specific_rules, **self.kwargs
)

identifier = ", ".join(query_rewrite)

# Launch tasks for each sub-query
inner_solving_generators.append(
consume_inner_messages_from_agentic_flow(
inner_autogen_text_2_sql.process_question(
question=query_rewrite, parameters=user_parameters
question=query_rewrite, injected_parameters=injected_parameters
),
query_rewrite,
identifier,
database_results,
)
)
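The new `contains_pre_run_results` branch above folds cache hits that were already executed against the database into the same `database_results` structure as live query results. A sketch of the message shape it consumes (key names taken from the diff; the example values are illustrative):

```python
# Illustrative inner message carrying pre-run cache results, as handled by
# the branch added above. Keys match the diff; the values are made up.
inner_message = {
    "contains_pre_run_results": True,
    "cached_questions_and_schemas": {
        "SELECT COUNT(*) AS order_count\nFROM Orders": {"sql_rows": [{"order_count": 42}]},
    },
}

identifier = "total orders in June 2008"  # now built as ", ".join(query_rewrite)
database_results = {identifier: []}

if inner_message.get("contains_pre_run_results") is True:
    for pre_run_sql_query, pre_run_result in inner_message[
        "cached_questions_and_schemas"
    ].items():
        database_results[identifier].append(
            {
                "sql_query": pre_run_sql_query.replace("\n", " "),
                "sql_rows": pre_run_result["sql_rows"],
            }
        )
```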
@@ -39,55 +39,46 @@ async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
# Get the decomposed questions from the query_rewrite_agent
parameter_input = messages[0].content
last_response = messages[-1].content
try:
user_questions = json.loads(last_response)
injected_parameters = json.loads(parameter_input)["injected_parameters"]
request_details = json.loads(messages[0].content)
injected_parameters = request_details["injected_parameters"]
user_questions = request_details["question"]
logging.info(f"Processing questions: {user_questions}")
logging.info(f"Input Parameters: {injected_parameters}")
except json.JSONDecodeError:
# If not JSON array, process as single question
raise ValueError("Could not load message")

# Initialize results dictionary
cached_results = {
"cached_questions_and_schemas": [],
"contains_pre_run_results": False,
}

# Process each question sequentially
for question in user_questions:
# Fetch the queries from the cache based on the question
logging.info(f"Fetching queries from cache for question: {question}")
cached_query = await self.sql_connector.fetch_queries_from_cache(
question, injected_parameters=injected_parameters
)
# Initialize results dictionary
cached_results = {
"cached_questions_and_schemas": [],
"contains_pre_run_results": False,
}

# If any question has pre-run results, set the flag
if cached_query.get("contains_pre_run_results", False):
cached_results["contains_pre_run_results"] = True
# Process each question sequentially
for question in user_questions:
# Fetch the queries from the cache based on the question
logging.info(f"Fetching queries from cache for question: {question}")
cached_query = await self.sql_connector.fetch_queries_from_cache(
question, injected_parameters=injected_parameters
)

# Add the cached results for this question
if cached_query.get("cached_questions_and_schemas"):
cached_results["cached_questions_and_schemas"].extend(
cached_query["cached_questions_and_schemas"]
)
# If any question has pre-run results, set the flag
if cached_query.get("contains_pre_run_results", False):
cached_results["contains_pre_run_results"] = True

logging.info(f"Final cached results: {cached_results}")
yield Response(
chat_message=TextMessage(
content=json.dumps(cached_results), source=self.name
)
)
except json.JSONDecodeError:
# If not JSON array, process as single question
logging.info(f"Processing single question: {last_response}")
cached_queries = await self.sql_connector.fetch_queries_from_cache(
last_response
)
yield Response(
chat_message=TextMessage(
content=json.dumps(cached_queries), source=self.name
# Add the cached results for this question
if cached_query.get("cached_questions_and_schemas"):
cached_results["cached_questions_and_schemas"].extend(
cached_query["cached_questions_and_schemas"]
)

logging.info(f"Final cached results: {cached_results}")
yield Response(
chat_message=TextMessage(
content=json.dumps(cached_results), source=self.name
)
)

async def on_reset(self, cancellation_token: CancellationToken) -> None:
pass
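After this refactor the cache agent parses a single JSON request (decomposed questions plus injected parameters) from the first message and always yields one aggregated payload, instead of special-casing a bare question string. A sketch of the input and output contracts (key names from the diff; values illustrative):

```python
import json

# Input: content of the first chat message, as parsed by on_messages_stream.
request_details = {
    "question": ["total orders in June 2008"],  # list of decomposed questions
    "injected_parameters": {},                  # forwarded to the cache lookup
}
first_message_content = json.dumps(request_details)

# Output: one aggregated payload covering every question, serialized into a
# single TextMessage via json.dumps(cached_results).
cached_results = {
    "cached_questions_and_schemas": [],  # extended with each cache hit
    "contains_pre_run_results": False,   # True if any hit was pre-executed
}
```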