Skip to content

Commit c416cc7

Browse files
Update interaction payloads (#148)
1 parent 40e21ed commit c416cc7

File tree

4 files changed

+84
-21
lines changed

4 files changed

+84
-21
lines changed

text_2_sql/.env.example

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,12 @@ Text2Sql__Tsql__ConnectionString=<Tsql databaseConnectionString if using Tsql Da
2727
Text2Sql__Tsql__Database=<Tsql database if using Tsql Data Source>
2828

2929
# PostgreSQL Specific Connection Details
30-
Text2Sql__Postgresql__ConnectionString=<Postgresql databaseConnectionString if using Postgresql Data Source>
30+
Text2Sql__Postgresql__ConnectionString=<Postgresql databaseConnectionString if using Postgresql Data Source and a connection string>
3131
Text2Sql__Postgresql__Database=<Postgresql database if using Postgresql Data Source>
32+
Text2Sql__Postgresql__User=<Postgresql user if using Postgresql Data Source and not the connections string>
33+
Text2Sql__Postgresql__Password=<Postgresql password if using Postgresql Data Source and not the connections string>
34+
Text2Sql__Postgresql__ServerHostname=<Postgresql serverHostname if using Postgresql Data Source and not the connections string>
35+
Text2Sql__Postgresql__Port=<Postgresql port if using Postgresql Data Source and not the connections string>
3236

3337
# Snowflake Specific Connection Details
3438
Text2Sql__Snowflake__User=<snowflakeUser if using Snowflake Data Source>

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def __init__(self, **kwargs):
4141

4242
self.kwargs = {**DEFAULT_INJECTED_PARAMETERS, **kwargs}
4343

44+
self._agentic_flow = None
45+
4446
def get_all_agents(self):
4547
"""Get all agents for the complete flow."""
4648

@@ -97,14 +99,20 @@ def unified_selector(self, messages):
9799
@property
98100
def agentic_flow(self):
99101
"""Create the unified flow for the complete process."""
102+
103+
if self._agentic_flow is not None:
104+
return self._agentic_flow
105+
100106
flow = SelectorGroupChat(
101107
self.get_all_agents(),
102108
allow_repeated_speaker=False,
103109
model_client=LLMModelCreator.get_model("4o-mini"),
104110
termination_condition=self.termination_condition,
105111
selector_func=self.unified_selector,
106112
)
107-
return flow
113+
114+
self._agentic_flow = flow
115+
return self._agentic_flow
108116

109117
def parse_message_content(self, content):
110118
"""Parse different message content formats into a dictionary."""
@@ -250,7 +258,7 @@ async def process_user_message(
250258
Args:
251259
----
252260
task (str): The user message to process.
253-
chat_history (list[str], optional): The chat history. Defaults to None.
261+
chat_history (list[str], optional): The chat history. Defaults to None. The last message is the most recent message.
254262
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
255263
256264
Returns:
@@ -262,17 +270,23 @@ async def process_user_message(
262270

263271
agent_input = {
264272
"message": message_payload.body.user_message,
265-
"chat_history": {},
266273
"injected_parameters": message_payload.body.injected_parameters,
267274
}
268275

276+
latest_state = None
269277
if chat_history is not None:
270278
# Update input
271-
for idx, chat in enumerate(chat_history):
272-
if chat.root.payload_type == PayloadType.USER_MESSAGE:
273-
# For now only consider the user query
274-
chat_history_key = f"chat_{idx}"
275-
agent_input[chat_history_key] = chat.root.body.user_message
279+
for chat in reversed(chat_history):
280+
if chat.root.payload_type in [
281+
PayloadType.ANSWER_WITH_SOURCES,
282+
PayloadType.DISAMBIGUATION_REQUESTS,
283+
]:
284+
latest_state = chat.body.assistant_state
285+
break
286+
287+
# TODO: Trim the chat history to the last message from the user
288+
if latest_state is not None:
289+
await self.agentic_flow.load_state(latest_state)
276290

277291
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
278292
logging.debug("Message: %s", message)
@@ -312,6 +326,22 @@ async def process_user_message(
312326
logging.error("Unexpected TaskResult: %s", message)
313327
raise ValueError("Unexpected TaskResult")
314328

315-
if payload is not None:
329+
if (
330+
payload is not None
331+
and payload.payload_type is PayloadType.PROCESSING_UPDATE
332+
):
316333
logging.debug("Payload: %s", payload)
317334
yield payload
335+
336+
# Return the final payload
337+
if (
338+
payload is not None
339+
and payload.payload_type is not PayloadType.PROCESSING_UPDATE
340+
):
341+
# Get the state
342+
assistant_state = await self.agentic_flow.save_state()
343+
payload.body.assistant_state = assistant_state
344+
345+
logging.debug("Final Payload: %s", payload)
346+
347+
yield payload

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import logging
88
import json
9-
9+
from urllib.parse import urlparse
1010
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields
1111

1212

@@ -66,10 +66,35 @@ async def query_execution(
6666
"""
6767
logging.info(f"Running query: {sql_query}")
6868
results = []
69-
connection_string = os.environ["Text2Sql__Postgresql__ConnectionString"]
69+
70+
if "Text2Sql__Postgresql__ConnectionString" in os.environ:
71+
logging.info("Postgresql Connection string found in environment variables.")
72+
73+
p = urlparse(os.environ["Text2Sql__Postgresql__ConnectionString"])
74+
75+
postgres_connections = {
76+
"dbname": p.path[1:],
77+
"user": p.username,
78+
"password": p.password,
79+
"port": p.port,
80+
"host": p.hostname,
81+
}
82+
else:
83+
logging.warning(
84+
"Postgresql Connection string not found in environment variables. Using individual variables."
85+
)
86+
postgres_connections = {
87+
"dbname": os.environ["Text2Sql__Postgresql__Database"],
88+
"user": os.environ["Text2Sql__Postgresql__User"],
89+
"password": os.environ["Text2Sql__Postgresql__Password"],
90+
"port": os.environ["Text2Sql__Postgresql__Port"],
91+
"host": os.environ["Text2Sql__Postgresql__ServerHostname"],
92+
}
7093

7194
# Establish an asynchronous connection to the PostgreSQL database
72-
async with await psycopg.AsyncConnection.connect(connection_string) as conn:
95+
async with await psycopg.AsyncConnection.connect(
96+
**postgres_connections
97+
) as conn:
7398
# Create an asynchronous cursor
7499
async with conn.cursor() as cursor:
75100
await cursor.execute(sql_query)

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
class PayloadSource(StrEnum):
1919
USER = "user"
20-
AGENT = "agent"
20+
ASSISTANT = "assistant"
2121

2222

2323
class PayloadType(StrEnum):
@@ -42,11 +42,13 @@ class PayloadBase(InteractionPayloadBase):
4242
payload_type: PayloadType = Field(..., alias="payloadType")
4343
payload_source: PayloadSource = Field(..., alias="payloadSource")
4444

45+
body: InteractionPayloadBase | None = Field(default=None)
46+
4547

4648
class DismabiguationRequestsPayload(InteractionPayloadBase):
4749
class Body(InteractionPayloadBase):
4850
class DismabiguationRequest(InteractionPayloadBase):
49-
agent_question: str | None = Field(..., alias="agentQuestion")
51+
ASSISTANT_question: str | None = Field(..., alias="ASSISTANTQuestion")
5052
user_choices: list[str] | None = Field(default=None, alias="userChoices")
5153

5254
disambiguation_requests: list[DismabiguationRequest] | None = Field(
@@ -55,12 +57,13 @@ class DismabiguationRequest(InteractionPayloadBase):
5557
decomposed_user_messages: list[list[str]] = Field(
5658
default_factory=list, alias="decomposedUserMessages"
5759
)
60+
assistant_state: dict | None = Field(default=None, alias="assistantState")
5861

5962
payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field(
6063
PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType"
6164
)
62-
payload_source: Literal[PayloadSource.AGENT] = Field(
63-
default=PayloadSource.AGENT, alias="payloadSource"
65+
payload_source: Literal[PayloadSource.ASSISTANT] = Field(
66+
default=PayloadSource.ASSISTANT, alias="payloadSource"
6467
)
6568
body: Body | None = Field(default=None)
6669

@@ -83,12 +86,13 @@ class Source(InteractionPayloadBase):
8386
default_factory=list, alias="decomposedUserMessages"
8487
)
8588
sources: list[Source] = Field(default_factory=list)
89+
assistant_state: dict | None = Field(default=None, alias="assistantState")
8690

8791
payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field(
8892
PayloadType.ANSWER_WITH_SOURCES, alias="payloadType"
8993
)
90-
payload_source: Literal[PayloadSource.AGENT] = Field(
91-
PayloadSource.AGENT, alias="payloadSource"
94+
payload_source: Literal[PayloadSource.ASSISTANT] = Field(
95+
PayloadSource.ASSISTANT, alias="payloadSource"
9296
)
9397
body: Body | None = Field(default=None)
9498

@@ -108,8 +112,8 @@ class Body(InteractionPayloadBase):
108112
payload_type: Literal[PayloadType.PROCESSING_UPDATE] = Field(
109113
PayloadType.PROCESSING_UPDATE, alias="payloadType"
110114
)
111-
payload_source: Literal[PayloadSource.AGENT] = Field(
112-
PayloadSource.AGENT, alias="payloadSource"
115+
payload_source: Literal[PayloadSource.ASSISTANT] = Field(
116+
PayloadSource.ASSISTANT, alias="payloadSource"
113117
)
114118
body: Body | None = Field(default=None)
115119

0 commit comments

Comments
 (0)