Skip to content

Commit bebbf07

Browse files
committed
Add support for psycopg and asyncpg drivers
This introduces the `crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect identifiers. The asynchronous variant of `psycopg` is also supported.
1 parent 314d430 commit bebbf07

File tree

6 files changed

+273
-10
lines changed

6 files changed

+273
-10
lines changed

CHANGES.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Changelog
22

33
## Unreleased
4+
- Added support for `psycopg` and `asyncpg` drivers, by introducing the
5+
`crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect
6+
identifiers. The asynchronous variant of `psycopg` is also supported.
47

58
## 2024/08/29 0.39.0
69
Added `quote_relation_name` support utility function

pyproject.toml

+10-2
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ dependencies = [
8888
"verlib2==0.2",
8989
]
9090
optional-dependencies.all = [
91-
"sqlalchemy-cratedb[vector]",
91+
"sqlalchemy-cratedb[postgresql,vector]",
9292
]
9393
optional-dependencies.develop = [
9494
"mypy<1.12",
@@ -101,6 +101,9 @@ optional-dependencies.doc = [
101101
"crate-docs-theme>=0.26.5",
102102
"sphinx>=3.5,<9",
103103
]
104+
optional-dependencies.postgresql = [
105+
"sqlalchemy-postgresql-relaxed",
106+
]
104107
optional-dependencies.release = [
105108
"build<2",
106109
"twine<6",
@@ -111,6 +114,7 @@ optional-dependencies.test = [
111114
"pandas<2.3",
112115
"pueblo>=0.0.7",
113116
"pytest<9",
117+
"pytest-asyncio<0.24",
114118
"pytest-cov<6",
115119
"pytest-mock<4",
116120
]
@@ -121,7 +125,11 @@ urls.changelog = "https://github.com/crate/sqlalchemy-cratedb/blob/main/CHANGES.
121125
urls.documentation = "https://cratedb.com/docs/sqlalchemy-cratedb/"
122126
urls.homepage = "https://cratedb.com/docs/sqlalchemy-cratedb/"
123127
urls.repository = "https://github.com/crate/sqlalchemy-cratedb"
124-
entry-points."sqlalchemy.dialects".crate = "sqlalchemy_cratedb:dialect"
128+
entry-points."sqlalchemy.dialects"."crate" = "sqlalchemy_cratedb:dialect"
129+
entry-points."sqlalchemy.dialects"."crate.asyncpg" = "sqlalchemy_cratedb.dialect_more:dialect_asyncpg"
130+
entry-points."sqlalchemy.dialects"."crate.psycopg" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg"
131+
entry-points."sqlalchemy.dialects"."crate.psycopg_async" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg_async"
132+
entry-points."sqlalchemy.dialects"."crate.urllib3" = "sqlalchemy_cratedb.dialect_more:dialect_urllib3"
125133

126134
[tool.black]
127135
line-length = 100

src/sqlalchemy_cratedb/dialect.py

+43-7
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import logging
2323
from datetime import date, datetime
24+
from types import ModuleType
2425

2526
from sqlalchemy import types as sqltypes
2627
from sqlalchemy.engine import default, reflection
@@ -212,6 +213,12 @@ def initialize(self, connection):
212213
# get default schema name
213214
self.default_schema_name = self._get_default_schema_name(connection)
214215

216+
def set_isolation_level(self, dbapi_connection, level):
    """
    Accept a request to change the isolation level, without acting on it.

    For CrateDB, this is implemented as a noop: both `dbapi_connection`
    and `level` are ignored, and nothing is returned.
    """
    return None
221+
215222
def do_rollback(self, connection):
216223
# if any exception is raised by the dbapi, sqlalchemy by default
217224
# attempts to do a rollback crate doesn't support rollbacks.
@@ -230,7 +237,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
230237
use_ssl = asbool(kwargs.pop("ssl", False))
231238
if use_ssl:
232239
servers = ["https://" + server for server in servers]
233-
return self.dbapi.connect(servers=servers, **kwargs)
240+
241+
is_module = isinstance(self.dbapi, ModuleType)
242+
if is_module:
243+
driver_name = self.dbapi.__name__
244+
else:
245+
driver_name = self.dbapi.__class__.__name__
246+
if driver_name == "crate.client":
247+
if "database" in kwargs:
248+
del kwargs["database"]
249+
return self.dbapi.connect(servers=servers, **kwargs)
250+
elif driver_name in ["psycopg", "PsycopgAdaptDBAPI", "AsyncAdapt_asyncpg_dbapi"]:
251+
return self.dbapi.connect(host=host, port=port, **kwargs)
252+
else:
253+
raise ValueError(f"Unknown driver variant: {driver_name}")
254+
234255
return self.dbapi.connect(**kwargs)
235256

236257
def _get_default_schema_name(self, connection):
@@ -276,10 +297,12 @@ def get_table_names(self, connection, schema=None, **kw):
276297
if schema is None:
277298
schema = self._get_effective_schema_name(connection)
278299
cursor = connection.exec_driver_sql(
279-
"SELECT table_name FROM information_schema.tables "
280-
"WHERE {0} = ? "
281-
"AND table_type = 'BASE TABLE' "
282-
"ORDER BY table_name ASC, {0} ASC".format(self.schema_column),
300+
self._format_query(
301+
"SELECT table_name FROM information_schema.tables "
302+
"WHERE {0} = ? "
303+
"AND table_type = 'BASE TABLE' "
304+
"ORDER BY table_name ASC, {0} ASC"
305+
).format(self.schema_column),
283306
(schema or self.default_schema_name,),
284307
)
285308
return [row[0] for row in cursor.fetchall()]
@@ -302,7 +325,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
302325
"AND column_name !~ ?".format(self.schema_column)
303326
)
304327
cursor = connection.exec_driver_sql(
305-
query,
328+
self._format_query(query),
306329
(
307330
table_name,
308331
schema or self.default_schema_name,
@@ -342,7 +365,9 @@ def result_fun(result):
342365
rows = result.fetchone()
343366
return set(rows[0] if rows else [])
344367

345-
pk_result = engine.exec_driver_sql(query, (table_name, schema or self.default_schema_name))
368+
pk_result = engine.exec_driver_sql(
369+
self._format_query(query), (table_name, schema or self.default_schema_name)
370+
)
346371
pks = result_fun(pk_result)
347372
return {"constrained_columns": sorted(pks), "name": "PRIMARY KEY"}
348373

@@ -381,6 +406,17 @@ def has_ilike_operator(self):
381406
server_version_info = self.server_version_info
382407
return server_version_info is not None and server_version_info >= (4, 1, 0)
383408

409+
def _format_query(self, query):
410+
"""
411+
When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`,
412+
the paramstyle is not `qmark`, but `pyformat`.
413+
414+
TODO: Review: Is it legit and sane? Are there alternatives?
415+
"""
416+
if self.paramstyle == "pyformat":
417+
query = query.replace("= ?", "= %s").replace("!~ ?", "!~ %s")
418+
return query
419+
384420

385421
class DateTrunc(functions.GenericFunction):
386422
name = "date_trunc"
+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# -*- coding: utf-8; -*-
2+
#
3+
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
4+
# license agreements. See the NOTICE file distributed with this work for
5+
# additional information regarding copyright ownership. Crate licenses
6+
# this file to you under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License. You may
8+
# obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
# License for the specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# However, if you have executed another commercial license agreement
19+
# with Crate these terms will supersede the license and you may use the
20+
# software solely pursuant to the terms of the relevant commercial agreement.
21+
from sqlalchemy.engine.reflection import Inspector
22+
from sqlalchemy_postgresql_relaxed.asyncpg import PGDialect_asyncpg_relaxed
23+
from sqlalchemy_postgresql_relaxed.base import PGDialect_relaxed
24+
from sqlalchemy_postgresql_relaxed.psycopg import (
25+
PGDialect_psycopg_relaxed,
26+
PGDialectAsync_psycopg_relaxed,
27+
)
28+
29+
from sqlalchemy_cratedb import dialect
30+
31+
32+
class CrateDialectPostgresAdapter(PGDialect_relaxed, dialect):
    """
    Provide a dialect on top of the relaxed PostgreSQL dialect.

    The class derives from both `PGDialect_relaxed` and the CrateDB `dialect`.
    Reflection methods are explicitly bound to the CrateDB implementations.
    """

    inspector = Inspector

    # Need to manually override some methods because of polymorphic inheritance woes.
    # TODO: Investigate if this can be solved using metaprogramming or other techniques.

    # Existence checks.
    has_schema = dialect.has_schema
    has_table = dialect.has_table

    # Reflection of individual objects.
    get_schema_names = dialect.get_schema_names
    get_table_names = dialect.get_table_names
    get_view_names = dialect.get_view_names
    get_columns = dialect.get_columns
    get_pk_constraint = dialect.get_pk_constraint
    get_foreign_keys = dialect.get_foreign_keys
    get_indexes = dialect.get_indexes

    # Bulk reflection.
    get_multi_columns = dialect.get_multi_columns
    get_multi_pk_constraint = dialect.get_multi_pk_constraint
    get_multi_foreign_keys = dialect.get_multi_foreign_keys

    # The bulk reflection hooks below always report "nothing found".
    # TODO: Those may want to go to dialect instead?
    def get_multi_indexes(self, *args, **kwargs):
        """Return an empty list; all arguments are ignored."""
        return []

    def get_multi_unique_constraints(self, *args, **kwargs):
        """Return an empty list; all arguments are ignored."""
        return []

    def get_multi_check_constraints(self, *args, **kwargs):
        """Return an empty list; all arguments are ignored."""
        return []

    def get_multi_table_comment(self, *args, **kwargs):
        """Return an empty list; all arguments are ignored."""
        return []
67+
68+
69+
class CrateDialect_psycopg(PGDialect_psycopg_relaxed, CrateDialectPostgresAdapter):
    """
    Dialect variant using the `psycopg` driver.
    """

    driver = "psycopg"

    @classmethod
    def import_dbapi(cls):
        """Import the `psycopg` module lazily, and return it."""
        import psycopg as dbapi_module

        return dbapi_module

    @classmethod
    def get_async_dialect_cls(cls, url):
        """Return the asynchronous counterpart of this dialect; `url` is ignored."""
        return CrateDialectAsync_psycopg
81+
82+
83+
class CrateDialectAsync_psycopg(PGDialectAsync_psycopg_relaxed, CrateDialectPostgresAdapter):
    """
    Dialect variant using the `psycopg` driver in asynchronous mode.
    """

    is_async = True
    driver = "psycopg_async"
86+
87+
88+
class CrateDialect_asyncpg(PGDialect_asyncpg_relaxed, CrateDialectPostgresAdapter):
    """
    Dialect variant using the `asyncpg` driver.
    """

    driver = "asyncpg"

    # TODO: asyncpg may have `paramstyle="numeric_dollar"`. Review this!

    # TODO: AttributeError: module 'asyncpg' has no attribute 'paramstyle'
    # NOTE(review): Presumably the error above is why the `import_dbapi`
    #               override below is disabled -- confirm before enabling it.
    #
    # @classmethod
    # def import_dbapi(cls):
    #     import asyncpg
    #
    #     return asyncpg
101+
102+
103+
# Module-level aliases for the dialect classes, one per driver variant.
dialect_asyncpg = CrateDialect_asyncpg
dialect_psycopg = CrateDialect_psycopg
dialect_psycopg_async = CrateDialectAsync_psycopg
dialect_urllib3 = dialect

tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ def cratedb_service():
1616
Provide a CrateDB service instance to the test suite.
1717
"""
1818
db = CrateDBTestAdapter()
19-
db.start()
19+
db.start(ports={4200: None, 5432: None})
2020
yield db
2121
db.stop()

tests/engine_test.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import pytest
2+
import sqlalchemy as sa
3+
from sqlalchemy.dialects import registry as dialect_registry
4+
5+
from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION
6+
7+
if SA_VERSION < SA_2_0:
8+
raise pytest.skip("Only supported on SQLAlchemy 2.0 and higher", allow_module_level=True)
9+
10+
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
11+
12+
# Registering the additional dialects manually seems to be needed when running
13+
# under tests. Apparently, manual registration is not needed under regular
14+
# circumstances, as this is wired through the `sqlalchemy.dialects` entrypoint
15+
# registrations in `pyproject.toml`. It is definitely weird, but c'est la vie.
16+
dialect_registry.register("crate.urllib3", "sqlalchemy_cratedb.dialect_more", "dialect_urllib3")
17+
dialect_registry.register("crate.asyncpg", "sqlalchemy_cratedb.dialect_more", "dialect_asyncpg")
18+
dialect_registry.register("crate.psycopg", "sqlalchemy_cratedb.dialect_more", "dialect_psycopg")
19+
20+
21+
QUERY = sa.text("SELECT mountain, coordinates FROM sys.summits ORDER BY mountain LIMIT 3;")
22+
23+
24+
def test_engine_sync_vanilla(cratedb_service):
    """
    crate:// -- Verify connectivity and data transport with vanilla HTTP-based driver.
    """
    http_port = cratedb_service.cratedb.get_exposed_port(4200)
    url = f"crate://crate@localhost:{http_port}/"
    engine = sa.create_engine(url, echo=True)
    assert isinstance(engine, sa.engine.Engine)

    # The HTTP driver returns the coordinates as a list of floats.
    expected = {"mountain": "Acherkogel", "coordinates": [10.95667, 47.18917]}
    with engine.connect() as connection:
        first_row = connection.execute(QUERY).mappings().fetchone()
        assert first_row == expected
37+
38+
39+
def test_engine_sync_urllib3(cratedb_service):
    """
    crate+urllib3:// -- Verify connectivity and data transport *explicitly* selecting the HTTP driver.
    """  # noqa: E501
    http_port = cratedb_service.cratedb.get_exposed_port(4200)
    url = f"crate+urllib3://crate@localhost:{http_port}/"
    engine = sa.create_engine(url, isolation_level="AUTOCOMMIT", echo=True)
    assert isinstance(engine, sa.engine.Engine)

    # The HTTP driver returns the coordinates as a list of floats.
    expected = {"mountain": "Acherkogel", "coordinates": [10.95667, 47.18917]}
    with engine.connect() as connection:
        first_row = connection.execute(QUERY).mappings().fetchone()
        assert first_row == expected
54+
55+
56+
def test_engine_sync_psycopg(cratedb_service):
    """
    crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
    """
    pg_port = cratedb_service.cratedb.get_exposed_port(5432)
    url = f"crate+psycopg://crate@localhost:{pg_port}/"
    engine = sa.create_engine(url, isolation_level="AUTOCOMMIT", echo=True)
    assert isinstance(engine, sa.engine.Engine)

    # With psycopg, the coordinates arrive in textual representation.
    expected = {"mountain": "Acherkogel", "coordinates": "(10.95667,47.18917)"}
    with engine.connect() as connection:
        first_row = connection.execute(QUERY).mappings().fetchone()
        assert first_row == expected
71+
72+
73+
@pytest.mark.asyncio
async def test_engine_async_psycopg(cratedb_service):
    """
    crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
    This time, in asynchronous mode.
    """
    port5432 = cratedb_service.cratedb.get_exposed_port(5432)
    engine = create_async_engine(
        f"crate+psycopg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
    )
    try:
        assert isinstance(engine, AsyncEngine)
        async with engine.begin() as conn:
            result = await conn.execute(QUERY)
            # With psycopg, the coordinates arrive in textual representation.
            assert result.mappings().fetchone() == {
                "mountain": "Acherkogel",
                "coordinates": "(10.95667,47.18917)",
            }
    finally:
        # The engine is function-scoped: release its pooled connections
        # while the event loop of this test is still running.
        await engine.dispose()
90+
91+
92+
@pytest.mark.asyncio
async def test_engine_async_asyncpg(cratedb_service):
    """
    crate+asyncpg:// -- Verify connectivity and data transport using the asyncpg driver.
    This exclusively uses asynchronous mode.
    """
    port5432 = cratedb_service.cratedb.get_exposed_port(5432)
    from asyncpg.pgproto.types import Point

    engine = create_async_engine(
        f"crate+asyncpg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
    )
    try:
        assert isinstance(engine, AsyncEngine)
        async with engine.begin() as conn:
            result = await conn.execute(QUERY)
            # With asyncpg, the coordinates arrive as a `Point` object.
            assert result.mappings().fetchone() == {
                "mountain": "Acherkogel",
                "coordinates": Point(10.95667, 47.18917),
            }
    finally:
        # The engine is function-scoped: release its pooled connections
        # while the event loop of this test is still running.
        await engine.dispose()

0 commit comments

Comments
 (0)