Skip to content

Commit f9a2bcd

Browse files
Merge branch 'main' into feature/semantic-chunking
2 parents 70e9e36 + b8e45f3 commit f9a2bcd

31 files changed

+1088
-942
lines changed

deploy_ai_search/.env

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ OpenAI__Endpoint=<openAIEndpoint>
2020
OpenAI__EmbeddingModel=<openAIEmbeddingModelName>
2121
OpenAI__EmbeddingDeployment=<openAIEmbeddingDeploymentId>
2222
OpenAI__EmbeddingDimensions=1536
23-
Text2Sql__DatabaseName=<databaseName>
23+
Text2Sql__DatabaseEngine=<databaseEngine SQL Server / Snowflake / Databricks >

deploy_ai_search/text_2_sql_query_cache.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
SearchableField,
88
SimpleField,
99
ComplexField,
10+
SemanticField,
11+
SemanticPrioritizedFields,
12+
SemanticConfiguration,
13+
SemanticSearch,
1014
)
1115
from ai_search import AISearch
1216
from environment import (
@@ -107,3 +111,22 @@ def get_index_fields(self) -> list[SearchableField]:
107111
]
108112

109113
return fields
114+
115+
def get_semantic_search(self) -> SemanticSearch:
116+
"""This function returns the semantic search configuration for sql index
117+
118+
Returns:
119+
SemanticSearch: The semantic search configuration"""
120+
121+
semantic_config = SemanticConfiguration(
122+
name=self.semantic_config_name,
123+
prioritized_fields=SemanticPrioritizedFields(
124+
content_fields=[
125+
SemanticField(field_name="Question"),
126+
],
127+
),
128+
)
129+
130+
semantic_search = SemanticSearch(configurations=[semantic_config])
131+
132+
return semantic_search

deploy_ai_search/text_2_sql_schema_store.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@
2424
from environment import (
2525
IndexerType,
2626
)
27+
import os
28+
from enum import StrEnum
29+
30+
31+
class DatabaseEngine(StrEnum):
32+
"""An enumeration to represent a database engine."""
33+
34+
SNOWFLAKE = "SNOWFLAKE"
35+
SQL_SERVER = "SQL_SERVER"
36+
DATABRICKS = "DATABRICKS"
2737

2838

2939
class Text2SqlSchemaStoreAISearch(AISearch):
@@ -42,13 +52,34 @@ def __init__(
4252
rebuild (bool, optional): Whether to rebuild the index. Defaults to False.
4353
"""
4454
self.indexer_type = IndexerType.TEXT_2_SQL_SCHEMA_STORE
55+
self.database_engine = DatabaseEngine[
56+
os.environ["Text2Sql__DatabaseEngine"].upper()
57+
]
4558
super().__init__(suffix, rebuild)
4659

4760
if single_data_dictionary:
4861
self.parsing_mode = BlobIndexerParsingMode.JSON_ARRAY
4962
else:
5063
self.parsing_mode = BlobIndexerParsingMode.JSON
5164

65+
@property
66+
def excluded_fields_for_database_engine(self):
67+
"""A method to get the excluded fields for the database engine."""
68+
69+
all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
70+
if self.database_engine == DatabaseEngine.SNOWFLAKE:
71+
engine_specific_fields = ["Warehouse", "Database"]
72+
elif self.database_engine == DatabaseEngine.SQL_SERVER:
73+
engine_specific_fields = ["Database"]
74+
elif self.database_engine == DatabaseEngine.DATABRICKS:
75+
engine_specific_fields = ["Catalog"]
76+
77+
return [
78+
field
79+
for field in all_engine_specific_fields
80+
if field not in engine_specific_fields
81+
]
82+
5283
def get_index_fields(self) -> list[SearchableField]:
5384
"""This function returns the index fields for sql index.
5485
@@ -78,6 +109,10 @@ def get_index_fields(self) -> list[SearchableField]:
78109
name="Warehouse",
79110
type=SearchFieldDataType.String,
80111
),
112+
SearchableField(
113+
name="Catalog",
114+
type=SearchFieldDataType.String,
115+
),
81116
SearchableField(
82117
name="Definition",
83118
type=SearchFieldDataType.String,
@@ -161,6 +196,13 @@ def get_index_fields(self) -> list[SearchableField]:
161196
),
162197
]
163198

199+
# Remove fields that are not supported by the database engine
200+
fields = [
201+
field
202+
for field in fields
203+
if field.name not in self.excluded_fields_for_database_engine
204+
]
205+
164206
return fields
165207

166208
def get_semantic_search(self) -> SemanticSearch:
@@ -309,4 +351,12 @@ def get_indexer(self) -> SearchIndexer:
309351
parameters=indexer_parameters,
310352
)
311353

354+
# Remove fields that are not supported by the database engine
355+
indexer.output_field_mappings = [
356+
field_mapping
357+
for field_mapping in indexer.output_field_mappings
358+
if field_mapping.target_field_name
359+
not in self.excluded_fields_for_database_engine
360+
]
361+
312362
return indexer

text_2_sql/autogen/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Multi-Shot Text2SQL Component - AutoGen
2+
3+
Very much still work in progress, more documentation coming soon.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import dotenv\n",
10+
"import logging\n",
11+
"from autogen_agentchat.task import Console\n",
12+
"from agentic_text_2_sql import text_2_sql_generator"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": null,
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"logging.basicConfig(level=logging.INFO)"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": null,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"dotenv.load_dotenv()"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": null,
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
39+
"result = text_2_sql_generator.run_stream(task=\"What are the total number of sales within 2008?\")"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
45+
"metadata": {},
46+
"outputs": [],
47+
"source": [
48+
"await Console(result)"
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": null,
54+
"metadata": {},
55+
"outputs": [],
56+
"source": []
57+
}
58+
],
59+
"metadata": {
60+
"kernelspec": {
61+
"display_name": "Python 3",
62+
"language": "python",
63+
"name": "python3"
64+
},
65+
"language_info": {
66+
"codemirror_mode": {
67+
"name": "ipython",
68+
"version": 3
69+
},
70+
"file_extension": ".py",
71+
"mimetype": "text/x-python",
72+
"name": "python",
73+
"nbconvert_exporter": "python",
74+
"pygments_lexer": "ipython3",
75+
"version": "3.12.6"
76+
}
77+
},
78+
"nbformat": 4,
79+
"nbformat_minor": 2
80+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from autogen_agentchat.task import TextMentionTermination, MaxMessageTermination
4+
from autogen_agentchat.teams import SelectorGroupChat
5+
from utils.models import MINI_MODEL
6+
from utils.llm_agent_creator import LLMAgentCreator
7+
import logging
8+
from custom_agents.sql_query_cache_agent import SqlQueryCacheAgent
9+
import json
10+
11+
SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
12+
"sql_query_generation_agent",
13+
target_engine="Microsoft SQL Server",
14+
engine_specific_rules="Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.",
15+
)
16+
SQL_SCHEMA_SELECTION_AGENT = LLMAgentCreator.create("sql_schema_selection_agent")
17+
SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
18+
"sql_query_correction_agent",
19+
target_engine="Microsoft SQL Server",
20+
engine_specific_rules="Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.",
21+
)
22+
SQL_QUERY_CACHE_AGENT = SqlQueryCacheAgent()
23+
ANSWER_AGENT = LLMAgentCreator.create("answer_agent")
24+
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create("question_decomposition_agent")
25+
26+
27+
def text_2_sql_generator_selector_func(messages):
28+
logging.info("Messages: %s", messages)
29+
decision = None # Initialize decision variable
30+
31+
if len(messages) == 1:
32+
decision = "sql_query_cache_agent"
33+
34+
elif (
35+
messages[-1].source == "sql_query_cache_agent"
36+
and messages[-1].content is not None
37+
):
38+
cache_result = json.loads(messages[-1].content)
39+
if cache_result.get("cached_questions_and_schemas") is not None:
40+
decision = "sql_query_correction_agent"
41+
else:
42+
decision = "sql_schema_selection_agent"
43+
44+
elif messages[-1].source == "question_decomposition_agent":
45+
decision = "sql_schema_selection_agent"
46+
47+
elif messages[-1].source == "sql_schema_selection_agent":
48+
decision = "sql_query_generation_agent"
49+
50+
elif (
51+
messages[-1].source == "sql_query_correction_agent"
52+
and messages[-1].content == "VALIDATED"
53+
):
54+
decision = "answer_agent"
55+
56+
elif messages[-1].source == "sql_query_correction_agent":
57+
decision = "sql_query_correction_agent"
58+
59+
# Log the decision
60+
logging.info("Decision: %s", decision)
61+
62+
return decision
63+
64+
65+
termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(10)
66+
text_2_sql_generator = SelectorGroupChat(
67+
[
68+
SQL_QUERY_GENERATION_AGENT,
69+
SQL_SCHEMA_SELECTION_AGENT,
70+
SQL_QUERY_CORRECTION_AGENT,
71+
SQL_QUERY_CACHE_AGENT,
72+
ANSWER_AGENT,
73+
QUESTION_DECOMPOSITION_AGENT,
74+
],
75+
allow_repeated_speaker=False,
76+
model_client=MINI_MODEL,
77+
termination_condition=termination,
78+
selector_func=text_2_sql_generator_selector_func,
79+
)

text_2_sql/autogen/custom_agents/__init__.py

Whitespace-only changes.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from typing import AsyncGenerator, List, Sequence
4+
5+
from autogen_agentchat.agents import BaseChatAgent
6+
from autogen_agentchat.base import Response
7+
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
8+
from autogen_core.base import CancellationToken
9+
from utils.sql_utils import fetch_queries_from_cache
10+
import json
11+
import logging
12+
13+
14+
class SqlQueryCacheAgent(BaseChatAgent):
15+
def __init__(self):
16+
super().__init__(
17+
"sql_query_cache_agent",
18+
"An agent that fetches the queries from the cache based on the user question.",
19+
)
20+
21+
@property
22+
def produced_message_types(self) -> List[type[ChatMessage]]:
23+
return [TextMessage]
24+
25+
async def on_messages(
26+
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
27+
) -> Response:
28+
# Calls the on_messages_stream.
29+
response: Response | None = None
30+
async for message in self.on_messages_stream(messages, cancellation_token):
31+
if isinstance(message, Response):
32+
response = message
33+
assert response is not None
34+
return response
35+
36+
async def on_messages_stream(
37+
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
38+
) -> AsyncGenerator[AgentMessage | Response, None]:
39+
user_question = messages[0].content
40+
41+
# Fetch the queries from the cache based on the user question.
42+
logging.info("Fetching queries from cache based on the user question...")
43+
44+
cached_queries = await fetch_queries_from_cache(user_question)
45+
46+
yield Response(
47+
chat_message=TextMessage(
48+
content=json.dumps(cached_queries), source=self.name
49+
)
50+
)
51+
52+
async def on_reset(self, cancellation_token: CancellationToken) -> None:
53+
pass

text_2_sql/autogen/environment.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
import os
4+
from enum import Enum
5+
6+
7+
class IdentityType(Enum):
8+
"""The type of the indexer"""
9+
10+
USER_ASSIGNED = "user_assigned"
11+
SYSTEM_ASSIGNED = "system_assigned"
12+
KEY = "key"
13+
14+
15+
def get_identity_type() -> IdentityType:
16+
"""This function returns the identity type.
17+
18+
Returns:
19+
IdentityType: The identity type
20+
"""
21+
identity = os.environ.get("IdentityType")
22+
23+
if identity == "user_assigned":
24+
return IdentityType.USER_ASSIGNED
25+
elif identity == "system_assigned":
26+
return IdentityType.SYSTEM_ASSIGNED
27+
elif identity == "key":
28+
return IdentityType.KEY
29+
else:
30+
raise ValueError("Invalid identity type")

0 commit comments

Comments
 (0)