Skip to content

Add state store for Autogen #150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
"source": [
"import dotenv\n",
"import logging\n",
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload"
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload\n",
"from autogen_text_2_sql.state_store import InMemoryStateStore"
]
},
{
Expand Down Expand Up @@ -86,16 +87,10 @@
"metadata": {},
"outputs": [],
"source": [
"agentic_text_2_sql = AutoGenText2Sql(use_case=\"Analysing sales data\")"
"# The state store allows AutoGen to store the states in memory across invocations. Whilst not necessary, you can replace it with your own implementation that is backed by a database or file system. \n",
"agentic_text_2_sql = AutoGenText2Sql(state_store=InMemoryStateStore(), use_case=\"Analysing sales data\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -109,7 +104,7 @@
"metadata": {},
"outputs": [],
"source": [
"async for message in agentic_text_2_sql.process_user_message(UserMessagePayload(user_message=\"what are the total sales\")):\n",
"async for message in agentic_text_2_sql.process_user_message(thread_id=\"1\", message_payload=UserMessagePayload(user_message=\"what are the total sales\")):\n",
" logging.info(\"Received %s Message from Text2SQL System\", message)"
]
},
Expand Down Expand Up @@ -137,7 +132,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
"version": "3.11.2"
}
},
"nbformat": 4,
Expand Down
4 changes: 4 additions & 0 deletions text_2_sql/autogen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ Contains specialized agent implementations:
- **sql_schema_selection_agent.py:** Handles schema selection and management
- **answer_and_sources_agent.py:** Formats and standardizes final outputs

## State Store

To enable the [AutoGen State](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.state.html) to be tracked across invocations, a state store implementation must be provided. A basic `InMemoryStateStore` is provided, but this can be replaced with an implementation for a database or file system for when the Agentic System is running behind an API. This enables the AutoGen state to be saved behind the scenes and recalled later when the message is part of the same thread. A `thread_id` must be provided to the entrypoint.

## Configuration

The system behavior can be controlled through environment variables:
Expand Down
1 change: 1 addition & 0 deletions text_2_sql/autogen/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"text_2_sql_core",
"sqlparse>=0.4.4",
"nltk>=3.8.1",
"cachetools>=5.5.1",
]

[dependency-groups]
Expand Down
3 changes: 3 additions & 0 deletions text_2_sql/autogen/src/autogen_text_2_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
from autogen_text_2_sql.state_store import InMemoryStateStore

from text_2_sql_core.payloads.interaction_payloads import (
UserMessagePayload,
DismabiguationRequestsPayload,
Expand All @@ -16,4 +18,5 @@
"AnswerWithSourcesPayload",
"ProcessingUpdatePayload",
"InteractionPayload",
"InMemoryStateStore",
]
31 changes: 12 additions & 19 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import (
ParallelQuerySolvingAgent,
)
from autogen_text_2_sql.state_store import StateStore
from autogen_agentchat.messages import TextMessage
import json
import os
Expand All @@ -31,9 +32,13 @@


class AutoGenText2Sql:
def __init__(self, **kwargs):
def __init__(self, state_store: StateStore, **kwargs):
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()

if not state_store:
raise ValueError("State store must be provided")
self.state_store = state_store

if "use_case" not in kwargs:
logging.warning(
"No use case provided. It is advised to provide a use case to help the LLM reason."
Expand Down Expand Up @@ -250,43 +255,31 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:

async def process_user_message(
self,
thread_id: str,
message_payload: UserMessagePayload,
chat_history: list[InteractionPayload] = None,
) -> AsyncGenerator[InteractionPayload, None]:
"""Process the complete message through the unified system.

Args:
----
thread_id (str): The ID of the thread the message belongs to.
task (str): The user message to process.
chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message.
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.

Returns:
-------
dict: The response from the system.
"""
logging.info("Processing message: %s", message_payload.body.user_message)
logging.info("Chat history: %s", chat_history)

agent_input = {
"message": message_payload.body.user_message,
"injected_parameters": message_payload.body.injected_parameters,
}

latest_state = None
if chat_history is not None:
# Update input
for chat in reversed(chat_history):
if chat.root.payload_type in [
PayloadType.ANSWER_WITH_SOURCES,
PayloadType.DISAMBIGUATION_REQUESTS,
]:
latest_state = chat.body.assistant_state
break

# TODO: Trim the chat history to the last message from the user
if latest_state is not None:
await self.agentic_flow.load_state(latest_state)
state = self.state_store.get_state(thread_id)
if state is not None:
await self.agentic_flow.load_state(state)

async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
logging.debug("Message: %s", message)
Expand Down Expand Up @@ -340,7 +333,7 @@ async def process_user_message(
):
# Get the state
assistant_state = await self.agentic_flow.save_state()
payload.body.assistant_state = assistant_state
self.state_store.save_state(thread_id, assistant_state)

logging.debug("Final Payload: %s", payload)

Expand Down
23 changes: 23 additions & 0 deletions text_2_sql/autogen/src/autogen_text_2_sql/state_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import ABC, abstractmethod
from cachetools import TTLCache


class StateStore(ABC):
@abstractmethod
def get_state(self, thread_id):
pass

@abstractmethod
def save_state(self, thread_id, state):
pass


class InMemoryStateStore(StateStore):
    """State store that keeps thread state in a process-local ``TTLCache``.

    State lives only in this object's memory; it is not written anywhere
    else by this class.
    """

    def __init__(self, maxsize: int = 1000, ttl: float = 4 * 60 * 60) -> None:
        """Create the backing cache.

        Args:
        ----
            maxsize (int, optional): ``maxsize`` passed to ``TTLCache``.
                Defaults to 1000.
            ttl (float, optional): ``ttl`` passed to ``TTLCache``, in seconds.
                Defaults to 4 hours.
        """
        self.cache = TTLCache(maxsize=maxsize, ttl=ttl)

    def get_state(self, thread_id: str) -> dict | None:
        """Return the cached state for ``thread_id``, or ``None`` if absent.

        Args:
        ----
            thread_id (str): Identifier of the thread whose state is wanted.

        Returns:
        -------
            dict | None: The cached state, or ``None`` when the cache holds
            no entry for the thread.
        """
        # .get() rather than indexing so a missing entry yields None
        # instead of raising KeyError.
        return self.cache.get(thread_id)

    def save_state(self, thread_id: str, state: dict) -> None:
        """Cache ``state`` under ``thread_id``, replacing any previous entry.

        Args:
        ----
            thread_id (str): Identifier of the thread the state belongs to.
            state (dict): The state to keep for that thread.
        """
        self.cache[thread_id] = state
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class DismabiguationRequest(InteractionPayloadBase):
decomposed_user_messages: list[list[str]] = Field(
default_factory=list, alias="decomposedUserMessages"
)
assistant_state: dict | None = Field(default=None, alias="assistantState")

payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field(
PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType"
Expand Down Expand Up @@ -86,7 +85,6 @@ class Source(InteractionPayloadBase):
default_factory=list, alias="decomposedUserMessages"
)
sources: list[Source] = Field(default_factory=list)
assistant_state: dict | None = Field(default=None, alias="assistantState")

payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field(
PayloadType.ANSWER_WITH_SOURCES, alias="payloadType"
Expand Down
21 changes: 16 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading