Commit d456ace

Improve building of autogen
1 parent 3ad3a5e commit d456ace

File tree

4 files changed: +99 -71 lines changed

text_2_sql/autogen/agentic_text_2_sql.py

Lines changed: 11 additions & 7 deletions
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 from autogen_agentchat.task import TextMentionTermination, MaxMessageTermination
 from autogen_agentchat.teams import SelectorGroupChat
-from utils.models import MINI_MODEL
+from utils.llm_model_creator import LLMModelCreator
 from utils.llm_agent_creator import LLMAgentCreator
 import logging
 from agents.custom_agents.sql_query_cache_agent import SqlQueryCacheAgent

@@ -86,13 +86,17 @@ def selector(messages):
             and messages[-1].content is not None
         ):
             cache_result = json.loads(messages[-1].content)
-            if cache_result.get("cached_questions_and_schemas") is not None:
+            if cache_result.get(
+                "cached_questions_and_schemas"
+            ) is not None and cache_result.get("contains_pre_run_results"):
                 decision = "sql_query_correction_agent"
+            if (
+                cache_result.get("cached_questions_and_schemas") is not None
+                and cache_result.get("contains_pre_run_results") is False
+            ):
+                decision = "sql_query_generation_agent"
             else:
-                decision = "sql_schema_selection_agent"
-
-        elif messages[-1].source == "sql_query_cache_agent":
-            decision = "question_decomposition_agent"
+                decision = "question_decomposition_agent"

         elif messages[-1].source == "question_decomposition_agent":
             decomposition_result = json.loads(messages[-1].content)

@@ -129,7 +133,7 @@ def agentic_flow(self):
         agentic_flow = SelectorGroupChat(
             self.agents,
             allow_repeated_speaker=False,
-            model_client=MINI_MODEL,
+            model_client=LLMModelCreator.get_model("4o-mini"),
             termination_condition=self.termination_condition,
             selector_func=AgenticText2Sql.selector,
         )
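
Read together, the updated branch routes a cache hit that already contains pre-run results to sql_query_correction_agent, a cache hit without pre-run results to sql_query_generation_agent, and everything else to question_decomposition_agent. Below is a minimal standalone sketch of that routing, written with elif so the trailing else cannot override an earlier match (the committed version keeps two separate if statements); the payload values are illustrative only:

import json

def route_from_cache(cache_message: str) -> str:
    """Decide the next speaker from the cache agent's JSON output."""
    cache_result = json.loads(cache_message)
    cache_hit = cache_result.get("cached_questions_and_schemas") is not None
    if cache_hit and cache_result.get("contains_pre_run_results"):
        return "sql_query_correction_agent"
    elif cache_hit and cache_result.get("contains_pre_run_results") is False:
        return "sql_query_generation_agent"
    else:
        return "question_decomposition_agent"

# Illustrative payloads; the real cache agent may emit additional fields.
route_from_cache('{"cached_questions_and_schemas": [], "contains_pre_run_results": true}')   # -> correction
route_from_cache('{"cached_questions_and_schemas": [], "contains_pre_run_results": false}')  # -> generation
route_from_cache('{"cached_questions_and_schemas": null}')                                   # -> decomposition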

text_2_sql/autogen/utils/llm_agent_creator.py

Lines changed: 2 additions & 20 deletions
@@ -4,10 +4,9 @@
 from autogen_core.components.tools import FunctionTool
 from autogen_agentchat.agents import AssistantAgent
 from utils.sql import query_execution, get_entity_schemas, query_validation
-from utils.models import GPT_4O_MINI_MODEL, GPT_4O_MODEL
+from utils.llm_model_creator import LLMModelCreator
 from jinja2 import Template
 from datetime import datetime
-from autogen_ext.models import AzureOpenAIChatCompletionClient


 class LLMAgentCreator:

@@ -27,23 +26,6 @@ def load_agent_file(cls, name: str) -> dict:

         return file

-    @classmethod
-    def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
-        """Retrieves the model based on the model name.
-
-        Args:
-        ----
-            model_name (str): The name of the model to retrieve.
-
-        Returns:
-            AzureOpenAIChatCompletionClient: The model client."""
-        if model_name == "4o-mini":
-            return GPT_4O_MINI_MODEL
-        elif model_name == "4o":
-            return GPT_4O_MODEL
-        else:
-            raise ValueError(f"Model {model_name} not found")
-
     @classmethod
     def get_tool(cls, tool_name: str) -> FunctionTool:
         """Retrieves the tool based on the tool name.

@@ -120,7 +102,7 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
         agent = AssistantAgent(
             name=name,
             tools=tools,
-            model_client=cls.get_model(agent_file["model"]),
+            model_client=LLMModelCreator.get_model(agent_file["model"]),
             description=cls.get_property_and_render_parameters(
                 agent_file, "description", kwargs
             ),
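
With the in-class helper removed, an agent definition only needs to name a model alias and LLMAgentCreator.create resolves it through the shared LLMModelCreator. A small sketch of that contract, assuming a hypothetical agent definition dict (in the repo the definition is loaded by load_agent_file; only the "model" key with the values "4o-mini" or "4o" is relied on here):

from utils.llm_model_creator import LLMModelCreator

# Hypothetical agent definition for illustration; the real one comes from load_agent_file(name).
agent_file = {
    "model": "4o-mini",  # alias resolved by LLMModelCreator.get_model
    "description": "Generates SQL queries for the target database engine.",
}

model_client = LLMModelCreator.get_model(agent_file["model"])  # AzureOpenAIChatCompletionClient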

text_2_sql/autogen/utils/llm_model_creator.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from autogen_ext.models import AzureOpenAIChatCompletionClient
+from environment import IdentityType, get_identity_type
+
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+import os
+import dotenv
+
+dotenv.load_dotenv()
+
+
+class LLMModelCreator:
+    @classmethod
+    def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
+        """Retrieves the model based on the model name.
+
+        Args:
+        ----
+            model_name (str): The name of the model to retrieve.
+
+        Returns:
+            AzureOpenAIChatCompletionClient: The model client."""
+        if model_name == "4o-mini":
+            return cls.gpt_4o_mini_model()
+        elif model_name == "4o":
+            return cls.gpt_4o_model()
+        else:
+            raise ValueError(f"Model {model_name} not found")
+
+    @classmethod
+    def get_authentication_properties(cls) -> dict:
+        if get_identity_type() == IdentityType.SYSTEM_ASSIGNED:
+            # Create the token provider
+            api_key = None
+            token_provider = get_bearer_token_provider(
+                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+            )
+        elif get_identity_type() == IdentityType.USER_ASSIGNED:
+            # Create the token provider
+            api_key = None
+            token_provider = get_bearer_token_provider(
+                DefaultAzureCredential(
+                    managed_identity_client_id=os.environ["ClientId"]
+                ),
+                "https://cognitiveservices.azure.com/.default",
+            )
+        else:
+            token_provider = None
+            api_key = os.environ["OpenAI__ApiKey"]
+
+        return token_provider, api_key
+
+    @classmethod
+    def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
+        token_provider, api_key = cls.get_authentication_properties()
+        return AzureOpenAIChatCompletionClient(
+            azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
+            model=os.environ["OpenAI__MiniCompletionDeployment"],
+            api_version="2024-08-01-preview",
+            azure_endpoint=os.environ["OpenAI__Endpoint"],
+            azure_ad_token_provider=token_provider,
+            api_key=api_key,
+            model_capabilities={
+                "vision": False,
+                "function_calling": True,
+                "json_output": True,
+            },
+        )
+
+    @classmethod
+    def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
+        token_provider, api_key = cls.get_authentication_properties()
+        return AzureOpenAIChatCompletionClient(
+            azure_deployment=os.environ["OpenAI__CompletionDeployment"],
+            model=os.environ["OpenAI__CompletionDeployment"],
+            api_version="2024-08-01-preview",
+            azure_endpoint=os.environ["OpenAI__Endpoint"],
+            azure_ad_token_provider=token_provider,
+            api_key=api_key,
+            model_capabilities={
+                "vision": False,
+                "function_calling": True,
+                "json_output": True,
+            },
+        )
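
The new creator also centralises authentication: with a system- or user-assigned managed identity it hands the client a bearer-token provider, and otherwise it falls back to the OpenAI__ApiKey key. A rough sketch of the environment a caller would set up before requesting a client, using the variable names from the file above with placeholder values:

import os

# Placeholder values for illustration; in a deployment these come from app settings or the .env file.
os.environ["OpenAI__Endpoint"] = "https://<your-resource>.openai.azure.com/"
os.environ["OpenAI__MiniCompletionDeployment"] = "gpt-4o-mini"
os.environ["OpenAI__CompletionDeployment"] = "gpt-4o"
os.environ["OpenAI__ApiKey"] = "<api-key, only needed for key-based auth>"

from utils.llm_model_creator import LLMModelCreator

mini_client = LLMModelCreator.get_model("4o-mini")  # client bound to the mini deployment
full_client = LLMModelCreator.get_model("4o")       # client bound to the full deployment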

text_2_sql/autogen/utils/models.py

Lines changed: 0 additions & 44 deletions
This file was deleted.
