Skip to content

Commit 39bbedc

Browse files
Initial AutoGen Work - Still WIP (#55)
1 parent 9943302 commit 39bbedc

22 files changed

+797
-930
lines changed

deploy_ai_search/text_2_sql_query_cache.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
SearchableField,
88
SimpleField,
99
ComplexField,
10+
SemanticField,
11+
SemanticPrioritizedFields,
12+
SemanticConfiguration,
13+
SemanticSearch,
1014
)
1115
from ai_search import AISearch
1216
from environment import (
@@ -107,3 +111,22 @@ def get_index_fields(self) -> list[SearchableField]:
107111
]
108112

109113
return fields
114+
115+
def get_semantic_search(self) -> SemanticSearch:
116+
"""This function returns the semantic search configuration for sql index
117+
118+
Returns:
119+
SemanticSearch: The semantic search configuration"""
120+
121+
semantic_config = SemanticConfiguration(
122+
name=self.semantic_config_name,
123+
prioritized_fields=SemanticPrioritizedFields(
124+
content_fields=[
125+
SemanticField(field_name="Question"),
126+
],
127+
),
128+
)
129+
130+
semantic_search = SemanticSearch(configurations=[semantic_config])
131+
132+
return semantic_search

text_2_sql/autogen/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Multi-Shot Text2SQL Component - AutoGen
2+
3+
Very much still work in progress, more documentation coming soon.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import dotenv\n",
10+
"import logging\n",
11+
"from autogen_agentchat.task import Console\n",
12+
"from agentic_text_2_sql import text_2_sql_generator"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": null,
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"logging.basicConfig(level=logging.INFO)"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": null,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"dotenv.load_dotenv()"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": null,
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
39+
"result = text_2_sql_generator.run_stream(task=\"What are the total number of sales within 2008?\")"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
45+
"metadata": {},
46+
"outputs": [],
47+
"source": [
48+
"await Console(result)"
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": null,
54+
"metadata": {},
55+
"outputs": [],
56+
"source": []
57+
}
58+
],
59+
"metadata": {
60+
"kernelspec": {
61+
"display_name": "Python 3",
62+
"language": "python",
63+
"name": "python3"
64+
},
65+
"language_info": {
66+
"codemirror_mode": {
67+
"name": "ipython",
68+
"version": 3
69+
},
70+
"file_extension": ".py",
71+
"mimetype": "text/x-python",
72+
"name": "python",
73+
"nbconvert_exporter": "python",
74+
"pygments_lexer": "ipython3",
75+
"version": "3.12.6"
76+
}
77+
},
78+
"nbformat": 4,
79+
"nbformat_minor": 2
80+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from autogen_agentchat.task import TextMentionTermination, MaxMessageTermination
2+
from autogen_agentchat.teams import SelectorGroupChat
3+
from utils.models import MINI_MODEL
4+
from utils.llm_agent_creator import LLMAgentCreator
5+
import logging
6+
from custom_agents.sql_query_cache_agent import SqlQueryCacheAgent
7+
import json
8+
9+
SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
10+
"sql_query_generation_agent",
11+
target_engine="Microsoft SQL Server",
12+
engine_specific_rules="Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.",
13+
)
14+
SQL_SCHEMA_SELECTION_AGENT = LLMAgentCreator.create("sql_schema_selection_agent")
15+
SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
16+
"sql_query_correction_agent",
17+
target_engine="Microsoft SQL Server",
18+
engine_specific_rules="Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.",
19+
)
20+
SQL_QUERY_CACHE_AGENT = SqlQueryCacheAgent()
21+
ANSWER_AGENT = LLMAgentCreator.create("answer_agent")
22+
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create("question_decomposition_agent")
23+
24+
25+
def text_2_sql_generator_selector_func(messages):
26+
logging.info("Messages: %s", messages)
27+
decision = None # Initialize decision variable
28+
29+
if len(messages) == 1:
30+
decision = "sql_query_cache_agent"
31+
32+
elif (
33+
messages[-1].source == "sql_query_cache_agent"
34+
and messages[-1].content is not None
35+
):
36+
cache_result = json.loads(messages[-1].content)
37+
if cache_result.get("cached_questions_and_schemas") is not None:
38+
decision = "sql_query_correction_agent"
39+
else:
40+
decision = "sql_schema_selection_agent"
41+
42+
elif messages[-1].source == "question_decomposition_agent":
43+
decision = "sql_schema_selection_agent"
44+
45+
elif messages[-1].source == "sql_schema_selection_agent":
46+
decision = "sql_query_generation_agent"
47+
48+
elif (
49+
messages[-1].source == "sql_query_correction_agent"
50+
and messages[-1].content == "VALIDATED"
51+
):
52+
decision = "answer_agent"
53+
54+
elif messages[-1].source == "sql_query_correction_agent":
55+
decision = "sql_query_correction_agent"
56+
57+
# Log the decision
58+
logging.info("Decision: %s", decision)
59+
60+
return decision
61+
62+
63+
termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(10)
64+
text_2_sql_generator = SelectorGroupChat(
65+
[
66+
SQL_QUERY_GENERATION_AGENT,
67+
SQL_SCHEMA_SELECTION_AGENT,
68+
SQL_QUERY_CORRECTION_AGENT,
69+
SQL_QUERY_CACHE_AGENT,
70+
ANSWER_AGENT,
71+
QUESTION_DECOMPOSITION_AGENT,
72+
],
73+
allow_repeated_speaker=False,
74+
model_client=MINI_MODEL,
75+
termination_condition=termination,
76+
selector_func=text_2_sql_generator_selector_func,
77+
)
78+
79+
# text_2_sql_cache_updater = SelectorGroupChat(
80+
# [SQL_QUERY_CACHE_AGENT], model_client=MINI_MODEL, termination_condition=termination
81+
# )

