Skip to content

Commit dd1f062

Browse files
Merge branch 'main' into feature/auto-gen-agent
2 parents 87ad106 + 9943302 commit dd1f062

File tree

9 files changed

+285
-12
lines changed

9 files changed

+285
-12
lines changed

deploy_ai_search/.env

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ OpenAI__Endpoint=<openAIEndpoint>
1919
OpenAI__EmbeddingModel=<openAIEmbeddingModelName>
2020
OpenAI__EmbeddingDeployment=<openAIEmbeddingDeploymentId>
2121
OpenAI__EmbeddingDimensions=1536
22-
Text2Sql__DatabaseName=<databaseName>
22+
Text2Sql__DatabaseEngine=<databaseEngine SQL_SERVER / SNOWFLAKE / DATABRICKS>

deploy_ai_search/text_2_sql_schema_store.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@
2424
from environment import (
2525
IndexerType,
2626
)
27+
import os
28+
from enum import StrEnum
29+
30+
31+
class DatabaseEngine(StrEnum):
    """An enumeration to represent a database engine.

    Member names/values are upper-case so an engine can be resolved
    straight from the ``Text2Sql__DatabaseEngine`` environment variable
    via ``DatabaseEngine[value.upper()]`` (see ``__init__`` below).
    """

    SNOWFLAKE = "SNOWFLAKE"
    SQL_SERVER = "SQL_SERVER"
    DATABRICKS = "DATABRICKS"
2737

2838

2939
class Text2SqlSchemaStoreAISearch(AISearch):
@@ -42,13 +52,34 @@ def __init__(
4252
rebuild (bool, optional): Whether to rebuild the index. Defaults to False.
4353
"""
4454
self.indexer_type = IndexerType.TEXT_2_SQL_SCHEMA_STORE
55+
self.database_engine = DatabaseEngine[
56+
os.environ["Text2Sql__DatabaseEngine"].upper()
57+
]
4558
super().__init__(suffix, rebuild)
4659

4760
if single_data_dictionary:
4861
self.parsing_mode = BlobIndexerParsingMode.JSON_ARRAY
4962
else:
5063
self.parsing_mode = BlobIndexerParsingMode.JSON
5164

65+
@property
def excluded_fields_for_database_engine(self):
    """Return the engine-specific fields NOT used by the configured engine.

    Each engine uses only a subset of the engine-specific index fields
    ("Warehouse", "Database", "Catalog"); the remainder must be stripped
    from the index definition and the indexer output field mappings.

    Returns:
        list[str]: Field names to exclude for ``self.database_engine``.
    """
    all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
    if self.database_engine == DatabaseEngine.SNOWFLAKE:
        engine_specific_fields = ["Warehouse", "Database"]
    elif self.database_engine == DatabaseEngine.SQL_SERVER:
        engine_specific_fields = ["Database"]
    elif self.database_engine == DatabaseEngine.DATABRICKS:
        engine_specific_fields = ["Catalog"]
    else:
        # FIX: the original left engine_specific_fields unbound here,
        # raising NameError for an unrecognized engine. Treat an unknown
        # engine as using none of the engine-specific fields.
        engine_specific_fields = []

    return [
        field
        for field in all_engine_specific_fields
        if field not in engine_specific_fields
    ]
82+
5283
def get_index_fields(self) -> list[SearchableField]:
5384
"""This function returns the index fields for sql index.
5485
@@ -78,6 +109,10 @@ def get_index_fields(self) -> list[SearchableField]:
78109
name="Warehouse",
79110
type=SearchFieldDataType.String,
80111
),
112+
SearchableField(
113+
name="Catalog",
114+
type=SearchFieldDataType.String,
115+
),
81116
SearchableField(
82117
name="Definition",
83118
type=SearchFieldDataType.String,
@@ -161,6 +196,13 @@ def get_index_fields(self) -> list[SearchableField]:
161196
),
162197
]
163198

