diff --git a/text_2_sql/.env.example b/text_2_sql/.env.example index 21c358f..9a705d3 100644 --- a/text_2_sql/.env.example +++ b/text_2_sql/.env.example @@ -27,8 +27,12 @@ Text2Sql__Tsql__ConnectionString= # PostgreSQL Specific Connection Details -Text2Sql__Postgresql__ConnectionString= +Text2Sql__Postgresql__ConnectionString= Text2Sql__Postgresql__Database= +Text2Sql__Postgresql__User= +Text2Sql__Postgresql__Password= +Text2Sql__Postgresql__ServerHostname= +Text2Sql__Postgresql__Port= # Snowflake Specific Connection Details Text2Sql__Snowflake__User= diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 9dcb127..a5b6482 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -41,6 +41,8 @@ def __init__(self, **kwargs): self.kwargs = {**DEFAULT_INJECTED_PARAMETERS, **kwargs} + self._agentic_flow = None + def get_all_agents(self): """Get all agents for the complete flow.""" @@ -97,6 +99,10 @@ def unified_selector(self, messages): @property def agentic_flow(self): """Create the unified flow for the complete process.""" + + if self._agentic_flow is not None: + return self._agentic_flow + flow = SelectorGroupChat( self.get_all_agents(), allow_repeated_speaker=False, @@ -104,7 +110,9 @@ def agentic_flow(self): termination_condition=self.termination_condition, selector_func=self.unified_selector, ) - return flow + + self._agentic_flow = flow + return self._agentic_flow def parse_message_content(self, content): """Parse different message content formats into a dictionary.""" @@ -250,7 +258,7 @@ async def process_user_message( Args: ---- task (str): The user message to process. - chat_history (list[str], optional): The chat history. Defaults to None. + chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message. injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. Returns: @@ -262,17 +270,23 @@ async def process_user_message( agent_input = { "message": message_payload.body.user_message, - "chat_history": {}, "injected_parameters": message_payload.body.injected_parameters, } + latest_state = None if chat_history is not None: # Update input - for idx, chat in enumerate(chat_history): - if chat.root.payload_type == PayloadType.USER_MESSAGE: - # For now only consider the user query - chat_history_key = f"chat_{idx}" - agent_input[chat_history_key] = chat.root.body.user_message + for chat in reversed(chat_history): + if chat.root.payload_type in [ + PayloadType.ANSWER_WITH_SOURCES, + PayloadType.DISAMBIGUATION_REQUESTS, + ]: + latest_state = chat.body.assistant_state + break + + # TODO: Trim the chat history to the last message from the user + if latest_state is not None: + await self.agentic_flow.load_state(latest_state) async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): logging.debug("Message: %s", message) @@ -312,6 +326,22 @@ async def process_user_message( logging.error("Unexpected TaskResult: %s", message) raise ValueError("Unexpected TaskResult") - if payload is not None: + if ( + payload is not None + and payload.payload_type is PayloadType.PROCESSING_UPDATE + ): logging.debug("Payload: %s", payload) yield payload + + # Return the final payload + if ( + payload is not None + and payload.payload_type is not PayloadType.PROCESSING_UPDATE + ): + # Get the state + assistant_state = await self.agentic_flow.save_state() + payload.body.assistant_state = assistant_state + + logging.debug("Final Payload: %s", payload) + + yield payload diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py index 4192e8d..396d38b 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py @@ -6,7 +6,7 @@ import os import logging import json - +from urllib.parse import urlparse from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields @@ -66,10 +66,35 @@ async def query_execution( """ logging.info(f"Running query: {sql_query}") results = [] - connection_string = os.environ["Text2Sql__Postgresql__ConnectionString"] + + if "Text2Sql__Postgresql__ConnectionString" in os.environ: + logging.info("Postgresql Connection string found in environment variables.") + + p = urlparse(os.environ["Text2Sql__Postgresql__ConnectionString"]) + + postgres_connections = { + "dbname": p.path[1:], + "user": p.username, + "password": p.password, + "port": p.port, + "host": p.hostname, + } + else: + logging.warning( + "Postgresql Connection string not found in environment variables. Using individual variables." + ) + postgres_connections = { + "dbname": os.environ["Text2Sql__Postgresql__Database"], + "user": os.environ["Text2Sql__Postgresql__User"], + "password": os.environ["Text2Sql__Postgresql__Password"], + "port": os.environ["Text2Sql__Postgresql__Port"], + "host": os.environ["Text2Sql__Postgresql__ServerHostname"], + } # Establish an asynchronous connection to the PostgreSQL database - async with await psycopg.AsyncConnection.connect(connection_string) as conn: + async with await psycopg.AsyncConnection.connect( + **postgres_connections + ) as conn: # Create an asynchronous cursor async with conn.cursor() as cursor: await cursor.execute(sql_query) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index ce7fa65..c68d31a 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -17,7 +17,7 @@ class PayloadSource(StrEnum): USER = "user" - AGENT = "agent" + ASSISTANT = "assistant" class PayloadType(StrEnum): @@ -42,11 +42,13 @@ class PayloadBase(InteractionPayloadBase): payload_type: PayloadType = Field(..., alias="payloadType") payload_source: PayloadSource = Field(..., alias="payloadSource") + body: InteractionPayloadBase | None = Field(default=None) + class DismabiguationRequestsPayload(InteractionPayloadBase): class Body(InteractionPayloadBase): class DismabiguationRequest(InteractionPayloadBase): - agent_question: str | None = Field(..., alias="agentQuestion") + ASSISTANT_question: str | None = Field(..., alias="ASSISTANTQuestion") user_choices: list[str] | None = Field(default=None, alias="userChoices") disambiguation_requests: list[DismabiguationRequest] | None = Field( @@ -55,12 +57,13 @@ class DismabiguationRequest(InteractionPayloadBase): decomposed_user_messages: list[list[str]] = Field( default_factory=list, alias="decomposedUserMessages" ) + assistant_state: dict | None = Field(default=None, alias="assistantState") payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field( PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType" ) - payload_source: Literal[PayloadSource.AGENT] = Field( - default=PayloadSource.AGENT, alias="payloadSource" + payload_source: Literal[PayloadSource.ASSISTANT] = Field( + default=PayloadSource.ASSISTANT, alias="payloadSource" ) body: Body | None = Field(default=None) @@ -83,12 +86,13 @@ class Source(InteractionPayloadBase): default_factory=list, alias="decomposedUserMessages" ) sources: list[Source] = Field(default_factory=list) + assistant_state: dict | None = Field(default=None, alias="assistantState") payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field( PayloadType.ANSWER_WITH_SOURCES, alias="payloadType" ) - payload_source: Literal[PayloadSource.AGENT] = Field( - PayloadSource.AGENT, alias="payloadSource" + payload_source: Literal[PayloadSource.ASSISTANT] = Field( + PayloadSource.ASSISTANT, alias="payloadSource" ) body: Body | None = Field(default=None) @@ -108,8 +112,8 @@ class Body(InteractionPayloadBase): payload_type: Literal[PayloadType.PROCESSING_UPDATE] = Field( PayloadType.PROCESSING_UPDATE, alias="payloadType" ) - payload_source: Literal[PayloadSource.AGENT] = Field( - PayloadSource.AGENT, alias="payloadSource" + payload_source: Literal[PayloadSource.ASSISTANT] = Field( + PayloadSource.ASSISTANT, alias="payloadSource" ) body: Body | None = Field(default=None)