Skip to content

Commit aa36e62

Browse files
krisnylKristian NylundBenConstable9
authored
Add state store for Autogen (#150)
* added state store * added expiry for in-memory state store * remove test.py * type annotations * - * Update notebook sample * Update README * Run formatter --------- Co-authored-by: Kristian Nylund <Kristian.Nylund@microsoft.com> Co-authored-by: Ben Constable <benconstable@microsoft.com>
1 parent c416cc7 commit aa36e62

File tree

8 files changed

+65
-37
lines changed

8 files changed

+65
-37
lines changed

text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@
5252
"source": [
5353
"import dotenv\n",
5454
"import logging\n",
55-
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload"
55+
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload\n",
56+
"from autogen_text_2_sql.state_store import InMemoryStateStore"
5657
]
5758
},
5859
{
@@ -86,16 +87,10 @@
8687
"metadata": {},
8788
"outputs": [],
8889
"source": [
89-
"agentic_text_2_sql = AutoGenText2Sql(use_case=\"Analysing sales data\")"
90+
"# The state store allows AutoGen to store the states in memory across invocation. Whilst not neccessary, you can replace it with your own implementation that is backed by a database or file system. \n",
91+
"agentic_text_2_sql = AutoGenText2Sql(state_store=InMemoryStateStore(), use_case=\"Analysing sales data\")"
9092
]
9193
},
92-
{
93-
"cell_type": "code",
94-
"execution_count": null,
95-
"metadata": {},
96-
"outputs": [],
97-
"source": []
98-
},
9994
{
10095
"cell_type": "markdown",
10196
"metadata": {},
@@ -109,7 +104,7 @@
109104
"metadata": {},
110105
"outputs": [],
111106
"source": [
112-
"async for message in agentic_text_2_sql.process_user_message(UserMessagePayload(user_message=\"what are the total sales\")):\n",
107+
"async for message in agentic_text_2_sql.process_user_message(thread_id=\"1\", message_payload=UserMessagePayload(user_message=\"what are the total sales\")):\n",
113108
" logging.info(\"Received %s Message from Text2SQL System\", message)"
114109
]
115110
},
@@ -137,7 +132,7 @@
137132
"name": "python",
138133
"nbconvert_exporter": "python",
139134
"pygments_lexer": "ipython3",
140-
"version": "3.12.8"
135+
"version": "3.11.2"
141136
}
142137
},
143138
"nbformat": 4,

text_2_sql/autogen/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ Contains specialized agent implementations:
121121
- **sql_schema_selection_agent.py:** Handles schema selection and management
122122
- **answer_and_sources_agent.py:** Formats and standardizes final outputs
123123

124+
## State Store
125+
126+
To enable the [AutoGen State](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.state.html) to be tracked across invocations, a state store implementation must be provided. A basic `InMemoryStateStore` is provided, but this can be replaced with an implementation for a database or file system for when the Agentic System is running behind an API. This enables the AutoGen state to be saved behind the scenes and recalled later when the message is part of the same thread. A `thread_id` must be provided to the entrypoint.
127+
124128
## Configuration
125129

126130
The system behavior can be controlled through environment variables:

text_2_sql/autogen/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies = [
1717
"text_2_sql_core",
1818
"sqlparse>=0.4.4",
1919
"nltk>=3.8.1",
20+
"cachetools>=5.5.1",
2021
]
2122

2223
[dependency-groups]

text_2_sql/autogen/src/autogen_text_2_sql/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
4+
from autogen_text_2_sql.state_store import InMemoryStateStore
5+
46
from text_2_sql_core.payloads.interaction_payloads import (
57
UserMessagePayload,
68
DismabiguationRequestsPayload,
@@ -16,4 +18,5 @@
1618
"AnswerWithSourcesPayload",
1719
"ProcessingUpdatePayload",
1820
"InteractionPayload",
21+
"InMemoryStateStore",
1922
]

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import (
1313
ParallelQuerySolvingAgent,
1414
)
15+
from autogen_text_2_sql.state_store import StateStore
1516
from autogen_agentchat.messages import TextMessage
1617
import json
1718
import os
@@ -31,9 +32,13 @@
3132