199+
# Remove fields that are not supported by the database engine
200+
fields = [
201+
field
202+
for field in fields
203+
if field.name not in self.excluded_fields_for_database_engine
204+
]
205+
164206
return fields
165207

166208
def get_semantic_search(self) -> SemanticSearch:
@@ -309,4 +351,12 @@ def get_indexer(self) -> SearchIndexer:
309351
parameters=indexer_parameters,
310352
)
311353

354+
# Remove fields that are not supported by the database engine
355+
indexer.output_field_mappings = [
356+
field_mapping
357+
for field_mapping in indexer.output_field_mappings
358+
if field_mapping.target_field_name
359+
not in self.excluded_fields_for_database_engine
360+
]
361+
312362
return indexer

text_2_sql/data_dictionary/.env

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ OpenAI__EmbeddingModel=<openAIEmbeddingModelName>
33
OpenAI__Endpoint=<openAIEndpoint>
44
OpenAI__ApiKey=<openAIKey if using non managed identity>
55
OpenAI__ApiVersion=<openAIApiVersion>
6-
Text2Sql__DatabaseEngine=<databaseEngine>
76
Text2Sql__DatabaseName=<databaseName>
87
Text2Sql__DatabaseConnectionString=<databaseConnectionString>
98
Text2Sql__Snowflake__User=<snowflakeUser if using Snowflake Data Source>
109
Text2Sql__Snowflake__Password=<snowflakePassword if using Snowflake Data Source>
1110
Text2Sql__Snowflake__Account=<snowflakeAccount if using Snowflake Data Source>
1211
Text2Sql__Snowflake__Warehouse=<snowflakeWarehouse if using Snowflake Data Source>
12+
Text2Sql__Databricks__Catalog=<databricksCatalog if using Databricks Data Source with Unity Catalog>
13+
Text2Sql__Databricks__ServerHostname=<databricksServerHostname if using Databricks Data Source with Unity Catalog>
14+
Text2Sql__Databricks__HttpPath=<databricksHttpPath if using Databricks Data Source with Unity Catalog>
15+
Text2Sql__Databricks__AccessToken=<databricks AccessToken if using Databricks Data Source with Unity Catalog>
1316
IdentityType=<identityType> # system_assigned or user_assigned or key
14-
ClientId=<clientId if using user assigned identity>
17+
ClientId=<clientId if using user assigned identity>

text_2_sql/data_dictionary/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ See `./generated_samples/` for an example output of the script. This can then be
9999

100100
The following Databases have pre-built scripts for them:
101101

102-
- **Microsoft SQL Server:** `sql_server_data_dictionary_creator.py`
102+
- **Databricks:** `databricks_data_dictionary_creator.py`
103103
- **Snowflake:** `snowflake_data_dictionary_creator.py`
104+
- **SQL Server:** `sql_server_data_dictionary_creator.py`
104105

105106
If there is no pre-built script for your database engine, take one of the above as a starting point and adjust it.

text_2_sql/data_dictionary/data_dictionary_creator.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,19 @@
1515
import random
1616
import re
1717
import networkx as nx
18+
from enum import StrEnum
1819

1920
logging.basicConfig(level=logging.INFO)
2021

2122

23+
class DatabaseEngine(StrEnum):
    """An enumeration to represent a database engine.

    Engine-specific ``DataDictionaryCreator`` subclasses assign one of
    these members to ``self.database_engine`` (e.g. the Databricks
    creator sets ``DatabaseEngine.DATABRICKS``) so that
    ``excluded_fields_for_database_engine`` can tailor the output.
    """

    SNOWFLAKE = "SNOWFLAKE"
    SQL_SERVER = "SQL_SERVER"
    DATABRICKS = "DATABRICKS"
