Skip to content

Commit 74d28ba

Browse files
committed
Update snowflake script
1 parent d035ba2 commit 74d28ba

File tree

5 files changed

+131
-3
lines changed

5 files changed

+131
-3
lines changed

text_2_sql/data_dictionary/.env

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ Text2Sql__UseQueryCache=<whether to use the query cache first or not>
88
Text2Sql__PreRunQueryCache=<whether to pre-run the top result from the query cache or not>
99
Text2Sql__DatabaseName=<databaseName>
1010
Text2Sql__DatabaseConnectionString=<databaseConnectionString>
11+
Text2Sql__Snowflake__User=<snowflakeUser if using Snowflake Data Source>
12+
Text2Sql__Snowflake__Password=<snowflakePassword if using Snowflake Data Source>
13+
Text2Sql__Snowflake__Account=<snowflakeAccount if using Snowflake Data Source>
14+
Text2Sql__Snowflake__Warehouse=<snowflakeWarehouse if using Snowflake Data Source>
1115
AIService__AzureSearchOptions__Endpoint=<searchServiceEndpoint>
1216
AIService__AzureSearchOptions__Key=<searchServiceKey if not using identity>
1317
AIService__AzureSearchOptions__RagDocuments__Index=<ragDocumentsIndexName>

text_2_sql/data_dictionary/data_dictionary_creator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(
7878
self,
7979
entities: list[str] = None,
8080
excluded_entities: list[str] = None,
81+
excluded_schemas: list[str] = None,
8182
single_file: bool = False,
8283
generate_descriptions: bool = True,
8384
):
@@ -86,12 +87,14 @@ def __init__(
8687
Args:
8788
entities (list[str], optional): A list of entities to extract. Defaults to None. If None, all entities are extracted.
8889
excluded_entities (list[str], optional): A list of entities to exclude. Defaults to None.
90+
excluded_schemas (list[str], optional): A list of schemas to exclude. Defaults to None.
8991
single_file (bool, optional): A flag to indicate if the data dictionary should be saved to a single file. Defaults to False.
9092
generate_descriptions (bool, optional): A flag to indicate if descriptions should be generated. Defaults to True.
9193
"""
9294

9395
self.entities = entities
9496
self.excluded_entities = excluded_entities
97+
self.excluded_schemas = excluded_schemas
9598
self.single_file = single_file
9699
self.generate_descriptions = generate_descriptions
97100

@@ -189,6 +192,7 @@ async def extract_entities_with_descriptions(self) -> list[EntityItem]:
189192
entity
190193
for entity in all_entities
191194
if entity.entity not in self.excluded_entities
195+
and entity.entity_schema not in self.excluded_schemas
192196
]
193197

194198
return all_entities

text_2_sql/data_dictionary/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ azure-identity
33
python-dotenv
44
pydantic
55
openai
6+
snowflake-connector-python
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from data_dictionary_creator import DataDictionaryCreator, EntityItem
4+
import asyncio
5+
import snowflake.connector
6+
import logging
7+
import os
8+
9+
10+
class SnowflakeDataDictionaryCreator(DataDictionaryCreator):
11+
def __init__(
12+
self,
13+
entities: list[str] = None,
14+
excluded_entities: list[str] = None,
15+
single_file: bool = False,
16+
):
17+
"""A method to initialize the DataDictionaryCreator class.
18+
19+
Args:
20+
entities (list[str], optional): A list of entities to extract. Defaults to None. If None, all entities are extracted.
21+
excluded_entities (list[str], optional): A list of entities to exclude. Defaults to None.
22+
single_file (bool, optional): A flag to indicate if the data dictionary should be saved to a single file. Defaults to False.
23+
"""
24+
if excluded_entities is None:
25+
excluded_entities = []
26+
27+
excluded_schemas = ["dbo", "sys"]
28+
return super().__init__(
29+
entities, excluded_entities, excluded_schemas, single_file
30+
)
31+
32+
"""A class to extract data dictionary information from a Snowflake database."""
33+
34+
@property
35+
def extract_table_entities_sql_query(self) -> str:
36+
"""A property to extract table entities from a Snowflake database."""
37+
return """SELECT
38+
t.TABLE_NAME AS Entity,
39+
t.TABLE_SCHEMA AS EntitySchema,
40+
t.COMMENT AS Description
41+
FROM
42+
INFORMATION_SCHEMA.TABLES t"""
43+
44+
@property
45+
def extract_view_entities_sql_query(self) -> str:
46+
"""A property to extract view entities from a Snowflake database."""
47+
return """SELECT
48+
v.TABLE_NAME AS Entity,
49+
v.TABLE_SCHEMA AS EntitySchema,
50+
v.COMMENT AS Description
51+
FROM
52+
INFORMATION_SCHEMA.VIEWS v"""
53+
54+
def extract_columns_sql_query(self, entity: EntityItem) -> str:
55+
"""A property to extract column information from a Snowflake database."""
56+
return f"""SELECT
57+
COLUMN_NAME AS Name,
58+
DATA_TYPE AS Type,
59+
COMMENT AS Definition
60+
FROM
61+
INFORMATION_SCHEMA.COLUMNS
62+
WHERE
63+
TABLE_SCHEMA = '{entity.entity_schema}'
64+
AND TABLE_NAME = '{entity.name}';"""
65+
66+
async def query_entities(
67+
self, sql_query: str, cast_to: any = None
68+
) -> list[EntityItem]:
69+
"""A method to query a database for entities using Snowflake Connector. Overrides the base class method.
70+
71+
Args:
72+
sql_query (str): The SQL query to run.
73+
cast_to (any, optional): The class to cast the results to. Defaults to None.
74+
75+
Returns:
76+
list[EntityItem]: The list of entities.
77+
"""
78+
logging.info(f"Running query: {sql_query}")
79+
results = []
80+
81+
# Create a connection to Snowflake, without specifying a schema
82+
conn = snowflake.connector.connect(
83+
user=os.environ["Text2Sql__Snowflake__User"],
84+
password=os.environ["Text2Sql__Snowflake__Password"],
85+
account=os.environ["Text2Sql__Snowflake__Account"],
86+
warehouse=os.environ["Text2Sql__Snowflake__Warehouse"],
87+
database=os.environ["Text2Sql__DatabaseName"],
88+
)
89+
90+
try:
91+
# Using the connection to create a cursor
92+
cursor = conn.cursor()
93+
94+
# Execute the query
95+
await asyncio.to_thread(cursor.execute, sql_query)
96+
97+
# Fetch column names
98+
columns = [col[0] for col in cursor.description]
99+
100+
# Fetch rows
101+
rows = await asyncio.to_thread(cursor.fetchall)
102+
103+
# Process rows
104+
for row in rows:
105+
if cast_to:
106+
results.append(cast_to.from_sql_row(row, columns))
107+
else:
108+
results.append(dict(zip(columns, row)))
109+
110+
finally:
111+
cursor.close()
112+
conn.close()
113+
114+
return results
115+
116+
117+
if __name__ == "__main__":
118+
data_dictionary_creator = SnowflakeDataDictionaryCreator()
119+
asyncio.run(data_dictionary_creator.create_data_dictionary())

text_2_sql/data_dictionary/sql_sever_data_dictionary_creator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ def __init__(
2121
if excluded_entities is None:
2222
excluded_entities = []
2323

24-
excluded_entities.extend(
25-
["dbo.BuildVersion", "dbo.ErrorLog", "sys.database_firewall_rules"]
24+
excluded_schemas = ["dbo", "sys"]
25+
return super().__init__(
26+
entities, excluded_entities, excluded_schemas, single_file
2627
)
27-
return super().__init__(entities, excluded_entities, single_file)
2828

2929
"""A class to extract data dictionary information from a SQL Server database."""
3030

0 commit comments

Comments
 (0)