Skip to content

Use built in chat history support #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion text_2_sql/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ Text2Sql__Tsql__ConnectionString=<Tsql databaseConnectionString if using Tsql Da
Text2Sql__Tsql__Database=<Tsql database if using Tsql Data Source>

# PostgreSQL Specific Connection Details
Text2Sql__Postgresql__ConnectionString=<Postgresql databaseConnectionString if using Postgresql Data Source>
Text2Sql__Postgresql__ConnectionString=<Postgresql databaseConnectionString if using Postgresql Data Source and a connection string>
Text2Sql__Postgresql__Database=<Postgresql database if using Postgresql Data Source>
Text2Sql__Postgresql__User=<Postgresql user if using Postgresql Data Source and not the connections string>
Text2Sql__Postgresql__Password=<Postgresql password if using Postgresql Data Source and not the connections string>
Text2Sql__Postgresql__ServerHostname=<Postgresql serverHostname if using Postgresql Data Source and not the connections string>
Text2Sql__Postgresql__Port=<Postgresql port if using Postgresql Data Source and not the connections string>

# Snowflake Specific Connection Details
Text2Sql__Snowflake__User=<snowflakeUser if using Snowflake Data Source>
Expand Down
48 changes: 39 additions & 9 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def __init__(self, **kwargs):

self.kwargs = {**DEFAULT_INJECTED_PARAMETERS, **kwargs}

self._agentic_flow = None

def get_all_agents(self):
"""Get all agents for the complete flow."""

Expand Down Expand Up @@ -97,14 +99,20 @@ def unified_selector(self, messages):
@property
def agentic_flow(self):
"""Create the unified flow for the complete process."""

if self._agentic_flow is not None:
return self._agentic_flow

flow = SelectorGroupChat(
self.get_all_agents(),
allow_repeated_speaker=False,
model_client=LLMModelCreator.get_model("4o-mini"),
termination_condition=self.termination_condition,
selector_func=self.unified_selector,
)
return flow

self._agentic_flow = flow
return self._agentic_flow

