diff --git a/CHANGES.md b/CHANGES.md index 2fba45ed..ca5ee5b9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,11 @@ # Changelog ## Unreleased +- Added canonical [PostgreSQL client parameter `sslmode`], implementing + `sslmode=require` to connect to SSL-enabled CrateDB instances without + verifying the host name. + +[PostgreSQL client parameter `sslmode`]: https://www.postgresql.org/docs/current/libpq-ssl.html#LIBPQ-SSL-PROTECTION ## 2025/01/30 0.41.0 - Dependencies: Updated to `crate-2.0.0`, which uses `orjson` for JSON marshalling diff --git a/docs/inspection-reflection.rst b/docs/inspection-reflection.rst index 3faa6094..253baa01 100644 --- a/docs/inspection-reflection.rst +++ b/docs/inspection-reflection.rst @@ -87,10 +87,10 @@ Create a SQLAlchemy table object: Reflect column data types from the table metadata: >>> table.columns.get('name') - Column('name', String(), table=) + Column('name', VARCHAR(), table=) >>> table.primary_key - PrimaryKeyConstraint(Column('id', String(), table=, primary_key=True... + PrimaryKeyConstraint(Column('id', VARCHAR(), table=, primary_key=True... CrateDialect diff --git a/pyproject.toml b/pyproject.toml index 66a0730b..edfc4bf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,10 @@ requires = [ [project] name = "sqlalchemy-cratedb" +#dynamic = [ +# "version", +#] +version = "0.42.0.dev2" description = "SQLAlchemy dialect for CrateDB." 
readme = "README.md" keywords = [ @@ -77,9 +81,6 @@ classifiers = [ "Topic :: Text Processing", "Topic :: Utilities", ] -dynamic = [ - "version", -] dependencies = [ "backports.zoneinfo<1; python_version<'3.9'", "crate>=2,<3", @@ -110,7 +111,7 @@ optional-dependencies.release = [ optional-dependencies.test = [ "cratedb-toolkit[testing]", "dask[dataframe]", - "pandas<2.3", + "pandas[test]<2.3", "pueblo>=0.0.7", "pytest<9", "pytest-cov<7", diff --git a/src/sqlalchemy_cratedb/compiler.py b/src/sqlalchemy_cratedb/compiler.py index 7b2c5ccd..6d46a798 100644 --- a/src/sqlalchemy_cratedb/compiler.py +++ b/src/sqlalchemy_cratedb/compiler.py @@ -200,6 +200,17 @@ def visit_unique_constraint(self, constraint, **kw): ) return + def visit_create_index(self, create, **kw) -> str: + """ + CrateDB does not support `CREATE INDEX` statements. + """ + warnings.warn( + "CrateDB does not support `CREATE INDEX` statements, " + "they will be omitted when generating DDL statements.", + stacklevel=2, + ) + return "SELECT 1" + class CrateTypeCompiler(compiler.GenericTypeCompiler): def visit_string(self, type_, **kw): @@ -254,6 +265,36 @@ def visit_TIMESTAMP(self, type_, **kw): """ return "TIMESTAMP %s" % ((type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",) + def visit_BLOB(self, type_, **kw): + return "STRING" + + def visit_FLOAT(self, type_, **kw): + """ + From `sqlalchemy.sql.sqltypes.Float`. + + When a :paramref:`.Float.precision` is not provided in a + :class:`_types.Float` type some backend may compile this type as + an 8 bytes / 64 bit float datatype. To use a 4 bytes / 32 bit float + datatype a precision <= 24 can usually be provided or the + :class:`_types.REAL` type can be used. + This is known to be the case in the PostgreSQL and MSSQL dialects + that render the type as ``FLOAT`` that's in both an alias of + ``DOUBLE PRECISION``. Other third party dialects may have similar + behavior. 
+ """ + if not type_.precision: + return "FLOAT" + elif type_.precision <= 24: + return "FLOAT" + else: + return "DOUBLE" + + def visit_JSON(self, type_, **kw): + return "OBJECT" + + def visit_JSONB(self, type_, **kw): + return "OBJECT" + class CrateCompiler(compiler.SQLCompiler): def visit_getitem_binary(self, binary, operator, **kw): diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 90102a78..d0f0f399 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -34,46 +34,60 @@ ) from .sa_version import SA_1_4, SA_2_0, SA_VERSION from .type import FloatVector, ObjectArray, ObjectType +from .type.binary import LargeBinary +# For SQLAlchemy >= 1.1. TYPES_MAP = { - "boolean": sqltypes.Boolean, - "short": sqltypes.SmallInteger, - "smallint": sqltypes.SmallInteger, - "timestamp": sqltypes.TIMESTAMP(timezone=False), + "boolean": sqltypes.BOOLEAN, + "short": sqltypes.SMALLINT, + "smallint": sqltypes.SMALLINT, + "timestamp": sqltypes.TIMESTAMP, "timestamp with time zone": sqltypes.TIMESTAMP(timezone=True), + "timestamp without time zone": sqltypes.TIMESTAMP(timezone=False), "object": ObjectType, - "integer": sqltypes.Integer, - "long": sqltypes.NUMERIC, - "bigint": sqltypes.NUMERIC, + "object_array": ObjectArray, # TODO: Can this also be improved to use `sqltypes.ARRAY`? + "integer": sqltypes.INTEGER, + "long": sqltypes.BIGINT, + "bigint": sqltypes.BIGINT, + "float": sqltypes.FLOAT, "double": sqltypes.DECIMAL, "double precision": sqltypes.DECIMAL, - "object_array": ObjectArray, - "float": sqltypes.Float, - "real": sqltypes.Float, - "string": sqltypes.String, - "text": sqltypes.String, + "real": sqltypes.REAL, + "string": sqltypes.VARCHAR, + "text": sqltypes.VARCHAR, "float_vector": FloatVector, } -# Needed for SQLAlchemy >= 1.1. -# TODO: Dissolve. +# For SQLAlchemy >= 1.4. 
try: from sqlalchemy.types import ARRAY - TYPES_MAP["integer_array"] = ARRAY(sqltypes.Integer) - TYPES_MAP["boolean_array"] = ARRAY(sqltypes.Boolean) - TYPES_MAP["short_array"] = ARRAY(sqltypes.SmallInteger) - TYPES_MAP["smallint_array"] = ARRAY(sqltypes.SmallInteger) + TYPES_MAP["integer_array"] = ARRAY(sqltypes.INTEGER) + TYPES_MAP["boolean_array"] = ARRAY(sqltypes.BOOLEAN) + TYPES_MAP["short_array"] = ARRAY(sqltypes.SMALLINT) + TYPES_MAP["smallint_array"] = ARRAY(sqltypes.SMALLINT) + TYPES_MAP["timestamp_array"] = ARRAY(sqltypes.TIMESTAMP) TYPES_MAP["timestamp_array"] = ARRAY(sqltypes.TIMESTAMP(timezone=False)) TYPES_MAP["timestamp with time zone_array"] = ARRAY(sqltypes.TIMESTAMP(timezone=True)) - TYPES_MAP["long_array"] = ARRAY(sqltypes.NUMERIC) - TYPES_MAP["bigint_array"] = ARRAY(sqltypes.NUMERIC) - TYPES_MAP["double_array"] = ARRAY(sqltypes.DECIMAL) - TYPES_MAP["double precision_array"] = ARRAY(sqltypes.DECIMAL) - TYPES_MAP["float_array"] = ARRAY(sqltypes.Float) - TYPES_MAP["real_array"] = ARRAY(sqltypes.Float) - TYPES_MAP["string_array"] = ARRAY(sqltypes.String) - TYPES_MAP["text_array"] = ARRAY(sqltypes.String) + TYPES_MAP["long_array"] = ARRAY(sqltypes.BIGINT) + TYPES_MAP["bigint_array"] = ARRAY(sqltypes.BIGINT) + TYPES_MAP["float_array"] = ARRAY(sqltypes.FLOAT) + TYPES_MAP["real_array"] = ARRAY(sqltypes.REAL) + TYPES_MAP["string_array"] = ARRAY(sqltypes.VARCHAR) + TYPES_MAP["text_array"] = ARRAY(sqltypes.VARCHAR) +except Exception: # noqa: S110 + pass + +# For SQLAlchemy >= 2.0. 
+try: + from sqlalchemy.types import DOUBLE, DOUBLE_PRECISION + + TYPES_MAP["real"] = DOUBLE + TYPES_MAP["real_array"] = ARRAY(DOUBLE) + TYPES_MAP["double"] = DOUBLE + TYPES_MAP["double_array"] = ARRAY(DOUBLE) + TYPES_MAP["double precision"] = DOUBLE_PRECISION + TYPES_MAP["double precision_array"] = ARRAY(DOUBLE_PRECISION) except Exception: # noqa: S110 pass @@ -158,6 +172,7 @@ def process(value): sqltypes.Date: Date, sqltypes.DateTime: DateTime, sqltypes.TIMESTAMP: DateTime, + sqltypes.LargeBinary: LargeBinary, } @@ -206,6 +221,15 @@ def __init__(self, **kwargs): # start with _. Adding it here causes sqlalchemy to quote such columns. self.identifier_preparer.illegal_initial_characters.add("_") + def get_isolation_level_values(self, dbapi_conn): + return () + + def set_isolation_level(self, dbapi_connection, level): + pass + + def get_isolation_level(self, dbapi_connection): + return "NONE" + def initialize(self, connection): # get lowest server version self.server_version_info = self._get_server_version_info(connection) @@ -228,8 +252,12 @@ def connect(self, host=None, port=None, *args, **kwargs): servers = to_list(server) if servers: use_ssl = asbool(kwargs.pop("ssl", False)) - if use_ssl: + # TODO: Switch to the canonical default `sslmode=prefer` later. 
+ sslmode = kwargs.pop("sslmode", "disable") + if use_ssl or sslmode in ["allow", "prefer", "require", "verify-ca", "verify-full"]: servers = ["https://" + server for server in servers] + if sslmode == "require": + kwargs["verify_ssl_cert"] = False return self.dbapi.connect(servers=servers, **kwargs) return self.dbapi.connect(**kwargs) diff --git a/src/sqlalchemy_cratedb/type/__init__.py b/src/sqlalchemy_cratedb/type/__init__.py index b524bb39..6d92e0e2 100644 --- a/src/sqlalchemy_cratedb/type/__init__.py +++ b/src/sqlalchemy_cratedb/type/__init__.py @@ -1,4 +1,5 @@ from .array import ObjectArray +from .binary import LargeBinary from .geo import Geopoint, Geoshape from .object import ObjectType from .vector import FloatVector, knn_match @@ -6,6 +7,7 @@ __all__ = [ Geopoint, Geoshape, + LargeBinary, ObjectArray, ObjectType, FloatVector, diff --git a/src/sqlalchemy_cratedb/type/array.py b/src/sqlalchemy_cratedb/type/array.py index 7798692c..71efca4b 100644 --- a/src/sqlalchemy_cratedb/type/array.py +++ b/src/sqlalchemy_cratedb/type/array.py @@ -96,6 +96,8 @@ def __init__(self, left, right, operator=operators.eq): self.operator = operator +# TODO: Should this be inherited from PostgreSQL's +# `ARRAY`, in order to improve type checking? class _ObjectArray(sqltypes.UserDefinedType): cache_ok = True @@ -139,5 +141,8 @@ def any(self, other, operator=operators.eq): def get_col_spec(self, **kws): return "ARRAY(OBJECT)" + def as_generic(self, **kwargs): + return sqltypes.ARRAY + ObjectArray = MutableList.as_mutable(_ObjectArray) diff --git a/src/sqlalchemy_cratedb/type/binary.py b/src/sqlalchemy_cratedb/type/binary.py new file mode 100644 index 00000000..04b04073 --- /dev/null +++ b/src/sqlalchemy_cratedb/type/binary.py @@ -0,0 +1,44 @@ +import base64 + +import sqlalchemy as sa + + +class LargeBinary(sa.String): + """A type for large binary byte data. 
+ + The :class:`.LargeBinary` type corresponds to a large and/or unlengthed + binary type for the target platform, such as BLOB on MySQL and BYTEA for + PostgreSQL. It also handles the necessary conversions for the DBAPI. + + """ + + __visit_name__ = "large_binary" + + def bind_processor(self, dialect): + if dialect.dbapi is None: + return None + + # TODO: DBAPIBinary = dialect.dbapi.Binary + + def process(value): + if value is not None: + # TODO: return DBAPIBinary(value) + return base64.b64encode(value).decode() + else: + return None + + return process + + # Python 3 has native bytes() type + # both sqlite3 and pg8000 seem to return it, + # psycopg2 as of 2.5 returns 'memoryview' + def result_processor(self, dialect, coltype): + if dialect.returns_native_bytes: + return None + + def process(value): + if value is not None: + return base64.b64decode(value) + return value + + return process diff --git a/tests/integration.py b/tests/integration.py index 75c36a27..35e5e647 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -30,7 +30,7 @@ from crate.client import connect -from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_2_0, SA_VERSION from tests.settings import crate_host log = logging.getLogger() @@ -179,16 +179,22 @@ def create_test_suite(): "docs/crud.rst", "docs/working-with-types.rst", "docs/advanced-querying.rst", - "docs/inspection-reflection.rst", ] - # Don't run DataFrame integration tests on SQLAlchemy 1.3 and Python 3.7. - skip_dataframe = SA_VERSION < SA_2_0 or sys.version_info < (3, 8) or sys.version_info >= (3, 13) + # Don't run DataFrame integration tests on SQLAlchemy 1.4 and earlier, or Python 3.7. + skip_dataframe = SA_VERSION < SA_2_0 or sys.version_info < (3, 8) if not skip_dataframe: sqlalchemy_integration_tests += [ "docs/dataframe.rst", ] + # Don't run reflection integration tests on SQLAlchemy 1.3 and earlier and Python 3.10 and 3.11. 
+ skip_reflection = SA_VERSION < SA_1_4 and (3, 10) <= sys.version_info < (3, 12) + if not skip_reflection: + sqlalchemy_integration_tests += [ + "docs/inspection-reflection.rst", + ] + s = doctest.DocFileSuite( *sqlalchemy_integration_tests, module_relative=False, diff --git a/tests/test_schema.py b/tests/test_schema.py index fa9a4764..14b4ce90 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,8 +1,12 @@ +from unittest import skipIf + import sqlalchemy as sa +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION from tests.conftest import TESTDRIVE_DATA_SCHEMA +@skipIf(SA_VERSION < SA_1_4, "Does not work correctly on SQLAlchemy 1.3") def test_correct_schema(cratedb_service): """ Tests that the correct schema is being picked up. diff --git a/tests/test_support_pandas.py b/tests/test_support_pandas.py index 47fe9c7a..5ba15705 100644 --- a/tests/test_support_pandas.py +++ b/tests/test_support_pandas.py @@ -1,7 +1,10 @@ import re import sys +import pandas as pd import pytest +import sqlalchemy as sa +from pandas._testing import assert_equal from pueblo.testing.pandas import makeTimeDataFrame from sqlalchemy.exc import ProgrammingError @@ -15,6 +18,18 @@ df = makeTimeDataFrame(nper=INSERT_RECORDS, freq="S") df["time"] = df.index +float_double_data = { + "col_1": [19556.88, 629414.27, 51570.0, 2933.52, 20338.98], + "col_2": [ + 15379.920000000002, + 1107140.42, + 8081.999999999999, + 1570.0300000000002, + 29468.539999999997, + ], +} +float_double_df = pd.DataFrame.from_dict(float_double_data) + @pytest.mark.skipif( sys.version_info < (3, 8), reason="Feature not supported on Python 3.7 and earlier" @@ -113,3 +128,34 @@ def test_table_kwargs_unknown(cratedb_service): "passed to [ALTER | CREATE] TABLE statement]" ) ) + + +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="Feature not supported on Python 3.7 and earlier" +) +@pytest.mark.skipif( + SA_VERSION < SA_2_0, reason="Feature not supported on SQLAlchemy 1.4 and earlier" +) +def 
test_float_double(cratedb_service): + """ + Validate I/O with floating point numbers, specifically DOUBLE types. + + Motto: Do not lose precision when DOUBLE is required. + """ + tablename = "pandas_double" + engine = cratedb_service.database.engine + float_double_df.to_sql( + tablename, + engine, + if_exists="replace", + index=False, + ) + with engine.connect() as conn: + conn.execute(sa.text(f"REFRESH TABLE {tablename}")) + df_load = pd.read_sql_table(tablename, engine) + + before = float_double_df.sort_values(by="col_1", ignore_index=True) + after = df_load.sort_values(by="col_1", ignore_index=True) + + pd.options.display.float_format = "{:.12f}".format + assert_equal(before, after, check_exact=True)