diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py index cca4cc5..14c99ac 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py @@ -8,8 +8,15 @@ import logging import json +from text_2_sql_core.utils.database import DatabaseEngine + class DatabricksSqlConnector(SqlConnector): + def __init__(self): + super().__init__() + + self.database_engine = DatabaseEngine.DATABRICKS + async def query_execution( self, sql_query: Annotated[ diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py index c25f47e..c8da9c2 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py @@ -8,8 +8,15 @@ import logging import json +from text_2_sql_core.utils.database import DatabaseEngine + class SnowflakeSqlConnector(SqlConnector): + def __init__(self): + super().__init__() + + self.database_engine = DatabaseEngine.SNOWFLAKE + async def query_execution( self, sql_query: Annotated[ diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py index cae21e1..2309128 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py @@ -28,6 +28,8 @@ def __init__(self): self.ai_search_connector = ConnectorFactory.get_ai_search_connector() + self.database_engine = None + def get_current_datetime(self) -> str: """Get the current datetime.""" return datetime.now().strftime("%d/%m/%Y, %H:%M:%S") @@ -138,7 +140,11 @@ async def query_validation( """Validate the SQL query.""" try: logging.info("Validating SQL Query: %s", sql_query) - sqlglot.transpile(sql_query) + sqlglot.transpile( + sql_query, + read=self.database_engine.value.lower(), + error_level=sqlglot.ErrorLevel.ERROR, + ) except sqlglot.errors.ParseError as e: logging.error("SQL Query is invalid: %s", e.errors) return e.errors diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py index ddb9b28..e494bb1 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py @@ -7,8 +7,15 @@ import logging import json +from text_2_sql_core.utils.database import DatabaseEngine + class TSQLSqlConnector(SqlConnector): + def __init__(self): + super().__init__() + + self.database_engine = DatabaseEngine.TSQL + async def query_execution( self, sql_query: Annotated[