From 3ea1ffe86eea4389eba43219ec7ec64af26d41e8 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 29 Nov 2024 16:41:17 +0000 Subject: [PATCH 1/4] Add current datetime --- .../agents/llm_agents/sql_query_correction_agent.yaml | 1 + .../agents/llm_agents/sql_query_generation_agent.yaml | 1 + text_2_sql/autogen/utils/llm_agent_creator.py | 6 ++++++ 3 files changed, 8 insertions(+) diff --git a/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml b/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml index 3641472..5e9cc9c 100644 --- a/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml @@ -20,3 +20,4 @@ system_message: tools: - sql_get_entity_schemas_tool - sql_query_execution_tool + - current_datetime_tool diff --git a/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml b/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml index 6a9e92c..f514453 100644 --- a/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml @@ -39,3 +39,4 @@ tools: - sql_query_execution_tool - sql_get_entity_schemas_tool - sql_query_validation_tool + - current_datetime_tool diff --git a/text_2_sql/autogen/utils/llm_agent_creator.py b/text_2_sql/autogen/utils/llm_agent_creator.py index d72c887..3889ed1 100644 --- a/text_2_sql/autogen/utils/llm_agent_creator.py +++ b/text_2_sql/autogen/utils/llm_agent_creator.py @@ -6,6 +6,7 @@ from utils.sql import query_execution, get_entity_schemas, query_validation from utils.models import MINI_MODEL from jinja2 import Template +from datetime import datetime class LLMAgentCreator: @@ -40,6 +41,11 @@ def get_tool(cls, tool_name): query_validation, description="Validates the SQL query to ensure that it is syntactically correct for the target database engine. Use this BEFORE executing any SQL statement.", ) + elif tool_name == "current_datetime_tool": + return FunctionTool( + lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + description="Gets the current date and time.", + ) else: raise ValueError(f"Tool {tool_name} not found") From 3ad3a5e7d2cd3570b12a5ffea738b164703c642d Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 29 Nov 2024 16:50:38 +0000 Subject: [PATCH 2/4] Update agent logic --- ...n 5 - Agentic Vector Based Text2SQL.ipynb} | 0 .../agents/llm_agents/answer_agent.yaml | 2 +- .../question_decomposition_agent.yaml | 2 +- .../sql_query_correction_agent.yaml | 2 +- .../sql_query_generation_agent.yaml | 2 +- .../sql_schema_selection_agent.yaml | 2 +- text_2_sql/autogen/utils/llm_agent_creator.py | 67 ++++++++++++++++--- text_2_sql/autogen/utils/models.py | 17 ++++- text_2_sql/autogen/utils/sql.py | 12 +++- 9 files changed, 89 insertions(+), 17 deletions(-) rename text_2_sql/autogen/{agentic_text_2_sql.ipynb => Iteration 5 - Agentic Vector Based Text2SQL.ipynb} (100%) diff --git a/text_2_sql/autogen/agentic_text_2_sql.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb similarity index 100% rename from text_2_sql/autogen/agentic_text_2_sql.ipynb rename to text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb diff --git a/text_2_sql/autogen/agents/llm_agents/answer_agent.yaml b/text_2_sql/autogen/agents/llm_agents/answer_agent.yaml index 61b5893..d70f0f3 100644 --- a/text_2_sql/autogen/agents/llm_agents/answer_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/answer_agent.yaml @@ -1,5 +1,5 @@ model: - gpt-4o-mini + 4o-mini description: "An agent that takes the final results from the SQL query and writes the answer to the user's question" system_message: diff --git a/text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml b/text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml index 57e139d..5878ef5 100644 --- a/text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml @@ -1,5 +1,5 @@ model: - gpt-4o-mini + 4o-mini description: "An agent that will decompose the user's question into smaller parts to be used in the SQL queries. Use this agent when the user's question is too complex to be answered in one SQL query. Only use if the user's question is too complex to be answered in one SQL query." system_message: diff --git a/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml b/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml index 5e9cc9c..6a930ae 100644 --- a/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml @@ -1,5 +1,5 @@ model: - gpt-4o-mini + 4o-mini description: "An agent that will look at the SQL query, SQL query results and correct any mistakes in the SQL query to ensure the correct results are returned. Use this agent AFTER the SQL query has been executed and the results are not as expected." system_message: diff --git a/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml b/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml index f514453..76ee8bf 100644 --- a/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml @@ -1,5 +1,5 @@ model: - gpt-4o-mini + 4o-mini description: "An agent that can generate SQL queries once given the schema and the user's question. It will run the SQL query to fetch the results. Use this agent after the SQL Schema Selection Agent has selected the correct schema." system_message: diff --git a/text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml b/text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml index 3876faf..ccec4fe 100644 --- a/text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml @@ -1,5 +1,5 @@ model: - gpt-4o-mini + 4o-mini description: "An agent that can take a user's question and extract the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. diff --git a/text_2_sql/autogen/utils/llm_agent_creator.py b/text_2_sql/autogen/utils/llm_agent_creator.py index 3889ed1..9a335ae 100644 --- a/text_2_sql/autogen/utils/llm_agent_creator.py +++ b/text_2_sql/autogen/utils/llm_agent_creator.py @@ -4,28 +4,56 @@ from autogen_core.components.tools import FunctionTool from autogen_agentchat.agents import AssistantAgent from utils.sql import query_execution, get_entity_schemas, query_validation -from utils.models import MINI_MODEL +from utils.models import GPT_4O_MINI_MODEL, GPT_4O_MODEL from jinja2 import Template from datetime import datetime +from autogen_ext.models import AzureOpenAIChatCompletionClient class LLMAgentCreator: @classmethod - def load_agent_file(cls, name): + def load_agent_file(cls, name: str) -> dict: + """Loads the agent file based on the agent name. + + Args: + ---- + name (str): The name of the agent to load. + + Returns: + ------- + dict: The agent file.""" with open(f"./agents/llm_agents/{name.lower()}.yaml", "r") as file: file = yaml.safe_load(file) return file @classmethod - def get_model(cls, model_name): - if model_name == "gpt-4o-mini": - return MINI_MODEL + def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient: + """Retrieves the model based on the model name. + + Args: + ---- + model_name (str): The name of the model to retrieve. + + Returns: + AzureOpenAIChatCompletionClient: The model client.""" + if model_name == "4o-mini": + return GPT_4O_MINI_MODEL + elif model_name == "4o": + return GPT_4O_MODEL else: raise ValueError(f"Model {model_name} not found") @classmethod - def get_tool(cls, tool_name): + def get_tool(cls, tool_name: str) -> FunctionTool: + """Retrieves the tool based on the tool name. + + Args: + ---- + tool_name (str): The name of the tool to retrieve. + + Returns: + FunctionTool: The tool.""" if tool_name == "sql_query_execution_tool": return FunctionTool( query_execution, @@ -50,7 +78,20 @@ def get_tool(cls, tool_name): raise ValueError(f"Tool {tool_name} not found") @classmethod - def get_property_and_render_parameters(cls, agent_file, property, parameters): + def get_property_and_render_parameters( + cls, agent_file: dict, property: str, parameters: dict + ) -> str: + """Gets the property from the agent file and renders the parameters. + + Args: + ---- + agent_file (dict): The agent file. + property (str): The property to retrieve. + parameters (dict): The parameters to render. + + Returns: + ------- + str: The rendered property.""" unrendered_parameters = agent_file[property] rendered_template = Template(unrendered_parameters).render(parameters) @@ -58,7 +99,17 @@ def get_property_and_render_parameters(cls, agent_file, property, parameters): return rendered_template @classmethod - def create(cls, name: str, **kwargs): + def create(cls, name: str, **kwargs) -> AssistantAgent: + """Creates an assistant agent based on the agent name. + + Args: + ---- + name (str): The name of the agent to create. + **kwargs: The parameters to render. + + Returns: + ------- + AssistantAgent: The assistant agent.""" agent_file = cls.load_agent_file(name) tools = [] diff --git a/text_2_sql/autogen/utils/models.py b/text_2_sql/autogen/utils/models.py index 1d4edbb..0d7e003 100644 --- a/text_2_sql/autogen/utils/models.py +++ b/text_2_sql/autogen/utils/models.py @@ -13,7 +13,7 @@ # DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" # ) -MINI_MODEL = AzureOpenAIChatCompletionClient( +GPT_4O_MINI_MODEL = AzureOpenAIChatCompletionClient( azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"], model=os.environ["OpenAI__MiniCompletionDeployment"], api_version="2024-08-01-preview", @@ -27,3 +27,18 @@ "json_output": True, }, ) + +GPT_4O_MODEL = AzureOpenAIChatCompletionClient( + azure_deployment=os.environ["OpenAI__CompletionDeployment"], + model=os.environ["OpenAI__CompletionDeployment"], + api_version="2024-08-01-preview", + azure_endpoint=os.environ["OpenAI__Endpoint"], + # # Optional if you choose key-based authentication. + # azure_ad_token_provider=token_provider, + api_key=os.environ["OpenAI__ApiKey"], # For key-based authentication. + model_capabilities={ + "vision": False, + "function_calling": True, + "json_output": True, + }, +) diff --git a/text_2_sql/autogen/utils/sql.py b/text_2_sql/autogen/utils/sql.py index 4f340c1..ce5027d 100644 --- a/text_2_sql/autogen/utils/sql.py +++ b/text_2_sql/autogen/utils/sql.py @@ -138,7 +138,7 @@ async def fetch_queries_from_cache(question: str) -> str: ) if len(cached_schemas) == 0: - return {"cached_questions_and_schemas": None} + return {"contains_pre_run_results": False, "cached_questions_and_schemas": None} logging.info("Cached schemas: %s", cached_schemas) if PRE_RUN_QUERY_CACHE and len(cached_schemas) > 0: @@ -165,6 +165,12 @@ async def fetch_queries_from_cache(question: str) -> str: "schemas": sql_query["Schemas"], } - return {"cached_questions_and_schemas": query_result_store} + return { + "contains_pre_run_results": True, + "cached_questions_and_schemas": query_result_store, + } - return {"cached_questions_and_schemas": cached_schemas} + return { + "contains_pre_run_results": False, + "cached_questions_and_schemas": cached_schemas, + } From d456aced50cc17d36fc092be1774288173ccee6e Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 29 Nov 2024 17:00:07 +0000 Subject: [PATCH 3/4] Improve building of autogen --- text_2_sql/autogen/agentic_text_2_sql.py | 18 ++-- text_2_sql/autogen/utils/llm_agent_creator.py | 22 +---- text_2_sql/autogen/utils/llm_model_creator.py | 86 +++++++++++++++++++ text_2_sql/autogen/utils/models.py | 44 ---------- 4 files changed, 99 insertions(+), 71 deletions(-) create mode 100644 text_2_sql/autogen/utils/llm_model_creator.py delete mode 100644 text_2_sql/autogen/utils/models.py diff --git a/text_2_sql/autogen/agentic_text_2_sql.py b/text_2_sql/autogen/agentic_text_2_sql.py index 2320e56..bbc6fba 100644 --- a/text_2_sql/autogen/agentic_text_2_sql.py +++ b/text_2_sql/autogen/agentic_text_2_sql.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from autogen_agentchat.task import TextMentionTermination, MaxMessageTermination from autogen_agentchat.teams import SelectorGroupChat -from utils.models import MINI_MODEL +from utils.llm_model_creator import LLMModelCreator from utils.llm_agent_creator import LLMAgentCreator import logging from agents.custom_agents.sql_query_cache_agent import SqlQueryCacheAgent @@ -86,13 +86,17 @@ def selector(messages): and messages[-1].content is not None ): cache_result = json.loads(messages[-1].content) - if cache_result.get("cached_questions_and_schemas") is not None: + if cache_result.get( + "cached_questions_and_schemas" + ) is not None and cache_result.get("contains_pre_run_results"): decision = "sql_query_correction_agent" + if ( + cache_result.get("cached_questions_and_schemas") is not None + and cache_result.get("contains_pre_run_results") is False + ): + decision = "sql_query_generation_agent" else: - decision = "sql_schema_selection_agent" - - elif messages[-1].source == "sql_query_cache_agent": - decision = "question_decomposition_agent" + decision = "question_decomposition_agent" elif messages[-1].source == "question_decomposition_agent": decomposition_result = json.loads(messages[-1].content) @@ -129,7 +133,7 @@ def agentic_flow(self): agentic_flow = SelectorGroupChat( self.agents, allow_repeated_speaker=False, - model_client=MINI_MODEL, + model_client=LLMModelCreator.get_model("4o-mini"), termination_condition=self.termination_condition, selector_func=AgenticText2Sql.selector, ) diff --git a/text_2_sql/autogen/utils/llm_agent_creator.py b/text_2_sql/autogen/utils/llm_agent_creator.py index 9a335ae..aee3734 100644 --- a/text_2_sql/autogen/utils/llm_agent_creator.py +++ b/text_2_sql/autogen/utils/llm_agent_creator.py @@ -4,10 +4,9 @@ from autogen_core.components.tools import FunctionTool from autogen_agentchat.agents import AssistantAgent from utils.sql import query_execution, get_entity_schemas, query_validation -from utils.models import GPT_4O_MINI_MODEL, GPT_4O_MODEL +from utils.llm_model_creator import LLMModelCreator from jinja2 import Template from datetime import datetime -from autogen_ext.models import AzureOpenAIChatCompletionClient class LLMAgentCreator: @@ -27,23 +26,6 @@ def load_agent_file(cls, name: str) -> dict: return file - @classmethod - def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient: - """Retrieves the model based on the model name. - - Args: - ---- - model_name (str): The name of the model to retrieve. - - Returns: - AzureOpenAIChatCompletionClient: The model client.""" - if model_name == "4o-mini": - return GPT_4O_MINI_MODEL - elif model_name == "4o": - return GPT_4O_MODEL - else: - raise ValueError(f"Model {model_name} not found") - @classmethod def get_tool(cls, tool_name: str) -> FunctionTool: """Retrieves the tool based on the tool name. @@ -120,7 +102,7 @@ def create(cls, name: str, **kwargs) -> AssistantAgent: agent = AssistantAgent( name=name, tools=tools, - model_client=cls.get_model(agent_file["model"]), + model_client=LLMModelCreator.get_model(agent_file["model"]), description=cls.get_property_and_render_parameters( agent_file, "description", kwargs ), diff --git a/text_2_sql/autogen/utils/llm_model_creator.py b/text_2_sql/autogen/utils/llm_model_creator.py new file mode 100644 index 0000000..e62ecfd --- /dev/null +++ b/text_2_sql/autogen/utils/llm_model_creator.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from autogen_ext.models import AzureOpenAIChatCompletionClient +from environment import IdentityType, get_identity_type + +from azure.identity import DefaultAzureCredential, get_bearer_token_provider +import os +import dotenv + +dotenv.load_dotenv() + + +class LLMModelCreator: + @classmethod + def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient: + """Retrieves the model based on the model name. + + Args: + ---- + model_name (str): The name of the model to retrieve. + + Returns: + AzureOpenAIChatCompletionClient: The model client.""" + if model_name == "4o-mini": + return cls.gpt_4o_mini_model() + elif model_name == "4o": + return cls.gpt_4o_model() + else: + raise ValueError(f"Model {model_name} not found") + + @classmethod + def get_authentication_properties(cls) -> dict: + if get_identity_type() == IdentityType.SYSTEM_ASSIGNED: + # Create the token provider + api_key = None + token_provider = get_bearer_token_provider( + DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + ) + elif get_identity_type() == IdentityType.USER_ASSIGNED: + # Create the token provider + api_key = None + token_provider = get_bearer_token_provider( + DefaultAzureCredential( + managed_identity_client_id=os.environ["ClientId"] + ), + "https://cognitiveservices.azure.com/.default", + ) + else: + token_provider = None + api_key = os.environ["OpenAI__ApiKey"] + + return token_provider, api_key + + @classmethod + def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient: + token_provider, api_key = cls.get_authentication_properties() + return AzureOpenAIChatCompletionClient( + azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"], + model=os.environ["OpenAI__MiniCompletionDeployment"], + api_version="2024-08-01-preview", + azure_endpoint=os.environ["OpenAI__Endpoint"], + azure_ad_token_provider=token_provider, + api_key=api_key, + model_capabilities={ + "vision": False, + "function_calling": True, + "json_output": True, + }, + ) + + @classmethod + def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient: + token_provider, api_key = cls.get_authentication_properties() + return AzureOpenAIChatCompletionClient( + azure_deployment=os.environ["OpenAI__CompletionDeployment"], + model=os.environ["OpenAI__CompletionDeployment"], + api_version="2024-08-01-preview", + azure_endpoint=os.environ["OpenAI__Endpoint"], + azure_ad_token_provider=token_provider, + api_key=api_key, + model_capabilities={ + "vision": False, + "function_calling": True, + "json_output": True, + }, + ) diff --git a/text_2_sql/autogen/utils/models.py b/text_2_sql/autogen/utils/models.py deleted file mode 100644 index 0d7e003..0000000 --- a/text_2_sql/autogen/utils/models.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from autogen_ext.models import AzureOpenAIChatCompletionClient - -# from azure.identity import DefaultAzureCredential, get_bearer_token_provider -import os -import dotenv - -dotenv.load_dotenv() - -# # Create the token provider -# token_provider = get_bearer_token_provider( -# DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" -# ) - -GPT_4O_MINI_MODEL = AzureOpenAIChatCompletionClient( - azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"], - model=os.environ["OpenAI__MiniCompletionDeployment"], - api_version="2024-08-01-preview", - azure_endpoint=os.environ["OpenAI__Endpoint"], - # # Optional if you choose key-based authentication. - # azure_ad_token_provider=token_provider, - api_key=os.environ["OpenAI__ApiKey"], # For key-based authentication. - model_capabilities={ - "vision": False, - "function_calling": True, - "json_output": True, - }, -) - -GPT_4O_MODEL = AzureOpenAIChatCompletionClient( - azure_deployment=os.environ["OpenAI__CompletionDeployment"], - model=os.environ["OpenAI__CompletionDeployment"], - api_version="2024-08-01-preview", - azure_endpoint=os.environ["OpenAI__Endpoint"], - # # Optional if you choose key-based authentication. - # azure_ad_token_provider=token_provider, - api_key=os.environ["OpenAI__ApiKey"], # For key-based authentication. - model_capabilities={ - "vision": False, - "function_calling": True, - "json_output": True, - }, -) From 3349a0ae64908bbdf9772aa4e9724ea84deba6cc Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 29 Nov 2024 17:01:25 +0000 Subject: [PATCH 4/4] Update README code --- text_2_sql/autogen/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/text_2_sql/autogen/README.md b/text_2_sql/autogen/README.md index 1114e0c..bc6e876 100644 --- a/text_2_sql/autogen/README.md +++ b/text_2_sql/autogen/README.md @@ -20,7 +20,7 @@ As the query cache is shared between users (no data is stored in the cache), a n ## Provided Notebooks & Scripts -- `./agentic_text_2_sql.ipynb` provides example of how to utilise the Agentic Vector Based Text2SQL approach to query the database. The query cache plugin will be enabled or disabled depending on the environmental parameters. +- `./Iteration 5 - Agentic Vector Based Text2SQL.ipynb` provides example of how to utilise the Agentic Vector Based Text2SQL approach to query the database. The query cache plugin will be enabled or disabled depending on the environmental parameters. ## Agents