
Commit 2edb05c

Improve Agentic Setup and add DateTime Tool (#70)
* Add current datetime
* Update agent logic
* Improve building of autogen
* Update README code
1 parent 011f665 commit 2edb05c

12 files changed: +165 -57 lines changed

text_2_sql/autogen/README.md

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ As the query cache is shared between users (no data is stored in the cache), a n

  ## Provided Notebooks & Scripts

- - `./agentic_text_2_sql.ipynb` provides example of how to utilise the Agentic Vector Based Text2SQL approach to query the database. The query cache plugin will be enabled or disabled depending on the environmental parameters.
+ - `./Iteration 5 - Agentic Vector Based Text2SQL.ipynb` provides example of how to utilise the Agentic Vector Based Text2SQL approach to query the database. The query cache plugin will be enabled or disabled depending on the environmental parameters.

  ## Agents


text_2_sql/autogen/agentic_text_2_sql.py

Lines changed: 11 additions & 7 deletions
@@ -2,7 +2,7 @@
  # Licensed under the MIT License.
  from autogen_agentchat.task import TextMentionTermination, MaxMessageTermination
  from autogen_agentchat.teams import SelectorGroupChat
- from utils.models import MINI_MODEL
+ from utils.llm_model_creator import LLMModelCreator
  from utils.llm_agent_creator import LLMAgentCreator
  import logging
  from agents.custom_agents.sql_query_cache_agent import SqlQueryCacheAgent
@@ -86,13 +86,17 @@ def selector(messages):
              and messages[-1].content is not None
          ):
              cache_result = json.loads(messages[-1].content)
-             if cache_result.get("cached_questions_and_schemas") is not None:
+             if cache_result.get(
+                 "cached_questions_and_schemas"
+             ) is not None and cache_result.get("contains_pre_run_results"):
                  decision = "sql_query_correction_agent"
+             if (
+                 cache_result.get("cached_questions_and_schemas") is not None
+                 and cache_result.get("contains_pre_run_results") is False
+             ):
+                 decision = "sql_query_generation_agent"
              else:
-                 decision = "sql_schema_selection_agent"
-
-         elif messages[-1].source == "sql_query_cache_agent":
-             decision = "question_decomposition_agent"
+                 decision = "question_decomposition_agent"

          elif messages[-1].source == "question_decomposition_agent":
              decomposition_result = json.loads(messages[-1].content)
@@ -129,7 +133,7 @@ def agentic_flow(self):
          agentic_flow = SelectorGroupChat(
              self.agents,
              allow_repeated_speaker=False,
-             model_client=MINI_MODEL,
+             model_client=LLMModelCreator.get_model("4o-mini"),
              termination_condition=self.termination_condition,
              selector_func=AgenticText2Sql.selector,
          )
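The selector change means a cache hit that already contains pre-run results goes straight to correction, a hit without pre-run results goes to SQL generation, and anything else falls through to question decomposition. A minimal standalone sketch of that routing (the function name and early-return style are illustrative, not code from this commit; the agent names match the diff):

import json

def route_after_cache(cache_message_content: str) -> str:
    """Illustrative: pick the next speaker once sql_query_cache_agent has replied."""
    cache_result = json.loads(cache_message_content)
    cache_hit = cache_result.get("cached_questions_and_schemas") is not None
    pre_run = cache_result.get("contains_pre_run_results")

    if cache_hit and pre_run:
        # Cached queries were already executed; only validate/correct the results.
        return "sql_query_correction_agent"
    if cache_hit and pre_run is False:
        # Cached schemas exist but nothing was pre-run; generate the SQL again.
        return "sql_query_generation_agent"
    # No usable cache entry; decompose the question first.
    return "question_decomposition_agent"

# route_after_cache('{"cached_questions_and_schemas": null}') -> "question_decomposition_agent"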

text_2_sql/autogen/agents/llm_agents/answer_agent.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
  model:
-   gpt-4o-mini
+   4o-mini
  description:
    "An agent that takes the final results from the SQL query and writes the answer to the user's question"
  system_message:

text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
  model:
-   gpt-4o-mini
+   4o-mini
  description:
    "An agent that will decompose the user's question into smaller parts to be used in the SQL queries. Use this agent when the user's question is too complex to be answered in one SQL query. Only use if the user's question is too complex to be answered in one SQL query."
  system_message:

text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,5 @@
  model:
-   gpt-4o-mini
+   4o-mini
  description:
    "An agent that will look at the SQL query, SQL query results and correct any mistakes in the SQL query to ensure the correct results are returned. Use this agent AFTER the SQL query has been executed and the results are not as expected."
  system_message:
@@ -20,3 +20,4 @@ system_message:
  tools:
    - sql_get_entity_schemas_tool
    - sql_query_execution_tool
+   - current_datetime_tool

text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,5 @@
  model:
-   gpt-4o-mini
+   4o-mini
  description:
    "An agent that can generate SQL queries once given the schema and the user's question. It will run the SQL query to fetch the results. Use this agent after the SQL Schema Selection Agent has selected the correct schema."
  system_message:
@@ -39,3 +39,4 @@ tools:
    - sql_query_execution_tool
    - sql_get_entity_schemas_tool
    - sql_query_validation_tool
+   - current_datetime_tool
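Both the generation and correction agents now list current_datetime_tool, so they can anchor relative phrases such as "last 30 days" to concrete dates instead of guessing. A hypothetical helper illustrating that resolution step (the function and window logic are assumptions, not code from this commit):

from datetime import datetime, timedelta
from typing import Optional, Tuple

def resolve_relative_window(days: int, now: Optional[datetime] = None) -> Tuple[str, str]:
    """Hypothetical: turn 'the past <days> days' into explicit ISO date bounds."""
    now = now or datetime.now()
    start = now - timedelta(days=days)
    return start.strftime("%Y-%m-%d"), now.strftime("%Y-%m-%d")

# start, end = resolve_relative_window(30)
# -> usable in a clause such as WHERE OrderDate BETWEEN '<start>' AND '<end>'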

text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
  model:
-   gpt-4o-mini
+   4o-mini
  description:
    "An agent that can take a user's question and extract the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term.


text_2_sql/autogen/utils/llm_agent_creator.py

Lines changed: 51 additions & 12 deletions
@@ -4,27 +4,38 @@
  from autogen_core.components.tools import FunctionTool
  from autogen_agentchat.agents import AssistantAgent
  from utils.sql import query_execution, get_entity_schemas, query_validation
- from utils.models import MINI_MODEL
+ from utils.llm_model_creator import LLMModelCreator
  from jinja2 import Template
+ from datetime import datetime


  class LLMAgentCreator:
      @classmethod
-     def load_agent_file(cls, name):
+     def load_agent_file(cls, name: str) -> dict:
+         """Loads the agent file based on the agent name.
+
+         Args:
+         ----
+             name (str): The name of the agent to load.
+
+         Returns:
+         -------
+             dict: The agent file."""
          with open(f"./agents/llm_agents/{name.lower()}.yaml", "r") as file:
              file = yaml.safe_load(file)

          return file

      @classmethod
-     def get_model(cls, model_name):
-         if model_name == "gpt-4o-mini":
-             return MINI_MODEL
-         else:
-             raise ValueError(f"Model {model_name} not found")
+     def get_tool(cls, tool_name: str) -> FunctionTool:
+         """Retrieves the tool based on the tool name.

-     @classmethod
-     def get_tool(cls, tool_name):
+         Args:
+         ----
+             tool_name (str): The name of the tool to retrieve.
+
+         Returns:
+             FunctionTool: The tool."""
          if tool_name == "sql_query_execution_tool":
              return FunctionTool(
                  query_execution,
@@ -40,19 +51,47 @@ def get_tool(cls, tool_name):
                  query_validation,
                  description="Validates the SQL query to ensure that it is syntactically correct for the target database engine. Use this BEFORE executing any SQL statement.",
              )
+         elif tool_name == "current_datetime_tool":
+             return FunctionTool(
+                 lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                 description="Gets the current date and time.",
+             )
          else:
              raise ValueError(f"Tool {tool_name} not found")

      @classmethod
-     def get_property_and_render_parameters(cls, agent_file, property, parameters):
+     def get_property_and_render_parameters(
+         cls, agent_file: dict, property: str, parameters: dict
+     ) -> str:
+         """Gets the property from the agent file and renders the parameters.
+
+         Args:
+         ----
+             agent_file (dict): The agent file.
+             property (str): The property to retrieve.
+             parameters (dict): The parameters to render.
+
+         Returns:
+         -------
+             str: The rendered property."""
          unrendered_parameters = agent_file[property]

          rendered_template = Template(unrendered_parameters).render(parameters)

          return rendered_template

      @classmethod
-     def create(cls, name: str, **kwargs):
+     def create(cls, name: str, **kwargs) -> AssistantAgent:
+         """Creates an assistant agent based on the agent name.
+
+         Args:
+         ----
+             name (str): The name of the agent to create.
+             **kwargs: The parameters to render.
+
+         Returns:
+         -------
+             AssistantAgent: The assistant agent."""
          agent_file = cls.load_agent_file(name)

          tools = []
@@ -63,7 +102,7 @@ def create(cls, name: str, **kwargs):
          agent = AssistantAgent(
              name=name,
              tools=tools,
-             model_client=cls.get_model(agent_file["model"]),
+             model_client=LLMModelCreator.get_model(agent_file["model"]),
              description=cls.get_property_and_render_parameters(
                  agent_file, "description", kwargs
              ),
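current_datetime_tool wraps a plain callable, so any agent that requests it simply receives a formatted timestamp string. A small sketch of the equivalent named function and its registration (the FunctionTool call mirrors the one in the diff; the print is only to show the output shape):

from datetime import datetime
from autogen_core.components.tools import FunctionTool

def current_datetime() -> str:
    """Named equivalent of the lambda registered in get_tool()."""
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

current_datetime_tool = FunctionTool(
    current_datetime,
    description="Gets the current date and time.",
)

print(current_datetime())  # e.g. "2024-12-02 14:31:08"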
text_2_sql/autogen/utils/llm_model_creator.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT License.
+ from autogen_ext.models import AzureOpenAIChatCompletionClient
+ from environment import IdentityType, get_identity_type
+
+ from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+ import os
+ import dotenv
+
+ dotenv.load_dotenv()
+
+
+ class LLMModelCreator:
+     @classmethod
+     def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
+         """Retrieves the model based on the model name.
+
+         Args:
+         ----
+             model_name (str): The name of the model to retrieve.
+
+         Returns:
+             AzureOpenAIChatCompletionClient: The model client."""
+         if model_name == "4o-mini":
+             return cls.gpt_4o_mini_model()
+         elif model_name == "4o":
+             return cls.gpt_4o_model()
+         else:
+             raise ValueError(f"Model {model_name} not found")
+
+     @classmethod
+     def get_authentication_properties(cls) -> dict:
+         if get_identity_type() == IdentityType.SYSTEM_ASSIGNED:
+             # Create the token provider
+             api_key = None
+             token_provider = get_bearer_token_provider(
+                 DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+             )
+         elif get_identity_type() == IdentityType.USER_ASSIGNED:
+             # Create the token provider
+             api_key = None
+             token_provider = get_bearer_token_provider(
+                 DefaultAzureCredential(
+                     managed_identity_client_id=os.environ["ClientId"]
+                 ),
+                 "https://cognitiveservices.azure.com/.default",
+             )
+         else:
+             token_provider = None
+             api_key = os.environ["OpenAI__ApiKey"]
+
+         return token_provider, api_key
+
+     @classmethod
+     def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
+         token_provider, api_key = cls.get_authentication_properties()
+         return AzureOpenAIChatCompletionClient(
+             azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
+             model=os.environ["OpenAI__MiniCompletionDeployment"],
+             api_version="2024-08-01-preview",
+             azure_endpoint=os.environ["OpenAI__Endpoint"],
+             azure_ad_token_provider=token_provider,
+             api_key=api_key,
+             model_capabilities={
+                 "vision": False,
+                 "function_calling": True,
+                 "json_output": True,
+             },
+         )
+
+     @classmethod
+     def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
+         token_provider, api_key = cls.get_authentication_properties()
+         return AzureOpenAIChatCompletionClient(
+             azure_deployment=os.environ["OpenAI__CompletionDeployment"],
+             model=os.environ["OpenAI__CompletionDeployment"],
+             api_version="2024-08-01-preview",
+             azure_endpoint=os.environ["OpenAI__Endpoint"],
+             azure_ad_token_provider=token_provider,
+             api_key=api_key,
+             model_capabilities={
+                 "vision": False,
+                 "function_calling": True,
+                 "json_output": True,
+             },
+         )
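With the deployment and endpoint environment variables set, callers resolve a client by its short alias, which is how the agent YAML files ("model: 4o-mini") and the SelectorGroupChat wiring now obtain their models. A short usage sketch (the environment values are placeholders for your own deployment):

import os

# Placeholders; in the repo these come from the .env file loaded by dotenv.
os.environ.setdefault("OpenAI__Endpoint", "https://<your-resource>.openai.azure.com/")
os.environ.setdefault("OpenAI__ApiKey", "<key, only needed when no managed identity is configured>")
os.environ.setdefault("OpenAI__MiniCompletionDeployment", "<your-gpt-4o-mini-deployment>")
os.environ.setdefault("OpenAI__CompletionDeployment", "<your-gpt-4o-deployment>")

from utils.llm_model_creator import LLMModelCreator

mini_client = LLMModelCreator.get_model("4o-mini")  # selector and most agents
full_client = LLMModelCreator.get_model("4o")       # larger deployment, same client interface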