text_2_sql/autogen/custom_agents/__init__.py

Whitespace-only changes.
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import AsyncGenerator, List, Sequence
2+
3+
from autogen_agentchat.agents import BaseChatAgent
4+
from autogen_agentchat.base import Response
5+
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
6+
from autogen_core.base import CancellationToken
7+
from utils.sql_utils import fetch_queries_from_cache
8+
import json
9+
import logging
10+
11+
12+
class SqlQueryCacheAgent(BaseChatAgent):
13+
def __init__(self):
14+
super().__init__(
15+
"sql_query_cache_agent",
16+
"An agent that fetches the queries from the cache based on the user question.",
17+
)
18+
19+
@property
20+
def produced_message_types(self) -> List[type[ChatMessage]]:
21+
return [TextMessage]
22+
23+
async def on_messages(
24+
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
25+
) -> Response:
26+
# Calls the on_messages_stream.
27+
response: Response | None = None
28+
async for message in self.on_messages_stream(messages, cancellation_token):
29+
if isinstance(message, Response):
30+
response = message
31+
assert response is not None
32+
return response
33+
34+
async def on_messages_stream(
35+
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
36+
) -> AsyncGenerator[AgentMessage | Response, None]:
37+
user_question = messages[0].content
38+
39+
# Fetch the queries from the cache based on the user question.
40+
logging.info("Fetching queries from cache based on the user question...")
41+
42+
cached_queries = await fetch_queries_from_cache(user_question)
43+
44+
yield Response(
45+
chat_message=TextMessage(
46+
content=json.dumps(cached_queries), source=self.name
47+
)
48+
)
49+
50+
async def on_reset(self, cancellation_token: CancellationToken) -> None:
51+
pass

text_2_sql/autogen/environment.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
import os
4+
from enum import Enum
5+
6+
7+
class IdentityType(Enum):
8+
"""The type of the indexer"""
9+
10+
USER_ASSIGNED = "user_assigned"
11+
SYSTEM_ASSIGNED = "system_assigned"
12+
KEY = "key"
13+
14+
15+
def get_identity_type() -> IdentityType:
16+
"""This function returns the identity type.
17+
18+
Returns:
19+
IdentityType: The identity type
20+
"""
21+
identity = os.environ.get("IdentityType")
22+
23+
if identity == "user_assigned":
24+
return IdentityType.USER_ASSIGNED
25+
elif identity == "system_assigned":
26+
return IdentityType.SYSTEM_ASSIGNED
27+
elif identity == "key":
28+
return IdentityType.KEY
29+
else:
30+
raise ValueError("Invalid identity type")
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
model:
2+
gpt-4o-mini
3+
description:
4+
"An agent that takes the final results from the SQL query and writes the answer to the user's question"
5+
system_message:
6+
"Write a data-driven answer that directly addresses the user's question. Use the results from the SQL query to provide the answer. Do not make up or guess the answer.
7+
8+
Return your answer in the following format:
9+
10+
{
11+
'answer': '<GENERATED ANSWER>',
12+
'sources': [
13+
{'title': <SOURCE SCHEMA NAME 1>, 'chunk': <SOURCE 1 CONTEXT CHUNK>, 'reference': '<SOURCE 1 SQL QUERY>'},
14+
{'title': <SOURCE SCHEMA NAME 2>, 'chunk': <SOURCE 2 CONTEXT CHUNK>, 'reference': '<SOURCE 2 SQL QUERY>'}
15+
]
16+
}
17+
18+
Title is the entity name of the schema, chunk is the result of the SQL query and reference is the SQL query used to generate the answer.
19+
20+
End your answer with 'TERMINATE'"
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
model:
2+
gpt-4o-mini
3+
description:
4+
"An agent that will decompose the user's question into smaller parts to be used in the SQL queries. Use this agent when the user's question is too complex to be answered in one SQL query. Only use if the user's question is too complex to be answered in one SQL query.
5+
6+
Only use this agent once per user question and after the 'Query Cache Agent' if the results are none."
7+
system_message:
8+
"You are a helpful AI Assistant that specialises in decomposing complex user questions into smaller parts that can be used in SQL queries.
9+
10+
Break down the user's question into smaller parts that can be used in SQL queries."
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
model:
2+
gpt-4o-mini
3+
description:
4+
"An agent that will look at the SQL query, SQL query results and correct any mistakes in the SQL query to ensure the correct results are returned. Use this agent AFTER the SQL query has been executed and the results are not as expected."
5+
system_message:
6+
"You are a helpful AI Assistant that specialises in correcting invalid SQL queries or queries that do not return the expected results.
7+
8+
Review the SQL query provided and correct any errors or issues that you find. Bear in mind that the target database engine is {{ target_engine }}, SQL queries must be able compatible to run on {{ target_engine }} {{ engine_specific_rules }}
9+
10+
Ensure that the corrected query returns the expected results in context of the question.
11+
12+
If there are no errors and the SQL query is correct, return 'VALIDATED'.
13+
14+
If the SQL query needs adjustment, correct the SQL query and provide the corrected SQL query and then run the query.
15+
16+
If you are consistently unable to correct the SQL query and cannot use the schemas to answer the question. Say 'I am unable to correct the SQL query. Please ask another question.' and then end your answer with 'TERMINATE'"
17+
tools:
18+
- sql_get_entity_schemas_tool
19+
- sql_query_execution_tool

0 commit comments

Comments
 (0)