Skip to content

Commit 307bd88

Browse files
committed
Update interactions
1 parent 05ee5b3 commit 307bd88

File tree

10 files changed

+133
-71
lines changed

10 files changed

+133
-71
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def __init__(self, state_store: StateStore, **kwargs):
4848

4949
self._agentic_flow = None
5050

51+
self._generate_follow_up_questions = (
52+
os.environ.get("Text2Sql__GenerateFollowUpQuestions", "True").lower()
53+
== "true"
54+
)
55+
5156
def get_all_agents(self):
5257
"""Get all agents for the complete flow."""
5358

@@ -57,7 +62,12 @@ def get_all_agents(self):
5762

5863
parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)
5964

60-
answer_agent = LLMAgentCreator.create("answer_agent", **self.kwargs)
65+
if self._generate_follow_up_questions:
66+
answer_agent = LLMAgentCreator.create(
67+
"answer_with_follow_up_questions_agent", **self.kwargs
68+
)
69+
else:
70+
answer_agent = LLMAgentCreator.create("answer_agent", **self.kwargs)
6171

6272
agents = [
6373
user_message_rewrite_agent,
@@ -72,6 +82,7 @@ def termination_condition(self):
7282
"""Define the termination condition for the chat."""
7383
termination = (
7484
SourceMatchTermination("answer_agent")
85+
| SourceMatchTermination("answer_with_follow_up_questions_agent")
7586
# | TextMentionTermination(
7687
# "[]",
7788
# sources=["user_message_rewrite_agent"],
@@ -97,6 +108,11 @@ def unified_selector(self, messages):
97108
elif current_agent == "user_message_rewrite_agent":
98109
decision = "parallel_query_solving_agent"
99110
# Handle transition after parallel query solving
111+
elif (
112+
current_agent == "parallel_query_solving_agent"
113+
and self._generate_follow_up_questions
114+
):
115+
decision = "answer_with_follow_up_questions_agent"
100116
elif current_agent == "parallel_query_solving_agent":
101117
decision = "answer_agent"
102118

@@ -148,10 +164,19 @@ def parse_message_content(self, content):
148164
# If all parsing attempts fail, return the content as-is
149165
return content
150166

167+
def last_message_by_agent(self, messages: list, agent_name: str) -> TextMessage:
168+
"""Get the last message by a specific agent."""
169+
for message in reversed(messages):
170+
if message.source == agent_name:
171+
return message.content
172+
return None
173+
151174
def extract_steps(self, messages: list) -> list[list[str]]:
152175
"""Extract the steps messages from the answer."""
153176
# Only load sub-message results if we have a database result
154-
sub_message_results = self.parse_message_content(messages[1].content)
177+
sub_message_results = json.loads(
178+
self.last_message_by_agent(messages, "user_message_rewrite_agent")
179+
)
155180
logging.info("Steps Results: %s", sub_message_results)
156181

157182
steps = sub_message_results.get("steps", [])
@@ -187,12 +212,18 @@ def extract_disambiguation_request(
187212

188213
def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
189214
"""Extract the sources from the answer."""
190-
answer = messages[-1].content
191-
sql_query_results = self.parse_message_content(messages[-2].content)
215+
answer_payload = json.loads(messages[-1].content)
216+
217+
logging.info("Answer Payload: %s", answer_payload)
218+
sql_query_results = self.last_message_by_agent(
219+
messages, "parallel_query_solving_agent"
220+
)
192221

193222
try:
194223
if isinstance(sql_query_results, str):
195224
sql_query_results = json.loads(sql_query_results)
225+
elif sql_query_results is None:
226+
sql_query_results = {}
196227
except json.JSONDecodeError:
197228
logging.warning("Unable to read SQL query results: %s", sql_query_results)
198229
sql_query_results = {}
@@ -201,7 +232,7 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
201232
steps = self.extract_steps(messages)
202233

203234
logging.info("SQL Query Results: %s", sql_query_results)
204-
payload = AnswerWithSourcesPayload(answer=answer, steps=steps)
235+
payload = AnswerWithSourcesPayload(**answer_payload, steps=steps)
205236

206237
if not isinstance(sql_query_results, dict):
207238
logging.error(f"Expected dict, got {type(sql_query_results)}")
@@ -246,10 +277,9 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
246277

247278
except Exception as e:
248279
logging.error("Error processing results: %s", str(e))
280+
249281
# Return payload with error context instead of empty
250-
return AnswerWithSourcesPayload(
251-
answer=f"{answer}\nError processing results: {str(e)}"
252-
)
282+
return AnswerWithSourcesPayload(**answer_payload)
253283

254284
async def process_user_message(
255285
self,
@@ -293,7 +323,10 @@ async def process_user_message(
293323
payload = ProcessingUpdatePayload(
294324
message="Solving the query...",
295325
)
296-
elif message.source == "answer_agent":
326+
elif (
327+
message.source == "answer_agent"
328+
or message.source == "answer_with_follow_up_questions_agent"
329+
):
297330
payload = ProcessingUpdatePayload(
298331
message="Generating the answer...",
299332
)
@@ -302,7 +335,11 @@ async def process_user_message(
302335
# Now we need to return the final answer or the disambiguation request
303336
logging.info("TaskResult: %s", message)
304337

305-
if message.messages[-1].source == "answer_agent":
338+
if (
339+
message.messages[-1].source == "answer_agent"
340+
or message.messages[-1].source
341+
== "answer_with_follow_up_questions_agent"
342+
):
306343
# If the message is from the answer_agent, we need to return the final answer
307344
payload = self.extract_answer_payload(message.messages)
308345
elif message.messages[-1].source == "parallel_query_solving_agent":

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from jinja2 import Template
99
import logging
1010
from text_2_sql_core.structured_outputs import (
11-
AnswerAgentWithFollowUpQuestionsAgentOutput,
11+
AnswerAgentOutput,
12+
AnswerWithFollowUpQuestionsAgentOutput,
1213
UserMessageRewriteAgentOutput,
1314
)
1415
from autogen_core.model_context import BufferedChatCompletionContext
@@ -114,8 +115,10 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
114115
structured_output = None
115116
if agent_file.get("structured_output", False):
116117
# Import the structured output agent
117-
if name == "answer_agent_with_follow_up_questions":
118-
structured_output = AnswerAgentWithFollowUpQuestionsAgentOutput
118+
if name == "answer_agent":
119+
structured_output = AnswerAgentOutput
120+
elif name == "answer_with_follow_up_questions_agent":
121+
structured_output = AnswerWithFollowUpQuestionsAgentOutput
119122
elif name == "user_message_rewrite_agent":
120123
structured_output = UserMessageRewriteAgentOutput
121124

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,19 @@ def clean_query(self, sql_query: str) -> str:
227227
str: The cleaned SQL query.
228228
"""
229229
single_line_query = sql_query.strip().replace("\n", " ")
230+
231+
def sanitize_identifier_wrapper(identifier):
232+
"""Wrap the identifier in double quotes if it contains special characters."""
233+
if re.match(
234+
r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier
235+
): # Valid SQL identifier
236+
return identifier
237+
238+
return self.sanitize_identifier(identifier)
239+
230240
cleaned_query = re.sub(
231-
r'(?<!["\[\w])\b([a-zA-Z_][a-zA-Z0-9_-]*)\b(?!["\]])',
232-
lambda m: self.sanitize_identifier(m.group(1)),
241+
r'(?<![\["`])\b([a-zA-Z_][a-zA-Z0-9_-]*)\b(?![\]"`])',
242+
lambda m: sanitize_identifier_wrapper(m.group(1)),
233243
single_line_query,
234244
)
235245

@@ -244,6 +254,7 @@ async def query_validation(
244254
) -> Union[bool | list[dict]]:
245255
"""Validate the SQL query."""
246256
try:
257+
logging.info("Input SQL Query: %s", sql_query)
247258
cleaned_query = self.clean_query(sql_query)
248259
logging.info("Validating SQL Query: %s", cleaned_query)
249260
parsed_queries = sqlglot.parse(

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ class Source(InteractionPayloadBase):
8484
follow_up_questions: list[str] | None = Field(
8585
default=None, alias="followUpQuestions"
8686
)
87-
assistant_state: dict | None = Field(default=None, alias="assistantState")
8887

8988
payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field(
9089
PayloadType.ANSWER_WITH_SOURCES, alias="payloadType"

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,24 @@ system_message: |
66
</role_and_objective>
77
88
<system_information>
9-
You are part of an overall system that provides Text2SQL and subsequent data analysis functionality only. You will be passed a result from multiple SQL queries, you must formulate a response to the user's question using this information.
10-
You can assume that the SQL queries are correct and that the results are accurate.
11-
You and the wider system can only generate SQL queries and process the results of these queries. You cannot access any external resources.
12-
The main ability of the system is to perform natural language understanding and generate SQL queries from the user's question. These queries are then automatically run against the database and the results are passed to you.
9+
- You are part of an overall system that provides Text2SQL and subsequent data analysis functionality only. You will be passed a result from multiple SQL queries, you must formulate a response to the user's question using this information.
10+
- You can assume that the SQL queries are correct and that the results are accurate.
11+
- You and the wider system can only generate SQL queries and process the results of these queries. You cannot access any external resources.
12+
- The main ability of the system is to perform natural language understanding and generate SQL queries from the user's question. These queries are then automatically run against the database and the results are passed to you.
1313
</system_information>
1414
1515
<instructions>
16-
17-
Use the information obtained to generate a response to the user's question. The question has been broken down into a series of SQL queries and you need to generate a response based on the results of these queries.
18-
19-
Do not use any external resources to generate the response. The response should be based solely on the information provided in the SQL queries and their results.
20-
21-
You have no access to the internet or any other external resources. You can only use the information provided in the SQL queries and their results, to generate the response.
22-
23-
You can use Markdown and Markdown tables to format the response. You MUST use the information obtained from the SQL queries to generate the response.
24-
25-
If the user is asking about your capabilities, use the <system_information> to explain what you do.
26-
27-
Make sure your response directly addresses every part of the user's question.
28-
16+
- Use the information obtained to generate a response to the user's question. The question has been broken down into a series of SQL queries and you need to generate a response based on the results of these queries.
17+
- Do not use any external resources to generate the response. The response should be based solely on the information provided in the SQL queries and their results.
18+
- You have no access to the internet or any other external resources. You can only use the information provided in the SQL queries and their results, to generate the response.
19+
- You can use Markdown and Markdown tables to format the response. You MUST use the information obtained from the SQL queries to generate the response.
20+
- If the user is asking about your capabilities, use the <system_information> to explain what you do.
21+
- Make sure your response directly addresses every part of the user's question.
2922
</instructions>
23+
24+
<output>
25+
{
26+
"answer": "The response to the user's question.",
27+
}
28+
</output>
3029
context_size: 8

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent_with_follow_up_questions.yaml

Lines changed: 0 additions & 34 deletions
This file was deleted.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
model: "4o-mini"
2+
description: "An agent that generates a response to a user's question."
3+
system_message: |
4+
<role_and_objective>
5+
You are Senior Data Analyst, specializing in providing data driven answers to a user's question. Use the general business use case of '{{ use_case }}' to aid understanding of the user's question. You should provide a clear and concise response based on the information obtained from the SQL queries and their results. Adopt a data-driven approach to generate the response.
6+
</role_and_objective>
7+
8+
<system_information>
9+
- You are part of an overall system that provides Text2SQL and subsequent data analysis functionality only. You will be passed a result from multiple SQL queries, you must formulate a response to the user's question using this information.
10+
- You can assume that the SQL queries are correct and that the results are accurate.
11+
- You and the wider system can only generate SQL queries and process the results of these queries. You cannot access any external resources.
12+
- The main ability of the system is to perform natural language understanding and generate SQL queries from the user's question. These queries are then automatically run against the database and the results are passed to you.
13+
</system_information>
14+
15+
<instructions>
16+
- Use the information obtained to generate a response to the user's question. The question has been broken down into a series of SQL queries and you need to generate a response based on the results of these queries.
17+
- Do not use any external resources to generate the response. The response should be based solely on the information provided in the SQL queries and their results.
18+
- You have no access to the internet or any other external resources. You can only use the information provided in the SQL queries and their results, to generate the response.
19+
- You can use Markdown and Markdown tables to format the response. You MUST use the information obtained from the SQL queries to generate the response.
20+
- If the user is asking about your capabilities, use the <system_information> to explain what you do.
21+
- Make sure your response directly addresses every part of the user's question.
22+
- Finally, generate 3 data driven follow-up questions based on the information obtained from the SQL queries and their results. Think carefully about what questions may arise from the data and how they can be used to further analyze the data.
23+
</instructions>
24+
25+
<output>
26+
{
27+
"answer": "The response to the user's question.",
28+
"follow_up_questions": [
29+
"Follow-up question 1",
30+
"Follow-up question 2",
31+
"Follow-up question 3"
32+
]
33+
}
34+
</output>
35+
context_size: 8
36+
structured_output: true

text_2_sql/text_2_sql_core/src/text_2_sql_core/structured_outputs/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from text_2_sql_core.structured_outputs.user_message_rewrite_agent import (
77
UserMessageRewriteAgentOutput,
88
)
9-
from text_2_sql_core.structured_outputs.answer_agent_with_follow_up_questions import (
10-
AnswerAgentWithFollowUpQuestionsAgentOutput,
9+
from text_2_sql_core.structured_outputs.answer_with_follow_up_questions_agent import (
10+
AnswerWithFollowUpQuestionsAgentOutput,
1111
)
12+
from text_2_sql_core.structured_outputs.answer_agent import AnswerAgentOutput
1213

1314
__all__ = [
14-
"AnswerAgentWithFollowUpQuestionsAgentOutput",
15+
"AnswerAgentOutput",
16+
"AnswerWithFollowUpQuestionsAgentOutput",
1517
"SQLSchemaSelectionAgentOutput",
1618
"UserMessageRewriteAgentOutput",
1719
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from pydantic import BaseModel
4+
5+
6+
class AnswerAgentOutput(BaseModel):
7+
"""The output of the answer agent with follow up questions."""
8+
9+
answer: str
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pydantic import BaseModel
44

55

6-
class AnswerAgentWithFollowUpQuestionsAgentOutput(BaseModel):
6+
class AnswerWithFollowUpQuestionsAgentOutput(BaseModel):
77
"""The output of the answer agent with follow up questions."""
88

99
answer: str

0 commit comments

Comments
 (0)