Skip to content

Commit 1e2d21f

Browse files
Improves disambiguation & answer agent (#99)
1 parent fe41812 commit 1e2d21f

File tree

13 files changed

+212
-158
lines changed

13 files changed

+212
-158
lines changed

text_2_sql/autogen/pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ description = "AutoGen Based Implementation"
55
readme = "README.md"
66
requires-python = ">=3.12"
77
dependencies = [
8-
"autogen-agentchat==0.4.0.dev9",
9-
"autogen-core==0.4.0.dev9",
10-
"autogen-ext[azure,openai]==0.4.0.dev9",
8+
"autogen-agentchat==0.4.0.dev11",
9+
"autogen-core==0.4.0.dev11",
10+
"autogen-ext[azure,openai]==0.4.0.dev11",
1111
"grpcio>=1.68.1",
1212
"pyyaml>=6.0.2",
13-
"text_2_sql_core",
13+
"text_2_sql_core[snowflake,databricks]",
1414
]
1515

1616
[dependency-groups]

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ def __init__(self, engine_specific_rules: str, **kwargs: dict):
5050
def set_mode(self):
5151
"""Set the mode of the plugin based on the environment variables."""
5252
self.use_query_cache = (
53-
os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true"
53+
os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true"
5454
)
5555

5656
self.pre_run_query_cache = (
57-
os.environ.get("Text2Sql__PreRunQueryCache", "False").lower() == "true"
57+
os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true"
5858
)
5959

6060
self.use_column_value_store = (
61-
os.environ.get("Text2Sql__UseColumnValueStore", "False").lower() == "true"
61+
os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true"
6262
)
6363

6464
def get_all_agents(self):
@@ -97,8 +97,10 @@ def get_all_agents(self):
9797
engine_specific_rules=self.engine_specific_rules,
9898
**self.kwargs,
9999
)
100-
101-
ANSWER_AGENT = LLMAgentCreator.create("answer_agent")
100+
101+
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
102+
"question_decomposition_agent"
103+
)
102104

103105
# Auto-responding UserProxyAgent
104106
USER_PROXY = EmptyResponseUserProxyAgent(
@@ -111,8 +113,8 @@ def get_all_agents(self):
111113
SQL_QUERY_GENERATION_AGENT,
112114
SQL_SCHEMA_SELECTION_AGENT,
113115
SQL_QUERY_CORRECTION_AGENT,
114-
SQL_DISAMBIGUATION_AGENT,
115-
ANSWER_AGENT,
116+
QUESTION_DECOMPOSITION_AGENT,
117+
SQL_DISAMBIGUATION_AGENT
116118
]
117119

118120
if self.use_query_cache:
@@ -126,12 +128,15 @@ def termination_condition(self):
126128
"""Define the termination condition for the chat."""
127129
termination = (
128130
TextMentionTermination("TERMINATE")
131+
| (
132+
TextMentionTermination("answer")
133+
& TextMentionTermination("sources")
134+
& SourceMatchTermination("sql_query_correction_agent")
135+
)
129136
| MaxMessageTermination(20)
130-
| SourceMatchTermination(["answer_agent"])
131137
)
132138
return termination
133139

134-
@staticmethod
135140
def unified_selector(messages):
136141
"""Unified selector for the complete flow."""
137142
logging.info("Messages: %s", messages)
@@ -165,13 +170,14 @@ def unified_selector(messages):
165170
decision = "sql_disambiguation_agent"
166171
elif messages[-1].source == "sql_disambiguation_agent":
167172
decision = "sql_query_generation_agent"
173+
174+
elif messages[-1].source == "sql_query_correction_agent":
175+
decision = "sql_query_generation_agent"
176+
168177
elif messages[-1].source == "sql_query_generation_agent":
169178
decision = "sql_query_correction_agent"
170179
elif messages[-1].source == "sql_query_correction_agent":
171-
if messages[-1].content == "VALIDATED":
172-
decision = "answer_agent"
173-
else:
174-
decision = "sql_query_correction_agent"
180+
decision = "sql_query_correction_agent"
175181
elif messages[-1].source == "answer_agent":
176182
return "user_proxy" # Let user_proxy send TERMINATE
177183

