Fix issues with model flow between agents, add structured output modes and add follow up questions #158

Merged
merged 13 commits on Feb 4, 2025
1 change: 1 addition & 0 deletions text_2_sql/.env.example
@@ -5,6 +5,7 @@ Text2Sql__DatabaseEngine=<DatabaseEngine> # TSQL or Postgres or Snowflake or Dat
Text2Sql__UseQueryCache=<Determines if the Query Cache will be used to speed up query generation. Defaults to True.> # True or False
Text2Sql__PreRunQueryCache=<Determines if the results from the Query Cache will be pre-run to speed up answer generation. Defaults to True.> # True or False
Text2Sql__UseColumnValueStore=<Determines if the Column Value Store will be used for schema selection. Defaults to True.> # True or False
Text2Sql__GenerateFollowUpQuestions=<Determines if follow up questions will be generated. Defaults to True.> # True or False

# Open AI Connection Details
OpenAI__CompletionDeployment=<openAICompletionDeploymentId. Used for data dictionary creator>
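For reference, a minimal sketch of how this flag is read (it mirrors the constructor change in autogen_text_2_sql.py below; an unset variable falls back to "True"):

import os

# Any casing of "true" enables follow-up question generation; anything
# else disables it. Unset defaults to enabled.
generate_follow_up_questions = (
    os.environ.get("Text2Sql__GenerateFollowUpQuestions", "True").lower() == "true"
)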
105 changes: 70 additions & 35 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
@@ -48,21 +48,31 @@ def __init__(self, state_store: StateStore, **kwargs):

self._agentic_flow = None

self._generate_follow_up_questions = (
os.environ.get("Text2Sql__GenerateFollowUpQuestions", "True").lower()
== "true"
)

def get_all_agents(self):
"""Get all agents for the complete flow."""

self.user_message_rewrite_agent = LLMAgentCreator.create(
user_message_rewrite_agent = LLMAgentCreator.create(
"user_message_rewrite_agent", **self.kwargs
)

self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)
parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)

self.answer_agent = LLMAgentCreator.create("answer_agent", **self.kwargs)
if self._generate_follow_up_questions:
answer_agent = LLMAgentCreator.create(
"answer_with_follow_up_questions_agent", **self.kwargs
)
else:
answer_agent = LLMAgentCreator.create("answer_agent", **self.kwargs)

agents = [
self.user_message_rewrite_agent,
self.parallel_query_solving_agent,
self.answer_agent,
user_message_rewrite_agent,
parallel_query_solving_agent,
answer_agent,
]

return agents
@@ -71,9 +81,16 @@ def get_all_agents(self):
def termination_condition(self):
"""Define the termination condition for the chat."""
termination = (
TextMentionTermination("TERMINATE")
| SourceMatchTermination("answer_agent")
| TextMentionTermination("contains_disambiguation_requests")
SourceMatchTermination("answer_agent")
| SourceMatchTermination("answer_with_follow_up_questions_agent")
# | TextMentionTermination(
# "[]",
# sources=["user_message_rewrite_agent"],
# )
| TextMentionTermination(
"contains_disambiguation_requests",
sources=["parallel_query_solving_agent"],
)
| MaxMessageTermination(5)
)
return termination
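As a usage note, the | operator on these autogen-agentchat conditions composes them into a single condition that stops the chat as soon as any branch fires. A minimal sketch mirroring the calls above (module path assumed from autogen-agentchat 0.4):

from autogen_agentchat.conditions import (
    MaxMessageTermination,
    SourceMatchTermination,
)

# Stops when answer_agent speaks or after 5 messages, whichever comes first.
stop = SourceMatchTermination("answer_agent") | MaxMessageTermination(5)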
@@ -91,6 +108,11 @@ def unified_selector(self, messages):
elif current_agent == "user_message_rewrite_agent":
decision = "parallel_query_solving_agent"
# Handle transition after parallel query solving
elif (
current_agent == "parallel_query_solving_agent"
and self._generate_follow_up_questions
):
decision = "answer_with_follow_up_questions_agent"
elif current_agent == "parallel_query_solving_agent":
decision = "answer_agent"

@@ -142,32 +164,35 @@ def parse_message_content(self, content):
# If all parsing attempts fail, return the content as-is
return content

def extract_decomposed_user_messages(self, messages: list) -> list[list[str]]:
"""Extract the decomposed messages from the answer."""
# Only load sub-message results if we have a database result
sub_message_results = self.parse_message_content(messages[1].content)
logging.info("Decomposed Results: %s", sub_message_results)
def last_message_by_agent(self, messages: list, agent_name: str) -> TextMessage:
"""Get the last message by a specific agent."""
for message in reversed(messages):
if message.source == agent_name:
return message.content
return None

decomposed_user_messages = sub_message_results.get(
"decomposed_user_messages", []
def extract_steps(self, messages: list) -> list[list[str]]:
"""Extract the steps messages from the answer."""
# Only load sub-message results if we have a database result
sub_message_results = json.loads(
self.last_message_by_agent(messages, "user_message_rewrite_agent")
)
logging.info("Steps Results: %s", sub_message_results)

logging.debug(
"Returning decomposed_user_messages: %s", decomposed_user_messages
)
steps = sub_message_results.get("steps", [])

logging.debug("Returning steps: %s", steps)

return decomposed_user_messages
return steps
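Given the list[list[str]] return type, the rewrite agent's JSON presumably nests one list of sub-messages per sequential step. A hypothetical payload showing the shape extract_steps expects (the "steps" key comes from the code above; the contents are invented):

import json

raw = '{"steps": [["What were total sales in 2024?"], ["Break them down by region"]]}'
steps = json.loads(raw).get("steps", [])
assert steps[1] == ["Break them down by region"]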

def extract_disambiguation_request(
self, messages: list
) -> DismabiguationRequestsPayload:
"""Extract the disambiguation request from the answer."""
all_disambiguation_requests = self.parse_message_content(messages[-1].content)

decomposed_user_messages = self.extract_decomposed_user_messages(messages)
request_payload = DismabiguationRequestsPayload(
decomposed_user_messages=decomposed_user_messages
)
steps = self.extract_steps(messages)
request_payload = DismabiguationRequestsPayload(steps=steps)

for per_question_disambiguation_request in all_disambiguation_requests[
"disambiguation_requests"
@@ -187,23 +212,27 @@ def extract_disambiguation_request(

def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
"""Extract the sources from the answer."""
answer = messages[-1].content
sql_query_results = self.parse_message_content(messages[-2].content)
answer_payload = json.loads(messages[-1].content)

logging.info("Answer Payload: %s", answer_payload)
sql_query_results = self.last_message_by_agent(
messages, "parallel_query_solving_agent"
)

try:
if isinstance(sql_query_results, str):
sql_query_results = json.loads(sql_query_results)
elif sql_query_results is None:
sql_query_results = {}
except json.JSONDecodeError:
logging.warning("Unable to read SQL query results: %s", sql_query_results)
sql_query_results = {}

try:
decomposed_user_messages = self.extract_decomposed_user_messages(messages)
steps = self.extract_steps(messages)

logging.info("SQL Query Results: %s", sql_query_results)
payload = AnswerWithSourcesPayload(
answer=answer, decomposed_user_messages=decomposed_user_messages
)
payload = AnswerWithSourcesPayload(**answer_payload, steps=steps)

if not isinstance(sql_query_results, dict):
logging.error(f"Expected dict, got {type(sql_query_results)}")
@@ -248,10 +277,9 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:

except Exception as e:
logging.error("Error processing results: %s", str(e))

# Return payload with error context instead of empty
return AnswerWithSourcesPayload(
answer=f"{answer}\nError processing results: {str(e)}"
)
return AnswerWithSourcesPayload(**answer_payload)
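For clarity, a hypothetical structured answer as json.loads(messages[-1].content) might see it. Only the spread into AnswerWithSourcesPayload is confirmed by this diff; the field names here are assumptions:

# Assumed shape: "answer" matches the previous payload field, and
# "follow_up_questions" is implied by the new agent's name.
answer_payload = {
    "answer": "Total sales in 2024 were 1.2M.",
    "follow_up_questions": ["Do you want a breakdown by region?"],
}
# The code above then spreads it into the payload model:
# AnswerWithSourcesPayload(**answer_payload, steps=steps)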

async def process_user_message(
self,
@@ -295,7 +323,10 @@ async def process_user_message(
payload = ProcessingUpdatePayload(
message="Solving the query...",
)
elif message.source == "answer_agent":
elif (
message.source == "answer_agent"
or message.source == "answer_with_follow_up_questions_agent"
):
payload = ProcessingUpdatePayload(
message="Generating the answer...",
)
@@ -304,7 +335,11 @@ async def process_user_message(
# Now we need to return the final answer or the disambiguation request
logging.info("TaskResult: %s", message)

if message.messages[-1].source == "answer_agent":
if (
message.messages[-1].source == "answer_agent"
or message.messages[-1].source
== "answer_with_follow_up_questions_agent"
):
# If the message is from the answer_agent, we need to return the final answer
payload = self.extract_answer_payload(message.messages)
elif message.messages[-1].source == "parallel_query_solving_agent":
text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py
@@ -7,6 +7,12 @@
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
from jinja2 import Template
import logging
from text_2_sql_core.structured_outputs import (
AnswerAgentOutput,
AnswerWithFollowUpQuestionsAgentOutput,
UserMessageRewriteAgentOutput,
)
from autogen_core.model_context import BufferedChatCompletionContext


class LLMAgentCreator:
@@ -106,10 +112,22 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
for tool in agent_file["tools"]:
tools.append(cls.get_tool(sql_helper, tool))

structured_output = None
if agent_file.get("structured_output", False):
# Import the structured output agent
if name == "answer_agent":
structured_output = AnswerAgentOutput
elif name == "answer_with_follow_up_questions_agent":
structured_output = AnswerWithFollowUpQuestionsAgentOutput
elif name == "user_message_rewrite_agent":
structured_output = UserMessageRewriteAgentOutput

agent = AssistantAgent(
name=name,
tools=tools,
model_client=LLMModelCreator.get_model(agent_file["model"]),
model_client=LLMModelCreator.get_model(
agent_file["model"], structured_output=structured_output
),
description=cls.get_property_and_render_parameters(
agent_file, "description", kwargs
),
@@ -118,4 +136,9 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
),
)

if "context_size" in agent_file:
agent.model_context = BufferedChatCompletionContext(
buffer_size=agent_file["context_size"]
)

return agent
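The structured output classes themselves aren't shown in this diff. A plausible sketch, assuming they are Pydantic models (the form AzureOpenAIChatCompletionClient's response_format accepts); the real definitions live in text_2_sql_core.structured_outputs and may differ:

from pydantic import BaseModel


class AnswerAgentOutput(BaseModel):
    answer: str


class AnswerWithFollowUpQuestionsAgentOutput(BaseModel):
    answer: str
    follow_up_questions: list[str]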
text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_model_creator.py
@@ -12,7 +12,9 @@

class LLMModelCreator:
@classmethod
def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
def get_model(
cls, model_name: str, structured_output=None
) -> AzureOpenAIChatCompletionClient:
"""Retrieves the model based on the model name.

Args:
@@ -22,9 +24,9 @@ def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
Returns:
AzureOpenAIChatCompletionClient: The model client."""
if model_name == "4o-mini":
return cls.gpt_4o_mini_model()
return cls.gpt_4o_mini_model(structured_output=structured_output)
elif model_name == "4o":
return cls.gpt_4o_model()
return cls.gpt_4o_model(structured_output=structured_output)
else:
raise ValueError(f"Model {model_name} not found")
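Putting the two creators together, a structured output model is threaded through like this (imports taken from the files above; the call mirrors the new get_model signature):

from text_2_sql_core.structured_outputs import AnswerAgentOutput
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator

# Returns an AzureOpenAIChatCompletionClient whose responses are constrained
# to AnswerAgentOutput's JSON schema via response_format.
client = LLMModelCreator.get_model("4o", structured_output=AnswerAgentOutput)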

@@ -46,7 +48,9 @@ def get_authentication_properties(cls) -> dict:
return token_provider, api_key

@classmethod
def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
def gpt_4o_mini_model(
cls, structured_output=None
) -> AzureOpenAIChatCompletionClient:
token_provider, api_key = cls.get_authentication_properties()
return AzureOpenAIChatCompletionClient(
azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
@@ -61,10 +65,11 @@ def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
"json_output": True,
},
temperature=0,
response_format=structured_output,
)

@classmethod
def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
def gpt_4o_model(cls, structured_output=None) -> AzureOpenAIChatCompletionClient:
token_provider, api_key = cls.get_authentication_properties()
return AzureOpenAIChatCompletionClient(
azure_deployment=os.environ["OpenAI__CompletionDeployment"],
@@ -79,4 +84,5 @@ def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
"json_output": True,
},
temperature=0,
response_format=structured_output,
)