diff --git a/deploy_ai_search/.env b/deploy_ai_search/.env index 2b858b2..e738621 100644 --- a/deploy_ai_search/.env +++ b/deploy_ai_search/.env @@ -19,4 +19,4 @@ OpenAI__Endpoint= OpenAI__EmbeddingModel= OpenAI__EmbeddingDeployment= OpenAI__EmbeddingDimensions=1536 -Text2Sql__DatabaseName= +Text2Sql__DatabaseEngine= diff --git a/deploy_ai_search/text_2_sql_schema_store.py b/deploy_ai_search/text_2_sql_schema_store.py index 59d981e..5f8c24d 100644 --- a/deploy_ai_search/text_2_sql_schema_store.py +++ b/deploy_ai_search/text_2_sql_schema_store.py @@ -24,6 +24,16 @@ from environment import ( IndexerType, ) +import os +from enum import StrEnum + + +class DatabaseEngine(StrEnum): + """An enumeration to represent a database engine.""" + + SNOWFLAKE = "SNOWFLAKE" + SQL_SERVER = "SQL_SERVER" + DATABRICKS = "DATABRICKS" class Text2SqlSchemaStoreAISearch(AISearch): @@ -42,6 +52,9 @@ def __init__( rebuild (bool, optional): Whether to rebuild the index. Defaults to False. """ self.indexer_type = IndexerType.TEXT_2_SQL_SCHEMA_STORE + self.database_engine = DatabaseEngine[ + os.environ["Text2Sql__DatabaseEngine"].upper() + ] super().__init__(suffix, rebuild) if single_data_dictionary: @@ -49,6 +62,24 @@ def __init__( else: self.parsing_mode = BlobIndexerParsingMode.JSON + @property + def excluded_fields_for_database_engine(self): + """A method to get the excluded fields for the database engine.""" + + all_engine_specific_fields = ["Warehouse", "Database", "Catalog"] + if self.database_engine == DatabaseEngine.SNOWFLAKE: + engine_specific_fields = ["Warehouse", "Database"] + elif self.database_engine == DatabaseEngine.SQL_SERVER: + engine_specific_fields = ["Database"] + elif self.database_engine == DatabaseEngine.DATABRICKS: + engine_specific_fields = ["Catalog"] + + return [ + field + for field in all_engine_specific_fields + if field not in engine_specific_fields + ] + def get_index_fields(self) -> list[SearchableField]: """This function returns the index fields for sql 
index. @@ -78,6 +109,10 @@ def get_index_fields(self) -> list[SearchableField]: name="Warehouse", type=SearchFieldDataType.String, ), + SearchableField( + name="Catalog", + type=SearchFieldDataType.String, + ), SearchableField( name="Definition", type=SearchFieldDataType.String, @@ -161,6 +196,13 @@ def get_index_fields(self) -> list[SearchableField]: ), ] + # Remove fields that are not supported by the database engine + fields = [ + field + for field in fields + if field.name not in self.excluded_fields_for_database_engine + ] + return fields def get_semantic_search(self) -> SemanticSearch: @@ -309,4 +351,12 @@ def get_indexer(self) -> SearchIndexer: parameters=indexer_parameters, ) + # Remove fields that are not supported by the database engine + indexer.output_field_mappings = [ + field_mapping + for field_mapping in indexer.output_field_mappings + if field_mapping.target_field_name + not in self.excluded_fields_for_database_engine + ] + return indexer diff --git a/text_2_sql/data_dictionary/.env b/text_2_sql/data_dictionary/.env index e5cca6f..ad420ec 100644 --- a/text_2_sql/data_dictionary/.env +++ b/text_2_sql/data_dictionary/.env @@ -3,12 +3,15 @@ OpenAI__EmbeddingModel= OpenAI__Endpoint= OpenAI__ApiKey= OpenAI__ApiVersion= -Text2Sql__DatabaseEngine= Text2Sql__DatabaseName= Text2Sql__DatabaseConnectionString= Text2Sql__Snowflake__User= Text2Sql__Snowflake__Password= Text2Sql__Snowflake__Account= Text2Sql__Snowflake__Warehouse= +Text2Sql__Databricks__Catalog= +Text2Sql__Databricks__ServerHostname= +Text2Sql__Databricks__HttpPath= +Text2Sql__Databricks__AccessToken= IdentityType= # system_assigned or user_assigned or key -ClientId= +ClientId= diff --git a/text_2_sql/data_dictionary/README.md b/text_2_sql/data_dictionary/README.md index 9ebba79..c8492e9 100644 --- a/text_2_sql/data_dictionary/README.md +++ b/text_2_sql/data_dictionary/README.md @@ -99,7 +99,8 @@ See `./generated_samples/` for an example output of the script. 
This can then be The following Databases have pre-built scripts for them: -- **Microsoft SQL Server:** `sql_server_data_dictionary_creator.py` +- **Databricks:** `databricks_data_dictionary_creator.py` - **Snowflake:** `snowflake_data_dictionary_creator.py` +- **SQL Server:** `sql_server_data_dictionary_creator.py` If there is no pre-built script for your database engine, take one of the above as a starting point and adjust it. diff --git a/text_2_sql/data_dictionary/data_dictionary_creator.py b/text_2_sql/data_dictionary/data_dictionary_creator.py index 24212a7..0da3209 100644 --- a/text_2_sql/data_dictionary/data_dictionary_creator.py +++ b/text_2_sql/data_dictionary/data_dictionary_creator.py @@ -15,10 +15,19 @@ import random import re import networkx as nx +from enum import StrEnum logging.basicConfig(level=logging.INFO) +class DatabaseEngine(StrEnum): + """An enumeration to represent a database engine.""" + + SNOWFLAKE = "SNOWFLAKE" + SQL_SERVER = "SQL_SERVER" + DATABRICKS = "DATABRICKS" + + class ForeignKeyRelationship(BaseModel): column: str = Field(..., alias="Column") foreign_column: str = Field(..., alias="ForeignColumn") @@ -124,6 +133,7 @@ class EntityItem(BaseModel): entity_name: Optional[str] = Field(default=None, alias="EntityName") database: Optional[str] = Field(default=None, alias="Database") warehouse: Optional[str] = Field(default=None, alias="Warehouse") + catalog: Optional[str] = Field(default=None, alias="Catalog") entity_relationships: Optional[list[EntityRelationship]] = Field( alias="EntityRelationships", default_factory=list @@ -186,6 +196,9 @@ def __init__( self.warehouse = None self.database = None + self.catalog = None + + self.database_engine = None load_dotenv(find_dotenv()) @@ -391,6 +404,7 @@ async def extract_entities_with_definitions(self) -> list[EntityItem]: for entity in all_entities: entity.warehouse = self.warehouse entity.database = self.database + entity.catalog = self.catalog return all_entities @@ -636,6 +650,24 @@ async 
def build_entity_entry(self, entity: EntityItem) -> EntityItem: return entity + @property + def excluded_fields_for_database_engine(self): + """A method to get the excluded fields for the database engine.""" + + all_engine_specific_fields = ["warehouse", "database", "catalog"] + if self.database_engine == DatabaseEngine.SNOWFLAKE: + engine_specific_fields = ["warehouse", "database"] + elif self.database_engine == DatabaseEngine.SQL_SERVER: + engine_specific_fields = ["database"] + elif self.database_engine == DatabaseEngine.DATABRICKS: + engine_specific_fields = ["catalog"] + + return [ + field + for field in all_engine_specific_fields + if field not in engine_specific_fields + ] + async def create_data_dictionary(self): """A method to build a data dictionary from a database. Writes to file.""" entities = await self.extract_entities_with_definitions() @@ -654,13 +686,28 @@ async def create_data_dictionary(self): if self.single_file: logging.info("Saving data dictionary to entities.json") with open("entities.json", "w", encoding="utf-8") as f: + data_dictionary_dump = [ + entity.model_dump( + by_alias=True, exclude=set(self.excluded_fields_for_database_engine) + ) + for entity in data_dictionary + ] json.dump( - data_dictionary.model_dump(by_alias=True), f, indent=4, default=str + data_dictionary_dump, + f, + indent=4, + default=str, ) else: for entity in data_dictionary: logging.info(f"Saving data dictionary for {entity.entity}") with open(f"{entity.entity}.json", "w", encoding="utf-8") as f: json.dump( - entity.model_dump(by_alias=True), f, indent=4, default=str + entity.model_dump( + by_alias=True, + exclude=set(self.excluded_fields_for_database_engine), + ), + f, + indent=4, + default=str, ) diff --git a/text_2_sql/data_dictionary/databricks_data_dictionary_creator.py b/text_2_sql/data_dictionary/databricks_data_dictionary_creator.py new file mode 100644 index 0000000..c6fc17b --- /dev/null +++ b/text_2_sql/data_dictionary/databricks_data_dictionary_creator.py @@ -0,0
+1,167 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from data_dictionary_creator import DataDictionaryCreator, EntityItem, DatabaseEngine +import asyncio +from databricks import sql +import logging +import os + + +class DatabricksDataDictionaryCreator(DataDictionaryCreator): + def __init__( + self, + entities: list[str] = None, + excluded_entities: list[str] = None, + single_file: bool = False, + ): + """A method to initialize the DataDictionaryCreator class. + + Args: + entities (list[str], optional): A list of entities to extract. Defaults to None. If None, all entities are extracted. + excluded_entities (list[str], optional): A list of entities to exclude. Defaults to None. + single_file (bool, optional): A flag to indicate if the data dictionary should be saved to a single file. Defaults to False. + """ + if excluded_entities is None: + excluded_entities = [] + + excluded_schemas = [] + super().__init__(entities, excluded_entities, excluded_schemas, single_file) + + self.catalog = os.environ["Text2Sql__Databricks__Catalog"] + self.database_engine = DatabaseEngine.DATABRICKS + + """A class to extract data dictionary information from Databricks Unity Catalog.""" + + @property + def extract_table_entities_sql_query(self) -> str: + """A property to extract table entities from Databricks Unity Catalog.""" + return f"""SELECT + t.TABLE_NAME AS Entity, + t.TABLE_SCHEMA AS EntitySchema, + t.COMMENT AS Definition + FROM + INFORMATION_SCHEMA.TABLES t + WHERE + t.TABLE_CATALOG = '{self.catalog}' + """ + + @property + def extract_view_entities_sql_query(self) -> str: + """A property to extract view entities from Databricks Unity Catalog.""" + return f"""SELECT + v.TABLE_NAME AS Entity, + v.TABLE_SCHEMA AS EntitySchema, + NULL AS Definition + FROM + INFORMATION_SCHEMA.VIEWS v + WHERE + v.TABLE_CATALOG = '{self.catalog}'""" + + def extract_columns_sql_query(self, entity: EntityItem) -> str: + """A property to extract column information from
Databricks Unity Catalog.""" + return f"""SELECT + COLUMN_NAME AS Name, + DATA_TYPE AS Type, + COMMENT AS Definition + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_CATALOG = '{self.catalog}' + AND TABLE_SCHEMA = '{entity.entity_schema}' + AND TABLE_NAME = '{entity.entity}';""" + + @property + def extract_entity_relationships_sql_query(self) -> str: + """A property to extract entity relationships from Databricks Unity Catalog.""" + return f"""SELECT + fk_schema.TABLE_SCHEMA AS EntitySchema, + fk_tab.TABLE_NAME AS Entity, + pk_schema.TABLE_SCHEMA AS ForeignEntitySchema, + pk_tab.TABLE_NAME AS ForeignEntity, + fk_col.COLUMN_NAME AS `Column`, + pk_col.COLUMN_NAME AS ForeignColumn + FROM + INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS fk + INNER JOIN + INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS fkc + ON fk.constraint_name = fkc.constraint_name + INNER JOIN + INFORMATION_SCHEMA.TABLES AS fk_tab + ON fk_tab.TABLE_NAME = fkc.TABLE_NAME AND fk_tab.TABLE_SCHEMA = fkc.TABLE_SCHEMA + INNER JOIN + INFORMATION_SCHEMA.SCHEMATA AS fk_schema + ON fk_tab.TABLE_SCHEMA = fk_schema.TABLE_SCHEMA + INNER JOIN + INFORMATION_SCHEMA.TABLES AS pk_tab + ON pk_tab.TABLE_NAME = fkc.referenced_TABLE_NAME AND pk_tab.TABLE_SCHEMA = fkc.referenced_TABLE_SCHEMA + INNER JOIN + INFORMATION_SCHEMA.SCHEMATA AS pk_schema + ON pk_tab.TABLE_SCHEMA = pk_schema.TABLE_SCHEMA + INNER JOIN + INFORMATION_SCHEMA.COLUMNS AS fk_col + ON fkc.COLUMN_NAME = fk_col.COLUMN_NAME AND fkc.TABLE_NAME = fk_col.TABLE_NAME AND fkc.TABLE_SCHEMA = fk_col.TABLE_SCHEMA + INNER JOIN + INFORMATION_SCHEMA.COLUMNS AS pk_col + ON fkc.referenced_COLUMN_NAME = pk_col.COLUMN_NAME AND fkc.referenced_TABLE_NAME = pk_col.TABLE_NAME AND fkc.referenced_TABLE_SCHEMA = pk_col.TABLE_SCHEMA + WHERE + fk.constraint_type = 'FOREIGN KEY' + AND fk_tab.TABLE_CATALOG = '{self.catalog}' + AND pk_tab.TABLE_CATALOG = '{self.catalog}' + ORDER BY + EntitySchema, Entity, ForeignEntitySchema, ForeignEntity; + """ + + async def query_entities(self, sql_query: str,
cast_to: any = None) -> list[dict]: + """ + A method to query a Databricks SQL endpoint for entities. + + Args: + sql_query (str): The SQL query to run. + cast_to (any, optional): The class to cast the results to. Defaults to None. + + Returns: + list[dict]: The list of entities or processed rows. + """ + logging.info(f"Running query: {sql_query}") + results = [] + + # Set up connection parameters for Databricks SQL endpoint + connection = sql.connect( + server_hostname=os.environ["Text2Sql__Databricks__ServerHostname"], + http_path=os.environ["Text2Sql__Databricks__HttpPath"], + access_token=os.environ["Text2Sql__Databricks__AccessToken"], + ) + + # Create a cursor + cursor = connection.cursor() + + try: + # Execute the query in a thread-safe manner + await asyncio.to_thread(cursor.execute, sql_query) + + # Fetch column names + columns = [col[0] for col in cursor.description] + + # Fetch rows + rows = await asyncio.to_thread(cursor.fetchall) + + # Process rows + for row in rows: + if cast_to: + results.append(cast_to.from_sql_row(row, columns)) + else: + results.append(dict(zip(columns, row))) + + except Exception as e: + logging.error(f"Error while executing query: {e}") + raise + finally: + cursor.close() + connection.close() + + return results + + +if __name__ == "__main__": + data_dictionary_creator = DatabricksDataDictionaryCreator() + asyncio.run(data_dictionary_creator.create_data_dictionary()) diff --git a/text_2_sql/data_dictionary/requirements.txt b/text_2_sql/data_dictionary/requirements.txt index 903dd0c..c8cc551 100644 --- a/text_2_sql/data_dictionary/requirements.txt +++ b/text_2_sql/data_dictionary/requirements.txt @@ -5,3 +5,4 @@ pydantic openai snowflake-connector-python networkx +databricks-sql-connector diff --git a/text_2_sql/data_dictionary/snowflake_data_dictionary_creator.py b/text_2_sql/data_dictionary/snowflake_data_dictionary_creator.py index e01c0c2..0232a50 100644 --- a/text_2_sql/data_dictionary/snowflake_data_dictionary_creator.py +++
b/text_2_sql/data_dictionary/snowflake_data_dictionary_creator.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from data_dictionary_creator import DataDictionaryCreator, EntityItem +from data_dictionary_creator import DataDictionaryCreator, EntityItem, DatabaseEngine import asyncio import snowflake.connector import logging @@ -25,9 +25,11 @@ def __init__( excluded_entities = [] excluded_schemas = ["INFORMATION_SCHEMA"] - return super().__init__( - entities, excluded_entities, excluded_schemas, single_file - ) + super().__init__(entities, excluded_entities, excluded_schemas, single_file) + + self.database = os.environ["Text2Sql__DatabaseName"] + self.warehouse = os.environ["Text2Sql__Snowflake__Warehouse"] + self.database_engine = DatabaseEngine.SNOWFLAKE """A class to extract data dictionary information from a Snowflake database.""" @@ -65,7 +67,7 @@ def extract_columns_sql_query(self, entity: EntityItem) -> str: @property def extract_entity_relationships_sql_query(self) -> str: - """A property to extract entity relationships from a SQL Server database.""" + """A property to extract entity relationships from a Snowflake database.""" return """SELECT tc.table_schema AS EntitySchema, tc.table_name AS Entity, diff --git a/text_2_sql/data_dictionary/sql_sever_data_dictionary_creator.py b/text_2_sql/data_dictionary/sql_sever_data_dictionary_creator.py index 235ff00..2b421d2 100644 --- a/text_2_sql/data_dictionary/sql_sever_data_dictionary_creator.py +++ b/text_2_sql/data_dictionary/sql_sever_data_dictionary_creator.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from data_dictionary_creator import DataDictionaryCreator, EntityItem +from data_dictionary_creator import DataDictionaryCreator, EntityItem, DatabaseEngine import asyncio import os @@ -26,6 +26,8 @@ def __init__( super().__init__(entities, excluded_entities, excluded_schemas, single_file) self.database = os.environ["Text2Sql__DatabaseName"] + self.database_engine = DatabaseEngine.SQL_SERVER + """A class to extract data dictionary information from a SQL Server database.""" @property