Skip to content

Improve Agentic Setup and add DateTime Tool #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion text_2_sql/autogen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ As the query cache is shared between users (no data is stored in the cache), a n

## Provided Notebooks & Scripts

- `./agentic_text_2_sql.ipynb` provides example of how to utilise the Agentic Vector Based Text2SQL approach to query the database. The query cache plugin will be enabled or disabled depending on the environmental parameters.
- `./Iteration 5 - Agentic Vector Based Text2SQL.ipynb` provides an example of how to utilise the Agentic Vector Based Text2SQL approach to query the database. The query cache plugin will be enabled or disabled depending on the environmental parameters.

## Agents

Expand Down
18 changes: 11 additions & 7 deletions text_2_sql/autogen/agentic_text_2_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.
from autogen_agentchat.task import TextMentionTermination, MaxMessageTermination
from autogen_agentchat.teams import SelectorGroupChat
from utils.models import MINI_MODEL
from utils.llm_model_creator import LLMModelCreator
from utils.llm_agent_creator import LLMAgentCreator
import logging
from agents.custom_agents.sql_query_cache_agent import SqlQueryCacheAgent
Expand Down Expand Up @@ -86,13 +86,17 @@ def selector(messages):
and messages[-1].content is not None
):
cache_result = json.loads(messages[-1].content)
if cache_result.get("cached_questions_and_schemas") is not None:
if cache_result.get(
"cached_questions_and_schemas"
) is not None and cache_result.get("contains_pre_run_results"):
decision = "sql_query_correction_agent"
if (
cache_result.get("cached_questions_and_schemas") is not None
and cache_result.get("contains_pre_run_results") is False
):
decision = "sql_query_generation_agent"
else:
decision = "sql_schema_selection_agent"

elif messages[-1].source == "sql_query_cache_agent":
decision = "question_decomposition_agent"
decision = "question_decomposition_agent"

elif messages[-1].source == "question_decomposition_agent":
decomposition_result = json.loads(messages[-1].content)
Expand Down Expand Up @@ -129,7 +133,7 @@ def agentic_flow(self):
agentic_flow = SelectorGroupChat(
self.agents,
allow_repeated_speaker=False,
model_client=MINI_MODEL,
model_client=LLMModelCreator.get_model("4o-mini"),
termination_condition=self.termination_condition,
selector_func=AgenticText2Sql.selector,
)
Expand Down
2 changes: 1 addition & 1 deletion text_2_sql/autogen/agents/llm_agents/answer_agent.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
model:
gpt-4o-mini
4o-mini
description:
"An agent that takes the final results from the SQL query and writes the answer to the user's question"
system_message:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
model:
gpt-4o-mini
4o-mini
description:
"An agent that will decompose the user's question into smaller parts to be used in the SQL queries. Use this agent when the user's question is too complex to be answered in one SQL query. Only use if the user's question is too complex to be answered in one SQL query."
system_message:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
model:
gpt-4o-mini
4o-mini
description:
"An agent that will look at the SQL query, SQL query results and correct any mistakes in the SQL query to ensure the correct results are returned. Use this agent AFTER the SQL query has been executed and the results are not as expected."
system_message:
Expand All @@ -20,3 +20,4 @@ system_message:
tools:
- sql_get_entity_schemas_tool
- sql_query_execution_tool
- current_datetime_tool
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
model:
gpt-4o-mini
4o-mini
description:
"An agent that can generate SQL queries once given the schema and the user's question. It will run the SQL query to fetch the results. Use this agent after the SQL Schema Selection Agent has selected the correct schema."
system_message:
Expand Down Expand Up @@ -39,3 +39,4 @@ tools:
- sql_query_execution_tool
- sql_get_entity_schemas_tool
- sql_query_validation_tool
- current_datetime_tool
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
model:
gpt-4o-mini
4o-mini
description:
"An agent that can take a user's question and extract the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term.

Expand Down
63 changes: 51 additions & 12 deletions text_2_sql/autogen/utils/llm_agent_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,38 @@
from autogen_core.components.tools import FunctionTool
from autogen_agentchat.agents import AssistantAgent
from utils.sql import query_execution, get_entity_schemas, query_validation
from utils.models import MINI_MODEL
from utils.llm_model_creator import LLMModelCreator
from jinja2 import Template
from datetime import datetime


class LLMAgentCreator:
@classmethod
def load_agent_file(cls, name):
def load_agent_file(cls, name: str) -> dict:
"""Loads the agent file based on the agent name.

Args:
----
name (str): The name of the agent to load.

Returns:
-------
dict: The agent file."""
with open(f"./agents/llm_agents/{name.lower()}.yaml", "r") as file:
file = yaml.safe_load(file)

return file

@classmethod
def get_model(cls, model_name):
if model_name == "gpt-4o-mini":
return MINI_MODEL
else:
raise ValueError(f"Model {model_name} not found")
def get_tool(cls, tool_name: str) -> FunctionTool:
"""Retrieves the tool based on the tool name.

@classmethod
def get_tool(cls, tool_name):
Args:
----
tool_name (str): The name of the tool to retrieve.

Returns:
FunctionTool: The tool."""
if tool_name == "sql_query_execution_tool":
return FunctionTool(
query_execution,
Expand All @@ -40,19 +51,47 @@ def get_tool(cls, tool_name):
query_validation,
description="Validates the SQL query to ensure that it is syntactically correct for the target database engine. Use this BEFORE executing any SQL statement.",
)
elif tool_name == "current_datetime_tool":
return FunctionTool(
lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
description="Gets the current date and time.",
)
else:
raise ValueError(f"Tool {tool_name} not found")

