diff --git a/text_2_sql/.env.example b/text_2_sql/.env.example index 35bbf3e..e87265c 100644 --- a/text_2_sql/.env.example +++ b/text_2_sql/.env.example @@ -5,7 +5,7 @@ Text2Sql__DatabaseEngine= # TSQL or Postgres or Snowflake or Dat Text2Sql__UseQueryCache= # True or False Text2Sql__PreRunQueryCache= # True or False Text2Sql__UseColumnValueStore= # True or False -Text2Sql__GenerateFollowUpQuestions= # True or False +Text2Sql__GenerateFollowUpSuggestions= # True or False # Open AI Connection Details OpenAI__CompletionDeployment= diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 02c1fdc..4ea0bf5 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -48,8 +48,8 @@ def __init__(self, state_store: StateStore, **kwargs): self._agentic_flow = None - self._generate_follow_up_questions = ( - os.environ.get("Text2Sql__GenerateFollowUpQuestions", "True").lower() + self._generate_follow_up_suggestions = ( + os.environ.get("Text2Sql__GenerateFollowUpSuggestions", "True").lower() == "true" ) @@ -62,9 +62,9 @@ def get_all_agents(self): parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs) - if self._generate_follow_up_questions: + if self._generate_follow_up_suggestions: answer_agent = LLMAgentCreator.create( - "answer_with_follow_up_questions_agent", **self.kwargs + "answer_with_follow_up_suggestions_agent", **self.kwargs ) else: answer_agent = LLMAgentCreator.create("answer_agent", **self.kwargs) @@ -82,7 +82,7 @@ def termination_condition(self): """Define the termination condition for the chat.""" termination = ( SourceMatchTermination("answer_agent") - | SourceMatchTermination("answer_with_follow_up_questions_agent") + | SourceMatchTermination("answer_with_follow_up_suggestions_agent") # | TextMentionTermination( # "[]", # sources=["user_message_rewrite_agent"], @@ -110,9 +110,9 @@ def unified_selector(self, messages): # Handle transition after parallel query solving elif ( current_agent == "parallel_query_solving_agent" - and self._generate_follow_up_questions + and self._generate_follow_up_suggestions ): - decision = "answer_with_follow_up_questions_agent" + decision = "answer_with_follow_up_suggestions_agent" elif current_agent == "parallel_query_solving_agent": decision = "answer_agent" @@ -325,7 +325,7 @@ async def process_user_message( ) elif ( message.source == "answer_agent" - or message.source == "answer_with_follow_up_questions_agent" + or message.source == "answer_with_follow_up_suggestions_agent" ): payload = ProcessingUpdatePayload( message="Generating the answer...", @@ -338,7 +338,7 @@ async def process_user_message( if ( message.messages[-1].source == "answer_agent" or message.messages[-1].source - == "answer_with_follow_up_questions_agent" + == "answer_with_follow_up_suggestions_agent" ): # If the message is from the answer_agent, we need to return the final answer payload = self.extract_answer_payload(message.messages) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py index 2bb7563..4ce60ba 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py @@ -9,7 +9,7 @@ import logging from text_2_sql_core.structured_outputs import ( AnswerAgentOutput, - AnswerWithFollowUpQuestionsAgentOutput, + AnswerWithFollowUpSuggestionsAgentOutput, UserMessageRewriteAgentOutput, ) from autogen_core.model_context import BufferedChatCompletionContext @@ -117,8 +117,8 @@ def create(cls, name: str, **kwargs) -> AssistantAgent: # Import the structured output agent if name == "answer_agent": structured_output = AnswerAgentOutput - elif name == "answer_with_follow_up_questions_agent": - structured_output = AnswerWithFollowUpQuestionsAgentOutput + elif name == "answer_with_follow_up_suggestions_agent": + structured_output = AnswerWithFollowUpSuggestionsAgentOutput elif name == "user_message_rewrite_agent": structured_output = UserMessageRewriteAgentOutput diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index 7c5b930..62f49c5 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -81,8 +81,8 @@ class Source(InteractionPayloadBase): answer: str steps: list[list[str]] = Field(default_factory=list, alias="Steps") sources: list[Source] = Field(default_factory=list) - follow_up_questions: list[str] | None = Field( - default=None, alias="followUpQuestions" + follow_up_suggestions: list[str] | None = Field( + default=None, alias="followUpSuggestions" ) payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field( diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_with_follow_up_questions_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_with_follow_up_suggestions_agent.yaml similarity index 98% rename from text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_with_follow_up_questions_agent.yaml rename to text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_with_follow_up_suggestions_agent.yaml index 1dacb23..97ef481 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_with_follow_up_questions_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_with_follow_up_suggestions_agent.yaml @@ -25,7 +25,7 @@ system_message: | { "answer": "The response to the user's question.", - "follow_up_questions": [ + "follow_up_suggestions": [ "Follow-up question 1", "Follow-up question 2", "Follow-up question 3" diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/__init__.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/__init__.py index 56ca200..daaf954 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/__init__.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/__init__.py @@ -6,14 +6,14 @@ from text_2_sql_core.structured_outputs.user_message_rewrite_agent import ( UserMessageRewriteAgentOutput, ) -from text_2_sql_core.structured_outputs.answer_with_follow_up_questions_agent import ( - AnswerWithFollowUpQuestionsAgentOutput, +from text_2_sql_core.structured_outputs.answer_with_follow_up_suggestions_agent import ( + AnswerWithFollowUpSuggestionsAgentOutput, ) from text_2_sql_core.structured_outputs.answer_agent import AnswerAgentOutput __all__ = [ "AnswerAgentOutput", - "AnswerWithFollowUpQuestionsAgentOutput", + "AnswerWithFollowUpSuggestionsAgentOutput", "SQLSchemaSelectionAgentOutput", "UserMessageRewriteAgentOutput", ] diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/answer_with_follow_up_questions_agent.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/answer_with_follow_up_suggestions_agent.py similarity index 66% rename from text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/answer_with_follow_up_questions_agent.py rename to text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/answer_with_follow_up_suggestions_agent.py index 4c747ee..d6af476 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/answer_with_follow_up_questions_agent.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/answer_with_follow_up_suggestions_agent.py @@ -3,8 +3,8 @@ from pydantic import BaseModel -class AnswerWithFollowUpQuestionsAgentOutput(BaseModel): +class AnswerWithFollowUpSuggestionsAgentOutput(BaseModel): """The output of the answer agent with follow up questions.""" answer: str - follow_up_questions: list[str] + follow_up_suggestions: list[str]