3233

3334
class AutoGenText2Sql:
34-
def __init__(self, **kwargs):
35+
def __init__(self, state_store: StateStore, **kwargs):
3536
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()
3637

38+
if not state_store:
39+
raise ValueError("State store must be provided")
40+
self.state_store = state_store
41+
3742
if "use_case" not in kwargs:
3843
logging.warning(
3944
"No use case provided. It is advised to provide a use case to help the LLM reason."
@@ -250,43 +255,31 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
250255

251256
async def process_user_message(
252257
self,
258+
thread_id: str,
253259
message_payload: UserMessagePayload,
254-
chat_history: list[InteractionPayload] = None,
255260
) -> AsyncGenerator[InteractionPayload, None]:
256261
"""Process the complete message through the unified system.
257262
258263
Args:
259264
----
265+
thread_id (str): The ID of the thread the message belongs to.
260266
task (str): The user message to process.
261-
chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message.
262267
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
263268
264269
Returns:
265270
-------
266271
dict: The response from the system.
267272
"""
268273
logging.info("Processing message: %s", message_payload.body.user_message)
269-
logging.info("Chat history: %s", chat_history)
270274

271275
agent_input = {
272276
"message": message_payload.body.user_message,
273277
"injected_parameters": message_payload.body.injected_parameters,
274278
}
275279

276-
latest_state = None
277-
if chat_history is not None:
278-
# Update input
279-
for chat in reversed(chat_history):
280-
if chat.root.payload_type in [
281-
PayloadType.ANSWER_WITH_SOURCES,
282-
PayloadType.DISAMBIGUATION_REQUESTS,
283-
]:
284-
latest_state = chat.body.assistant_state
285-
break
286-
287-
# TODO: Trim the chat history to the last message from the user
288-
if latest_state is not None:
289-
await self.agentic_flow.load_state(latest_state)
280+
state = self.state_store.get_state(thread_id)
281+
if state is not None:
282+
await self.agentic_flow.load_state(state)
290283

291284
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
292285
logging.debug("Message: %s", message)
@@ -340,7 +333,7 @@ async def process_user_message(
340333
):
341334
# Get the state
342335
assistant_state = await self.agentic_flow.save_state()
343-
payload.body.assistant_state = assistant_state
336+
self.state_store.save_state(thread_id, assistant_state)
344337

345338
logging.debug("Final Payload: %s", payload)
346339

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from abc import ABC, abstractmethod
2+
from cachetools import TTLCache
3+
4+
5+
class StateStore(ABC):
6+
@abstractmethod
7+
def get_state(self, thread_id):
8+
pass
9+
10+
@abstractmethod
11+
def save_state(self, thread_id, state):
12+
pass
13+
14+
15+
class InMemoryStateStore(StateStore):
16+
def __init__(self):
17+
self.cache = TTLCache(maxsize=1000, ttl=4 * 60 * 60) # 4 hours
18+
19+
def get_state(self, thread_id: str) -> dict:
20+
return self.cache.get(thread_id)
21+
22+
def save_state(self, thread_id: str, state: dict) -> None:
23+
self.cache[thread_id] = state

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class DismabiguationRequest(InteractionPayloadBase):
5757
decomposed_user_messages: list[list[str]] = Field(
5858
default_factory=list, alias="decomposedUserMessages"
5959
)
60-
assistant_state: dict | None = Field(default=None, alias="assistantState")
6160

6261
payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field(
6362
PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType"
@@ -86,7 +85,6 @@ class Source(InteractionPayloadBase):
8685
default_factory=list, alias="decomposedUserMessages"
8786
)
8887
sources: list[Source] = Field(default_factory=list)
89-
assistant_state: dict | None = Field(default=None, alias="assistantState")
9088

9189
payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field(
9290
PayloadType.ANSWER_WITH_SOURCES, alias="payloadType"

uv.lock

Lines changed: 16 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)