@classmethod
def get_property_and_render_parameters(cls, agent_file, property, parameters):
def get_property_and_render_parameters(
cls, agent_file: dict, property: str, parameters: dict
) -> str:
"""Gets the property from the agent file and renders the parameters.

Args:
----
agent_file (dict): The agent file.
property (str): The property to retrieve.
parameters (dict): The parameters to render.

Returns:
-------
str: The rendered property."""
unrendered_parameters = agent_file[property]

rendered_template = Template(unrendered_parameters).render(parameters)

return rendered_template

@classmethod
def create(cls, name: str, **kwargs):
def create(cls, name: str, **kwargs) -> AssistantAgent:
"""Creates an assistant agent based on the agent name.

Args:
----
name (str): The name of the agent to create.
**kwargs: The parameters to render.

Returns:
-------
AssistantAgent: The assistant agent."""
agent_file = cls.load_agent_file(name)

tools = []
Expand All @@ -63,7 +102,7 @@ def create(cls, name: str, **kwargs):
agent = AssistantAgent(
name=name,
tools=tools,
model_client=cls.get_model(agent_file["model"]),
model_client=LLMModelCreator.get_model(agent_file["model"]),
description=cls.get_property_and_render_parameters(
agent_file, "description", kwargs
),
Expand Down
86 changes: 86 additions & 0 deletions text_2_sql/autogen/utils/llm_model_creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from autogen_ext.models import AzureOpenAIChatCompletionClient
from environment import IdentityType, get_identity_type

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
import os
import dotenv

dotenv.load_dotenv()


class LLMModelCreator:
    """Factory for Azure OpenAI chat-completion clients used by the agents."""

    @classmethod
    def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
        """Retrieves the model client based on the model name.

        Args:
        ----
            model_name (str): The name of the model to retrieve.
                Supported values are "4o-mini" and "4o".

        Returns:
        -------
            AzureOpenAIChatCompletionClient: The model client.

        Raises:
        ------
            ValueError: If the model name is not recognised."""
        if model_name == "4o-mini":
            return cls.gpt_4o_mini_model()
        elif model_name == "4o":
            return cls.gpt_4o_model()
        else:
            raise ValueError(f"Model {model_name} not found")

    @classmethod
    def get_authentication_properties(cls) -> tuple:
        """Resolves the authentication mechanism for Azure OpenAI.

        Exactly one of the two returned values is non-None: a bearer token
        provider when a managed identity is configured, otherwise an API key
        read from the environment.

        Returns:
        -------
            tuple: (token_provider, api_key)."""
        # Hoisted: the original called get_identity_type() once per branch test.
        identity_type = get_identity_type()
        if identity_type == IdentityType.SYSTEM_ASSIGNED:
            # System-assigned managed identity: authenticate via Entra ID token.
            api_key = None
            token_provider = get_bearer_token_provider(
                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
            )
        elif identity_type == IdentityType.USER_ASSIGNED:
            # User-assigned managed identity: the client id comes from the environment.
            api_key = None
            token_provider = get_bearer_token_provider(
                DefaultAzureCredential(
                    managed_identity_client_id=os.environ["ClientId"]
                ),
                "https://cognitiveservices.azure.com/.default",
            )
        else:
            # Fall back to key-based authentication.
            token_provider = None
            api_key = os.environ["OpenAI__ApiKey"]

        return token_provider, api_key

    @classmethod
    def _create_client(cls, deployment: str) -> AzureOpenAIChatCompletionClient:
        """Builds a chat-completion client for the given deployment.

        Shared by the per-model factories so the API version, endpoint and
        capability flags are defined in exactly one place.

        Args:
        ----
            deployment (str): The Azure OpenAI deployment name (also used as
                the model name, matching the original behaviour).

        Returns:
        -------
            AzureOpenAIChatCompletionClient: The configured client."""
        token_provider, api_key = cls.get_authentication_properties()
        return AzureOpenAIChatCompletionClient(
            azure_deployment=deployment,
            model=deployment,
            api_version="2024-08-01-preview",
            azure_endpoint=os.environ["OpenAI__Endpoint"],
            azure_ad_token_provider=token_provider,
            api_key=api_key,
            model_capabilities={
                "vision": False,
                "function_calling": True,
                "json_output": True,
            },
        )

    @classmethod
    def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
        """Returns a client for the GPT-4o-mini deployment
        (env var OpenAI__MiniCompletionDeployment)."""
        return cls._create_client(os.environ["OpenAI__MiniCompletionDeployment"])

    @classmethod
    def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
        """Returns a client for the GPT-4o deployment
        (env var OpenAI__CompletionDeployment)."""
        return cls._create_client(os.environ["OpenAI__CompletionDeployment"])
29 changes: 0 additions & 29 deletions text_2_sql/autogen/utils/models.py

This file was deleted.

12 changes: 9 additions & 3 deletions text_2_sql/autogen/utils/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async def fetch_queries_from_cache(question: str) -> str:
)

if len(cached_schemas) == 0:
return {"cached_questions_and_schemas": None}
return {"contains_pre_run_results": False, "cached_questions_and_schemas": None}

logging.info("Cached schemas: %s", cached_schemas)
if PRE_RUN_QUERY_CACHE and len(cached_schemas) > 0:
Expand All @@ -165,6 +165,12 @@ async def fetch_queries_from_cache(question: str) -> str:
"schemas": sql_query["Schemas"],
}

return {"cached_questions_and_schemas": query_result_store}
return {
"contains_pre_run_results": True,
"cached_questions_and_schemas": query_result_store,
}

return {"cached_questions_and_schemas": cached_schemas}
return {
"contains_pre_run_results": False,
"cached_questions_and_schemas": cached_schemas,
}
Loading