29+
30+
2231
class ForeignKeyRelationship(BaseModel):
2332
column: str = Field(..., alias="Column")
2433
foreign_column: str = Field(..., alias="ForeignColumn")
@@ -124,6 +133,7 @@ class EntityItem(BaseModel):
124133
entity_name: Optional[str] = Field(default=None, alias="EntityName")
125134
database: Optional[str] = Field(default=None, alias="Database")
126135
warehouse: Optional[str] = Field(default=None, alias="Warehouse")
136+
catalog: Optional[str] = Field(default=None, alias="Catalog")
127137

128138
entity_relationships: Optional[list[EntityRelationship]] = Field(
129139
alias="EntityRelationships", default_factory=list
@@ -186,6 +196,9 @@ def __init__(
186196

187197
self.warehouse = None
188198
self.database = None
199+
self.catalog = None
200+
201+
self.database_engine = None
189202

190203
load_dotenv(find_dotenv())
191204

@@ -391,6 +404,7 @@ async def extract_entities_with_definitions(self) -> list[EntityItem]:
391404
for entity in all_entities:
392405
entity.warehouse = self.warehouse
393406
entity.database = self.database
407+
entity.catalog = self.catalog
394408

395409
return all_entities
396410

@@ -636,6 +650,24 @@ async def build_entity_entry(self, entity: EntityItem) -> EntityItem:
636650

637651
return entity
638652

653+
@property
def excluded_fields_for_database_engine(self):
    """Return the engine-specific fields NOT used by the configured engine.

    Each engine uses only a subset of the engine-specific entity fields
    ("Warehouse", "Database", "Catalog"); the remainder are excluded
    when dumping entities in ``create_data_dictionary``.

    Returns:
        list[str]: Field names to exclude for ``self.database_engine``.
    """
    all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
    if self.database_engine == DatabaseEngine.SNOWFLAKE:
        engine_specific_fields = ["Warehouse", "Database"]
    elif self.database_engine == DatabaseEngine.SQL_SERVER:
        engine_specific_fields = ["Database"]
    elif self.database_engine == DatabaseEngine.DATABRICKS:
        engine_specific_fields = ["Catalog"]
    else:
        # FIX: self.database_engine is initialized to None in __init__,
        # so the original raised NameError (engine_specific_fields
        # unbound) whenever a subclass forgot to set an engine. Fall
        # back to "no engine-specific fields apply".
        engine_specific_fields = []

    return [
        field
        for field in all_engine_specific_fields
        if field not in engine_specific_fields
    ]
670+
639671
async def create_data_dictionary(self):
640672
"""A method to build a data dictionary from a database. Writes to file."""
641673
entities = await self.extract_entities_with_definitions()
@@ -654,13 +686,28 @@ async def create_data_dictionary(self):
654686
if self.single_file:
655687
logging.info("Saving data dictionary to entities.json")
656688
with open("entities.json", "w", encoding="utf-8") as f:
689+
data_dictionary_dump = [
690+
entity.model_dump(
691+
by_alias=True, exclude=self.excluded_fields_for_database_engine
692+
)
693+
for entity in data_dictionary
694+
]
657695
json.dump(
658-
data_dictionary.model_dump(by_alias=True), f, indent=4, default=str
696+
data_dictionary_dump,
697+
f,
698+
indent=4,
699+
default=str,
659700
)
660701
else:
661702
for entity in data_dictionary:
662703
logging.info(f"Saving data dictionary for {entity.entity}")
663704
with open(f"{entity.entity}.json", "w", encoding="utf-8") as f:
664705
json.dump(
665-
entity.model_dump(by_alias=True), f, indent=4, default=str
706+
entity.model_dump(
707+
by_alias=True,
708+
exclude=self.excluded_fields_for_database_engine,
709+
),
710+
f,
711+
indent=4,
712+
default=str,
666713
)
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from data_dictionary_creator import DataDictionaryCreator, EntityItem, DatabaseEngine
4+
import asyncio
5+
from databricks import sql
6+
import logging
7+
import os
8+
9+
10+
class DatabricksDataDictionaryCreator(DataDictionaryCreator):
    """A class to extract data dictionary information from Databricks Unity Catalog."""
    # FIX: this docstring was originally a stray string statement at the
    # end of __init__, where it documented nothing.

    def __init__(
        self,
        entities: list[str] = None,
        excluded_entities: list[str] = None,
        single_file: bool = False,
    ):
        """A method to initialize the DataDictionaryCreator class.

        Args:
            entities (list[str], optional): A list of entities to extract. Defaults to None. If None, all entities are extracted.
            excluded_entities (list[str], optional): A list of entities to exclude. Defaults to None.
            single_file (bool, optional): A flag to indicate if the data dictionary should be saved to a single file. Defaults to False.
        """
        if excluded_entities is None:
            excluded_entities = []

        # No schemas are excluded for Databricks.
        excluded_schemas = []
        super().__init__(entities, excluded_entities, excluded_schemas, single_file)

        # Unity Catalog catalog that scopes every INFORMATION_SCHEMA query.
        self.catalog = os.environ["Text2Sql__Databricks__Catalog"]
        self.database_engine = DatabaseEngine.DATABRICKS

    @property
    def extract_table_entities_sql_query(self) -> str:
        """A property to extract table entities from Databricks Unity Catalog."""
        return f"""SELECT
            t.TABLE_NAME AS Entity,
            t.TABLE_SCHEMA AS EntitySchema,
            t.COMMENT AS Definition
        FROM
            INFORMATION_SCHEMA.TABLES t
        WHERE
            t.TABLE_CATALOG = '{self.catalog}'
        """

    @property
    def extract_view_entities_sql_query(self) -> str:
        """A property to extract view entities from Databricks Unity Catalog."""
        # FIX: the original was a plain (non-f) string, so '{self.catalog}'
        # was sent to the engine literally, and it was missing the comma
        # after EntitySchema, which is a SQL syntax error.
        return f"""SELECT
            v.TABLE_NAME AS Entity,
            v.TABLE_SCHEMA AS EntitySchema,
            NULL AS Definition
        FROM
            INFORMATION_SCHEMA.VIEWS v
        WHERE
            v.TABLE_CATALOG = '{self.catalog}'"""

    def extract_columns_sql_query(self, entity: EntityItem) -> str:
        """A method to extract column information from Databricks Unity Catalog."""
        # NOTE(review): assumes EntityItem.name holds the table name and
        # entity_schema the schema — confirm against the other creators.
        return f"""SELECT
            COLUMN_NAME AS Name,
            DATA_TYPE AS Type,
            COMMENT AS Definition
        FROM
            INFORMATION_SCHEMA.COLUMNS
        WHERE
            TABLE_CATALOG = '{self.catalog}'
            AND TABLE_SCHEMA = '{entity.entity_schema}'
            AND TABLE_NAME = '{entity.name}';"""

    @property
    def extract_entity_relationships_sql_query(self) -> str:
        """A property to extract entity relationships from Databricks Unity Catalog."""
        # FIX: `[Column]` is T-SQL bracket quoting; Databricks SQL quotes
        # identifiers with backticks.
        # NOTE(review): the referenced_* columns on KEY_COLUMN_USAGE look
        # ported from another engine's INFORMATION_SCHEMA — verify they
        # exist in Unity Catalog before relying on this query.
        return f"""SELECT
            fk_schema.TABLE_SCHEMA AS EntitySchema,
            fk_tab.TABLE_NAME AS Entity,
            pk_schema.TABLE_SCHEMA AS ForeignEntitySchema,
            pk_tab.TABLE_NAME AS ForeignEntity,
            fk_col.COLUMN_NAME AS `Column`,
            pk_col.COLUMN_NAME AS ForeignColumn
        FROM
            INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS fk
        INNER JOIN
            INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS fkc
            ON fk.constraint_name = fkc.constraint_name
        INNER JOIN
            INFORMATION_SCHEMA.TABLES AS fk_tab
            ON fk_tab.TABLE_NAME = fkc.TABLE_NAME AND fk_tab.TABLE_SCHEMA = fkc.TABLE_SCHEMA
        INNER JOIN
            INFORMATION_SCHEMA.SCHEMATA AS fk_schema
            ON fk_tab.TABLE_SCHEMA = fk_schema.TABLE_SCHEMA
        INNER JOIN
            INFORMATION_SCHEMA.TABLES AS pk_tab
            ON pk_tab.TABLE_NAME = fkc.referenced_TABLE_NAME AND pk_tab.TABLE_SCHEMA = fkc.referenced_TABLE_SCHEMA
        INNER JOIN
            INFORMATION_SCHEMA.SCHEMATA AS pk_schema
            ON pk_tab.TABLE_SCHEMA = pk_schema.TABLE_SCHEMA
        INNER JOIN
            INFORMATION_SCHEMA.COLUMNS AS fk_col
            ON fkc.COLUMN_NAME = fk_col.COLUMN_NAME AND fkc.TABLE_NAME = fk_col.TABLE_NAME AND fkc.TABLE_SCHEMA = fk_col.TABLE_SCHEMA
        INNER JOIN
            INFORMATION_SCHEMA.COLUMNS AS pk_col
            ON fkc.referenced_COLUMN_NAME = pk_col.COLUMN_NAME AND fkc.referenced_TABLE_NAME = pk_col.TABLE_NAME AND fkc.referenced_TABLE_SCHEMA = pk_col.TABLE_SCHEMA
        WHERE
            fk.constraint_type = 'FOREIGN KEY'
            AND fk_tab.TABLE_CATALOG = '{self.catalog}'
            AND pk_tab.TABLE_CATALOG = '{self.catalog}'
        ORDER BY
            EntitySchema, Entity, ForeignEntitySchema, ForeignEntity;
        """

    async def query_entities(self, sql_query: str, cast_to: any = None) -> list[dict]:
        """
        A method to query a Databricks SQL endpoint for entities.

        Args:
            sql_query (str): The SQL query to run.
            cast_to (any, optional): The class to cast the results to. Defaults to None.

        Returns:
            list[dict]: The list of entities or processed rows.
        """
        logging.info(f"Running query: {sql_query}")
        results = []

        # Set up connection parameters for Databricks SQL endpoint
        connection = sql.connect(
            server_hostname=os.environ["Text2Sql__Databricks__ServerHostname"],
            http_path=os.environ["Text2Sql__Databricks__HttpPath"],
            access_token=os.environ["Text2Sql__Databricks__AccessToken"],
        )

        try:
            # Create a cursor
            cursor = connection.cursor()
            # FIX: cursor is now closed in its own finally block; the
            # original referenced `cursor` in the outer finally even when
            # connection.cursor() itself raised (NameError masking the
            # real error).
            try:
                # Execute the query in a thread-safe manner
                await asyncio.to_thread(cursor.execute, sql_query)

                # Fetch column names
                columns = [col[0] for col in cursor.description]

                # Fetch rows
                rows = await asyncio.to_thread(cursor.fetchall)

                # Process rows
                for row in rows:
                    if cast_to:
                        results.append(cast_to.from_sql_row(row, columns))
                    else:
                        results.append(dict(zip(columns, row)))
            finally:
                cursor.close()
        except Exception as e:
            logging.error(f"Error while executing query: {e}")
            raise
        finally:
            connection.close()

        return results
163+
164+
165+
# Script entry point: build and write the data dictionary for the
# Databricks catalog configured via the Text2Sql__Databricks__* env vars.
if __name__ == "__main__":
    data_dictionary_creator = DatabricksDataDictionaryCreator()
    asyncio.run(data_dictionary_creator.create_data_dictionary())

text_2_sql/data_dictionary/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ pydantic
55
openai
66
snowflake-connector-python
77
networkx
8+
databricks-sql-connector

0 commit comments

Comments
 (0)