From fe9cd88ec828e34d08bb7ea22be0a9a3295c1102 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 18:25:15 +0000 Subject: [PATCH 01/11] Include vali --- .../src/autogen_text_2_sql/creators/llm_agent_creator.py | 5 ----- .../src/text_2_sql_core/connectors/sql.py | 9 ++++++++- .../prompts/sql_query_generation_agent.yaml | 4 +--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py index 1ebd956..baade79 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py @@ -49,11 +49,6 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str): ai_search_helper.get_column_values, description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.", ) - elif tool_name == "sql_query_validation_tool": - return FunctionTool( - sql_helper.query_validation, - description="Validates the SQL query to ensure that it is syntactically correct for the target database engine. Use this BEFORE executing any SQL statement.", - ) elif tool_name == "current_datetime_tool": return FunctionTool( sql_helper.get_current_datetime, diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index f506816..270a3cb 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -91,7 +91,14 @@ async def query_execution_with_limit( ------- list[dict]: The results of the SQL query. """ - return await self.query_execution(sql_query, cast_to=None, limit=25) + + # Validate the SQL query + validation_result = await self.query_validation(sql_query) + + if isinstance(validation_result, bool) and validation_result: + return await self.query_execution(sql_query, cast_to=None, limit=25) + else: + return validation_result async def query_validation( self, diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_generation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_generation_agent.yaml index 6b56687..25dd2d2 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_generation_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_generation_agent.yaml @@ -8,8 +8,7 @@ system_message: You must: 1. Use the schema information provided and this mapping to generate a SQL query that will answer the user's question. 2. If you need additional schema information, you can obtain it using the schema selection tool. Only use this when you do not have enough information to generate the SQL query. - 3. Validate the SQL query to ensure it is syntactically correct using the validation tool. - 4. Run the SQL query to fetch the results. + 3. Run the SQL query to fetch the results. When generating the SQL query, you MUST follow these rules: @@ -36,5 +35,4 @@ system_message: tools: - sql_query_execution_tool - sql_get_entity_schemas_tool - - sql_query_validation_tool - current_datetime_tool From 8c81a3cc97fd9b91f0f02853d8db083b2823f37a Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 18:54:34 +0000 Subject: [PATCH 02/11] Update agents --- .../src/autogen_text_2_sql/autogen_text_2_sql.py | 15 ++++++++------- .../custom_agents/sql_query_cache_agent.py | 4 ++-- .../src/text_2_sql_core/connectors/sql.py | 6 +++--- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index a7f36db..2f19dd8 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -32,15 +32,15 @@ def __init__(self, engine_specific_rules: str, **kwargs: dict): def set_mode(self): """Set the mode of the plugin based on the environment variables.""" self.use_query_cache = ( - os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true" + os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true" ) self.pre_run_query_cache = ( - os.environ.get("Text2Sql__PreRunQueryCache", "False").lower() == "true" + os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true" ) self.use_column_value_store = ( - os.environ.get("Text2Sql__UseColumnValueStore", "False").lower() == "true" + os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true" ) @property @@ -100,13 +100,14 @@ def termination_condition(self): ) return termination - @staticmethod - def selector(messages): + def selector(self, messages): logging.info("Messages: %s", messages) decision = None # Initialize decision variable - if len(messages) == 1: + if len(messages) == 1 and self.use_query_cache: decision = "sql_query_cache_agent" + elif len(messages) == 1: + decision = "question_decomposition_agent" elif ( messages[-1].source == "sql_query_cache_agent" @@ -161,7 +162,7 @@ def agentic_flow(self): allow_repeated_speaker=False, model_client=LLMModelCreator.get_model("4o-mini"), termination_condition=self.termination_condition, - selector_func=AutoGenText2Sql.selector, + selector_func=self.selector, ) return agentic_flow diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index 49e0730..2e7ea8c 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -12,9 +12,9 @@ class SqlQueryCacheAgent(BaseChatAgent): - def __init__(self): + def __init__(self, name: str = "sql_query_cache_agent"): super().__init__( - "sql_query_cache_agent", + name, "An agent that fetches the queries from the cache based on the user question.", ) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 270a3cb..2f36e90 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -13,15 +13,15 @@ class SqlConnector(ABC): def __init__(self): self.use_query_cache = ( - os.environ.get("Text2Sql__UseQueryCache", "False").lower() == "true" + os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true" ) self.pre_run_query_cache = ( - os.environ.get("Text2Sql__PreRunQueryCache", "False").lower() == "true" + os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true" ) self.use_column_value_store = ( - os.environ.get("Text2Sql__UseColumnValueStore", "False").lower() == "true" + os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true" ) self.ai_search_connector = ConnectorFactory.get_ai_search_connector() From 4c1af2cc0f63f595059cd8030cc75b22f8be22b6 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 18:56:41 +0000 Subject: [PATCH 03/11] Update sql --- .../text_2_sql_core/src/text_2_sql_core/connectors/sql.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index 2f36e90..b838953 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -134,9 +134,7 @@ async def fetch_queries_from_cache(self, question: str) -> str: ["QuestionEmbedding"], ["Question", "SqlQueryDecomposition"], os.environ["AIService__AzureSearchOptions__Text2SqlQueryCache__Index"], - os.environ[ - "AIService__AzureSearchOptions__Text2SqlQueryCache__SemanticConfig" - ], + None, top=1, include_scores=True, minimum_score=1.5, From 6e4489a40f0a48e34a37894e2a0f32c19eb01355 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 18:58:09 +0000 Subject: [PATCH 04/11] Update deps --- text_2_sql/autogen/pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/text_2_sql/autogen/pyproject.toml b/text_2_sql/autogen/pyproject.toml index f9061b5..721b9f9 100644 --- a/text_2_sql/autogen/pyproject.toml +++ b/text_2_sql/autogen/pyproject.toml @@ -5,12 +5,12 @@ description = "AutoGen Based Implementation" readme = "README.md" requires-python = ">=3.12" dependencies = [ - "autogen-agentchat==0.4.0.dev9", - "autogen-core==0.4.0.dev9", - "autogen-ext[azure,openai]==0.4.0.dev9", + "autogen-agentchat==0.4.0.dev11", + "autogen-core==0.4.0.dev11", + "autogen-ext[azure,openai]==0.4.0.dev11", "grpcio>=1.68.1", "pyyaml>=6.0.2", - "text_2_sql_core", + "text_2_sql_core[snowflake,databricks]", ] [dependency-groups] From d8a2daa61f703d6b112ec4340b1cd0a93ac2f8f6 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 19:01:42 +0000 Subject: [PATCH 05/11] Update search --- .../src/text_2_sql_core/connectors/ai_search.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py index 40aee33..2bea86c 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py @@ -3,7 +3,7 @@ from azure.identity import DefaultAzureCredential from openai import AsyncAzureOpenAI from azure.core.credentials import AzureKeyCredential -from azure.search.documents.models import VectorizedQuery +from azure.search.documents.models import VectorizedQuery, QueryType from azure.search.documents.aio import SearchClient from text_2_sql_core.utils.environment import IdentityType, get_identity_type import os @@ -69,11 +69,9 @@ async def run_ai_search_query( credential=credential, ) as search_client: if semantic_config is not None and vector_query is not None: - query_type = "semantic" - elif vector_query is not None: - query_type = "hybrid" + query_type = QueryType.SEMANTIC else: - query_type = "full" + query_type = QueryType.FULL results = await search_client.search( top=top, From 5f48b4e2efc3a276bf34a085ab6c12f3caee5712 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 19:21:45 +0000 Subject: [PATCH 06/11] Temp changes for demo --- .../src/text_2_sql_core/connectors/ai_search.py | 12 +++++++----- .../src/text_2_sql_core/connectors/databricks_sql.py | 9 ++++++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py index 2bea86c..722116b 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py @@ -192,11 +192,12 @@ async def get_entity_schemas( """ retrieval_fields = [ - "FQN", + # "FQN", "Entity", "EntityName", - "Schema", - "Definition", + # "Schema", + # "Definition", + "Description", "Columns", "EntityRelationships", "CompleteEntityRelationshipsGraph", @@ -204,7 +205,8 @@ async def get_entity_schemas( schemas = await self.run_ai_search_query( text, - ["DefinitionEmbedding"], + # ["DefinitionEmbedding"], + ["DescriptionEmbedding"], retrieval_fields, os.environ["AIService__AzureSearchOptions__Text2SqlSchemaStore__Index"], os.environ[ @@ -219,7 +221,7 @@ async def get_entity_schemas( for schema in schemas: filtered_schemas = [] - del schema["FQN"] + # del schema["FQN"] if schema["Entity"].lower() not in excluded_entities: filtered_schemas.append(schema) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py index cca4cc5..d544b60 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py @@ -98,12 +98,15 @@ async def get_entity_schemas( ) for schema in schemas: - schema["SelectFromEntity"] = ".".join( - [schema["Catalog"], schema["Schema"], schema["Entity"]] + # schema["SelectFromEntity"] = ".".join( + # [schema["Catalog"], schema["Schema"], schema["Entity"]] + # ) + schema["SelectFromEntity"] = ( + os.environ["Text2Sql__Databricks__Catalog"] + "." + schema["Entity"] ) del schema["Entity"] - del schema["Schema"] + # del schema["Schema"] del schema["Catalog"] if as_json: From 8924cff6336e014b439d42cc6b5ea6fd7d0f1c36 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 19:35:32 +0000 Subject: [PATCH 07/11] Combine down to 1 agent --- .../autogen_text_2_sql/autogen_text_2_sql.py | 16 ++--- .../text_2_sql_core/connectors/ai_search.py | 2 + .../text_2_sql_core/prompts/answer_agent.yaml | 18 ------ .../prompts/sql_query_correction_agent.yaml | 64 ++++++++++++++----- uv.lock | 42 +++++++----- 5 files changed, 83 insertions(+), 59 deletions(-) delete mode 100644 text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 2f19dd8..15998fc 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -70,7 +70,6 @@ def agents(self): **self.kwargs, ) - ANSWER_AGENT = LLMAgentCreator.create("answer_agent") QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create( "question_decomposition_agent" ) @@ -79,7 +78,6 @@ def agents(self): SQL_QUERY_GENERATION_AGENT, SQL_SCHEMA_SELECTION_AGENT, SQL_QUERY_CORRECTION_AGENT, - ANSWER_AGENT, QUESTION_DECOMPOSITION_AGENT, SQL_DISAMBIGUATION_AGENT, ] @@ -95,8 +93,12 @@ def termination_condition(self): """Define the termination condition for the chat.""" termination = ( TextMentionTermination("TERMINATE") + | ( + TextMentionTermination("answer") + & TextMentionTermination("sources") + & SourceMatchTermination("sql_query_correction_agent") + ) | MaxMessageTermination(20) - | SourceMatchTermination(["answer_agent"]) ) return termination @@ -136,14 +138,8 @@ def selector(self, messages): # This would be user proxy agent tbc decision = "sql_query_generation_agent" - elif ( - messages[-1].source == "sql_query_correction_agent" - and messages[-1].content == "VALIDATED" - ): - decision = "answer_agent" - elif messages[-1].source == "sql_query_correction_agent": - decision = "sql_query_correction_agent" + decision = "sql_query_generation_agent" # Log the decision logging.info("Decision: %s", decision) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py index 722116b..7e73da5 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py @@ -191,6 +191,8 @@ async def get_entity_schemas( str: The schema of the views or tables in JSON format. """ + logging.info("Search Text: %s", text) + retrieval_fields = [ # "FQN", "Entity", diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml deleted file mode 100644 index afce65f..0000000 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml +++ /dev/null @@ -1,18 +0,0 @@ -model: - 4o -description: - "An agent that takes the final results from the SQL query and writes the answer to the user's question" -system_message: - "Write a data-driven answer that directly addresses the user's question. Use the results from the SQL query to provide the answer. Do not make up or guess the answer. - - Return your answer in the following format: - - { - 'answer': '', - 'sources': [ - {'sql_result_snippet': , 'sql_query_used': '', 'explanation': ''}, - {'sql_result_snippet': , 'sql_query_used': '', 'explanation': ''}, - ] - } - - " diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml index 6a930ae..074eb30 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml @@ -3,21 +3,55 @@ model: description: "An agent that will look at the SQL query, SQL query results and correct any mistakes in the SQL query to ensure the correct results are returned. Use this agent AFTER the SQL query has been executed and the results are not as expected." system_message: - "You are a helpful AI Assistant that specialises in correcting invalid SQL queries or queries that do not return the expected results. You have been provided with a SQL query and the results of the query. + " + You are a helpful AI Assistant specializing in diagnosing and correcting invalid SQL queries or improving SQL queries that do not return expected results. + - You must: - 1. Verify the SQL query provided is syntactically correct and correct it if it is not. - 2. Check the SQL query results and ensure that the results are as expected in the context of the question. You should verify that these results will actually answer the user's question. + + The user's question will be related to SQL queries and their results in the context of {{ target_engine }}. Queries must adhere to the syntax and rules of {{ target_engine }} {{ engine_specific_rules }}. + - Important Info: - - The target database engine is {{ target_engine }}, SQL queries must be able compatible to run on {{ target_engine }} {{ engine_specific_rules }} - - Ensure that the corrected query returns the expected results in context of the question. - - If the SQL query needs adjustment, correct the SQL query and provide the corrected SQL query and then run the query. + + 1. **Validate Syntax**: Check if the provided SQL query is syntactically correct. If not, suggest fixes to the query. + 2. **Verify Results**: Ensure the query results align with the user’s question. If the query fails to meet the expected results: + - Make suggestions to correct the query. + 3. **Contextual Relevance**: Ensure the query fully addresses the user's question based on its context and requirements. + - Output Info: - - If there are no errors and the SQL query is correct, return 'VALIDATED'. - - If you are consistently unable to correct the SQL query and cannot use the schemas to answer the question. Say 'I am unable to correct the SQL query. Please ask another question.' and then end your answer with 'TERMINATE'" -tools: - - sql_get_entity_schemas_tool - - sql_query_execution_tool - - current_datetime_tool + + - **If the SQL query is valid and the results are correct**: + ```json + { + \"answer\": \"\", + \"sources\": [ + { + \"sql_result_snippet\": \"\", + \"sql_query_used\": \"\", + \"explanation\": \"\" + }, + { + \"sql_result_snippet\": \"\", + \"sql_query_used\": \"\", + \"explanation\": \"\" + } + ] + } + ``` + - **If the SQL query needs corrections**: + ```json + [ + { + \"fix_request\": \"\", + \"explanation\": \"\" + } + ] + ``` + - **If the SQL query cannot be corrected**: + ```json + { + \"error\": \"Unable to correct the SQL query. Please request a new SQL query.\" + } + ``` + Followed by **TERMINATE**. + + " diff --git a/uv.lock b/uv.lock index a28f460..e2acbb2 100644 --- a/uv.lock +++ b/uv.lock @@ -66,6 +66,15 @@ requires-dist = [ { name = "tiktoken", specifier = ">=0.8.0" }, ] +[[package]] +name = "aioconsole" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c7/c9/c57e979eea211b10a63783882a826f257713fa7c0d6c9a6eac851e674fb4/aioconsole-0.8.1.tar.gz", hash = "sha256:0535ce743ba468fb21a1ba43c9563032c779534d4ecd923a46dbd350ad91d234", size = 61085 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/ea/23e756ec1fea0c685149304dda954b3b3932d6d06afbf42a66a2e6dc2184/aioconsole-0.8.1-py3-none-any.whl", hash = "sha256:e1023685cde35dde909fbf00631ffb2ed1c67fe0b7058ebb0892afbde5f213e5", size = 43324 }, +] + [[package]] name = "aiofiles" version = "24.1.0" @@ -279,21 +288,22 @@ wheels = [ [[package]] name = "autogen-agentchat" -version = "0.4.0.dev9" +version = "0.4.0.dev11" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "autogen-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f5/2f/8d0b6ae170013becad945ca43cadffdac5e0b0b6a4cf47b6ec962cd4d5ea/autogen_agentchat-0.4.0.dev9.tar.gz", hash = "sha256:ff93c9768b6801670f54fdcfd27490fe8cbc4eeb9ff2c1dc39744f53a89a204f", size = 46296 } +sdist = { url = "https://files.pythonhosted.org/packages/84/d6/b5141bb82f16774c582c990a6c53511ce580d1f8c809987ef546a78053c5/autogen_agentchat-0.4.0.dev11.tar.gz", hash = "sha256:d19091afb9f19f8a3ae7b91b88e5269cea39c696a9dde3b7b95b85f70343f0f3", size = 47859 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/96/30f3c13d1e198aa9dba99ca80f97ccfbfac705107cf5bac2907be84eaa5b/autogen_agentchat-0.4.0.dev9-py3-none-any.whl", hash = "sha256:da93313f3233cec81dc4cb3a5ccb1289277e5483604486c7845a092986922b03", size = 56046 }, + { url = "https://files.pythonhosted.org/packages/96/21/93cfb1590c1bab01a1addaf1bdd0b1849c710fbe3cf408b469e59c751bd4/autogen_agentchat-0.4.0.dev11-py3-none-any.whl", hash = "sha256:f53d5226b8bd2fa069e48e6481a5e6fa9b1620daa5b60161f5f4423d4efa85d5", size = 57428 }, ] [[package]] name = "autogen-core" -version = "0.4.0.dev9" +version = "0.4.0.dev11" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "aioconsole" }, { name = "aiohttp" }, { name = "asyncio-atexit" }, { name = "jsonref" }, @@ -305,21 +315,21 @@ dependencies = [ { name = "tiktoken" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/76/3d/9a17276bcc38570fba2066c6f906f54e0ec13a4b8ebf9f8075e23d661bc5/autogen_core-0.4.0.dev9.tar.gz", hash = "sha256:4c439ee61d2e724c8338f80e17b4a0478d93d2a986979c778df5da0f8b63612b", size = 2344904 } +sdist = { url = "https://files.pythonhosted.org/packages/e1/dd/5382b907aa446349011defb2bd5073dc75f81a612caf664f77ad8f849726/autogen_core-0.4.0.dev11.tar.gz", hash = "sha256:d98d0cdb8bba0c01cd100b889a4616e0100e8bd8403e8087c6f1d7eb57608564", size = 2273945 } wheels = [ - { url = "https://files.pythonhosted.org/packages/6d/f1/f7d338b0884f8561dcb79ebbe2303524e35c2f563d6cedeacdf2b18f64ed/autogen_core-0.4.0.dev9-py3-none-any.whl", hash = "sha256:e29003ba2cd7926dd75dc86b3516eb1af16326f119b63afb2ea51058fbb2e97e", size = 76274 }, + { url = "https://files.pythonhosted.org/packages/af/89/573b0396f62a0df2ab16daf5018ed27fa3694b53d12aac1c26a0ece1f2fd/autogen_core-0.4.0.dev11-py3-none-any.whl", hash = "sha256:47692b7e15424051cecdcea3e57d09b9c44c82d2661591aa24c7e54bc2df0fef", size = 78442 }, ] [[package]] name = "autogen-ext" -version = "0.4.0.dev9" +version = "0.4.0.dev11" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "autogen-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/76/66/4bd01325dc9b8ba52287187f19ff15e6785d0a183394707e7a841bb4e954/autogen_ext-0.4.0.dev9.tar.gz", hash = "sha256:111272b04de84d76c5f4657fead3d2956205ab1a34a687a0a8a72216860592d6", size = 91373 } +sdist = { url = "https://files.pythonhosted.org/packages/b2/86/0bd0057e577beb3ed614939b4aaf521ae8d8196bb08bcd9440aae8ec8702/autogen_ext-0.4.0.dev11.tar.gz", hash = "sha256:518b7d011abaa491988eb6d170707efb19bed4c72b000833e13f279548e7ecd0", size = 91683 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d0/ab/310bfaf92beb9db4f47635f7896129d71e6e2a35ec1e45238cbb40b2ff75/autogen_ext-0.4.0.dev9-py3-none-any.whl", hash = "sha256:b1d7ed0934bb20c34e9f991df1a1d05b310b6c6b9c95cea8098152cb6832843d", size = 99646 }, + { url = "https://files.pythonhosted.org/packages/f7/f5/f4ca335fef9956b50d5b132e427c2f7a883c101a9c33c52f50c5a422b48b/autogen_ext-0.4.0.dev11-py3-none-any.whl", hash = "sha256:b1d89e25356073c3adbb25aad5b05138176253288e682c0aac7b46883a2b8d30", size = 100352 }, ] [package.optional-dependencies] @@ -342,7 +352,7 @@ dependencies = [ { name = "autogen-ext", extra = ["azure", "openai"] }, { name = "grpcio" }, { name = "pyyaml" }, - { name = "text-2-sql-core" }, + { name = "text-2-sql-core", extra = ["databricks", "snowflake"] }, ] [package.dev-dependencies] @@ -359,12 +369,12 @@ dev = [ [package.metadata] requires-dist = [ - { name = "autogen-agentchat", specifier = "==0.4.0.dev9" }, - { name = "autogen-core", specifier = "==0.4.0.dev9" }, - { name = "autogen-ext", extras = ["azure", "openai"], specifier = "==0.4.0.dev9" }, + { name = "autogen-agentchat", specifier = "==0.4.0.dev11" }, + { name = "autogen-core", specifier = "==0.4.0.dev11" }, + { name = "autogen-ext", extras = ["azure", "openai"], specifier = "==0.4.0.dev11" }, { name = "grpcio", specifier = ">=1.68.1" }, { name = "pyyaml", specifier = ">=6.0.2" }, - { name = "text-2-sql-core", editable = "text_2_sql/text_2_sql_core" }, + { name = "text-2-sql-core", extras = ["snowflake", "databricks"], editable = "text_2_sql/text_2_sql_core" }, ] [package.metadata.requires-dev] @@ -803,7 +813,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/98/65/13d9e76ca19b0ba5603d71ac8424b5694415b348e719db277b5edc985ff5/cryptography-44.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:761817a3377ef15ac23cd7834715081791d4ec77f9297ee694ca1ee9c2c7e5eb", size = 3915420 }, { url = "https://files.pythonhosted.org/packages/b1/07/40fe09ce96b91fc9276a9ad272832ead0fddedcba87f1190372af8e3039c/cryptography-44.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3c672a53c0fb4725a29c303be906d3c1fa99c32f58abe008a82705f9ee96f40b", size = 4154498 }, { url = "https://files.pythonhosted.org/packages/75/ea/af65619c800ec0a7e4034207aec543acdf248d9bffba0533342d1bd435e1/cryptography-44.0.0-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4ac4c9f37eba52cb6fbeaf5b59c152ea976726b865bd4cf87883a7e7006cc543", size = 3932569 }, - { url = "https://files.pythonhosted.org/packages/4e/d5/9cc182bf24c86f542129565976c21301d4ac397e74bf5a16e48241aab8a6/cryptography-44.0.0-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:60eb32934076fa07e4316b7b2742fa52cbb190b42c2df2863dbc4230a0a9b385", size = 4164756 }, { url = "https://files.pythonhosted.org/packages/c7/af/d1deb0c04d59612e3d5e54203159e284d3e7a6921e565bb0eeb6269bdd8a/cryptography-44.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ed3534eb1090483c96178fcb0f8893719d96d5274dfde98aa6add34614e97c8e", size = 4016721 }, { url = "https://files.pythonhosted.org/packages/bd/69/7ca326c55698d0688db867795134bdfac87136b80ef373aaa42b225d6dd5/cryptography-44.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f3f6fdfa89ee2d9d496e2c087cebef9d4fcbb0ad63c40e821b39f74bf48d9c5e", size = 4240915 }, { url = "https://files.pythonhosted.org/packages/ef/d4/cae11bf68c0f981e0413906c6dd03ae7fa864347ed5fac40021df1ef467c/cryptography-44.0.0-cp37-abi3-win32.whl", hash = "sha256:eb33480f1bad5b78233b0ad3e1b0be21e8ef1da745d8d2aecbb20671658b9053", size = 2757925 }, @@ -814,7 +823,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/c7/c656eb08fd22255d21bc3129625ed9cd5ee305f33752ef2278711b3fa98b/cryptography-44.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c5eb858beed7835e5ad1faba59e865109f3e52b3783b9ac21e7e47dc5554e289", size = 3915417 }, { url = "https://files.pythonhosted.org/packages/ef/82/72403624f197af0db6bac4e58153bc9ac0e6020e57234115db9596eee85d/cryptography-44.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f53c2c87e0fb4b0c00fa9571082a057e37690a8f12233306161c8f4b819960b7", size = 4155160 }, { url = "https://files.pythonhosted.org/packages/a2/cd/2f3c440913d4329ade49b146d74f2e9766422e1732613f57097fea61f344/cryptography-44.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:9e6fc8a08e116fb7c7dd1f040074c9d7b51d74a8ea40d4df2fc7aa08b76b9e6c", size = 3932331 }, - { url = "https://files.pythonhosted.org/packages/31/d9/90409720277f88eb3ab72f9a32bfa54acdd97e94225df699e7713e850bd4/cryptography-44.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9abcc2e083cbe8dde89124a47e5e53ec38751f0d7dfd36801008f316a127d7ba", size = 4165207 }, { url = "https://files.pythonhosted.org/packages/7f/df/8be88797f0a1cca6e255189a57bb49237402b1880d6e8721690c5603ac23/cryptography-44.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d2436114e46b36d00f8b72ff57e598978b37399d2786fd39793c36c6d5cb1c64", size = 4017372 }, { url = "https://files.pythonhosted.org/packages/af/36/5ccc376f025a834e72b8e52e18746b927f34e4520487098e283a719c205e/cryptography-44.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a01956ddfa0a6790d594f5b34fc1bfa6098aca434696a03cfdbe469b8ed79285", size = 4239657 }, { url = "https://files.pythonhosted.org/packages/46/b0/f4f7d0d0bcfbc8dd6296c1449be326d04217c57afb8b2594f017eed95533/cryptography-44.0.0-cp39-abi3-win32.whl", hash = "sha256:eca27345e1214d1b9f9490d200f9db5a874479be914199194e746c893788d417", size = 2758672 }, @@ -2971,6 +2979,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/a9/d39f3c5ada0a3bb2870d7db41901125dbe2434fa4f12ca8c5b83a42d7c53/ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd", size = 706497 }, { url = "https://files.pythonhosted.org/packages/b0/fa/097e38135dadd9ac25aecf2a54be17ddf6e4c23e43d538492a90ab3d71c6/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31", size = 698042 }, { url = "https://files.pythonhosted.org/packages/ec/d5/a659ca6f503b9379b930f13bc6b130c9f176469b73b9834296822a83a132/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680", size = 745831 }, + { url = "https://files.pythonhosted.org/packages/db/5d/36619b61ffa2429eeaefaab4f3374666adf36ad8ac6330d855848d7d36fd/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d", size = 715692 }, { url = "https://files.pythonhosted.org/packages/b1/82/85cb92f15a4231c89b95dfe08b09eb6adca929ef7df7e17ab59902b6f589/ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5", size = 98777 }, { url = "https://files.pythonhosted.org/packages/d7/8f/c3654f6f1ddb75daf3922c3d8fc6005b1ab56671ad56ffb874d908bfa668/ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4", size = 115523 }, { url = "https://files.pythonhosted.org/packages/29/00/4864119668d71a5fa45678f380b5923ff410701565821925c69780356ffa/ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a", size = 132011 }, @@ -2979,6 +2988,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e2/a9/28f60726d29dfc01b8decdb385de4ced2ced9faeb37a847bd5cf26836815/ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6", size = 701785 }, { url = "https://files.pythonhosted.org/packages/84/7e/8e7ec45920daa7f76046578e4f677a3215fe8f18ee30a9cb7627a19d9b4c/ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf", size = 693017 }, { url = "https://files.pythonhosted.org/packages/c5/b3/d650eaade4ca225f02a648321e1ab835b9d361c60d51150bac49063b83fa/ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1", size = 741270 }, + { url = "https://files.pythonhosted.org/packages/87/b8/01c29b924dcbbed75cc45b30c30d565d763b9c4d540545a0eeecffb8f09c/ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4f6f3eac23941b32afccc23081e1f50612bdbe4e982012ef4f5797986828cd01", size = 709059 }, { url = "https://files.pythonhosted.org/packages/30/8c/ed73f047a73638257aa9377ad356bea4d96125b305c34a28766f4445cc0f/ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6", size = 98583 }, { url = "https://files.pythonhosted.org/packages/b0/85/e8e751d8791564dd333d5d9a4eab0a7a115f7e349595417fd50ecae3395c/ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3", size = 115190 }, ] From 0eaaa98654b8c80930865b1559bfc6592fcf1e1f Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 19:47:14 +0000 Subject: [PATCH 08/11] Update prompt --- .../autogen/src/autogen_text_2_sql/autogen_text_2_sql.py | 3 +++ .../src/text_2_sql_core/connectors/databricks_sql.py | 4 +--- .../prompts/sql_query_correction_agent.yaml | 9 ++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 15998fc..c4a38ae 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -141,6 +141,9 @@ def selector(self, messages): elif messages[-1].source == "sql_query_correction_agent": decision = "sql_query_generation_agent" + elif messages[-1].source == "sql_query_generation_agent": + decision = "sql_query_correction_agent" + # Log the decision logging.info("Decision: %s", decision) diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py index d544b60..8afa76d 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py @@ -101,9 +101,7 @@ async def get_entity_schemas( # schema["SelectFromEntity"] = ".".join( # [schema["Catalog"], schema["Schema"], schema["Entity"]] # ) - schema["SelectFromEntity"] = ( - os.environ["Text2Sql__Databricks__Catalog"] + "." + schema["Entity"] - ) + schema["SelectFromEntity"] = schema["Entity"] del schema["Entity"] # del schema["Schema"] diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml index 074eb30..60f853a 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml @@ -4,17 +4,17 @@ description: "An agent that will look at the SQL query, SQL query results and correct any mistakes in the SQL query to ensure the correct results are returned. Use this agent AFTER the SQL query has been executed and the results are not as expected." system_message: " - You are a helpful AI Assistant specializing in diagnosing and correcting invalid SQL queries or improving SQL queries that do not return expected results. + You are a helpful AI Assistant specializing in diagnosing and making fix suggestions for invalid SQL queries, or improving SQL queries that do not return expected results. - The user's question will be related to SQL queries and their results in the context of {{ target_engine }}. Queries must adhere to the syntax and rules of {{ target_engine }} {{ engine_specific_rules }}. + Queries must adhere to the syntax and rules of {{ target_engine }} {{ engine_specific_rules }}. 1. **Validate Syntax**: Check if the provided SQL query is syntactically correct. If not, suggest fixes to the query. 2. **Verify Results**: Ensure the query results align with the user’s question. If the query fails to meet the expected results: - - Make suggestions to correct the query. + - Make suggestions to the query writter on how to correct the query. 3. **Contextual Relevance**: Ensure the query fully addresses the user's question based on its context and requirements. @@ -41,8 +41,7 @@ system_message: ```json [ { - \"fix_request\": \"\", - \"explanation\": \"\" + \"requested_fix\": \"\" } ] ``` From d15a9b6cd45e865e54c5ce0abb66dc9fb8b595d3 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 20:15:45 +0000 Subject: [PATCH 09/11] Update the prompt --- .../sql_schema_selection_agent.py | 12 ++- .../text_2_sql_core/connectors/ai_search.py | 22 ++++- .../prompts/sql_disambiguation_agent.yaml | 91 +++++++++++++------ 3 files changed, 94 insertions(+), 31 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py index e245176..870db09 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py @@ -101,15 +101,19 @@ async def on_messages_stream( if schema not in final_schemas: final_schemas.append(schema) - final_colmns = [] + final_columns = [] for column_value_result in column_value_results: for column in column_value_result: - if column not in final_colmns: - final_colmns.append(column) + if column not in final_columns: + final_columns.append(column) + + all_column_lengths = [len(column) for column in final_columns] final_results = { + "MANDATORY_DISAMBIGUATION": max(all_column_lengths) > 3 + or len(final_columns) > 3, "schemas": final_schemas, - "column_values": final_colmns, + "column_values": final_columns, } logging.info(f"Final results: {final_results}") diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py index 7e73da5..15342ed 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py @@ -146,7 +146,7 @@ async def get_column_values( "AIService__AzureSearchOptions__Text2SqlColumnValueStore__Index" ], semantic_config=None, - top=15, + top=50, include_scores=False, minimum_score=5, ) @@ -161,6 +161,8 @@ async def get_column_values( column_values[trimmed_fqn].append(value["Value"]) + logging.info("Column Values: %s", column_values) + if as_json: return json.dumps(column_values, default=str) else: @@ -225,6 +227,24 @@ async def get_entity_schemas( # del schema["FQN"] + if ( + schema["CompleteEntityRelationshipsGraph"] is not None + and len(schema["CompleteEntityRelationshipsGraph"]) == 0 + ): + del schema["CompleteEntityRelationshipsGraph"] + + if ( + schema["SammpleValues"] is not None + and len(schema["SammpleValues"]) == 0 + ): + del schema["SammpleValues"] + + if ( + schema["EntityRelationships"] is not None + and len(schema["EntityRelationships"]) == 0 + ): + del schema["EntityRelationships"] + if schema["Entity"].lower() not in excluded_entities: filtered_schemas.append(schema) else: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml index 0967bf4..ccab2be 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml @@ -4,64 +4,83 @@ description: "An agent that specialises in disambiguating the user's question and mapping it to database schemas. Use this agent when the user's question is ambiguous and requires more information to generate the SQL query." system_message: " - You are a helpful AI Assistant specializing in disambiguating the user's question and mapping it to the relevant columns and schemas in the database. + You are a helpful AI Assistant specializing in disambiguating the user's question and mapping it to the relevant columns and schemas in the database. + Your job is to narrow down the possible mappings based on the user's question and the schema provided to generate a clear mapping. - The user's question will be related to {{ use_case }}. + The user's question will be related to {{ use_case }}. - - For every intent and filter condition in the question, map them to the columns in the schemas. Use the whole context of the question and information already provided to do so. + - If 'MANDATORY_DISAMBIGUATION' is True, you must perform disambiguation on the terms with high cardinality. It is mandatory. + + - For every intent and filter condition in the question, map them to the columns in the schemas and the appropriate filter value. Use the whole context of the question and information already provided to do so. - Do not ask for information already included in the question, schema, or what can reasonably be inferred from the question. - - Only provide possible filter values for string columns. Do not provide possible filter values for Date and Numerical values as it should be clear from the question. Only ask a follow-up question for Date and Numerical values if you are unsure which column to use or what the value means, e.g., does 100 in currency refer to 100 USD or 100 EUR. + - Only ask a follow-up question for Date and Numerical values if you are unsure which column to use or what the value means, e.g., does 100 in currency refer to 100 USD or 100 EUR. - If the context of the question makes the mapping explicit, directly map the terms to the relevant column FQN without generating disambiguation questions. + If the context of the question makes the mapping explicit, and the appropriate filter values can be found in 'column_values' directly map the terms to the relevant column FQN without generating disambiguation questions. + + When evaluating questions: + + Use the 'column_values' property to check for possible matching columns and compare these to the context of the question. ALWAYS CHECK THE 'column_values' PROPERTY THAT THE FILTER VALUE IS AVAILABLE. - Use the 'column_values' property to check for possible matching columns and compare these to the context of the question. + If there are multiple values in 'column_values' that could match the filter, ask for clarification or to narrow down the filter value or column to use. If in doubt, use disambiguation questions to clarify. - When evaluating filters: + Always consider the temporal and contextual phrases (e.g., \"June 2008\") in the question. If the context implies a direct match to a date column, do not request clarification unless multiple plausible columns exist. + For geographical or categorical terms (e.g., \"country\"), prioritize unique matches or add context to narrow down ambiguities based on the schema. - Always consider the temporal and contextual phrases (e.g., \"June 2008\") in the question. If the context implies a direct match to a date column, do not request clarification unless multiple plausible columns exist. - For geographical or categorical terms (e.g., \"country\"), prioritize unique matches or add context to narrow down ambiguities based on the schema. If all mappings are clear, output the JSON with mappings only. Question: \"What are the total number of sales within 2008 for the mountain bike product line?\" - Output: - json - Copy code { - \"mapping\": { - \"Mountain Bike\": \"vProductModelCatalogDescription.Category\", - \"2008\": \"SalesLT.SalesOrderHeader.OrderDate\" + \"filter_mapping\": { + \"bike\": [ + { + \"column\": \"vProductModelCatalogDescription.Category\", + \"filter_value\": \"Mountain Bike\" + } + ], + \"2008\": [ + { + \"column\": \"SalesLT.SalesOrderHeader.OrderDate\", + \"filter_value\": \"2008-01-01\", + } + ] + }, + \"intent_mapping\": { + \"total number of sales\": \"SalesLT.SalesOrderHeader.SalesOrderID\" } } - If the term is ambiguous, there are multiple matching columns/filters, or the question lacks enough context to infer the correct mapping: + If the term is ambiguous, there are multiple matching columns/questions in 'column_values', or the question lacks enough context to infer the correct mapping, then ask for clarification. - For ambiguous terms, evaluate the question context and schema relationships to narrow down matches. - Populate the 'filters' field with the identified filter and relevant FQN, matching columns, and possible filter values. - Include a clarification question in the 'question' field to request more information from the user. - If the clarification is not related to a column or a filter value, populate the 'user_choices' field with the possible choices they can select. + For ambiguous terms, evaluate the question context and schema relationships to narrow down matches. + Populate the 'questions' field with the identified filter and relevant FQN, matching columns, and possible filter values. + Include a clarification question in the 'question' field to request more information from the user. + If the clarification is not related to a column or a filter value, populate the 'user_choices' field with the possible choices they can select. - Prioritize clear disambiguation based on: - - Direct matches within schemas. - - Additional context provided by the question (e.g., temporal, categorical, or domain-specific keywords). + Prioritize clear disambiguation based on: + - Direct matches within schemas. + - Additional context provided by the question (e.g., temporal, categorical, or domain-specific keywords). + + Return all disambiguation questions in the 'questions' array. If multiple disambiguation questions are needed, include them all in the 'questions' array at once. - User question: \"What country did we sell the most to in June 2008?\" + User question: \"What country did we sell the most in June 2008?\" Schema contains multiple columns potentially related to \"country.\" If disambiguation is needed: + { - \"filters\": [ + \"questions\": [ { \"question\": \"What do you mean by 'country'?\", \"matching_columns\": [ @@ -74,7 +93,27 @@ system_message: ] } - Always include either the 'matching_columns', 'matching_filter_values' or `user_choices` field in the 'filters' array. + + + User question: \"What are the total sales for the mountain bike product line?\" + 'column_values' contains multiple columns potentially related to \"mountain bike.\" + + If disambiguation is needed: + { + \"questions\": [ + { + \"question\": \"What do you mean by 'mountain bike'?\", + \"matching_columns\": [ + \"vProductModelCatalogDescription.Category\", + \"vProductModelCatalogDescription.ProductLine\" + ], + \"matching_filter_values\": [], + \"user_choices\": [] + } + ] + } + + Always include either the 'matching_columns', 'matching_filter_values' or `user_choices` field in the 'questions' array. From 5329613ca893603726f322a10ef76f5271fc0a83 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 20:20:29 +0000 Subject: [PATCH 10/11] Update sql --- .../sql_schema_selection_agent.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py index 870db09..5eb5348 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py @@ -101,19 +101,29 @@ async def on_messages_stream( if schema not in final_schemas: final_schemas.append(schema) - final_columns = [] - for column_value_result in column_value_results: + columns_for_filter = {} + values_for_filter = {} + for filter, column_value_result in zip( + loaded_entity_result["filter_conditions"], column_value_results + ): + columns_for_filter[filter] = [] + values_for_filter[filter] = [] for column in column_value_result: - if column not in final_columns: - final_columns.append(column) + if column["Column"] not in columns_for_filter[filter]: + columns_for_filter[filter].append(column["Column"]) - all_column_lengths = [len(column) for column in final_columns] + if column["Value"] not in values_for_filter[filter]: + values_for_filter[filter].append(column["Value"]) + + num_all_values = [len(filter) for filter in values_for_filter] + num_all_columns = [len(filter) for filter in columns_for_filter] final_results = { - "MANDATORY_DISAMBIGUATION": max(all_column_lengths) > 3 - or len(final_columns) > 3, - "schemas": final_schemas, - "column_values": final_columns, + "MANDATORY_DISAMBIGUATION": max(num_all_values) > 3 + or max(num_all_columns) > 3, + "COLUMN_OPTIONS_FOR_FILTERS": columns_for_filter, + "VALUE_OPTIONS_FOR_FILTERS": values_for_filter, + "SCHEMA_OPTIONS": final_schemas, } logging.info(f"Final results: {final_results}") From 21a79bfba7a4dcb0c133b41c9eda66f2ec23d5ad Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Thu, 12 Dec 2024 22:48:54 +0000 Subject: [PATCH 11/11] Update column search --- .../sql_schema_selection_agent.py | 22 +--- .../text_2_sql_core/connectors/ai_search.py | 6 +- .../prompts/sql_disambiguation_agent.yaml | 106 ++++++------------ 3 files changed, 41 insertions(+), 93 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py index 5eb5348..b634538 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py @@ -101,28 +101,8 @@ async def on_messages_stream( if schema not in final_schemas: final_schemas.append(schema) - columns_for_filter = {} - values_for_filter = {} - for filter, column_value_result in zip( - loaded_entity_result["filter_conditions"], column_value_results - ): - columns_for_filter[filter] = [] - values_for_filter[filter] = [] - for column in column_value_result: - if column["Column"] not in columns_for_filter[filter]: - columns_for_filter[filter].append(column["Column"]) - - if column["Value"] not in values_for_filter[filter]: - values_for_filter[filter].append(column["Value"]) - - num_all_values = [len(filter) for filter in values_for_filter] - num_all_columns = [len(filter) for filter in columns_for_filter] - final_results = { - "MANDATORY_DISAMBIGUATION": max(num_all_values) > 3 - or max(num_all_columns) > 3, - "COLUMN_OPTIONS_FOR_FILTERS": columns_for_filter, - "VALUE_OPTIONS_FOR_FILTERS": values_for_filter, + "COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results, "SCHEMA_OPTIONS": final_schemas, } diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py index 15342ed..c632a59 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py @@ -163,10 +163,12 @@ async def get_column_values( logging.info("Column Values: %s", column_values) + filter_to_column = {text: column_values} + if as_json: - return json.dumps(column_values, default=str) + return json.dumps(filter_to_column, default=str) else: - return column_values + return filter_to_column async def get_entity_schemas( self, diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml index ccab2be..7f0dea8 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml @@ -13,30 +13,20 @@ system_message: - - If 'MANDATORY_DISAMBIGUATION' is True, you must perform disambiguation on the terms with high cardinality. It is mandatory. + - For every filter extracted from the user's question, you must: - - For every intent and filter condition in the question, map them to the columns in the schemas and the appropriate filter value. Use the whole context of the question and information already provided to do so. + - If it is not a datetime or numerical filter, map it to: + - A value from 'COLUMN_OPTIONS_FOR_FILTERS' + - And a value from 'VALUE_OPTIONS_FOR_FILTERS' - - Do not ask for information already included in the question, schema, or what can reasonably be inferred from the question. + - If the filter is a datetime or numerical filter, map it to: + - A column from 'SCHEMA_OPTIONS' - - Only ask a follow-up question for Date and Numerical values if you are unsure which column to use or what the value means, e.g., does 100 in currency refer to 100 USD or 100 EUR. + - Use the whole context of the question and information already provided to assist with your mapping. - - If the context of the question makes the mapping explicit, and the appropriate filter values can be found in 'column_values' directly map the terms to the relevant column FQN without generating disambiguation questions. - - When evaluating questions: - - Use the 'column_values' property to check for possible matching columns and compare these to the context of the question. ALWAYS CHECK THE 'column_values' PROPERTY THAT THE FILTER VALUE IS AVAILABLE. - - If there are multiple values in 'column_values' that could match the filter, ask for clarification or to narrow down the filter value or column to use. If in doubt, use disambiguation questions to clarify. - - Always consider the temporal and contextual phrases (e.g., \"June 2008\") in the question. If the context implies a direct match to a date column, do not request clarification unless multiple plausible columns exist. - For geographical or categorical terms (e.g., \"country\"), prioritize unique matches or add context to narrow down ambiguities based on the schema. - - If all mappings are clear, output the JSON with mappings only. - - - Question: \"What are the total number of sales within 2008 for the mountain bike product line?\" + + - If you can map it to an column and potential filter value: + - Only map if you are reasonably sure of the user's intention. { \"filter_mapping\": { \"bike\": [ @@ -52,35 +42,15 @@ system_message: } ] }, - \"intent_mapping\": { - \"total number of sales\": \"SalesLT.SalesOrderHeader.SalesOrderID\" - } } - - - - - If the term is ambiguous, there are multiple matching columns/questions in 'column_values', or the question lacks enough context to infer the correct mapping, then ask for clarification. - - For ambiguous terms, evaluate the question context and schema relationships to narrow down matches. - Populate the 'questions' field with the identified filter and relevant FQN, matching columns, and possible filter values. - Include a clarification question in the 'question' field to request more information from the user. - If the clarification is not related to a column or a filter value, populate the 'user_choices' field with the possible choices they can select. - - Prioritize clear disambiguation based on: - - Direct matches within schemas. - - Additional context provided by the question (e.g., temporal, categorical, or domain-specific keywords). - - Return all disambiguation questions in the 'questions' array. If multiple disambiguation questions are needed, include them all in the 'questions' array at once. - - - User question: \"What country did we sell the most in June 2008?\" - Schema contains multiple columns potentially related to \"country.\" + - If disambiguation is needed: + + - If you cannot map it to a column, add en entry to the disambiguation list with the clarification question you need from the user: + - If there are multiple possible options, or you are unsure how it maps, make sure to ask a clarification question. { - \"questions\": [ + \"disambiguation\": [ { \"question\": \"What do you mean by 'country'?\", \"matching_columns\": [ @@ -88,38 +58,34 @@ system_message: \"Customers.Country\" ], \"matching_filter_values\": [], - \"user_choices\": [] + \"other_user_choices\": [] } ] } - - - User question: \"What are the total sales for the mountain bike product line?\" - 'column_values' contains multiple columns potentially related to \"mountain bike.\" - - If disambiguation is needed: - { - \"questions\": [ - { - \"question\": \"What do you mean by 'mountain bike'?\", - \"matching_columns\": [ - \"vProductModelCatalogDescription.Category\", - \"vProductModelCatalogDescription.ProductLine\" - ], - \"matching_filter_values\": [], - \"user_choices\": [] - } - ] - } - - Always include either the 'matching_columns', 'matching_filter_values' or `user_choices` field in the 'questions' array. - - + + - Do not ask for information already included in the question, schema, or what can reasonably be inferred from the question. + + + + - For every intent extracted from the user's question: + - If you need to ask any clarification questions, add it to the clarification question list: + + { + \"clarification\": [ + { + \"question\": \"What do the sales to customers or businesses?\", + \"other_user_choices\": [ + \"customers\", + \"businesses\", + ] + } + ] + } If all mappings are clear, output the 'mapping' JSON only. - If disambiguation is required, output the disambiguation JSON followed by \"TERMINATE.\" + If disambiguation or clarification is required, output the JSON request followed by \"TERMINATE.\" Do not provide explanations or reasoning in the output. "