diff --git a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb index 3f0a4da..dcc5bd7 100644 --- a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb +++ b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb @@ -52,7 +52,8 @@ "source": [ "import dotenv\n", "import logging\n", - "from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload" + "from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload\n", + "from autogen_text_2_sql.state_store import InMemoryStateStore" ] }, { @@ -86,16 +87,10 @@ "metadata": {}, "outputs": [], "source": [ - "agentic_text_2_sql = AutoGenText2Sql(use_case=\"Analysing sales data\")" + "# The state store allows AutoGen to store the states in memory across invocations. Whilst not necessary, you can replace it with your own implementation that is backed by a database or file system. \n", + "agentic_text_2_sql = AutoGenText2Sql(state_store=InMemoryStateStore(), use_case=\"Analysing sales data\")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "metadata": {}, @@ -109,7 +104,7 @@ "metadata": {}, "outputs": [], "source": [ - "async for message in agentic_text_2_sql.process_user_message(UserMessagePayload(user_message=\"what are the total sales\")):\n", + "async for message in agentic_text_2_sql.process_user_message(thread_id=\"1\", message_payload=UserMessagePayload(user_message=\"what are the total sales\")):\n", " logging.info(\"Received %s Message from Text2SQL System\", message)" ] }, @@ -137,7 +132,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.8" + "version": "3.11.2" } }, "nbformat": 4, diff --git a/text_2_sql/autogen/README.md b/text_2_sql/autogen/README.md index a3eb1da..d6e4dd3 100644 --- a/text_2_sql/autogen/README.md +++ 
b/text_2_sql/autogen/README.md @@ -121,6 +121,10 @@ Contains specialized agent implementations: - **sql_schema_selection_agent.py:** Handles schema selection and management - **answer_and_sources_agent.py:** Formats and standardizes final outputs +## State Store + +To enable the [AutoGen State](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.state.html) to be tracked across invocations, a state store implementation must be provided. A basic `InMemoryStateStore` is provided, but this can be replaced with an implementation for a database or file system for when the Agentic System is running behind an API. This enables the AutoGen state to be saved behind the scenes and recalled later when the message is part of the same thread. A `thread_id` must be provided to the entrypoint. + ## Configuration The system behavior can be controlled through environment variables: diff --git a/text_2_sql/autogen/pyproject.toml b/text_2_sql/autogen/pyproject.toml index 174aa9a..950bdca 100644 --- a/text_2_sql/autogen/pyproject.toml +++ b/text_2_sql/autogen/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "text_2_sql_core", "sqlparse>=0.4.4", "nltk>=3.8.1", + "cachetools>=5.5.1", ] [dependency-groups] diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py index 03e6104..cc23dea 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql +from autogen_text_2_sql.state_store import InMemoryStateStore + from text_2_sql_core.payloads.interaction_payloads import ( UserMessagePayload, DismabiguationRequestsPayload, @@ -16,4 +18,5 @@ "AnswerWithSourcesPayload", "ProcessingUpdatePayload", "InteractionPayload", + "InMemoryStateStore", ] diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index a5b6482..98d3821 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -12,6 +12,7 @@ from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import ( ParallelQuerySolvingAgent, ) +from autogen_text_2_sql.state_store import StateStore from autogen_agentchat.messages import TextMessage import json import os @@ -31,9 +32,13 @@ class AutoGenText2Sql: - def __init__(self, **kwargs): + def __init__(self, state_store: StateStore, **kwargs): self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper() + if not state_store: + raise ValueError("State store must be provided") + self.state_store = state_store + if "use_case" not in kwargs: logging.warning( "No use case provided. It is advised to provide a use case to help the LLM reason." @@ -250,15 +255,15 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: async def process_user_message( self, + thread_id: str, message_payload: UserMessagePayload, - chat_history: list[InteractionPayload] = None, ) -> AsyncGenerator[InteractionPayload, None]: """Process the complete message through the unified system. Args: ---- + thread_id (str): The ID of the thread the message belongs to. task (str): The user message to process. - chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message. 
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. Returns: @@ -266,27 +271,15 @@ async def process_user_message( dict: The response from the system. """ logging.info("Processing message: %s", message_payload.body.user_message) - logging.info("Chat history: %s", chat_history) agent_input = { "message": message_payload.body.user_message, "injected_parameters": message_payload.body.injected_parameters, } - latest_state = None - if chat_history is not None: - # Update input - for chat in reversed(chat_history): - if chat.root.payload_type in [ - PayloadType.ANSWER_WITH_SOURCES, - PayloadType.DISAMBIGUATION_REQUESTS, - ]: - latest_state = chat.body.assistant_state - break - - # TODO: Trim the chat history to the last message from the user - if latest_state is not None: - await self.agentic_flow.load_state(latest_state) + state = self.state_store.get_state(thread_id) + if state is not None: + await self.agentic_flow.load_state(state) async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): logging.debug("Message: %s", message) @@ -340,7 +333,7 @@ async def process_user_message( ): # Get the state assistant_state = await self.agentic_flow.save_state() - payload.body.assistant_state = assistant_state + self.state_store.save_state(thread_id, assistant_state) logging.debug("Final Payload: %s", payload) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py b/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py new file mode 100644 index 0000000..849bc13 --- /dev/null +++ b/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod +from cachetools import TTLCache + + +class StateStore(ABC): + @abstractmethod + def get_state(self, thread_id): + pass + + @abstractmethod + def save_state(self, thread_id, state): + pass + + +class InMemoryStateStore(StateStore): + def __init__(self): + self.cache = TTLCache(maxsize=1000, ttl=4 * 60 * 60) # 4 hours + 
+ def get_state(self, thread_id: str) -> dict: + return self.cache.get(thread_id) + + def save_state(self, thread_id: str, state: dict) -> None: + self.cache[thread_id] = state diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index c68d31a..80f4a20 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -57,7 +57,6 @@ class DismabiguationRequest(InteractionPayloadBase): decomposed_user_messages: list[list[str]] = Field( default_factory=list, alias="decomposedUserMessages" ) - assistant_state: dict | None = Field(default=None, alias="assistantState") payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field( PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType" @@ -86,7 +85,6 @@ class Source(InteractionPayloadBase): default_factory=list, alias="decomposedUserMessages" ) sources: list[Source] = Field(default_factory=list) - assistant_state: dict | None = Field(default=None, alias="assistantState") payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field( PayloadType.ANSWER_WITH_SOURCES, alias="payloadType" diff --git a/uv.lock b/uv.lock index afd07b1..a1d59e4 100644 --- a/uv.lock +++ b/uv.lock @@ -317,6 +317,7 @@ dependencies = [ { name = "autogen-agentchat" }, { name = "autogen-core" }, { name = "autogen-ext", extra = ["azure", "openai"] }, + { name = "cachetools" }, { name = "grpcio" }, { name = "nltk" }, { name = "pyyaml" }, @@ -355,6 +356,7 @@ requires-dist = [ { name = "autogen-agentchat", specifier = "==0.4.2" }, { name = "autogen-core", specifier = "==0.4.2" }, { name = "autogen-ext", extras = ["azure", "openai"], specifier = "==0.4.2" }, + { name = "cachetools", specifier = ">=5.5.1" }, { name = "grpcio", specifier = ">=1.68.1" }, { name = "nltk", specifier = ">=3.8.1" }, { name = "pyyaml", 
specifier = ">=6.0.2" }, @@ -637,6 +639,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/bb/bf7aab772a159614954d84aa832c129624ba6c32faa559dfb200a534e50b/bs4-0.0.2-py2.py3-none-any.whl", hash = "sha256:abf8742c0805ef7f662dce4b51cca104cffe52b835238afc169142ab9b3fbccc", size = 1189 }, ] +[[package]] +name = "cachetools" +version = "5.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d9/74/57df1ab0ce6bc5f6fa868e08de20df8ac58f9c44330c7671ad922d2bbeae/cachetools-5.5.1.tar.gz", hash = "sha256:70f238fbba50383ef62e55c6aff6d9673175fe59f7c6782c7a0b9e38f4a9df95", size = 28044 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/4e/de4ff18bcf55857ba18d3a4bd48c8a9fde6bb0980c9d20b263f05387fd88/cachetools-5.5.1-py3-none-any.whl", hash = "sha256:b76651fdc3b24ead3c648bbdeeb940c1b04d365b38b4af66788f9ec4a81d42bb", size = 9530 }, +] + [[package]] name = "catalogue" version = "2.0.10" @@ -762,7 +773,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -1059,7 +1070,7 @@ source = { url = "https://github.com/explosion/spacy-models/releases/download/en dependencies = [ { name = "spacy" }, ] -sdist = { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.7.1/en_core_web_md-3.7.1.tar.gz", hash = "sha256:3273a1335fcb688be09949c5cdb73e85eb584ec3dfc50d4338c17daf6ccd4628" } +sdist = { hash = "sha256:3273a1335fcb688be09949c5cdb73e85eb584ec3dfc50d4338c17daf6ccd4628" } [package.metadata] requires-dist = [{ name = "spacy", specifier = 
">=3.7.2,<3.8.0" }] @@ -1378,7 +1389,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -2497,7 +2508,7 @@ name = "portalocker" version = "2.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywin32", marker = "platform_system == 'Windows'" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 } wheels = [ @@ -3786,7 +3797,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [