Skip to content

Commit 439776e

Browse files
committed
Dump of work
1 parent 2962719 commit 439776e

File tree

11 files changed

+233
-60
lines changed

11 files changed

+233
-60
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
4141
)
4242
elif tool_name == "sql_get_entity_schemas_tool":
4343
return FunctionTool(
44-
ai_search_helper.get_entity_schemas,
44+
sql_helper.get_entity_schemas,
4545
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the system prompt are not sufficient to answer the question.",
4646
)
4747
elif tool_name == "sql_get_column_values_tool":

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def __init__(self, **kwargs):
2525

2626
self.open_ai_connector = ConnectorFactory.get_open_ai_connector()
2727

28+
self.sql_connector = ConnectorFactory.get_database_connector()
29+
2830
system_prompt = load("sql_schema_selection_agent")["system_message"]
2931

3032
self.system_prompt = Template(system_prompt).render(kwargs)
@@ -75,7 +77,7 @@ async def on_messages_stream(
7577
logging.info(f"Loaded entity result: {loaded_entity_result}")
7678

7779
entity_search_tasks.append(
78-
self.ai_search_connector.get_entity_schemas(
80+
self.sql_connector.get_entity_schemas(
7981
" ".join(loaded_entity_result["entities"]), as_json=False
8082
)
8183
)

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,27 @@ async def run_ai_search_query(
9090
async for result in results.by_page():
9191
async for item in result:
9292
if (
93-
minimum_score is not None
94-
and item["@search.reranker_score"] < minimum_score
93+
"@search.reranker_score" in item
94+
and item["@search.reranker_score"] is not None
9595
):
96+
score = item["@search.reranker_score"]
97+
elif "@search.score" in item and item["@search.score"] is not None:
98+
score = item["@search.score"]
99+
else:
100+
raise Exception("No score found in the search results.")
101+
102+
if minimum_score is not None and score < minimum_score:
96103
continue
97104

98105
if include_scores is False:
99-
del item["@search.reranker_score"]
100-
del item["@search.score"]
101-
del item["@search.highlights"]
102-
del item["@search.captions"]
106+
if "@search.reranker_score" in item:
107+
del item["@search.reranker_score"]
108+
if "@search.score" in item:
109+
del item["@search.score"]
110+
if "@search.highlights" in item:
111+
del item["@search.highlights"]
112+
if "@search.captions" in item:
113+
del item["@search.captions"]
103114

104115
logging.info("Item: %s", item)
105116
combined_results.append(item)
@@ -131,19 +142,31 @@ async def get_column_values(
131142
text = " ".join([f"{word}~" for word in text.split()])
132143
values = await self.run_ai_search_query(
133144
text,
134-
[],
135-
["FQN", "Column", "Value"],
136-
os.environ[
145+
vector_fields=[],
146+
retrieval_fields=["FQN", "Column", "Value"],
147+
index_name=os.environ[
137148
"AIService__AzureSearchOptions__Text2SqlColumnValueStore__Index"
138149
],
139-
None,
150+
semantic_config=None,
140151
top=10,
152+
include_scores=True,
153+
minimum_score=5,
141154
)
142155

156+
# build into a common format
157+
column_values = {}
158+
159+
for value in values:
160+
trimmed_fqn = ".".join(value["FQN"].split(".")[:-1])
161+
if trimmed_fqn not in column_values:
162+
column_values[trimmed_fqn] = []
163+
164+
column_values[trimmed_fqn].append(value["Value"])
165+
143166
if as_json:
144-
return json.dumps(values, default=str)
167+
return json.dumps(column_values, default=str)
145168
else:
146-
return values
169+
return column_values
147170

148171
async def get_entity_schemas(
149172
self,
@@ -155,7 +178,6 @@ async def get_entity_schemas(
155178
list[str],
156179
"The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
157180
] = [],
158-
as_json: bool = True,
159181
) -> str:
160182
"""Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.
161183
@@ -187,19 +209,14 @@ async def get_entity_schemas(
187209
)
188210

189211
for schema in schemas:
190-
entity = schema["Entity"]
191-
192212
filtered_schemas = []
193213
for excluded_entity in excluded_entities:
194-
if excluded_entity.lower() == entity.lower():
214+
if excluded_entity.lower() == schema["Entity"].lower():
195215
logging.info("Excluded entity: %s", excluded_entity)
196216
else:
197217
filtered_schemas.append(schema)
198218

199-
if as_json:
200-
return json.dumps(schemas, default=str)
201-
else:
202-
return schemas
219+
return filtered_schemas
203220

204221
async def add_entry_to_index(document: dict, vector_fields: dict, index_name: str):
205222
"""Add an entry to the search index."""

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import asyncio
77
import os
88
import logging
9+
import json
910

1011

1112
class DatabricksSqlConnector(SqlConnector):
@@ -69,3 +70,39 @@ async def query_execution(
6970
connection.close()
7071

7172
return results
73+
74+
async def get_entity_schemas(
75+
self,
76+
text: Annotated[
77+
str,
78+
"The text to run a semantic search against. Relevant entities will be returned.",
79+
],
80+
excluded_entities: Annotated[
81+
list[str],
82+
"The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
83+
] = [],
84+
as_json: bool = True,
85+
) -> str:
86+
"""Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.
87+
88+
Args:
89+
----
90+
text (str): The text to run the search against.
91+
92+
Returns:
93+
str: The schema of the views or tables in JSON format.
94+
"""
95+
96+
schemas = await self.ai_search_connector.get_entity_schemas(
97+
text, excluded_entities
98+
)
99+
100+
for schema in schemas:
101+
schema["SelectFromEntity"] = ".".join(
102+
[schema["Catalog"], schema["Schema"], schema["Entity"]]
103+
)
104+
105+
if as_json:
106+
return json.dumps(schemas, default=str)
107+
else:
108+
return schemas

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import asyncio
77
import os
88
import logging
9+
import json
910

1011

1112
class SnowflakeSqlConnector(SqlConnector):
@@ -68,3 +69,44 @@ async def query_execution(
6869
conn.close()
6970

7071
return results
72+
73+
async def get_entity_schemas(
74+
self,
75+
text: Annotated[
76+
str,
77+
"The text to run a semantic search against. Relevant entities will be returned.",
78+
],
79+
excluded_entities: Annotated[
80+
list[str],
81+
"The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
82+
] = [],
83+
as_json: bool = True,
84+
) -> str:
85+
"""Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.
86+
87+
Args:
88+
----
89+
text (str): The text to run the search against.
90+
91+
Returns:
92+
str: The schema of the views or tables in JSON format.
93+
"""
94+
95+
schemas = await self.ai_search_connector.get_entity_schemas(
96+
text, excluded_entities
97+
)
98+
99+
for schema in schemas:
100+
schema["SelectFromEntity"] = ".".join(
101+
[
102+
schema["Warehouse"],
103+
schema["Database"],
104+
schema["Schema"],
105+
schema["Entity"],
106+
]
107+
)
108+
109+
if as_json:
110+
return json.dumps(schemas, default=str)
111+
else:
112+
return schemas

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,29 @@ async def query_execution(
5151
list[dict]: The results of the SQL query.
5252
"""
5353

54+
@abstractmethod
55+
async def get_entity_schemas(
56+
self,
57+
text: Annotated[
58+
str,
59+
"The text to run a semantic search against. Relevant entities will be returned.",
60+
],
61+
excluded_entities: Annotated[
62+
list[str],
63+
"The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
64+
] = [],
65+
as_json: bool = True,
66+
) -> str:
67+
"""Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.
68+
69+
Args:
70+
----
71+
text (str): The text to run the search against.
72+
73+
Returns:
74+
str: The schema of the views or tables in JSON format.
75+
"""
76+
5477
async def query_execution_with_limit(
5578
self,
5679
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Annotated
66
import os
77
import logging
8+
import json
89

910

1011
class TSQLSqlConnector(SqlConnector):
@@ -48,3 +49,37 @@ async def query_execution(
4849

4950
logging.debug("Results: %s", results)
5051
return results
52+
53+
async def get_entity_schemas(
54+
self,
55+
text: Annotated[
56+
str,
57+
"The text to run a semantic search against. Relevant entities will be returned.",
58+
],
59+
excluded_entities: Annotated[
60+
list[str],
61+
"The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
62+
] = [],
63+
as_json: bool = True,
64+
) -> str:
65+
"""Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.
66+
67+
Args:
68+
----
69+
text (str): The text to run the search against.
70+
71+
Returns:
72+
str: The schema of the views or tables in JSON format.
73+
"""
74+
75+
schemas = await self.ai_search_connector.get_entity_schemas(
76+
text, excluded_entities
77+
)
78+
79+
for schema in schemas:
80+
schema["SelectFromEntity"] = ".".join([schema["Schema"], schema["Entity"]])
81+
82+
if as_json:
83+
return json.dumps(schemas, default=str)
84+
else:
85+
return schemas

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/question_decomposition_agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ description:
55
system_message:
66
"You are a helpful AI Assistant that specialises in decomposing complex user questions into smaller parts that can be used in SQL queries.
77
8-
If a user's question is actually a combination of multiple questions, break down the user's question into smaller questions that can be used in SQL queries.
8+
If a user's question is actually a combination of multiple complex questions, break down the user's question into smaller questions that can be used in SQL queries. If a question is simple and is directly answerable by a set of SQL queries, do not decompose the question.
99
1010
Output Info:
1111
Return the decomposed questions to the user in the following format:

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_query_correction_agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ system_message:
2020
tools:
2121
- sql_get_entity_schemas_tool
2222
- sql_query_execution_tool
23-
# - current_datetime_tool
23+
- current_datetime_tool

0 commit comments

Comments
 (0)