Improves disambiguation & answer agent #99

Merged · 12 commits · Dec 13, 2024
8 changes: 4 additions & 4 deletions text_2_sql/autogen/pyproject.toml
@@ -5,12 +5,12 @@ description = "AutoGen Based Implementation"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"autogen-agentchat==0.4.0.dev9",
"autogen-core==0.4.0.dev9",
"autogen-ext[azure,openai]==0.4.0.dev9",
"autogen-agentchat==0.4.0.dev11",
"autogen-core==0.4.0.dev11",
"autogen-ext[azure,openai]==0.4.0.dev11",
"grpcio>=1.68.1",
"pyyaml>=6.0.2",
"text_2_sql_core",
"text_2_sql_core[snowflake,databricks]",
]

[dependency-groups]
35 changes: 21 additions & 14 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
@@ -50,15 +50,15 @@ def __init__(self, engine_specific_rules: str, **kwargs: dict):
def set_mode(self):
"""Set the mode of the plugin based on the environment variables."""
self.use_query_cache = (
os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true"
os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true"
)

self.pre_run_query_cache = (
os.environ.get("Text2Sql__PreRunQueryCache", "False").lower() == "true"
os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true"
)

self.use_column_value_store = (
os.environ.get("Text2Sql__UseColumnValueStore", "False").lower() == "true"
os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true"
)
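Worth noting on the flipped defaults above: the three feature flags now default to enabled and are only switched off by an explicit "False". A minimal standalone sketch of the parsing behaviour (the helper name is illustrative, not part of this PR):

import os

def _env_flag(name: str, default: str = "True") -> bool:
    # Only the literal string "true" (any casing) enables the flag; any other value disables it.
    return os.environ.get(name, default).lower() == "true"

# With no environment variables set, all three features are now on by default.
use_query_cache = _env_flag("Text2Sql__UseQueryCache")
pre_run_query_cache = _env_flag("Text2Sql__PreRunQueryCache")
use_column_value_store = _env_flag("Text2Sql__UseColumnValueStore")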

def get_all_agents(self):
@@ -97,8 +97,10 @@ def get_all_agents(self):
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)

ANSWER_AGENT = LLMAgentCreator.create("answer_agent")

QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
"question_decomposition_agent"
)

# Auto-responding UserProxyAgent
USER_PROXY = EmptyResponseUserProxyAgent(
@@ -111,8 +113,8 @@
SQL_QUERY_GENERATION_AGENT,
SQL_SCHEMA_SELECTION_AGENT,
SQL_QUERY_CORRECTION_AGENT,
- SQL_DISAMBIGUATION_AGENT,
- ANSWER_AGENT,
+ QUESTION_DECOMPOSITION_AGENT,
+ SQL_DISAMBIGUATION_AGENT
]

if self.use_query_cache:
@@ -126,12 +128,15 @@ def termination_condition(self):
"""Define the termination condition for the chat."""
termination = (
TextMentionTermination("TERMINATE")
| (
TextMentionTermination("answer")
& TextMentionTermination("sources")
& SourceMatchTermination("sql_query_correction_agent")
)
| MaxMessageTermination(20)
| SourceMatchTermination(["answer_agent"])
)
return termination
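For readers less familiar with autogen-agentchat 0.4, termination conditions compose with | (stop when any condition fires) and & (stop only when all of them have fired), which is how the mention-, count- and source-based conditions above are being combined. A standalone illustration, assuming the stable 0.4 module layout (the dev builds pinned here may expose these classes under a different path):

from autogen_agentchat.conditions import (
    MaxMessageTermination,
    SourceMatchTermination,
    TextMentionTermination,
)

# Any one of these ends the run: an explicit TERMINATE message, a hard cap of
# 20 messages, or any message produced by the answer agent.
stop = (
    TextMentionTermination("TERMINATE")
    | MaxMessageTermination(20)
    | SourceMatchTermination(["answer_agent"])
)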

@staticmethod
def unified_selector(messages):
"""Unified selector for the complete flow."""
logging.info("Messages: %s", messages)
@@ -165,13 +170,14 @@ def unified_selector(messages):
decision = "sql_disambiguation_agent"
elif messages[-1].source == "sql_disambiguation_agent":
decision = "sql_query_generation_agent"

elif messages[-1].source == "sql_query_correction_agent":
decision = "sql_query_generation_agent"

elif messages[-1].source == "sql_query_generation_agent":
decision = "sql_query_correction_agent"
elif messages[-1].source == "sql_query_correction_agent":
if messages[-1].content == "VALIDATED":
decision = "answer_agent"
else:
decision = "sql_query_correction_agent"
decision = "sql_query_correction_agent"
elif messages[-1].source == "answer_agent":
return "user_proxy" # Let user_proxy send TERMINATE

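For context on the selector changes above: judging by the constructor arguments in the next hunk, the flow is a SelectorGroupChat, which consults the selector function before falling back to its model-based speaker selection. Returning an agent name pins the next speaker; returning None defers to the model. A minimal, hypothetical illustration of that pattern (not the PR's selector):

def prefer_correction_after_generation(messages):
    # Force the correction agent to speak right after the generation agent;
    # for every other case, return None and let the model-based selector decide.
    if messages and messages[-1].source == "sql_query_generation_agent":
        return "sql_query_correction_agent"
    return None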
@@ -186,7 +192,8 @@ def agentic_flow(self):
allow_repeated_speaker=False,
model_client=LLMModelCreator.get_model("4o-mini"),
termination_condition=self.termination_condition,
selector_func=AutoGenText2Sql.unified_selector,
selector_func=self.selector,
selector_func=self.unified_selector,
)
return flow

— changes in another file (filename not shown) —
@@ -49,11 +49,6 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
ai_search_helper.get_column_values,
description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.",
)
elif tool_name == "sql_query_validation_tool":
return FunctionTool(
sql_helper.query_validation,
description="Validates the SQL query to ensure that it is syntactically correct for the target database engine. Use this BEFORE executing any SQL statement.",
)
elif tool_name == "current_datetime_tool":
return FunctionTool(
sql_helper.get_current_datetime,
— changes in another file (filename not shown) —
@@ -12,9 +12,9 @@


class SqlQueryCacheAgent(BaseChatAgent):
- def __init__(self):
+ def __init__(self, name: str = "sql_query_cache_agent"):
super().__init__(
"sql_query_cache_agent",
name,
"An agent that fetches the queries from the cache based on the user question.",
)

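Making the agent name a constructor argument, with the old name kept as the default, means more than one cache agent can be registered in a flow without their names colliding. A brief usage sketch under that assumption:

# Default keeps the previous behaviour; an explicit name avoids collisions
# when several cache agents participate in the same group chat.
default_cache_agent = SqlQueryCacheAgent()
scoped_cache_agent = SqlQueryCacheAgent(name="sales_sql_query_cache_agent")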
— changes in another file (filename not shown) —
@@ -101,15 +101,9 @@ async def on_messages_stream(
if schema not in final_schemas:
final_schemas.append(schema)

- final_colmns = []
- for column_value_result in column_value_results:
- for column in column_value_result:
- if column not in final_colmns:
- final_colmns.append(column)

final_results = {
"schemas": final_schemas,
"column_values": final_colmns,
"COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results,
"SCHEMA_OPTIONS": final_schemas,
}

logging.info(f"Final results: {final_results}")
— changes in another file (filename not shown) —
@@ -3,7 +3,7 @@
from azure.identity import DefaultAzureCredential
from openai import AsyncAzureOpenAI
from azure.core.credentials import AzureKeyCredential
- from azure.search.documents.models import VectorizedQuery
+ from azure.search.documents.models import VectorizedQuery, QueryType
from azure.search.documents.aio import SearchClient
from text_2_sql_core.utils.environment import IdentityType, get_identity_type
import os
@@ -69,11 +69,9 @@ async def run_ai_search_query(
credential=credential,
) as search_client:
if semantic_config is not None and vector_query is not None:
query_type = "semantic"
elif vector_query is not None:
query_type = "hybrid"
query_type = QueryType.SEMANTIC
else:
query_type = "full"
query_type = QueryType.FULL

results = await search_client.search(
top=top,
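The string literals give way to the azure-search-documents QueryType enum, and the separate "hybrid" branch disappears, presumably because hybrid retrieval in Azure AI Search is expressed by sending both search text and a vector query rather than by a distinct query type. A hedged sketch of how the enum feeds into a search call (the configuration name is invented for illustration):

from azure.search.documents.aio import SearchClient
from azure.search.documents.models import QueryType, VectorizedQuery

async def example_search(search_client: SearchClient, embedding: list[float]):
    # Text plus a vector query in one request gives hybrid retrieval;
    # QueryType.SEMANTIC additionally applies the semantic ranker on top of it.
    results = await search_client.search(
        search_text="sales by region",
        vector_queries=[
            VectorizedQuery(
                vector=embedding,
                k_nearest_neighbors=5,
                fields="DescriptionEmbedding",
            )
        ],
        query_type=QueryType.SEMANTIC,
        semantic_configuration_name="default-semantic-config",  # assumed name
        top=5,
    )
    return [result async for result in results]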
Expand Down Expand Up @@ -148,7 +146,7 @@ async def get_column_values(
"AIService__AzureSearchOptions__Text2SqlColumnValueStore__Index"
],
semantic_config=None,
- top=15,
+ top=50,
include_scores=False,
minimum_score=5,
)
@@ -163,10 +161,14 @@

column_values[trimmed_fqn].append(value["Value"])

logging.info("Column Values: %s", column_values)

+ filter_to_column = {text: column_values}

if as_json:
- return json.dumps(column_values, default=str)
+ return json.dumps(filter_to_column, default=str)
else:
- return column_values
+ return filter_to_column
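The return value changes from a flat column-to-values map to one keyed by the original search text, so callers can tell which filter term produced which candidate values. A hypothetical example of the new shape (entity and values invented; as_json=True returns the same structure serialized to JSON):

# Illustrative value returned by get_column_values("fabric brands") with as_json=False.
expected = {
    "fabric brands": {
        "dbo.Product.BrandName": ["Contoso", "Fabrikam", "Litware"],
    }
}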

async def get_entity_schemas(
self,
@@ -193,20 +195,24 @@
str: The schema of the views or tables in JSON format.
"""

logging.info("Search Text: %s", text)

retrieval_fields = [
"FQN",
# "FQN",
"Entity",
"EntityName",
"Schema",
"Definition",
# "Schema",
# "Definition",
"Description",
"Columns",
"EntityRelationships",
"CompleteEntityRelationshipsGraph",
] + engine_specific_fields

schemas = await self.run_ai_search_query(
text,
["DefinitionEmbedding"],
# ["DefinitionEmbedding"],
["DescriptionEmbedding"],
retrieval_fields,
os.environ["AIService__AzureSearchOptions__Text2SqlSchemaStore__Index"],
os.environ[
@@ -221,7 +227,25 @@
for schema in schemas:
filtered_schemas = []

del schema["FQN"]
# del schema["FQN"]

+ if (
+ schema["CompleteEntityRelationshipsGraph"] is not None
+ and len(schema["CompleteEntityRelationshipsGraph"]) == 0
+ ):
+ del schema["CompleteEntityRelationshipsGraph"]

+ if (
+ schema["SammpleValues"] is not None
+ and len(schema["SammpleValues"]) == 0
+ ):
+ del schema["SammpleValues"]

+ if (
+ schema["EntityRelationships"] is not None
+ and len(schema["EntityRelationships"]) == 0
+ ):
+ del schema["EntityRelationships"]

if schema["Entity"].lower() not in excluded_entities:
filtered_schemas.append(schema)
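The three new guards drop relationship and sample-value fields that came back as empty lists, which keeps the schema payload handed to the model smaller. A compact equivalent of that pruning, written as a hypothetical helper rather than the PR's inline form:

def prune_empty_list_fields(schema: dict) -> dict:
    # Mirrors the inline guards above: remove keys whose value is an empty list.
    # Uses .get so missing keys are tolerated; "SammpleValues" is spelled as in the source.
    for key in ("CompleteEntityRelationshipsGraph", "SammpleValues", "EntityRelationships"):
        if schema.get(key) is not None and len(schema[key]) == 0:
            del schema[key]
    return schema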
— changes in another file (filename not shown) —
@@ -98,12 +98,13 @@ async def get_entity_schemas(
)

for schema in schemas:
schema["SelectFromEntity"] = ".".join(
[schema["Catalog"], schema["Schema"], schema["Entity"]]
)
# schema["SelectFromEntity"] = ".".join(
# [schema["Catalog"], schema["Schema"], schema["Entity"]]
# )
schema["SelectFromEntity"] = schema["Entity"]

del schema["Entity"]
del schema["Schema"]
# del schema["Schema"]
del schema["Catalog"]

if as_json:
— changes in another file (filename not shown) —
@@ -13,15 +13,15 @@
class SqlConnector(ABC):
def __init__(self):
self.use_query_cache = (
os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true"
os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true"
)

self.pre_run_query_cache = (
os.environ.get("Text2Sql__PreRunQueryCache", "False").lower() == "true"
os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true"
)

self.use_column_value_store = (
os.environ.get("Text2Sql__UseColumnValueStore", "False").lower() == "true"
os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true"
)

self.ai_search_connector = ConnectorFactory.get_ai_search_connector()
@@ -91,7 +91,14 @@ async def query_execution_with_limit(
-------
list[dict]: The results of the SQL query.
"""
- return await self.query_execution(sql_query, cast_to=None, limit=25)

+ # Validate the SQL query
+ validation_result = await self.query_validation(sql_query)

+ if isinstance(validation_result, bool) and validation_result:
+ return await self.query_execution(sql_query, cast_to=None, limit=25)
+ else:
+ return validation_result
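With validation folded into query_execution_with_limit, a caller now receives either the result rows or the validation payload. A hypothetical caller showing how the two outcomes might be told apart (connector construction elided; assumes the validation payload is not itself a list):

# Hypothetical caller; connector is any concrete SqlConnector implementation.
async def run_limited_query(connector, sql_query: str) -> dict:
    result = await connector.query_execution_with_limit(sql_query)
    if isinstance(result, list):
        # Validation passed and the query ran: rows arrive as list[dict].
        return {"rows": result}
    # Validation failed: hand the validation output back, e.g. to the
    # correction agent, instead of executing anything.
    return {"validation_error": result}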

async def query_validation(
self,
@@ -127,9 +134,7 @@ async def fetch_queries_from_cache(self, question: str) -> str:
["QuestionEmbedding"],
["Question", "SqlQueryDecomposition"],
os.environ["AIService__AzureSearchOptions__Text2SqlQueryCache__Index"],
- os.environ[
- "AIService__AzureSearchOptions__Text2SqlQueryCache__SemanticConfig"
- ],
+ None,
top=1,
include_scores=True,
minimum_score=1.5,

This file was deleted. (Another file in the PR; its name is not shown in this capture.)
