From 016003f5e7d854c3030a2b1141ec3b4137fc2b3d Mon Sep 17 00:00:00 2001 From: Kristian Nylund Date: Wed, 29 Jan 2025 13:44:52 +0200 Subject: [PATCH 1/8] added state store --- .../autogen_text_2_sql/autogen_text_2_sql.py | 31 +++++++------------ .../src/autogen_text_2_sql/state_store.py | 22 +++++++++++++ .../autogen/src/autogen_text_2_sql/test.py | 6 ++++ .../payloads/interaction_payloads.py | 2 -- 4 files changed, 40 insertions(+), 21 deletions(-) create mode 100644 text_2_sql/autogen/src/autogen_text_2_sql/state_store.py create mode 100644 text_2_sql/autogen/src/autogen_text_2_sql/test.py diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index a5b6482..08c8ef0 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -12,6 +12,7 @@ from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import ( ParallelQuerySolvingAgent, ) +from state_store import StateStore from autogen_agentchat.messages import TextMessage import json import os @@ -31,9 +32,13 @@ class AutoGenText2Sql: - def __init__(self, **kwargs): + def __init__(self, state_store : StateStore, **kwargs): self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper() + if not state_store: + raise ValueError("State store must be provided") + self.state_store = state_store + if "use_case" not in kwargs: logging.warning( "No use case provided. It is advised to provide a use case to help the LLM reason." @@ -250,15 +255,15 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload: async def process_user_message( self, + thread_id: str, message_payload: UserMessagePayload, - chat_history: list[InteractionPayload] = None, ) -> AsyncGenerator[InteractionPayload, None]: """Process the complete message through the unified system. 
Args: ---- + thread_id (str): The ID of the thread the message belongs to. task (str): The user message to process. - chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message. injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None. Returns: @@ -266,27 +271,15 @@ async def process_user_message( dict: The response from the system. """ logging.info("Processing message: %s", message_payload.body.user_message) - logging.info("Chat history: %s", chat_history) agent_input = { "message": message_payload.body.user_message, "injected_parameters": message_payload.body.injected_parameters, } - latest_state = None - if chat_history is not None: - # Update input - for chat in reversed(chat_history): - if chat.root.payload_type in [ - PayloadType.ANSWER_WITH_SOURCES, - PayloadType.DISAMBIGUATION_REQUESTS, - ]: - latest_state = chat.body.assistant_state - break - - # TODO: Trim the chat history to the last message from the user - if latest_state is not None: - await self.agentic_flow.load_state(latest_state) + state = self.state_store.get_state(thread_id) + if state is not None: + await self.agentic_flow.load_state(state) async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)): logging.debug("Message: %s", message) @@ -340,7 +333,7 @@ async def process_user_message( ): # Get the state assistant_state = await self.agentic_flow.save_state() - payload.body.assistant_state = assistant_state + self.state_store.save_state(thread_id, assistant_state) logging.debug("Final Payload: %s", payload) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py b/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py new file mode 100644 index 0000000..c4fa30c --- /dev/null +++ b/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod + +class StateStore(ABC): + @abstractmethod + def get_state(self, thread_id): + pass + 
+ @abstractmethod + def save_state(self, thread_id, state): + pass + + +class InMemoryStateStore(StateStore): + def __init__(self): + # Replace with a caching library or something to have some sort of expiry for entries so this doesn't grow forever + self.cache = {} + + def get_state(self, thread_id): + return self.cache.get(thread_id) + + def save_state(self, thread_id, state): + self.cache[thread_id] = state diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/test.py b/text_2_sql/autogen/src/autogen_text_2_sql/test.py new file mode 100644 index 0000000..23018ca --- /dev/null +++ b/text_2_sql/autogen/src/autogen_text_2_sql/test.py @@ -0,0 +1,6 @@ +from state_store import InMemoryStateStore + +x=InMemoryStateStore() +print(x.get_state("1")) +x.save_state("1", {'x':2}) +print(x.get_state("1")) \ No newline at end of file diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py index c68d31a..80f4a20 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py @@ -57,7 +57,6 @@ class DismabiguationRequest(InteractionPayloadBase): decomposed_user_messages: list[list[str]] = Field( default_factory=list, alias="decomposedUserMessages" ) - assistant_state: dict | None = Field(default=None, alias="assistantState") payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field( PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType" @@ -86,7 +85,6 @@ class Source(InteractionPayloadBase): default_factory=list, alias="decomposedUserMessages" ) sources: list[Source] = Field(default_factory=list) - assistant_state: dict | None = Field(default=None, alias="assistantState") payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field( PayloadType.ANSWER_WITH_SOURCES, alias="payloadType" From 31d82b5ac3613a683eceb6418ab2395c4a810cd8 Mon 
Sep 17 00:00:00 2001 From: Kristian Nylund Date: Wed, 29 Jan 2025 14:27:53 +0200 Subject: [PATCH 2/8] added expiry for in-memory state store --- text_2_sql/autogen/pyproject.toml | 1 + .../src/autogen_text_2_sql/__init__.py | 3 +++ .../src/autogen_text_2_sql/state_store.py | 4 ++-- uv.lock | 21 ++++++++++++++----- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/text_2_sql/autogen/pyproject.toml b/text_2_sql/autogen/pyproject.toml index 174aa9a..950bdca 100644 --- a/text_2_sql/autogen/pyproject.toml +++ b/text_2_sql/autogen/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "text_2_sql_core", "sqlparse>=0.4.4", "nltk>=3.8.1", + "cachetools>=5.5.1", ] [dependency-groups] diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py index 03e6104..cc23dea 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/__init__.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql +from autogen_text_2_sql.state_store import InMemoryStateStore + from text_2_sql_core.payloads.interaction_payloads import ( UserMessagePayload, DismabiguationRequestsPayload, @@ -16,4 +18,5 @@ "AnswerWithSourcesPayload", "ProcessingUpdatePayload", "InteractionPayload", + "InMemoryStateStore", ] diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py b/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py index c4fa30c..7fa7c74 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from cachetools import TTLCache class StateStore(ABC): @abstractmethod @@ -12,8 +13,7 @@ def save_state(self, thread_id, state): class InMemoryStateStore(StateStore): def __init__(self): - # Replace with a caching library or something to have some sort of expiry for entries so this doesn't grow forever - self.cache = {} + self.cache = TTLCache(maxsize=1000, ttl=4*60*60) # 4 hours def get_state(self, thread_id): return self.cache.get(thread_id) diff --git a/uv.lock b/uv.lock index afd07b1..a1d59e4 100644 --- a/uv.lock +++ b/uv.lock @@ -317,6 +317,7 @@ dependencies = [ { name = "autogen-agentchat" }, { name = "autogen-core" }, { name = "autogen-ext", extra = ["azure", "openai"] }, + { name = "cachetools" }, { name = "grpcio" }, { name = "nltk" }, { name = "pyyaml" }, @@ -355,6 +356,7 @@ requires-dist = [ { name = "autogen-agentchat", specifier = "==0.4.2" }, { name = "autogen-core", specifier = "==0.4.2" }, { name = "autogen-ext", extras = ["azure", "openai"], specifier = "==0.4.2" }, + { name = "cachetools", specifier = ">=5.5.1" }, { name = "grpcio", specifier = ">=1.68.1" }, { name = "nltk", specifier = ">=3.8.1" }, { name = "pyyaml", specifier = ">=6.0.2" }, @@ -637,6 +639,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/51/bb/bf7aab772a159614954d84aa832c129624ba6c32faa559dfb200a534e50b/bs4-0.0.2-py2.py3-none-any.whl", hash = "sha256:abf8742c0805ef7f662dce4b51cca104cffe52b835238afc169142ab9b3fbccc", size = 1189 }, ] +[[package]] +name = "cachetools" +version = "5.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d9/74/57df1ab0ce6bc5f6fa868e08de20df8ac58f9c44330c7671ad922d2bbeae/cachetools-5.5.1.tar.gz", hash = "sha256:70f238fbba50383ef62e55c6aff6d9673175fe59f7c6782c7a0b9e38f4a9df95", size = 28044 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/4e/de4ff18bcf55857ba18d3a4bd48c8a9fde6bb0980c9d20b263f05387fd88/cachetools-5.5.1-py3-none-any.whl", hash = "sha256:b76651fdc3b24ead3c648bbdeeb940c1b04d365b38b4af66788f9ec4a81d42bb", size = 9530 }, +] + [[package]] name = "catalogue" version = "2.0.10" @@ -762,7 +773,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -1059,7 +1070,7 @@ source = { url = "https://github.com/explosion/spacy-models/releases/download/en dependencies = [ { name = "spacy" }, ] -sdist = { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.7.1/en_core_web_md-3.7.1.tar.gz", hash = "sha256:3273a1335fcb688be09949c5cdb73e85eb584ec3dfc50d4338c17daf6ccd4628" } +sdist = { hash = "sha256:3273a1335fcb688be09949c5cdb73e85eb584ec3dfc50d4338c17daf6ccd4628" } [package.metadata] requires-dist = [{ name = "spacy", specifier = ">=3.7.2,<3.8.0" }] @@ -1378,7 +1389,7 @@ name = "ipykernel" version = 
"6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -2497,7 +2508,7 @@ name = "portalocker" version = "2.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywin32", marker = "platform_system == 'Windows'" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 } wheels = [ @@ -3786,7 +3797,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [ From 09a748c3a35aa52a04d9a7a628c5ca6c39be7cce Mon Sep 17 00:00:00 2001 From: Kristian Nylund Date: Wed, 29 Jan 2025 14:30:16 +0200 Subject: [PATCH 3/8] remove test.py --- text_2_sql/autogen/src/autogen_text_2_sql/test.py | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 text_2_sql/autogen/src/autogen_text_2_sql/test.py diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/test.py b/text_2_sql/autogen/src/autogen_text_2_sql/test.py deleted file mode 100644 index 23018ca..0000000 --- a/text_2_sql/autogen/src/autogen_text_2_sql/test.py +++ /dev/null @@ -1,6 +0,0 @@ -from state_store import InMemoryStateStore - -x=InMemoryStateStore() -print(x.get_state("1")) -x.save_state("1", {'x':2}) 
-print(x.get_state("1")) \ No newline at end of file From d2483dd0aa20143ec48103e308fda3f3dd2a787c Mon Sep 17 00:00:00 2001 From: Kristian Nylund Date: Wed, 29 Jan 2025 15:04:20 +0200 Subject: [PATCH 4/8] type annotations --- text_2_sql/autogen/src/autogen_text_2_sql/state_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py b/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py index 7fa7c74..d313249 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py @@ -15,8 +15,8 @@ class InMemoryStateStore(StateStore): def __init__(self): self.cache = TTLCache(maxsize=1000, ttl=4*60*60) # 4 hours - def get_state(self, thread_id): + def get_state(self, thread_id: str) -> dict: return self.cache.get(thread_id) - def save_state(self, thread_id, state): + def save_state(self, thread_id: str, state: dict) -> None: self.cache[thread_id] = state From 862d9a1db66ff2db9e757ddf8e458d489ef6ee22 Mon Sep 17 00:00:00 2001 From: Kristian Nylund Date: Wed, 29 Jan 2025 15:31:03 +0200 Subject: [PATCH 5/8] - --- text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 08c8ef0..dd73fbe 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -12,7 +12,7 @@ from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import ( ParallelQuerySolvingAgent, ) -from state_store import StateStore +from autogen_text_2_sql.state_store import StateStore from autogen_agentchat.messages import TextMessage import json import os From 4f45f7075e206a2b55f2a79a2e2215efa4b8792d Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Wed, 29 Jan 2025 
14:39:23 +0000 Subject: [PATCH 6/8] Update notebook sample --- ...tion 5 - Agentic Vector Based Text2SQL.ipynb | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb index 3f0a4da..dcc5bd7 100644 --- a/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb +++ b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb @@ -52,7 +52,8 @@ "source": [ "import dotenv\n", "import logging\n", - "from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload" + "from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload\n", + "from autogen_text_2_sql.state_store import InMemoryStateStore" ] }, { @@ -86,16 +87,10 @@ "metadata": {}, "outputs": [], "source": [ - "agentic_text_2_sql = AutoGenText2Sql(use_case=\"Analysing sales data\")" + "# The state store allows AutoGen to store the states in memory across invocations. Whilst not necessary, you can replace it with your own implementation that is backed by a database or file system. 
\n", + "agentic_text_2_sql = AutoGenText2Sql(state_store=InMemoryStateStore(), use_case=\"Analysing sales data\")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "metadata": {}, @@ -109,7 +104,7 @@ "metadata": {}, "outputs": [], "source": [ - "async for message in agentic_text_2_sql.process_user_message(UserMessagePayload(user_message=\"what are the total sales\")):\n", + "async for message in agentic_text_2_sql.process_user_message(thread_id=\"1\", message_payload=UserMessagePayload(user_message=\"what are the total sales\")):\n", " logging.info(\"Received %s Message from Text2SQL System\", message)" ] }, @@ -137,7 +132,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.8" + "version": "3.11.2" } }, "nbformat": 4, From 7a637447f197838c92875e5373e5d2a0764d3f11 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Wed, 29 Jan 2025 14:42:55 +0000 Subject: [PATCH 7/8] Update README --- text_2_sql/autogen/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/text_2_sql/autogen/README.md b/text_2_sql/autogen/README.md index a3eb1da..d6e4dd3 100644 --- a/text_2_sql/autogen/README.md +++ b/text_2_sql/autogen/README.md @@ -121,6 +121,10 @@ Contains specialized agent implementations: - **sql_schema_selection_agent.py:** Handles schema selection and management - **answer_and_sources_agent.py:** Formats and standardizes final outputs +## State Store + +To enable the [AutoGen State](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.state.html) to be tracked across invocations, a state store implementation must be provided. A basic `InMemoryStateStore` is provided, but this can be replaced with an implementation for a database or file system for when the Agentic System is running behind an API. 
This enables the AutoGen state to be saved behind the scenes and recalled later when the message is part of the same thread. A `thread_id` must be provided to the entrypoint. + ## Configuration The system behavior can be controlled through environment variables: From a94afc0dad520a55630aa975f636cf92f640bf49 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Wed, 29 Jan 2025 14:57:15 +0000 Subject: [PATCH 8/8] Run formatter --- .../autogen/src/autogen_text_2_sql/autogen_text_2_sql.py | 2 +- text_2_sql/autogen/src/autogen_text_2_sql/state_store.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index dd73fbe..98d3821 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -32,7 +32,7 @@ class AutoGenText2Sql: - def __init__(self, state_store : StateStore, **kwargs): + def __init__(self, state_store: StateStore, **kwargs): self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper() if not state_store: diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py b/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py index d313249..849bc13 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/state_store.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from cachetools import TTLCache + class StateStore(ABC): @abstractmethod def get_state(self, thread_id): @@ -13,7 +14,7 @@ def save_state(self, thread_id, state): class InMemoryStateStore(StateStore): def __init__(self): - self.cache = TTLCache(maxsize=1000, ttl=4*60*60) # 4 hours + self.cache = TTLCache(maxsize=1000, ttl=4 * 60 * 60) # 4 hours def get_state(self, thread_id: str) -> dict: return self.cache.get(thread_id)