@@ -186,7 +192,8 @@ def agentic_flow(self):
186192
allow_repeated_speaker=False,
187193
model_client=LLMModelCreator.get_model("4o-mini"),
188194
termination_condition=self.termination_condition,
189-
selector_func=AutoGenText2Sql.unified_selector,
195+
selector_func=self.selector,
196+
selector_func=self.unified_selector,
190197
)
191198
return flow
192199

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
4949
ai_search_helper.get_column_values,
5050
description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.",
5151
)
52-
elif tool_name == "sql_query_validation_tool":
53-
return FunctionTool(
54-
sql_helper.query_validation,
55-
description="Validates the SQL query to ensure that it is syntactically correct for the target database engine. Use this BEFORE executing any SQL statement.",
56-
)
5752
elif tool_name == "current_datetime_tool":
5853
return FunctionTool(
5954
sql_helper.get_current_datetime,

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313

1414
class SqlQueryCacheAgent(BaseChatAgent):
15-
def __init__(self):
15+
def __init__(self, name: str = "sql_query_cache_agent"):
1616
super().__init__(
17-
"sql_query_cache_agent",
17+
name,
1818
"An agent that fetches the queries from the cache based on the user question.",
1919
)
2020

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,9 @@ async def on_messages_stream(
101101
if schema not in final_schemas:
102102
final_schemas.append(schema)
103103

104-
final_colmns = []
105-
for column_value_result in column_value_results:
106-
for column in column_value_result:
107-
if column not in final_colmns:
108-
final_colmns.append(column)
109-
110104
final_results = {
111-
"schemas": final_schemas,
112-
"column_values": final_colmns,
105+
"COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results,
106+
"SCHEMA_OPTIONS": final_schemas,
113107
}
114108

115109
logging.info(f"Final results: {final_results}")

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from azure.identity import DefaultAzureCredential
44
from openai import AsyncAzureOpenAI
55
from azure.core.credentials import AzureKeyCredential
6-
from azure.search.documents.models import VectorizedQuery
6+
from azure.search.documents.models import VectorizedQuery, QueryType
77
from azure.search.documents.aio import SearchClient
88
from text_2_sql_core.utils.environment import IdentityType, get_identity_type
99
import os
@@ -69,11 +69,9 @@ async def run_ai_search_query(
6969
credential=credential,
7070
) as search_client:
7171
if semantic_config is not None and vector_query is not None:
72-
query_type = "semantic"
73-
elif vector_query is not None:
74-
query_type = "hybrid"
72+
query_type = QueryType.SEMANTIC
7573
else:
76-
query_type = "full"
74+
query_type = QueryType.FULL
7775

7876
results = await search_client.search(
7977
top=top,
@@ -148,7 +146,7 @@ async def get_column_values(
148146
"AIService__AzureSearchOptions__Text2SqlColumnValueStore__Index"
149147
],
150148
semantic_config=None,
151-
top=15,
149+
top=50,
152150
include_scores=False,
153151
minimum_score=5,
154152
)
@@ -163,10 +161,14 @@ async def get_column_values(
163161

164162
column_values[trimmed_fqn].append(value["Value"])
165163

164+
logging.info("Column Values: %s", column_values)
165+
166+
filter_to_column = {text: column_values}
167+
166168
if as_json:
167-
return json.dumps(column_values, default=str)
169+
return json.dumps(filter_to_column, default=str)
168170
else:
169-
return column_values
171+
return filter_to_column
170172

171173
async def get_entity_schemas(
172174
self,
@@ -193,20 +195,24 @@ async def get_entity_schemas(
193195
str: The schema of the views or tables in JSON format.
194196
"""
195197

198+
logging.info("Search Text: %s", text)
199+
196200
retrieval_fields = [
197-
"FQN",
201+
# "FQN",
198202
"Entity",
199203
"EntityName",
200-
"Schema",
201-
"Definition",
204+
# "Schema",
205+
# "Definition",
206+
"Description",
202207
"Columns",
203208
"EntityRelationships",
204209
"CompleteEntityRelationshipsGraph",
205210
] + engine_specific_fields
206211

207212
schemas = await self.run_ai_search_query(
208213
text,
209-
["DefinitionEmbedding"],
214+
# ["DefinitionEmbedding"],
215+
["DescriptionEmbedding"],
210216
retrieval_fields,
211217
os.environ["AIService__AzureSearchOptions__Text2SqlSchemaStore__Index"],
212218
os.environ[
@@ -221,7 +227,25 @@ async def get_entity_schemas(
221227
for schema in schemas:
222228
filtered_schemas = []
223229

224-
del schema["FQN"]
230+
# del schema["FQN"]
231+
232+
if (
233+
schema["CompleteEntityRelationshipsGraph"] is not None
234+
and len(schema["CompleteEntityRelationshipsGraph"]) == 0
235+
):
236+
del schema["CompleteEntityRelationshipsGraph"]
237+
238+
if (
239+
schema["SammpleValues"] is not None
240+
and len(schema["SammpleValues"]) == 0
241+
):
242+
del schema["SammpleValues"]
243+
244+
if (
245+
schema["EntityRelationships"] is not None
246+
and len(schema["EntityRelationships"]) == 0
247+
):
248+
del schema["EntityRelationships"]
225249

226250
if schema["Entity"].lower() not in excluded_entities:
227251
filtered_schemas.append(schema)

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,13 @@ async def get_entity_schemas(
9898
)
9999

100100
for schema in schemas:
101-
schema["SelectFromEntity"] = ".".join(
102-
[schema["Catalog"], schema["Schema"], schema["Entity"]]
103-
)
101+
# schema["SelectFromEntity"] = ".".join(
102+
# [schema["Catalog"], schema["Schema"], schema["Entity"]]
103+
# )
104+
schema["SelectFromEntity"] = schema["Entity"]
104105

105106
del schema["Entity"]
106-
del schema["Schema"]
107+
# del schema["Schema"]
107108
del schema["Catalog"]
108109

109110
if as_json:

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
class SqlConnector(ABC):
1414
def __init__(self):
1515
self.use_query_cache = (
16-
os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true"
16+
os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true"
1717
)
1818

1919
self.pre_run_query_cache = (
20-
os.environ.get("Text2Sql__PreRunQueryCache", "False").lower() == "true"
20+
os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true"
2121
)
2222

2323
self.use_column_value_store = (
24-
os.environ.get("Text2Sql__UseColumnValueStore", "False").lower() == "true"
24+
os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true"
2525
)
2626

2727
self.ai_search_connector = ConnectorFactory.get_ai_search_connector()
@@ -91,7 +91,14 @@ async def query_execution_with_limit(
9191
-------
9292
list[dict]: The results of the SQL query.
9393
"""
94-
return await self.query_execution(sql_query, cast_to=None, limit=25)
94+
95+
# Validate the SQL query
96+
validation_result = await self.query_validation(sql_query)
97+
98+
if isinstance(validation_result, bool) and validation_result:
99+
return await self.query_execution(sql_query, cast_to=None, limit=25)
100+
else:
101+
return validation_result
95102

96103
async def query_validation(
97104
self,
@@ -127,9 +134,7 @@ async def fetch_queries_from_cache(self, question: str) -> str:
127134
["QuestionEmbedding"],
128135
["Question", "SqlQueryDecomposition"],
129136
os.environ["AIService__AzureSearchOptions__Text2SqlQueryCache__Index"],
130-
os.environ[
131-
"AIService__AzureSearchOptions__Text2SqlQueryCache__SemanticConfig"
132-
],
137+
None,
133138
top=1,
134139
include_scores=True,
135140
minimum_score=1.5,

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)