
Commit 2962719

Add custom agent
1 parent 3ead4a1 commit 2962719

File tree

10 files changed (+218, -73 lines)

text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "result = agentic_text_2_sql.run_stream(task=\"What are the total number of sales within 2008?\")"
+    "result = agentic_text_2_sql.run_stream(task=\"What are the total number of sales within 2008 for the mountain bike?\")"
    ]
   },
   {
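For reference, run_stream returns an async generator, so the notebook cell's result can be consumed by iterating it. A minimal consumption sketch, assuming agentic_text_2_sql has already been constructed in earlier cells; the printing loop itself is illustrative and not part of the commit:

    # Iterate the stream produced by the cell above and print each message.
    async def print_stream():
        async for message in agentic_text_2_sql.run_stream(
            task="What are the total number of sales within 2008 for the mountain bike?"
        ):
            print(message)

    await print_stream()  # top-level await works inside a Jupyter notebook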

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 9 additions & 3 deletions

@@ -1,11 +1,14 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from autogen_agentchat.task import TextMentionTermination, MaxMessageTermination
+from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination
 from autogen_agentchat.teams import SelectorGroupChat
 from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
 from autogen_text_2_sql.creators.llm_agent_creator import LLMAgentCreator
 import logging
 from autogen_text_2_sql.custom_agents.sql_query_cache_agent import SqlQueryCacheAgent
+from autogen_text_2_sql.custom_agents.sql_schema_selection_agent import (
+    SqlSchemaSelectionAgent,
+)
 import json
 import os

@@ -32,6 +35,10 @@ def set_mode(self):
             os.environ.get("Text2Sql__PreRunQueryCache", "False").lower() == "true"
         )

+        self.use_column_value_store = (
+            os.environ.get("Text2Sql__UseColumnValueStore", "False").lower() == "true"
+        )
+
     @property
     def agents(self):
         """Define the agents for the chat."""

@@ -41,8 +48,7 @@ def agents(self):
             engine_specific_rules=self.engine_specific_rules,
             **self.kwargs,
         )
-        SQL_SCHEMA_SELECTION_AGENT = LLMAgentCreator.create(
-            "sql_schema_selection_agent",
+        SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent(
             target_engine=self.target_engine,
             engine_specific_rules=self.engine_specific_rules,
             **self.kwargs,
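The new Text2Sql__UseColumnValueStore flag follows the same pattern as the existing Text2Sql__PreRunQueryCache toggle: any value other than a case-insensitive "true" leaves the feature off. A minimal sketch of enabling both flags before the team object is built (set_mode() reads them at construction time; the values here are illustrative):

    import os

    # "True"/"true" enables the feature; anything else leaves it disabled.
    os.environ["Text2Sql__UseColumnValueStore"] = "True"
    os.environ["Text2Sql__PreRunQueryCache"] = "True"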

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_model_creator.py

Lines changed: 9 additions & 27 deletions

@@ -1,9 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 from autogen_ext.models import AzureOpenAIChatCompletionClient
-from text_2_sql_core.utils.environment import IdentityType, get_identity_type
+from text_2_sql_core.connectors.factory import ConnectorFactory

-from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 import os
 import dotenv

@@ -28,32 +27,12 @@ def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
         else:
             raise ValueError(f"Model {model_name} not found")

-    @classmethod
-    def get_authentication_properties(cls) -> dict:
-        if get_identity_type() == IdentityType.SYSTEM_ASSIGNED:
-            # Create the token provider
-            api_key = None
-            token_provider = get_bearer_token_provider(
-                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
-            )
-        elif get_identity_type() == IdentityType.USER_ASSIGNED:
-            # Create the token provider
-            api_key = None
-            token_provider = get_bearer_token_provider(
-                DefaultAzureCredential(
-                    managed_identity_client_id=os.environ["ClientId"]
-                ),
-                "https://cognitiveservices.azure.com/.default",
-            )
-        else:
-            token_provider = None
-            api_key = os.environ["OpenAI__ApiKey"]
-
-        return token_provider, api_key
-
     @classmethod
     def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
-        token_provider, api_key = cls.get_authentication_properties()
+        (
+            token_provider,
+            api_key,
+        ) = ConnectorFactory.get_open_ai_connector().get_authentication_properties()
         return AzureOpenAIChatCompletionClient(
             azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
             model=os.environ["OpenAI__MiniCompletionDeployment"],

@@ -70,7 +49,10 @@ def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:

     @classmethod
     def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
-        token_provider, api_key = cls.get_authentication_properties()
+        (
+            token_provider,
+            api_key,
+        ) = ConnectorFactory.get_open_ai_connector().get_authentication_properties()
         return AzureOpenAIChatCompletionClient(
             azure_deployment=os.environ["OpenAI__CompletionDeployment"],
             model=os.environ["OpenAI__CompletionDeployment"],
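In short, credential resolution now lives on the shared OpenAI connector rather than on the model creator itself; the creators only unpack the (token_provider, api_key) pair, of which exactly one is populated. A small sketch of the equivalent standalone call, using only names shown in the diff above:

    from text_2_sql_core.connectors.factory import ConnectorFactory

    # Managed identity -> token_provider is set and api_key is None;
    # otherwise api_key comes from OpenAI__ApiKey and token_provider is None.
    token_provider, api_key = (
        ConnectorFactory.get_open_ai_connector().get_authentication_properties()
    )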

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 64 additions & 11 deletions

@@ -6,20 +6,28 @@
 from autogen_agentchat.base import Response
 from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
 from autogen_core import CancellationToken
-from text_2_sql_core.connectors.sql import SqlConnector
+from text_2_sql_core.connectors.factory import ConnectorFactory
 import json
 import logging
+from text_2_sql_core.prompts.load import load
+from jinja2 import Template
+import asyncio


-class SqlQueryCacheAgent(BaseChatAgent):
+class SqlSchemaSelectionAgent(BaseChatAgent):
     def __init__(self, **kwargs):
         super().__init__(
             "sql_schema_selection_agent",
             "An agent that fetches the schemas from the cache based on the user question.",
         )

-        self.kwargs = kwargs
-        self.sql_connector = SqlConnector()
+        self.ai_search_connector = ConnectorFactory.get_ai_search_connector()
+
+        self.open_ai_connector = ConnectorFactory.get_open_ai_connector()
+
+        system_prompt = load("sql_schema_selection_agent")["system_message"]
+
+        self.system_prompt = Template(system_prompt).render(kwargs)

     @property
     def produced_message_types(self) -> List[type[ChatMessage]]:

@@ -39,18 +47,63 @@ async def on_messages(
     async def on_messages_stream(
         self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
     ) -> AsyncGenerator[AgentMessage | Response, None]:
-        user_question = messages[-1].content
+        last_response = messages[-1].content

-        # Fetch the queries from the cache based on the user question.
-        logging.info("Fetching queries from cache based on the user question...")
+        # Load the JSON of the last message and get the user questions

-        cached_queries = await self.sql_connector.fetch_queries_from_cache(
-            user_question
-        )
+        user_questions = json.loads(last_response)
+
+        logging.info(f"User questions: {user_questions}")
+
+        entity_tasks = []
+
+        for user_question in user_questions:
+            messages = [
+                {"role": "system", "content": self.system_prompt},
+                {"role": "user", "content": user_question},
+            ]
+            entity_tasks.append(self.open_ai_connector.run_completion_request(messages))
+
+        entity_results = await asyncio.gather(*entity_tasks)
+
+        entity_search_tasks = []
+        column_search_tasks = []
+
+        for entity_result in entity_results:
+            loaded_entity_result = json.loads(entity_result)
+
+            logging.info(f"Loaded entity result: {loaded_entity_result}")
+
+            entity_search_tasks.append(
+                self.ai_search_connector.get_entity_schemas(
+                    " ".join(loaded_entity_result["entities"]), as_json=False
+                )
+            )
+
+            for filter_condition in loaded_entity_result["filter_conditions"]:
+                column_search_tasks.append(
+                    self.ai_search_connector.get_column_values(
+                        filter_condition, as_json=False
+                    )
+                )
+
+        schemas_results = await asyncio.gather(*entity_search_tasks)
+        column_value_results = await asyncio.gather(*column_search_tasks)
+
+        final_results = {
+            "schemas": [
+                schema for schema_result in schemas_results for schema in schema_result
+            ],
+            "column_values": [
+                column_values
+                for column_values_result in column_value_results
+                for column_values in column_values_result
+            ],
+        }

         yield Response(
             chat_message=TextMessage(
-                content=json.dumps(cached_queries), source=self.name
+                content=json.dumps(final_results), source=self.name
             )
         )
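The new agent's message contract, as implied by the body above: it reads the previous message as a JSON list of user questions, expects each per-question completion to parse into a JSON object with "entities" and "filter_conditions", and emits a JSON object with "schemas" and "column_values". A shape-only sketch with illustrative values; the real values come from the prompt and the search indexes:

    import json

    # Incoming message content: a JSON list of rewritten user questions.
    incoming = json.dumps(
        ["What are the total number of sales within 2008 for the mountain bike?"]
    )

    # Expected shape of each completion returned by run_completion_request:
    entity_result = {
        "entities": ["SalesOrderDetail", "Product"],  # illustrative only
        "filter_conditions": ["mountain bike"],       # illustrative only
    }

    # Outgoing TextMessage content: flattened search results, serialised as JSON.
    outgoing = json.dumps({"schemas": [], "column_values": []})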

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py

Lines changed: 44 additions & 23 deletions

@@ -29,24 +29,29 @@ async def run_ai_search_query(
         """Run the AI search query."""
         identity_type = get_identity_type()

-        async with AsyncAzureOpenAI(
-            # This is the default and can be omitted
-            api_key=os.environ["OpenAI__ApiKey"],
-            azure_endpoint=os.environ["OpenAI__Endpoint"],
-            api_version=os.environ["OpenAI__ApiVersion"],
-        ) as open_ai_client:
-            embeddings = await open_ai_client.embeddings.create(
-                model=os.environ["OpenAI__EmbeddingModel"], input=query
-            )
+        if len(vector_fields) > 0:
+            async with AsyncAzureOpenAI(
+                # This is the default and can be omitted
+                api_key=os.environ["OpenAI__ApiKey"],
+                azure_endpoint=os.environ["OpenAI__Endpoint"],
+                api_version=os.environ["OpenAI__ApiVersion"],
+            ) as open_ai_client:
+                embeddings = await open_ai_client.embeddings.create(
+                    model=os.environ["OpenAI__EmbeddingModel"], input=query
+                )

-            # Extract the embedding vector
-            embedding_vector = embeddings.data[0].embedding
+                # Extract the embedding vector
+                embedding_vector = embeddings.data[0].embedding

-        vector_query = VectorizedQuery(
-            vector=embedding_vector,
-            k_nearest_neighbors=7,
-            fields=",".join(vector_fields),
-        )
+            vector_query = [
+                VectorizedQuery(
+                    vector=embedding_vector,
+                    k_nearest_neighbors=7,
+                    fields=",".join(vector_fields),
+                )
+            ]
+        else:
+            vector_query = None

         if identity_type == IdentityType.SYSTEM_ASSIGNED:
             credential = DefaultAzureCredential()

@@ -63,13 +68,20 @@ async def run_ai_search_query(
             index_name=index_name,
             credential=credential,
         ) as search_client:
+            if semantic_config is not None and vector_query is not None:
+                query_type = "semantic"
+            elif vector_query is not None:
+                query_type = "hybrid"
+            else:
+                query_type = "full"
+
             results = await search_client.search(
                 top=top,
                 semantic_configuration_name=semantic_config,
                 search_text=query,
                 select=",".join(retrieval_fields),
-                vector_queries=[vector_query],
-                query_type="semantic",
+                vector_queries=vector_query,
+                query_type=query_type,
                 query_language="en-GB",
             )

@@ -102,6 +114,7 @@ async def get_column_values(
             str,
             "The text to run a semantic search against. Relevant entities will be returned.",
         ],
+        as_json: bool = True,
     ):
         """Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.

@@ -113,20 +126,24 @@
         -------
             str: The values of the column in JSON format.
         """
+
+        # Adds tildes after each text word to do a fuzzy search
+        text = " ".join([f"{word}~" for word in text.split()])
         values = await self.run_ai_search_query(
             text,
             [],
             ["FQN", "Column", "Value"],
             os.environ[
                 "AIService__AzureSearchOptions__Text2SqlColumnValueStore__Index"
             ],
-            os.environ[
-                "AIService__AzureSearchOptions__Text2SqlColumnValueStore__SemanticConfig"
-            ],
+            None,
             top=10,
         )

-        return json.dumps(values, default=str)
+        if as_json:
+            return json.dumps(values, default=str)
+        else:
+            return values

     async def get_entity_schemas(
         self,

@@ -138,6 +155,7 @@ async def get_entity_schemas(
             list[str],
             "The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
         ] = [],
+        as_json: bool = True,
     ) -> str:
         """Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.

@@ -178,7 +196,10 @@
         else:
             filtered_schemas.append(schema)

-        return json.dumps(schemas, default=str)
+        if as_json:
+            return json.dumps(schemas, default=str)
+        else:
+            return schemas

 async def add_entry_to_index(document: dict, vector_fields: dict, index_name: str):
     """Add an entry to the search index."""

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py

Lines changed: 5 additions & 0 deletions

@@ -1,5 +1,6 @@
 import os
 from text_2_sql_core.connectors.ai_search import AISearchConnector
+from text_2_sql_core.connectors.open_ai import OpenAIConnector


 class ConnectorFactory:

@@ -36,3 +37,7 @@ def get_database_connector():
     @staticmethod
     def get_ai_search_connector():
         return AISearchConnector()
+
+    @staticmethod
+    def get_open_ai_connector():
+        return OpenAIConnector()

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py (new file; path inferred from the import added in factory.py)

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+from openai import AsyncAzureOpenAI
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+import os
+import dotenv
+from text_2_sql_core.utils.environment import IdentityType, get_identity_type
+
+dotenv.load_dotenv()
+
+
+class OpenAIConnector:
+    @classmethod
+    def get_authentication_properties(cls) -> dict:
+        if get_identity_type() == IdentityType.SYSTEM_ASSIGNED:
+            # Create the token provider
+            api_key = None
+            token_provider = get_bearer_token_provider(
+                DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+            )
+        elif get_identity_type() == IdentityType.USER_ASSIGNED:
+            # Create the token provider
+            api_key = None
+            token_provider = get_bearer_token_provider(
+                DefaultAzureCredential(
+                    managed_identity_client_id=os.environ["ClientId"]
+                ),
+                "https://cognitiveservices.azure.com/.default",
+            )
+        else:
+            token_provider = None
+            api_key = os.environ["OpenAI__ApiKey"]
+
+        return token_provider, api_key
+
+    async def run_completion_request(self, messages: list[dict], temperature=0):
+        async with AsyncAzureOpenAI(
+            api_key=os.environ["OpenAI__ApiKey"],
+            azure_endpoint=os.environ["OpenAI__Endpoint"],
+            api_version=os.environ["OpenAI__ApiVersion"],
+        ) as open_ai_client:
+            response = await open_ai_client.chat.completions.create(
+                model=os.environ["OpenAI__MiniCompletionDeployment"],
+                messages=messages,
+                temperature=temperature,
+            )
+            return response.choices[0].message.content
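A short usage sketch tying the factory and the new connector together. The asyncio wrapper and message contents are illustrative; the environment variables are the ones named in the file above and must be set for the call to succeed:

    import asyncio

    from text_2_sql_core.connectors.factory import ConnectorFactory


    async def main():
        connector = ConnectorFactory.get_open_ai_connector()
        # Runs a single chat completion against the mini deployment and
        # returns the assistant message content as a string.
        answer = await connector.run_completion_request(
            [
                {"role": "system", "content": "Extract entities from the question."},
                {"role": "user", "content": "Total sales within 2008 for the mountain bike?"},
            ]
        )
        print(answer)


    asyncio.run(main())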