def parse_message_content(self, content):
"""Parse different message content formats into a dictionary."""
Expand Down Expand Up @@ -250,7 +258,7 @@ async def process_user_message(
Args:
----
task (str): The user message to process.
chat_history (list[str], optional): The chat history. Defaults to None.
chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message.
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.

Returns:
Expand All @@ -262,17 +270,23 @@ async def process_user_message(

agent_input = {
"message": message_payload.body.user_message,
"chat_history": {},
"injected_parameters": message_payload.body.injected_parameters,
}

latest_state = None
if chat_history is not None:
# Update input
for idx, chat in enumerate(chat_history):
if chat.root.payload_type == PayloadType.USER_MESSAGE:
# For now only consider the user query
chat_history_key = f"chat_{idx}"
agent_input[chat_history_key] = chat.root.body.user_message
for chat in reversed(chat_history):
if chat.root.payload_type in [
PayloadType.ANSWER_WITH_SOURCES,
PayloadType.DISAMBIGUATION_REQUESTS,
]:
latest_state = chat.body.assistant_state
break

# TODO: Trim the chat history to the last message from the user
if latest_state is not None:
await self.agentic_flow.load_state(latest_state)

async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
logging.debug("Message: %s", message)
Expand Down Expand Up @@ -312,6 +326,22 @@ async def process_user_message(
logging.error("Unexpected TaskResult: %s", message)
raise ValueError("Unexpected TaskResult")

if payload is not None:
if (
payload is not None
and payload.payload_type is PayloadType.PROCESSING_UPDATE
):
logging.debug("Payload: %s", payload)
yield payload

# Return the final payload
if (
payload is not None
and payload.payload_type is not PayloadType.PROCESSING_UPDATE
):
# Get the state
assistant_state = await self.agentic_flow.save_state()
payload.body.assistant_state = assistant_state

logging.debug("Final Payload: %s", payload)

yield payload
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import logging
import json

from urllib.parse import urlparse
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields


Expand Down Expand Up @@ -66,10 +66,35 @@ async def query_execution(
"""
logging.info(f"Running query: {sql_query}")
results = []
connection_string = os.environ["Text2Sql__Postgresql__ConnectionString"]

if "Text2Sql__Postgresql__ConnectionString" in os.environ:
logging.info("Postgresql Connection string found in environment variables.")

p = urlparse(os.environ["Text2Sql__Postgresql__ConnectionString"])

postgres_connections = {
"dbname": p.path[1:],
"user": p.username,
"password": p.password,
"port": p.port,
"host": p.hostname,
}
else:
logging.warning(
"Postgresql Connection string not found in environment variables. Using individual variables."
)
postgres_connections = {
"dbname": os.environ["Text2Sql__Postgresql__Database"],
"user": os.environ["Text2Sql__Postgresql__User"],
"password": os.environ["Text2Sql__Postgresql__Password"],
"port": os.environ["Text2Sql__Postgresql__Port"],
"host": os.environ["Text2Sql__Postgresql__ServerHostname"],
}

# Establish an asynchronous connection to the PostgreSQL database
async with await psycopg.AsyncConnection.connect(connection_string) as conn:
async with await psycopg.AsyncConnection.connect(
**postgres_connections
) as conn:
# Create an asynchronous cursor
async with conn.cursor() as cursor:
await cursor.execute(sql_query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class PayloadSource(StrEnum):
USER = "user"
AGENT = "agent"
ASSISTANT = "assistant"


class PayloadType(StrEnum):
Expand All @@ -42,11 +42,13 @@ class PayloadBase(InteractionPayloadBase):
payload_type: PayloadType = Field(..., alias="payloadType")
payload_source: PayloadSource = Field(..., alias="payloadSource")

body: InteractionPayloadBase | None = Field(default=None)


class DismabiguationRequestsPayload(InteractionPayloadBase):
class Body(InteractionPayloadBase):
class DismabiguationRequest(InteractionPayloadBase):
agent_question: str | None = Field(..., alias="agentQuestion")
ASSISTANT_question: str | None = Field(..., alias="ASSISTANTQuestion")
user_choices: list[str] | None = Field(default=None, alias="userChoices")

disambiguation_requests: list[DismabiguationRequest] | None = Field(
Expand All @@ -55,12 +57,13 @@ class DismabiguationRequest(InteractionPayloadBase):
decomposed_user_messages: list[list[str]] = Field(
default_factory=list, alias="decomposedUserMessages"
)
assistant_state: dict | None = Field(default=None, alias="assistantState")

payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field(
PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType"
)
payload_source: Literal[PayloadSource.AGENT] = Field(
default=PayloadSource.AGENT, alias="payloadSource"
payload_source: Literal[PayloadSource.ASSISTANT] = Field(
default=PayloadSource.ASSISTANT, alias="payloadSource"
)
body: Body | None = Field(default=None)

Expand All @@ -83,12 +86,13 @@ class Source(InteractionPayloadBase):
default_factory=list, alias="decomposedUserMessages"
)
sources: list[Source] = Field(default_factory=list)
assistant_state: dict | None = Field(default=None, alias="assistantState")

payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field(
PayloadType.ANSWER_WITH_SOURCES, alias="payloadType"
)
payload_source: Literal[PayloadSource.AGENT] = Field(
PayloadSource.AGENT, alias="payloadSource"
payload_source: Literal[PayloadSource.ASSISTANT] = Field(
PayloadSource.ASSISTANT, alias="payloadSource"
)
body: Body | None = Field(default=None)

Expand All @@ -108,8 +112,8 @@ class Body(InteractionPayloadBase):
payload_type: Literal[PayloadType.PROCESSING_UPDATE] = Field(
PayloadType.PROCESSING_UPDATE, alias="payloadType"
)
payload_source: Literal[PayloadSource.AGENT] = Field(
PayloadSource.AGENT, alias="payloadSource"
payload_source: Literal[PayloadSource.ASSISTANT] = Field(
PayloadSource.ASSISTANT, alias="payloadSource"
)
body: Body | None = Field(default=None)

Expand Down
Loading