Skip to content

Commit b4662a6

Browse files
PostgreSQL Support (#127)
* Add postgres support * Store excluded fields in * Fix bad property * Add postgres identifiers * Update lock file * Update postgresql commands * Make optional * Update postgresql queries
1 parent f1fd21c commit b4662a6

File tree

14 files changed

+373
-40
lines changed

14 files changed

+373
-40
lines changed

deploy_ai_search/pyproject.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,14 @@ dev = [
2626

2727
[tool.uv.sources]
2828
text_2_sql_core = { workspace = true }
29+
30+
[project.optional-dependencies]
31+
snowflake = [
32+
"text_2_sql_core[snowflake]",
33+
]
34+
databricks = [
35+
"text_2_sql_core[databricks]",
36+
]
37+
postgresql = [
38+
"text_2_sql_core[postgresql]",
39+
]

deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
import os
2828
from text_2_sql_core.utils.database import DatabaseEngine
29+
from text_2_sql_core.connectors.factory import ConnectorFactory
2930

3031

3132
class Text2SqlSchemaStoreAISearch(AISearch):
@@ -49,29 +50,13 @@ def __init__(
4950
os.environ["Text2Sql__DatabaseEngine"].upper()
5051
]
5152

53+
self.database_connector = ConnectorFactory.get_database_connector()
54+
5255
if single_data_dictionary_file:
5356
self.parsing_mode = BlobIndexerParsingMode.JSON_ARRAY
5457
else:
5558
self.parsing_mode = BlobIndexerParsingMode.JSON
5659

57-
@property
58-
def excluded_fields_for_database_engine(self):
59-
"""A method to get the excluded fields for the database engine."""
60-
61-
all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
62-
if self.database_engine == DatabaseEngine.SNOWFLAKE:
63-
engine_specific_fields = ["Warehouse", "Database"]
64-
elif self.database_engine == DatabaseEngine.TSQL:
65-
engine_specific_fields = ["Database"]
66-
elif self.database_engine == DatabaseEngine.DATABRICKS:
67-
engine_specific_fields = ["Catalog"]
68-
69-
return [
70-
field
71-
for field in all_engine_specific_fields
72-
if field not in engine_specific_fields
73-
]
74-
7560
def get_index_fields(self) -> list[SearchableField]:
7661
"""This function returns the index fields for sql index.
7762
@@ -196,7 +181,7 @@ def get_index_fields(self) -> list[SearchableField]:
196181
fields = [
197182
field
198183
for field in fields
199-
if field.name not in self.excluded_fields_for_database_engine
184+
if field.name not in self.database_connector.excluded_engine_specific_fields
200185
]
201186

202187
return fields

text_2_sql/autogen/pyproject.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies = [
1111
"autogen-ext[azure,openai]==0.4.0.dev11",
1212
"grpcio>=1.68.1",
1313
"pyyaml>=6.0.2",
14-
"text_2_sql_core[snowflake,databricks]",
14+
"text_2_sql_core",
1515
]
1616

1717
[dependency-groups]
@@ -28,3 +28,14 @@ dev = [
2828

2929
[tool.uv.sources]
3030
text_2_sql_core = { workspace = true }
31+
32+
[project.optional-dependencies]
33+
snowflake = [
34+
"text_2_sql_core[snowflake]",
35+
]
36+
databricks = [
37+
"text_2_sql_core[databricks]",
38+
]
39+
postgresql = [
40+
"text_2_sql_core[postgresql]",
41+
]

text_2_sql/text_2_sql_core/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ databricks = [
4646
"databricks-sql-connector>=3.0.1",
4747
"pyarrow>=14.0.2,<17",
4848
]
49+
postgresql = [
50+
"psycopg>=3.2.3",
51+
]
52+
4953

5054
[build-system]
5155
requires = ["hatchling"]

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import json
1010

11-
from text_2_sql_core.utils.database import DatabaseEngine
11+
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields
1212

1313

1414
class DatabricksSqlConnector(SqlConnector):
@@ -17,6 +17,11 @@ def __init__(self):
1717

1818
self.database_engine = DatabaseEngine.DATABRICKS
1919

20+
@property
def engine_specific_fields(self) -> list[str]:
    """Return the engine specific fields that apply to Databricks.

    Returns:
    -------
        list[str]: A single-item list holding the ``CATALOG`` member of ``DatabaseEngineSpecificFields``.
    """
    catalog_field = DatabaseEngineSpecificFields.CATALOG
    return [catalog_field]
24+
2025
@property
2126
def invalid_identifiers(self) -> list[str]:
2227
"""Get the invalid identifiers upon which a sql query is rejected."""

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ def get_database_connector():
2525
from text_2_sql_core.connectors.tsql_sql import TSQLSqlConnector
2626

2727
return TSQLSqlConnector()
28+
elif os.environ["Text2Sql__DatabaseEngine"].upper() == "POSTGRESQL":
29+
from text_2_sql_core.connectors.postgresql_sql import (
30+
PostgresqlSqlConnector,
31+
)
32+
33+
return PostgresqlSqlConnector()
2834
else:
2935
raise ValueError(
3036
f"""Database engine {
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from text_2_sql_core.connectors.sql import SqlConnector
4+
import psycopg
5+
from typing import Annotated
6+
import os
7+
import logging
8+
import json
9+
10+
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields
11+
12+
13+
class PostgresqlSqlConnector(SqlConnector):
    """SQL connector that runs queries against a PostgreSQL database.

    A new ``psycopg`` connection is opened for every query using the
    connection string held in the ``Text2Sql__DatabaseConnectionString``
    environment variable.
    """

    def __init__(self):
        super().__init__()

        self.database_engine = DatabaseEngine.POSTGRESQL

    @property
    def engine_specific_fields(self) -> list[str]:
        """Get the engine specific fields."""
        return [DatabaseEngineSpecificFields.DATABASE]

    @property
    def invalid_identifiers(self) -> list[str]:
        """Get the invalid identifiers upon which a sql query is rejected."""

        # NOTE(review): several entries carry a trailing "()". The visible part
        # of SqlConnector strips parentheses from parsed tokens before
        # upper-casing them - confirm these entries are compared in a form
        # that can actually match.
        return [
            "CURRENT_USER",  # Returns the name of the current user
            "SESSION_USER",  # Returns the name of the user that initiated the session
            "USER",  # Returns the name of the current user
            "CURRENT_ROLE",  # Returns the current role
            "CURRENT_DATABASE",  # Returns the name of the current database
            "CURRENT_SCHEMA()",  # Returns the name of the current schema
            "CURRENT_SETTING()",  # Returns the value of a specified configuration parameter
            "PG_CURRENT_XACT_ID()",  # Returns the current transaction ID
            # (if the extension is enabled) Provides a view of query statistics
            "PG_STAT_STATEMENTS()",
            "PG_SLEEP()",  # Delays execution by the specified number of seconds
            "CLIENT_ADDR()",  # Returns the IP address of the client (from pg_stat_activity)
            "CLIENT_HOSTNAME()",  # Returns the hostname of the client (from pg_stat_activity)
            "PGP_SYM_DECRYPT()",  # (from pgcrypto extension) Symmetric decryption function
            "PGP_PUB_DECRYPT()",  # (from pgcrypto extension) Asymmetric decryption function
        ]

    async def query_execution(
        self,
        sql_query: Annotated[str, "The SQL query to run against the database."],
        cast_to: any = None,
        limit=None,
    ) -> list[dict]:
        """Run the SQL query against the PostgreSQL database asynchronously.

        Args:
        ----
            sql_query (str): The SQL query to run against the database.
            cast_to (any): Optional type exposing ``from_sql_row(row, columns)``. When given, each row is converted with it instead of being returned as a dict.
            limit (int): Optional maximum number of rows to fetch. All rows are fetched when None.

        Returns:
        -------
            list[dict]: The results of the SQL query. Empty when the statement produces no result set.
        """
        logging.info("Running query: %s", sql_query)
        results = []
        connection_string = os.environ["Text2Sql__DatabaseConnectionString"]

        # AsyncConnection.connect() is a coroutine: it has to be awaited to
        # obtain the connection object that implements the async context
        # manager protocol.
        async with await psycopg.AsyncConnection.connect(connection_string) as conn:
            # Create an asynchronous cursor
            async with conn.cursor() as cursor:
                await cursor.execute(sql_query)

                # description is None when the statement returned no result
                # set, so there are neither columns nor rows to read.
                if cursor.description is None:
                    logging.debug("Results: %s", results)
                    return results

                # Fetch column names
                columns = [column[0] for column in cursor.description]

                # Fetch rows based on the limit
                if limit is not None:
                    rows = await cursor.fetchmany(limit)
                else:
                    rows = await cursor.fetchall()

                # Process the rows
                for row in rows:
                    if cast_to:
                        results.append(cast_to.from_sql_row(row, columns))
                    else:
                        results.append(dict(zip(columns, row)))

        logging.debug("Results: %s", results)
        return results

    async def get_entity_schemas(
        self,
        text: Annotated[
            str,
            "The text to run a semantic search against. Relevant entities will be returned.",
        ],
        excluded_entities: Annotated[
            list[str],
            "The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
        ] = [],
        as_json: bool = True,
    ) -> str:
        """Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.

        Args:
        ----
            text (str): The text to run the search against.
            excluded_entities (list[str]): The entities to leave out of the search results.
            as_json (bool): When True the schemas are serialised to a JSON string, otherwise the list of schema dicts is returned as is.

        Returns:
            str: The schema of the views or tables in JSON format (or the list of schemas when as_json is False).
        """

        # Hand over a copy so the shared default list can never be modified
        # by the search connector.
        schemas = await self.ai_search_connector.get_entity_schemas(
            text, list(excluded_entities)
        )

        for schema in schemas:
            # PostgreSQL entities are addressed as <schema>.<entity>.
            schema["SelectFromEntity"] = ".".join([schema["Schema"], schema["Entity"]])

            del schema["Entity"]
            del schema["Schema"]

        if as_json:
            return json.dumps(schemas, default=str)
        else:
            return schemas

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import json
1010

11-
from text_2_sql_core.utils.database import DatabaseEngine
11+
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields
1212

1313

1414
class SnowflakeSqlConnector(SqlConnector):
@@ -17,6 +17,14 @@ def __init__(self):
1717

1818
self.database_engine = DatabaseEngine.SNOWFLAKE
1919

20+
@property
def engine_specific_fields(self) -> list[str]:
    """Return the engine specific fields that apply to Snowflake.

    Returns:
    -------
        list[str]: The ``WAREHOUSE`` and ``DATABASE`` members of ``DatabaseEngineSpecificFields``, in that order.
    """
    warehouse_field = DatabaseEngineSpecificFields.WAREHOUSE
    database_field = DatabaseEngineSpecificFields.DATABASE
    return [warehouse_field, database_field]
27+
2028
@property
2129
def invalid_identifiers(self) -> list[str]:
2230
"""Get the invalid identifiers upon which a sql query is rejected."""

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from abc import ABC, abstractmethod
1111
from jinja2 import Template
1212
import json
13+
from text_2_sql_core.utils.database import DatabaseEngineSpecificFields
1314

1415

1516
class SqlConnector(ABC):
@@ -36,6 +37,22 @@ def invalid_identifiers(self) -> list[str]:
3637
"""Get the invalid identifiers upon which a sql query is rejected."""
3738
pass
3839

40+
@property
@abstractmethod
def engine_specific_fields(self) -> list[str]:
    """Return the engine specific fields used by the concrete connector.

    Abstract: every connector subclass has to provide its own list.
    """
    return None
45+
46+
@property
def excluded_engine_specific_fields(self) -> list[str]:
    """Get the engine specific fields that this database engine does not use.

    Returns:
    -------
        list[str]: The capitalised value of every ``DatabaseEngineSpecificFields`` member that is not listed in ``engine_specific_fields``.
    """

    # Read the property once instead of re-evaluating it for each enum member.
    supported_fields = self.engine_specific_fields

    return [
        field.value.capitalize()
        for field in DatabaseEngineSpecificFields
        if field not in supported_fields
    ]
55+
3956
@abstractmethod
4057
async def query_execution(
4158
self,
@@ -155,7 +172,7 @@ def handle_node(node):
155172

156173
for token in expressions + identifiers:
157174
if isinstance(token, Parameter):
158-
identifier = token.this.this
175+
identifier = str(token.this.this).upper()
159176
else:
160177
identifier = str(token).strip("()").upper()
161178

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
import json
99

10-
from text_2_sql_core.utils.database import DatabaseEngine
10+
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields
1111

1212

1313
class TSQLSqlConnector(SqlConnector):
@@ -16,6 +16,11 @@ def __init__(self):
1616

1717
self.database_engine = DatabaseEngine.TSQL
1818

19+
@property
def engine_specific_fields(self) -> list[str]:
    """Return the engine specific fields that apply to T-SQL.

    Returns:
    -------
        list[str]: A single-item list holding the ``DATABASE`` member of ``DatabaseEngineSpecificFields``.
    """
    database_field = DatabaseEngineSpecificFields.DATABASE
    return [database_field]
23+
1924
@property
2025
def invalid_identifiers(self) -> list[str]:
2126
"""Get the invalid identifiers upon which a sql query is rejected."""

0 commit comments

Comments
 (0)