Skip to content

Commit 08adb59

Browse files
committed
Add support for psycopg and asyncpg drivers
This introduces the `crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect identifiers. The asynchronous variant of `psycopg` is also supported.
1 parent f149dd4 commit 08adb59

File tree

6 files changed

+273
-10
lines changed

6 files changed

+273
-10
lines changed

CHANGES.md

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## Unreleased
44
- Dependencies: Updated to `crate-2.0.0`, which uses `orjson` for JSON marshalling
5+
- Added support for `psycopg` and `asyncpg` drivers, by introducing the
6+
`crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect
7+
identifiers. The asynchronous variant of `psycopg` is also supported.
58

69
## 2024/11/04 0.40.1
710
- CI: Verified support on Python 3.13

pyproject.toml

+10-2
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ dependencies = [
8989
"verlib2==0.2",
9090
]
9191
optional-dependencies.all = [
92-
"sqlalchemy-cratedb[vector]",
92+
"sqlalchemy-cratedb[postgresql,vector]",
9393
]
9494
optional-dependencies.develop = [
9595
"mypy<1.15",
@@ -102,6 +102,9 @@ optional-dependencies.doc = [
102102
"crate-docs-theme>=0.26.5",
103103
"sphinx>=3.5,<9",
104104
]
105+
optional-dependencies.postgresql = [
106+
"sqlalchemy-postgresql-relaxed",
107+
]
105108
optional-dependencies.release = [
106109
"build<2",
107110
"twine<7",
@@ -112,6 +115,7 @@ optional-dependencies.test = [
112115
"pandas<2.3",
113116
"pueblo>=0.0.7",
114117
"pytest<9",
118+
"pytest-asyncio<0.24",
115119
"pytest-cov<7",
116120
"pytest-mock<4",
117121
]
@@ -122,7 +126,11 @@ urls.changelog = "https://github.com/crate/sqlalchemy-cratedb/blob/main/CHANGES.
122126
urls.documentation = "https://cratedb.com/docs/sqlalchemy-cratedb/"
123127
urls.homepage = "https://cratedb.com/docs/sqlalchemy-cratedb/"
124128
urls.repository = "https://github.com/crate/sqlalchemy-cratedb"
125-
entry-points."sqlalchemy.dialects".crate = "sqlalchemy_cratedb:dialect"
129+
entry-points."sqlalchemy.dialects"."crate" = "sqlalchemy_cratedb:dialect"
130+
entry-points."sqlalchemy.dialects"."crate.asyncpg" = "sqlalchemy_cratedb.dialect_more:dialect_asyncpg"
131+
entry-points."sqlalchemy.dialects"."crate.psycopg" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg"
132+
entry-points."sqlalchemy.dialects"."crate.psycopg_async" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg_async"
133+
entry-points."sqlalchemy.dialects"."crate.urllib3" = "sqlalchemy_cratedb.dialect_more:dialect_urllib3"
126134

127135
[tool.black]
128136
line-length = 100

src/sqlalchemy_cratedb/dialect.py

+43-7
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import logging
2323
from datetime import date, datetime
24+
from types import ModuleType
2425

2526
from sqlalchemy import types as sqltypes
2627
from sqlalchemy.engine import default, reflection
@@ -212,6 +213,12 @@ def initialize(self, connection):
212213
# get default schema name
213214
self.default_schema_name = self._get_default_schema_name(connection)
214215

216+
def set_isolation_level(self, dbapi_connection, level):
    """
    Accept a requested isolation level and discard it.

    For CrateDB, this is deliberately a no-op: neither `dbapi_connection`
    nor `level` is inspected, and nothing is returned.
    """
    return None
221+
215222
def do_rollback(self, connection):
216223
# if any exception is raised by the dbapi, sqlalchemy by default
217224
# attempts to do a rollback crate doesn't support rollbacks.
@@ -230,7 +237,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
230237
use_ssl = asbool(kwargs.pop("ssl", False))
231238
if use_ssl:
232239
servers = ["https://" + server for server in servers]
233-
return self.dbapi.connect(servers=servers, **kwargs)
240+
241+
is_module = isinstance(self.dbapi, ModuleType)
242+
if is_module:
243+
driver_name = self.dbapi.__name__
244+
else:
245+
driver_name = self.dbapi.__class__.__name__
246+
if driver_name == "crate.client":
247+
if "database" in kwargs:
248+
del kwargs["database"]
249+
return self.dbapi.connect(servers=servers, **kwargs)
250+
elif driver_name in ["psycopg", "PsycopgAdaptDBAPI", "AsyncAdapt_asyncpg_dbapi"]:
251+
return self.dbapi.connect(host=host, port=port, **kwargs)
252+
else:
253+
raise ValueError(f"Unknown driver variant: {driver_name}")
254+
234255
return self.dbapi.connect(**kwargs)
235256

236257
def do_execute(self, cursor, statement, parameters, context=None):
@@ -300,10 +321,12 @@ def get_table_names(self, connection, schema=None, **kw):
300321
if schema is None:
301322
schema = self._get_effective_schema_name(connection)
302323
cursor = connection.exec_driver_sql(
303-
"SELECT table_name FROM information_schema.tables "
304-
"WHERE {0} = ? "
305-
"AND table_type = 'BASE TABLE' "
306-
"ORDER BY table_name ASC, {0} ASC".format(self.schema_column),
324+
self._format_query(
325+
"SELECT table_name FROM information_schema.tables "
326+
"WHERE {0} = ? "
327+
"AND table_type = 'BASE TABLE' "
328+
"ORDER BY table_name ASC, {0} ASC"
329+
).format(self.schema_column),
307330
(schema or self.default_schema_name,),
308331
)
309332
return [row[0] for row in cursor.fetchall()]
@@ -326,7 +349,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
326349
"AND column_name !~ ?".format(self.schema_column)
327350
)
328351
cursor = connection.exec_driver_sql(
329-
query,
352+
self._format_query(query),
330353
(
331354
table_name,
332355
schema or self.default_schema_name,
@@ -366,7 +389,9 @@ def result_fun(result):
366389
rows = result.fetchone()
367390
return set(rows[0] if rows else [])
368391

369-
pk_result = engine.exec_driver_sql(query, (table_name, schema or self.default_schema_name))
392+
pk_result = engine.exec_driver_sql(
393+
self._format_query(query), (table_name, schema or self.default_schema_name)
394+
)
370395
pks = result_fun(pk_result)
371396
return {"constrained_columns": sorted(pks), "name": "PRIMARY KEY"}
372397

@@ -405,6 +430,17 @@ def has_ilike_operator(self):
405430
server_version_info = self.server_version_info
406431
return server_version_info is not None and server_version_info >= (4, 1, 0)
407432

433+
def _format_query(self, query):
    """
    Rewrite `qmark` placeholders of a raw SQL string for the active driver.

    The reflection queries of this dialect are written with `?` placeholders.
    When using the PostgreSQL protocol with drivers like `psycopg` or
    `asyncpg`, the paramstyle is not `qmark`, so placeholders following an
    `=` or `!~` operator are translated:

    - `pyformat` and `format`: `?` becomes `%s`.
    - `numeric_dollar`: `?` becomes `$1`, `$2`, ... in order of appearance.

    With any other paramstyle, the query is returned unchanged. Question
    marks that do not directly follow `= ` or `!~ ` are never touched.

    TODO: Review: Is it legit and sane? Are there alternatives?

    :param query: SQL statement using `?` placeholders.
    :return: SQL statement with placeholders matching `self.paramstyle`.
    """
    # Function-scope imports keep this helper self-contained.
    import re
    from itertools import count

    paramstyle = self.paramstyle

    if paramstyle in ("pyformat", "format"):
        return query.replace("= ?", "= %s").replace("!~ ?", "!~ %s")

    if paramstyle == "numeric_dollar":
        # Number the placeholders left to right, across both operators.
        position = count(1)
        return re.sub(
            r"(=|!~) \?",
            lambda match: f"{match.group(1)} ${next(position)}",
            query,
        )

    return query
443+
408444

409445
class DateTrunc(functions.GenericFunction):
410446
name = "date_trunc"
+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# -*- coding: utf-8; -*-
2+
#
3+
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
4+
# license agreements. See the NOTICE file distributed with this work for
5+
# additional information regarding copyright ownership. Crate licenses
6+
# this file to you under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License. You may
8+
# obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
# License for the specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# However, if you have executed another commercial license agreement
19+
# with Crate these terms will supersede the license and you may use the
20+
# software solely pursuant to the terms of the relevant commercial agreement.
21+
from sqlalchemy.engine.reflection import Inspector
22+
from sqlalchemy_postgresql_relaxed.asyncpg import PGDialect_asyncpg_relaxed
23+
from sqlalchemy_postgresql_relaxed.base import PGDialect_relaxed
24+
from sqlalchemy_postgresql_relaxed.psycopg import (
25+
PGDialect_psycopg_relaxed,
26+
PGDialectAsync_psycopg_relaxed,
27+
)
28+
29+
from sqlalchemy_cratedb import dialect
30+
31+
32+
class CrateDialectPostgresAdapter(PGDialect_relaxed, dialect):
    """
    Provide a dialect on top of the relaxed PostgreSQL dialect.

    The class derives from both the relaxed PostgreSQL dialect and the
    CrateDB `dialect`. The reflection methods listed below are re-bound
    explicitly to the CrateDB implementations.
    """

    inspector = Inspector

    # Need to manually override some methods because of polymorphic inheritance woes.
    # TODO: Investigate if this can be solved using metaprogramming or other techniques.
    has_schema = dialect.has_schema
    has_table = dialect.has_table
    get_schema_names = dialect.get_schema_names
    get_table_names = dialect.get_table_names
    get_view_names = dialect.get_view_names
    get_columns = dialect.get_columns
    get_pk_constraint = dialect.get_pk_constraint
    get_foreign_keys = dialect.get_foreign_keys
    get_indexes = dialect.get_indexes

    get_multi_columns = dialect.get_multi_columns
    get_multi_pk_constraint = dialect.get_multi_pk_constraint
    get_multi_foreign_keys = dialect.get_multi_foreign_keys

    # TODO: Those may want to go to dialect instead?
    def get_multi_indexes(self, *args, **kwargs):
        """Return an empty list, regardless of the arguments."""
        return []

    def get_multi_unique_constraints(self, *args, **kwargs):
        """Return an empty list, regardless of the arguments."""
        return []

    def get_multi_check_constraints(self, *args, **kwargs):
        """Return an empty list, regardless of the arguments."""
        return []

    def get_multi_table_comment(self, *args, **kwargs):
        """Return an empty list, regardless of the arguments."""
        return []
67+
68+
69+
class CrateDialect_psycopg(PGDialect_psycopg_relaxed, CrateDialectPostgresAdapter):
    """CrateDB dialect for the `psycopg` driver."""

    driver = "psycopg"

    @classmethod
    def get_async_dialect_cls(cls, url):
        """Return the dialect class to use for asynchronous operation."""
        return CrateDialectAsync_psycopg

    @classmethod
    def import_dbapi(cls):
        """Import the `psycopg` module on demand and return it as the DBAPI."""
        import psycopg as dbapi_module

        return dbapi_module
81+
82+
83+
class CrateDialectAsync_psycopg(PGDialectAsync_psycopg_relaxed, CrateDialectPostgresAdapter):
    """CrateDB dialect for the `psycopg` driver, flagged as asynchronous."""

    is_async = True
    driver = "psycopg_async"
86+
87+
88+
class CrateDialect_asyncpg(PGDialect_asyncpg_relaxed, CrateDialectPostgresAdapter):
    """CrateDB dialect for the `asyncpg` driver."""

    driver = "asyncpg"

    # TODO: asyncpg may have `paramstyle="numeric_dollar"`. Review this!

    # TODO: AttributeError: module 'asyncpg' has no attribute 'paramstyle'
    #
    # The override below is disabled because of the error above. It is kept
    # as a comment, not as a bare string literal, so that it is not evaluated
    # as an expression statement in the class body.
    #
    # @classmethod
    # def import_dbapi(cls):
    #     import asyncpg
    #
    #     return asyncpg
101+
102+
103+
# Module-level aliases, one per driver variant.
# `dialect_urllib3` is the unchanged CrateDB `dialect` imported above.
dialect_urllib3 = dialect
dialect_psycopg = CrateDialect_psycopg
dialect_psycopg_async = CrateDialectAsync_psycopg
dialect_asyncpg = CrateDialect_asyncpg

tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ def cratedb_service():
1616
Provide a CrateDB service instance to the test suite.
1717
"""
1818
db = CrateDBTestAdapter()
19-
db.start()
19+
db.start(ports={4200: None, 5432: None})
2020
yield db
2121
db.stop()

tests/engine_test.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import pytest
2+
import sqlalchemy as sa
3+
from sqlalchemy.dialects import registry as dialect_registry
4+
5+
from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION
6+
7+
# Skip the whole module on SQLAlchemy versions older than 2.0.
# The guard sits ahead of the asyncio imports, so it takes effect before them.
if SA_VERSION < SA_2_0:
    raise pytest.skip(
        "Only supported on SQLAlchemy 2.0 and higher",
        allow_module_level=True,
    )
9+
10+
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
11+
12+
# Manual registration of the additional dialects appears to be required when
# running under the test suite. Under regular circumstances it should not be
# needed, because `pyproject.toml` wires them through `sqlalchemy.dialects`
# entry points. It is odd, but that is how it is.
dialect_registry.register(
    "crate.urllib3",
    "sqlalchemy_cratedb.dialect_more",
    "dialect_urllib3",
)
dialect_registry.register(
    "crate.asyncpg",
    "sqlalchemy_cratedb.dialect_more",
    "dialect_asyncpg",
)
dialect_registry.register(
    "crate.psycopg",
    "sqlalchemy_cratedb.dialect_more",
    "dialect_psycopg",
)


# Statement shared by all connectivity tests in this module.
QUERY = sa.text("SELECT mountain, coordinates FROM sys.summits ORDER BY mountain LIMIT 3;")
22+
23+
24+
def test_engine_sync_vanilla(cratedb_service):
    """
    crate:// -- Verify connectivity and data transport with vanilla HTTP-based driver.
    """
    port4200 = cratedb_service.cratedb.get_exposed_port(4200)
    engine = sa.create_engine(f"crate://crate@localhost:{port4200}/", echo=True)
    assert isinstance(engine, sa.engine.Engine)
    try:
        with engine.connect() as connection:
            result = connection.execute(QUERY)
            assert result.mappings().fetchone() == {
                "mountain": "Acherkogel",
                "coordinates": [10.95667, 47.18917],
            }
    finally:
        # Release pooled connections, also when an assertion fails.
        engine.dispose()
37+
38+
39+
def test_engine_sync_urllib3(cratedb_service):
    """
    crate+urllib3:// -- Verify connectivity and data transport *explicitly* selecting the HTTP driver.
    """  # noqa: E501
    port4200 = cratedb_service.cratedb.get_exposed_port(4200)
    engine = sa.create_engine(
        f"crate+urllib3://crate@localhost:{port4200}/", isolation_level="AUTOCOMMIT", echo=True
    )
    assert isinstance(engine, sa.engine.Engine)
    try:
        with engine.connect() as connection:
            result = connection.execute(QUERY)
            assert result.mappings().fetchone() == {
                "mountain": "Acherkogel",
                "coordinates": [10.95667, 47.18917],
            }
    finally:
        # Release pooled connections, also when an assertion fails.
        engine.dispose()
54+
55+
56+
def test_engine_sync_psycopg(cratedb_service):
    """
    crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
    """
    port5432 = cratedb_service.cratedb.get_exposed_port(5432)
    engine = sa.create_engine(
        f"crate+psycopg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
    )
    assert isinstance(engine, sa.engine.Engine)
    try:
        with engine.connect() as connection:
            result = connection.execute(QUERY)
            # Over this driver, the coordinates are received as a string.
            assert result.mappings().fetchone() == {
                "mountain": "Acherkogel",
                "coordinates": "(10.95667,47.18917)",
            }
    finally:
        # Release pooled connections, also when an assertion fails.
        engine.dispose()
71+
72+
73+
@pytest.mark.asyncio
async def test_engine_async_psycopg(cratedb_service):
    """
    crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
    This time, in asynchronous mode.
    """
    port5432 = cratedb_service.cratedb.get_exposed_port(5432)
    engine = create_async_engine(
        f"crate+psycopg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
    )
    assert isinstance(engine, AsyncEngine)
    try:
        async with engine.begin() as conn:
            result = await conn.execute(QUERY)
            assert result.mappings().fetchone() == {
                "mountain": "Acherkogel",
                "coordinates": "(10.95667,47.18917)",
            }
    finally:
        # Close pooled connections while the event loop is still running.
        await engine.dispose()
90+
91+
92+
@pytest.mark.asyncio
async def test_engine_async_asyncpg(cratedb_service):
    """
    crate+asyncpg:// -- Verify connectivity and data transport using the asyncpg driver.
    This exclusively uses asynchronous mode.
    """
    # Imported inside the test, as in the original, not at module level.
    from asyncpg.pgproto.types import Point

    port5432 = cratedb_service.cratedb.get_exposed_port(5432)
    engine = create_async_engine(
        f"crate+asyncpg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
    )
    assert isinstance(engine, AsyncEngine)
    try:
        async with engine.begin() as conn:
            result = await conn.execute(QUERY)
            # Over this driver, the coordinates are received as a `Point`.
            assert result.mappings().fetchone() == {
                "mountain": "Acherkogel",
                "coordinates": Point(10.95667, 47.18917),
            }
    finally:
        # Close pooled connections while the event loop is still running.
        await engine.dispose()

0 commit comments

Comments
 (0)