Skip to content

Commit 111d898

Browse files
Transpile with the read engine (#113)
1 parent e1a689f commit 111d898

File tree

4 files changed

+28
-1
lines changed

4 files changed

+28
-1
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,15 @@
88
import logging
99
import json
1010

11+
from text_2_sql_core.utils.database import DatabaseEngine
12+
1113

1214
class DatabricksSqlConnector(SqlConnector):
15+
def __init__(self):
16+
super().__init__()
17+
18+
self.database_engine = DatabaseEngine.DATABRICKS
19+
1320
async def query_execution(
1421
self,
1522
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,15 @@
88
import logging
99
import json
1010

11+
from text_2_sql_core.utils.database import DatabaseEngine
12+
1113

1214
class SnowflakeSqlConnector(SqlConnector):
15+
def __init__(self):
16+
super().__init__()
17+
18+
self.database_engine = DatabaseEngine.SNOWFLAKE
19+
1320
async def query_execution(
1421
self,
1522
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def __init__(self):
2828

2929
self.ai_search_connector = ConnectorFactory.get_ai_search_connector()
3030

31+
self.database_engine = None
32+
3133
def get_current_datetime(self) -> str:
3234
"""Get the current datetime."""
3335
return datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
@@ -138,7 +140,11 @@ async def query_validation(
138140
"""Validate the SQL query."""
139141
try:
140142
logging.info("Validating SQL Query: %s", sql_query)
141-
sqlglot.transpile(sql_query)
143+
sqlglot.transpile(
144+
sql_query,
145+
read=self.database_engine.value.lower(),
146+
error_level=sqlglot.ErrorLevel.ERROR,
147+
)
142148
except sqlglot.errors.ParseError as e:
143149
logging.error("SQL Query is invalid: %s", e.errors)
144150
return e.errors

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@
77
import logging
88
import json
99

10+
from text_2_sql_core.utils.database import DatabaseEngine
11+
1012

1113
class TSQLSqlConnector(SqlConnector):
14+
def __init__(self):
15+
super().__init__()
16+
17+
self.database_engine = DatabaseEngine.TSQL
18+
1219
async def query_execution(
1320
self,
1421
sql_query: Annotated[

0 commit comments

Comments
 (0)