diff --git a/text_2_sql/autogen/agentic_text_2_sql.ipynb b/text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb similarity index 100% rename from text_2_sql/autogen/agentic_text_2_sql.ipynb rename to text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb diff --git a/text_2_sql/autogen/README.md b/text_2_sql/autogen/README.md index 1114e0c..bc6e876 100644 --- a/text_2_sql/autogen/README.md +++ b/text_2_sql/autogen/README.md @@ -20,7 +20,7 @@ As the query cache is shared between users (no data is stored in the cache), a n ## Provided Notebooks & Scripts -- `./agentic_text_2_sql.ipynb` provides example of how to utilise the Agentic Vector Based Text2SQL approach to query the database. The query cache plugin will be enabled or disabled depending on the environmental parameters. +- `./Iteration 5 - Agentic Vector Based Text2SQL.ipynb` provides example of how to utilise the Agentic Vector Based Text2SQL approach to query the database. The query cache plugin will be enabled or disabled depending on the environmental parameters. ## Agents diff --git a/text_2_sql/autogen/agentic_text_2_sql.py b/text_2_sql/autogen/agentic_text_2_sql.py index 2320e56..bbc6fba 100644 --- a/text_2_sql/autogen/agentic_text_2_sql.py +++ b/text_2_sql/autogen/agentic_text_2_sql.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. 
from autogen_agentchat.task import TextMentionTermination, MaxMessageTermination from autogen_agentchat.teams import SelectorGroupChat -from utils.models import MINI_MODEL +from utils.llm_model_creator import LLMModelCreator from utils.llm_agent_creator import LLMAgentCreator import logging from agents.custom_agents.sql_query_cache_agent import SqlQueryCacheAgent @@ -86,13 +86,17 @@ def selector(messages): and messages[-1].content is not None ): cache_result = json.loads(messages[-1].content) - if cache_result.get("cached_questions_and_schemas") is not None: + if cache_result.get( + "cached_questions_and_schemas" + ) is not None and cache_result.get("contains_pre_run_results"): decision = "sql_query_correction_agent" + elif ( + cache_result.get("cached_questions_and_schemas") is not None + and cache_result.get("contains_pre_run_results") is False + ): + decision = "sql_query_generation_agent" else: - decision = "sql_schema_selection_agent" - - elif messages[-1].source == "sql_query_cache_agent": - decision = "question_decomposition_agent" + decision = "question_decomposition_agent" elif messages[-1].source == "question_decomposition_agent": decomposition_result = json.loads(messages[-1].content) @@ -129,7 +133,7 @@ def agentic_flow(self): agentic_flow = SelectorGroupChat( self.agents, allow_repeated_speaker=False, - model_client=MINI_MODEL, + model_client=LLMModelCreator.get_model("4o-mini"), termination_condition=self.termination_condition, selector_func=AgenticText2Sql.selector, ) diff --git a/text_2_sql/autogen/agents/llm_agents/answer_agent.yaml b/text_2_sql/autogen/agents/llm_agents/answer_agent.yaml index 61b5893..d70f0f3 100644 --- a/text_2_sql/autogen/agents/llm_agents/answer_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/answer_agent.yaml @@ -1,5 +1,5 @@ model: - gpt-4o-mini + 4o-mini description: "An agent that takes the final results from the SQL query and writes the answer to the user's question" system_message: diff --git 
a/text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml b/text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml index 57e139d..5878ef5 100644 --- a/text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml @@ -1,5 +1,5 @@ model: - gpt-4o-mini + 4o-mini description: "An agent that will decompose the user's question into smaller parts to be used in the SQL queries. Use this agent when the user's question is too complex to be answered in one SQL query. Only use if the user's question is too complex to be answered in one SQL query." system_message: diff --git a/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml b/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml index 3641472..6a930ae 100644 --- a/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml @@ -1,5 +1,5 @@ model: - gpt-4o-mini + 4o-mini description: "An agent that will look at the SQL query, SQL query results and correct any mistakes in the SQL query to ensure the correct results are returned. Use this agent AFTER the SQL query has been executed and the results are not as expected." system_message: @@ -20,3 +20,4 @@ system_message: tools: - sql_get_entity_schemas_tool - sql_query_execution_tool + - current_datetime_tool diff --git a/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml b/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml index 6a9e92c..76ee8bf 100644 --- a/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml @@ -1,5 +1,5 @@ model: - gpt-4o-mini + 4o-mini description: "An agent that can generate SQL queries once given the schema and the user's question. It will run the SQL query to fetch the results. 
Use this agent after the SQL Schema Selection Agent has selected the correct schema." system_message: @@ -39,3 +39,4 @@ tools: - sql_query_execution_tool - sql_get_entity_schemas_tool - sql_query_validation_tool + - current_datetime_tool diff --git a/text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml b/text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml index 3876faf..ccec4fe 100644 --- a/text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml +++ b/text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml @@ -1,5 +1,5 @@ model: - gpt-4o-mini + 4o-mini description: "An agent that can take a user's question and extract the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. diff --git a/text_2_sql/autogen/utils/llm_agent_creator.py b/text_2_sql/autogen/utils/llm_agent_creator.py index d72c887..aee3734 100644 --- a/text_2_sql/autogen/utils/llm_agent_creator.py +++ b/text_2_sql/autogen/utils/llm_agent_creator.py @@ -4,27 +4,38 @@ from autogen_core.components.tools import FunctionTool from autogen_agentchat.agents import AssistantAgent from utils.sql import query_execution, get_entity_schemas, query_validation -from utils.models import MINI_MODEL +from utils.llm_model_creator import LLMModelCreator from jinja2 import Template +from datetime import datetime class LLMAgentCreator: @classmethod - def load_agent_file(cls, name): + def load_agent_file(cls, name: str) -> dict: + """Loads the agent file based on the agent name. + + Args: + ---- + name (str): The name of the agent to load. 
+ + Returns: + ------- + dict: The agent file.""" with open(f"./agents/llm_agents/{name.lower()}.yaml", "r") as file: file = yaml.safe_load(file) return file @classmethod - def get_model(cls, model_name): - if model_name == "gpt-4o-mini": - return MINI_MODEL - else: - raise ValueError(f"Model {model_name} not found") + def get_tool(cls, tool_name: str) -> FunctionTool: + """Retrieves the tool based on the tool name. - @classmethod - def get_tool(cls, tool_name): + Args: + ---- + tool_name (str): The name of the tool to retrieve. + + Returns: + FunctionTool: The tool.""" if tool_name == "sql_query_execution_tool": return FunctionTool( query_execution, @@ -40,11 +51,29 @@ def get_tool(cls, tool_name): query_validation, description="Validates the SQL query to ensure that it is syntactically correct for the target database engine. Use this BEFORE executing any SQL statement.", ) + elif tool_name == "current_datetime_tool": + return FunctionTool( + lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + description="Gets the current date and time.", + ) else: raise ValueError(f"Tool {tool_name} not found") @classmethod - def get_property_and_render_parameters(cls, agent_file, property, parameters): + def get_property_and_render_parameters( + cls, agent_file: dict, property: str, parameters: dict + ) -> str: + """Gets the property from the agent file and renders the parameters. + + Args: + ---- + agent_file (dict): The agent file. + property (str): The property to retrieve. + parameters (dict): The parameters to render. 
+ + Returns: + ------- + str: The rendered property.""" unrendered_parameters = agent_file[property] rendered_template = Template(unrendered_parameters).render(parameters) @@ -52,7 +81,17 @@ def get_property_and_render_parameters(cls, agent_file, property, parameters): return rendered_template @classmethod - def create(cls, name: str, **kwargs): + def create(cls, name: str, **kwargs) -> AssistantAgent: + """Creates an assistant agent based on the agent name. + + Args: + ---- + name (str): The name of the agent to create. + **kwargs: The parameters to render. + + Returns: + ------- + AssistantAgent: The assistant agent.""" agent_file = cls.load_agent_file(name) tools = [] @@ -63,7 +102,7 @@ def create(cls, name: str, **kwargs): agent = AssistantAgent( name=name, tools=tools, - model_client=cls.get_model(agent_file["model"]), + model_client=LLMModelCreator.get_model(agent_file["model"]), description=cls.get_property_and_render_parameters( agent_file, "description", kwargs ), diff --git a/text_2_sql/autogen/utils/llm_model_creator.py b/text_2_sql/autogen/utils/llm_model_creator.py new file mode 100644 index 0000000..e62ecfd --- /dev/null +++ b/text_2_sql/autogen/utils/llm_model_creator.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from autogen_ext.models import AzureOpenAIChatCompletionClient +from environment import IdentityType, get_identity_type + +from azure.identity import DefaultAzureCredential, get_bearer_token_provider +import os +import dotenv + +dotenv.load_dotenv() + + +class LLMModelCreator: + @classmethod + def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient: + """Retrieves the model based on the model name. + + Args: + ---- + model_name (str): The name of the model to retrieve. 
+ + Returns: + AzureOpenAIChatCompletionClient: The model client.""" + if model_name == "4o-mini": + return cls.gpt_4o_mini_model() + elif model_name == "4o": + return cls.gpt_4o_model() + else: + raise ValueError(f"Model {model_name} not found") + + @classmethod + def get_authentication_properties(cls) -> tuple: + if get_identity_type() == IdentityType.SYSTEM_ASSIGNED: + # Create the token provider + api_key = None + token_provider = get_bearer_token_provider( + DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + ) + elif get_identity_type() == IdentityType.USER_ASSIGNED: + # Create the token provider + api_key = None + token_provider = get_bearer_token_provider( + DefaultAzureCredential( + managed_identity_client_id=os.environ["ClientId"] + ), + "https://cognitiveservices.azure.com/.default", + ) + else: + token_provider = None + api_key = os.environ["OpenAI__ApiKey"] + + return token_provider, api_key + + @classmethod + def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient: + token_provider, api_key = cls.get_authentication_properties() + return AzureOpenAIChatCompletionClient( + azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"], + model=os.environ["OpenAI__MiniCompletionDeployment"], + api_version="2024-08-01-preview", + azure_endpoint=os.environ["OpenAI__Endpoint"], + azure_ad_token_provider=token_provider, + api_key=api_key, + model_capabilities={ + "vision": False, + "function_calling": True, + "json_output": True, + }, + ) + + @classmethod + def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient: + token_provider, api_key = cls.get_authentication_properties() + return AzureOpenAIChatCompletionClient( + azure_deployment=os.environ["OpenAI__CompletionDeployment"], + model=os.environ["OpenAI__CompletionDeployment"], + api_version="2024-08-01-preview", + azure_endpoint=os.environ["OpenAI__Endpoint"], + azure_ad_token_provider=token_provider, + api_key=api_key, + model_capabilities={ + "vision": False, + 
"function_calling": True, + "json_output": True, + }, + ) diff --git a/text_2_sql/autogen/utils/models.py b/text_2_sql/autogen/utils/models.py deleted file mode 100644 index 1d4edbb..0000000 --- a/text_2_sql/autogen/utils/models.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from autogen_ext.models import AzureOpenAIChatCompletionClient - -# from azure.identity import DefaultAzureCredential, get_bearer_token_provider -import os -import dotenv - -dotenv.load_dotenv() - -# # Create the token provider -# token_provider = get_bearer_token_provider( -# DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" -# ) - -MINI_MODEL = AzureOpenAIChatCompletionClient( - azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"], - model=os.environ["OpenAI__MiniCompletionDeployment"], - api_version="2024-08-01-preview", - azure_endpoint=os.environ["OpenAI__Endpoint"], - # # Optional if you choose key-based authentication. - # azure_ad_token_provider=token_provider, - api_key=os.environ["OpenAI__ApiKey"], # For key-based authentication. 
- model_capabilities={ - "vision": False, - "function_calling": True, - "json_output": True, - }, -) diff --git a/text_2_sql/autogen/utils/sql.py b/text_2_sql/autogen/utils/sql.py index 4f340c1..ce5027d 100644 --- a/text_2_sql/autogen/utils/sql.py +++ b/text_2_sql/autogen/utils/sql.py @@ -138,7 +138,7 @@ async def fetch_queries_from_cache(question: str) -> str: ) if len(cached_schemas) == 0: - return {"cached_questions_and_schemas": None} + return {"contains_pre_run_results": False, "cached_questions_and_schemas": None} logging.info("Cached schemas: %s", cached_schemas) if PRE_RUN_QUERY_CACHE and len(cached_schemas) > 0: @@ -165,6 +165,12 @@ async def fetch_queries_from_cache(question: str) -> str: "schemas": sql_query["Schemas"], } - return {"cached_questions_and_schemas": query_result_store} + return { + "contains_pre_run_results": True, + "cached_questions_and_schemas": query_result_store, + } - return {"cached_questions_and_schemas": cached_schemas} + return { + "contains_pre_run_results": False, + "cached_questions_and_schemas": cached_schemas, + }