Databricks Data Dictionary Creator #54

Merged: 5 commits, Nov 20, 2024
2 changes: 1 addition & 1 deletion deploy_ai_search/.env
@@ -19,4 +19,4 @@ OpenAI__Endpoint=<openAIEndpoint>
OpenAI__EmbeddingModel=<openAIEmbeddingModelName>
OpenAI__EmbeddingDeployment=<openAIEmbeddingDeploymentId>
OpenAI__EmbeddingDimensions=1536
Text2Sql__DatabaseEngine=<databaseEngine: SQL Server / Snowflake / Databricks>
50 changes: 50 additions & 0 deletions deploy_ai_search/text_2_sql_schema_store.py
@@ -24,6 +24,16 @@
from environment import (
IndexerType,
)
import os
from enum import StrEnum


class DatabaseEngine(StrEnum):
"""An enumeration to represent a database engine."""

SNOWFLAKE = "SNOWFLAKE"
SQL_SERVER = "SQL_SERVER"
DATABRICKS = "DATABRICKS"


class Text2SqlSchemaStoreAISearch(AISearch):
@@ -42,13 +52,34 @@ def __init__(
rebuild (bool, optional): Whether to rebuild the index. Defaults to False.
"""
self.indexer_type = IndexerType.TEXT_2_SQL_SCHEMA_STORE
self.database_engine = DatabaseEngine[
os.environ["Text2Sql__DatabaseEngine"].upper()
]
super().__init__(suffix, rebuild)

if single_data_dictionary:
self.parsing_mode = BlobIndexerParsingMode.JSON_ARRAY
else:
self.parsing_mode = BlobIndexerParsingMode.JSON

@property
def excluded_fields_for_database_engine(self):
"""A method to get the excluded fields for the database engine."""

all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
if self.database_engine == DatabaseEngine.SNOWFLAKE:
engine_specific_fields = ["Warehouse", "Database"]
elif self.database_engine == DatabaseEngine.SQL_SERVER:
engine_specific_fields = ["Database"]
elif self.database_engine == DatabaseEngine.DATABRICKS:
engine_specific_fields = ["Catalog"]

return [
field
for field in all_engine_specific_fields
if field not in engine_specific_fields
]

def get_index_fields(self) -> list[SearchableField]:
"""This function returns the index fields for sql index.

@@ -78,6 +109,10 @@ def get_index_fields(self) -> list[SearchableField]:
name="Warehouse",
type=SearchFieldDataType.String,
),
SearchableField(
name="Catalog",
type=SearchFieldDataType.String,
),
SearchableField(
name="Definition",
type=SearchFieldDataType.String,
@@ -161,6 +196,13 @@ def get_index_fields(self) -> list[SearchableField]:
),
]

# Remove fields that are not supported by the database engine
fields = [
field
for field in fields
if field.name not in self.excluded_fields_for_database_engine
]

return fields

def get_semantic_search(self) -> SemanticSearch:
@@ -309,4 +351,12 @@ def get_indexer(self) -> SearchIndexer:
parameters=indexer_parameters,
)

# Remove fields that are not supported by the database engine
indexer.output_field_mappings = [
field_mapping
for field_mapping in indexer.output_field_mappings
if field_mapping.target_field_name
not in self.excluded_fields_for_database_engine
]

return indexer
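For reference, a minimal standalone sketch of the exclusion logic added above, with the if/elif chain folded into a lookup table (the `ENGINE_SPECIFIC_FIELDS` dict and `excluded_fields` function are illustrative, not part of the PR):

```python
from enum import StrEnum


class DatabaseEngine(StrEnum):
    """An enumeration to represent a database engine."""

    SNOWFLAKE = "SNOWFLAKE"
    SQL_SERVER = "SQL_SERVER"
    DATABRICKS = "DATABRICKS"


# Fields each engine actually uses; everything else is stripped from the index.
ENGINE_SPECIFIC_FIELDS = {
    DatabaseEngine.SNOWFLAKE: ["Warehouse", "Database"],
    DatabaseEngine.SQL_SERVER: ["Database"],
    DatabaseEngine.DATABRICKS: ["Catalog"],
}


def excluded_fields(engine: DatabaseEngine) -> list[str]:
    """Return the engine-specific fields to exclude for the given engine."""
    all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
    kept = ENGINE_SPECIFIC_FIELDS[engine]
    return [field for field in all_engine_specific_fields if field not in kept]


# Databricks keeps only Catalog, so Warehouse and Database are excluded.
assert excluded_fields(DatabaseEngine.DATABRICKS) == ["Warehouse", "Database"]
assert excluded_fields(DatabaseEngine.SQL_SERVER) == ["Warehouse", "Catalog"]
```

A table-driven form like this would avoid repeating the if/elif chain in both files, though the diff's version reads fine for three engines.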
7 changes: 5 additions & 2 deletions text_2_sql/data_dictionary/.env
@@ -3,12 +3,15 @@ OpenAI__EmbeddingModel=<openAIEmbeddingModelName>
OpenAI__Endpoint=<openAIEndpoint>
OpenAI__ApiKey=<openAIKey if using non managed identity>
OpenAI__ApiVersion=<openAIApiVersion>
Text2Sql__DatabaseEngine=<databaseEngine>
Text2Sql__DatabaseName=<databaseName>
Text2Sql__DatabaseConnectionString=<databaseConnectionString>
Text2Sql__Snowflake__User=<snowflakeUser if using Snowflake Data Source>
Text2Sql__Snowflake__Password=<snowflakePassword if using Snowflake Data Source>
Text2Sql__Snowflake__Account=<snowflakeAccount if using Snowflake Data Source>
Text2Sql__Snowflake__Warehouse=<snowflakeWarehouse if using Snowflake Data Source>
Text2Sql__Databricks__Catalog=<databricksCatalog if using Databricks Data Source with Unity Catalog>
Text2Sql__Databricks__ServerHostname=<databricksServerHostname if using Databricks Data Source with Unity Catalog>
Text2Sql__Databricks__HttpPath=<databricksHttpPath if using Databricks Data Source with Unity Catalog>
Text2Sql__Databricks__AccessToken=<databricks AccessToken if using Databricks Data Source with Unity Catalog>
IdentityType=<identityType> # system_assigned or user_assigned or key
ClientId=<clientId if using user assigned identity>
3 changes: 2 additions & 1 deletion text_2_sql/data_dictionary/README.md
@@ -99,7 +99,8 @@ See `./generated_samples/` for an example output of the script. This can then be

The following Databases have pre-built scripts for them:

- **Databricks:** `databricks_data_dictionary_creator.py`
- **Snowflake:** `snowflake_data_dictionary_creator.py`
- **SQL Server:** `sql_server_data_dictionary_creator.py`

If there is no pre-built script for your database engine, take one of the above as a starting point and adjust it.
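As a starting point, a hedged sketch of the surface such a script needs to implement, mirroring the method names used by `DatabricksDataDictionaryCreator` below (the class name and SQL strings are placeholders):

```python
from data_dictionary_creator import DataDictionaryCreator, DatabaseEngine, EntityItem


class MyEngineDataDictionaryCreator(DataDictionaryCreator):
    """Skeleton for a new engine; fill in engine-specific SQL and connectivity."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.database_engine = DatabaseEngine.SQL_SERVER  # or a new enum member

    @property
    def extract_table_entities_sql_query(self) -> str:
        return "SELECT ...  -- tables with schema and comment"

    @property
    def extract_view_entities_sql_query(self) -> str:
        return "SELECT ...  -- views with schema"

    def extract_columns_sql_query(self, entity: EntityItem) -> str:
        return f"SELECT ...  -- columns for '{entity.name}'"

    @property
    def extract_entity_relationships_sql_query(self) -> str:
        return "SELECT ...  -- foreign key relationships"

    async def query_entities(self, sql_query: str, cast_to: any = None) -> list[dict]:
        """Run the query against your engine and return rows as dicts."""
        raise NotImplementedError
```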
51 changes: 49 additions & 2 deletions text_2_sql/data_dictionary/data_dictionary_creator.py
@@ -15,10 +15,19 @@
import random
import re
import networkx as nx
from enum import StrEnum

logging.basicConfig(level=logging.INFO)


class DatabaseEngine(StrEnum):
"""An enumeration to represent a database engine."""

SNOWFLAKE = "SNOWFLAKE"
SQL_SERVER = "SQL_SERVER"
DATABRICKS = "DATABRICKS"


class ForeignKeyRelationship(BaseModel):
column: str = Field(..., alias="Column")
foreign_column: str = Field(..., alias="ForeignColumn")
@@ -124,6 +133,7 @@ class EntityItem(BaseModel):
entity_name: Optional[str] = Field(default=None, alias="EntityName")
database: Optional[str] = Field(default=None, alias="Database")
warehouse: Optional[str] = Field(default=None, alias="Warehouse")
catalog: Optional[str] = Field(default=None, alias="Catalog")

entity_relationships: Optional[list[EntityRelationship]] = Field(
alias="EntityRelationships", default_factory=list
@@ -186,6 +196,9 @@ def __init__(

self.warehouse = None
self.database = None
self.catalog = None

self.database_engine = None

load_dotenv(find_dotenv())

@@ -391,6 +404,7 @@ async def extract_entities_with_definitions(self) -> list[EntityItem]:
for entity in all_entities:
entity.warehouse = self.warehouse
entity.database = self.database
entity.catalog = self.catalog

return all_entities

@@ -636,6 +650,24 @@ async def build_entity_entry(self, entity: EntityItem) -> EntityItem:

return entity

@property
def excluded_fields_for_database_engine(self):
"""A method to get the excluded fields for the database engine."""

all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
if self.database_engine == DatabaseEngine.SNOWFLAKE:
engine_specific_fields = ["Warehouse", "Database"]
elif self.database_engine == DatabaseEngine.SQL_SERVER:
engine_specific_fields = ["Database"]
elif self.database_engine == DatabaseEngine.DATABRICKS:
engine_specific_fields = ["Catalog"]

return [
field
for field in all_engine_specific_fields
if field not in engine_specific_fields
]

async def create_data_dictionary(self):
"""A method to build a data dictionary from a database. Writes to file."""
entities = await self.extract_entities_with_definitions()
@@ -654,13 +686,28 @@ async def create_data_dictionary(self):
if self.single_file:
logging.info("Saving data dictionary to entities.json")
with open("entities.json", "w", encoding="utf-8") as f:
data_dictionary_dump = [
entity.model_dump(
by_alias=True, exclude=self.excluded_fields_for_database_engine
)
for entity in data_dictionary
]
json.dump(
data_dictionary_dump,
f,
indent=4,
default=str,
)
else:
for entity in data_dictionary:
logging.info(f"Saving data dictionary for {entity.entity}")
with open(f"{entity.entity}.json", "w", encoding="utf-8") as f:
json.dump(
entity.model_dump(
by_alias=True,
exclude=self.excluded_fields_for_database_engine,
),
f,
indent=4,
default=str,
)
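Worth noting when reusing this pattern: Pydantic matches `exclude` against the snake_case field names rather than the aliases, even when `by_alias=True` controls the output keys, so the excluded names must be the field names for the exclusion to take effect. A minimal sketch with a cut-down, hypothetical `EntityItem`:

```python
from typing import Optional

from pydantic import BaseModel, Field


class EntityItem(BaseModel):
    entity: str = Field(alias="Entity")
    warehouse: Optional[str] = Field(default=None, alias="Warehouse")
    catalog: Optional[str] = Field(default=None, alias="Catalog")

    model_config = {"populate_by_name": True}


item = EntityItem(entity="sales", warehouse="wh1", catalog="main")

# exclude keys on field names, not aliases, even though by_alias=True
# makes the output dict use the alias spellings.
print(item.model_dump(by_alias=True, exclude={"warehouse"}))
# -> {'Entity': 'sales', 'Catalog': 'main'}
```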
167 changes: 167 additions & 0 deletions text_2_sql/data_dictionary/databricks_data_dictionary_creator.py
@@ -0,0 +1,167 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from data_dictionary_creator import DataDictionaryCreator, EntityItem, DatabaseEngine
import asyncio
from databricks import sql
import logging
import os


class DatabricksDataDictionaryCreator(DataDictionaryCreator):
"""A class to extract data dictionary information from Databricks Unity Catalog."""

def __init__(
self,
entities: list[str] = None,
excluded_entities: list[str] = None,
single_file: bool = False,
):
"""A method to initialize the DataDictionaryCreator class.

Args:
entities (list[str], optional): A list of entities to extract. Defaults to None. If None, all entities are extracted.
excluded_entities (list[str], optional): A list of entities to exclude. Defaults to None.
single_file (bool, optional): A flag to indicate if the data dictionary should be saved to a single file. Defaults to False.
"""
if excluded_entities is None:
excluded_entities = []

excluded_schemas = []
super().__init__(entities, excluded_entities, excluded_schemas, single_file)

self.catalog = os.environ["Text2Sql__Databricks__Catalog"]
self.database_engine = DatabaseEngine.DATABRICKS

"""A class to extract data dictionary information from Databricks Unity Catalog."""

@property
def extract_table_entities_sql_query(self) -> str:
"""A property to extract table entities from Databricks Unity Catalog."""
return f"""SELECT
t.TABLE_NAME AS Entity,
t.TABLE_SCHEMA AS EntitySchema,
t.COMMENT AS Definition
FROM
INFORMATION_SCHEMA.TABLES t
WHERE
t.TABLE_CATALOG = '{self.catalog}'
"""

@property
def extract_view_entities_sql_query(self) -> str:
"""A property to extract view entities from Databricks Unity Catalog."""
return """SELECT
v.TABLE_NAME AS Entity,
v.TABLE_SCHEMA AS EntitySchema
NULL AS Definition
FROM
INFORMATION_SCHEMA.VIEWS v
WHERE
v.TABLE_CATALOG = '{self.catalog}'"""

def extract_columns_sql_query(self, entity: EntityItem) -> str:
"""A property to extract column information from Databricks Unity Catalog."""
return f"""SELECT
COLUMN_NAME AS Name,
DATA_TYPE AS Type,
COMMENT AS Definition
FROM
INFORMATION_SCHEMA.COLUMNS
WHERE
TABLE_CATALOG = '{self.catalog}'
AND TABLE_SCHEMA = '{entity.entity_schema}'
AND TABLE_NAME = '{entity.name}';"""

@property
def extract_entity_relationships_sql_query(self) -> str:
"""A property to extract entity relationships from Databricks Unity Catalog."""
return f"""SELECT
fk_schema.TABLE_SCHEMA AS EntitySchema,
fk_tab.TABLE_NAME AS Entity,
pk_schema.TABLE_SCHEMA AS ForeignEntitySchema,
pk_tab.TABLE_NAME AS ForeignEntity,
fk_col.COLUMN_NAME AS `Column`,
pk_col.COLUMN_NAME AS ForeignColumn
FROM
INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS fk
INNER JOIN
INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS fkc
ON fk.constraint_name = fkc.constraint_name
INNER JOIN
INFORMATION_SCHEMA.TABLES AS fk_tab
ON fk_tab.TABLE_NAME = fkc.TABLE_NAME AND fk_tab.TABLE_SCHEMA = fkc.TABLE_SCHEMA
INNER JOIN
INFORMATION_SCHEMA.SCHEMATA AS fk_schema
ON fk_tab.TABLE_SCHEMA = fk_schema.TABLE_SCHEMA
INNER JOIN
INFORMATION_SCHEMA.TABLES AS pk_tab
ON pk_tab.TABLE_NAME = fkc.referenced_TABLE_NAME AND pk_tab.TABLE_SCHEMA = fkc.referenced_TABLE_SCHEMA
INNER JOIN
INFORMATION_SCHEMA.SCHEMATA AS pk_schema
ON pk_tab.TABLE_SCHEMA = pk_schema.TABLE_SCHEMA
INNER JOIN
INFORMATION_SCHEMA.COLUMNS AS fk_col
ON fkc.COLUMN_NAME = fk_col.COLUMN_NAME AND fkc.TABLE_NAME = fk_col.TABLE_NAME AND fkc.TABLE_SCHEMA = fk_col.TABLE_SCHEMA
INNER JOIN
INFORMATION_SCHEMA.COLUMNS AS pk_col
ON fkc.referenced_COLUMN_NAME = pk_col.COLUMN_NAME AND fkc.referenced_TABLE_NAME = pk_col.TABLE_NAME AND fkc.referenced_TABLE_SCHEMA = pk_col.TABLE_SCHEMA
WHERE
fk.constraint_type = 'FOREIGN KEY'
AND fk_tab.TABLE_CATALOG = '{self.catalog}'
AND pk_tab.TABLE_CATALOG = '{self.catalog}'
ORDER BY
EntitySchema, Entity, ForeignEntitySchema, ForeignEntity;
"""

async def query_entities(self, sql_query: str, cast_to: any = None) -> list[dict]:
"""
A method to query a Databricks SQL endpoint for entities.

Args:
sql_query (str): The SQL query to run.
cast_to (any, optional): The class to cast the results to. Defaults to None.

Returns:
list[dict]: The list of entities or processed rows.
"""
logging.info(f"Running query: {sql_query}")
results = []

# Set up connection parameters for Databricks SQL endpoint
connection = sql.connect(
server_hostname=os.environ["Text2Sql__Databricks__ServerHostname"],
http_path=os.environ["Text2Sql__Databricks__HttpPath"],
access_token=os.environ["Text2Sql__Databricks__AccessToken"],
)

try:
# Create a cursor
cursor = connection.cursor()

# Execute the query in a thread-safe manner
await asyncio.to_thread(cursor.execute, sql_query)

# Fetch column names
columns = [col[0] for col in cursor.description]

# Fetch rows
rows = await asyncio.to_thread(cursor.fetchall)

# Process rows
for row in rows:
if cast_to:
results.append(cast_to.from_sql_row(row, columns))
else:
results.append(dict(zip(columns, row)))

except Exception as e:
logging.error(f"Error while executing query: {e}")
raise
finally:
cursor.close()
connection.close()

return results


if __name__ == "__main__":
data_dictionary_creator = DatabricksDataDictionaryCreator()
asyncio.run(data_dictionary_creator.create_data_dictionary())
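A hedged usage sketch for the new script (the catalog, hostname, HTTP path, and token values are placeholders; the environment variable names match the .env diff above):

```python
import asyncio
import os

# Placeholder connection details for a Unity Catalog-backed SQL warehouse.
os.environ["Text2Sql__Databricks__Catalog"] = "main"
os.environ["Text2Sql__Databricks__ServerHostname"] = "adb-<workspace-id>.azuredatabricks.net"
os.environ["Text2Sql__Databricks__HttpPath"] = "/sql/1.0/warehouses/<warehouse-id>"
os.environ["Text2Sql__Databricks__AccessToken"] = "<personal-access-token>"

from databricks_data_dictionary_creator import DatabricksDataDictionaryCreator

# single_file=True writes everything to entities.json; the default writes
# one <entity>.json file per table or view.
creator = DatabricksDataDictionaryCreator(single_file=True)
asyncio.run(creator.create_data_dictionary())
```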
1 change: 1 addition & 0 deletions text_2_sql/data_dictionary/requirements.txt
@@ -5,3 +5,4 @@ pydantic
openai
snowflake-connector-python
networkx
databricks-sql-connector