Skip to content

Commit e8bfd77

Browse files
committed
Add support for psycopg and asyncpg drivers
This introduces the `crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect identifiers. The asynchronous variant of `psycopg` is also supported.
1 parent 6db4702 commit e8bfd77

File tree

5 files changed

+237
-7
lines changed

5 files changed

+237
-7
lines changed

CHANGES.md

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33

44
## Unreleased
5+
- Added support for `psycopg` and `asyncpg` drivers, by introducing the
6+
`crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect
7+
identifiers. The asynchronous variant of `psycopg` is also supported.
58

69
## 2024/06/25 0.38.0
710
- Added/reactivated documentation as `sqlalchemy-cratedb`

pyproject.toml

+10-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ dependencies = [
9393
]
9494
[project.optional-dependencies]
9595
all = [
96-
"sqlalchemy-cratedb[vector]",
96+
"sqlalchemy-cratedb[postgresql,vector]",
9797
]
9898
develop = [
9999
"black<25",
@@ -107,6 +107,9 @@ doc = [
107107
"crate-docs-theme>=0.26.5",
108108
"sphinx<8,>=3.5",
109109
]
110+
postgresql = [
111+
"sqlalchemy-postgresql-relaxed",
112+
]
110113
release = [
111114
"build<2",
112115
"twine<6",
@@ -117,6 +120,7 @@ test = [
117120
"pandas<2.3",
118121
"pueblo>=0.0.7",
119122
"pytest<9",
123+
"pytest-asyncio<0.24",
120124
"pytest-cov<6",
121125
"pytest-mock<4",
122126
]
@@ -129,7 +133,11 @@ documentation = "https://cratedb.com/docs/sqlalchemy-cratedb/"
129133
homepage = "https://cratedb.com/docs/sqlalchemy-cratedb/"
130134
repository = "https://github.com/crate/sqlalchemy-cratedb"
131135
[project.entry-points."sqlalchemy.dialects"]
132-
crate = "sqlalchemy_cratedb:dialect"
136+
"crate" = "sqlalchemy_cratedb:dialect"
137+
"crate.urllib3" = "sqlalchemy_cratedb.dialect_more:dialect_urllib3"
138+
"crate.psycopg" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg"
139+
"crate.psycopg_async" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg_async"
140+
"crate.asyncpg" = "sqlalchemy_cratedb.dialect_more:dialect_asyncpg"
133141

134142
[tool.black]
135143
line-length = 100

src/sqlalchemy_cratedb/dialect.py

+37-5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import logging
2323
from datetime import datetime, date
24+
from types import ModuleType
2425

2526
from sqlalchemy import types as sqltypes
2627
from sqlalchemy.engine import default, reflection
@@ -202,6 +203,12 @@ def initialize(self, connection):
202203
self.default_schema_name = \
203204
self._get_default_schema_name(connection)
204205

206+
def set_isolation_level(self, dbapi_connection, level):
    """
    Accept an isolation level request without acting on it.

    For CrateDB, this is implemented as a noop: both `dbapi_connection`
    and `level` are ignored.
    """
    return None
211+
205212
def do_rollback(self, connection):
206213
# if any exception is raised by the dbapi, sqlalchemy by default
207214
# attempts to do a rollback crate doesn't support rollbacks.
@@ -220,7 +227,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
220227
use_ssl = asbool(kwargs.pop("ssl", False))
221228
if use_ssl:
222229
servers = ["https://" + server for server in servers]
223-
return self.dbapi.connect(servers=servers, **kwargs)
230+
231+
is_module = isinstance(self.dbapi, ModuleType)
232+
if is_module:
233+
driver_name = self.dbapi.__name__
234+
else:
235+
driver_name = self.dbapi.__class__.__name__
236+
if driver_name == "crate.client":
237+
if "database" in kwargs:
238+
del kwargs["database"]
239+
return self.dbapi.connect(servers=servers, **kwargs)
240+
elif driver_name in ["psycopg", "PsycopgAdaptDBAPI", "AsyncAdapt_asyncpg_dbapi"]:
241+
return self.dbapi.connect(host=host, port=port, **kwargs)
242+
else:
243+
raise ValueError(f"Unknown driver variant: {driver_name}")
244+
224245
return self.dbapi.connect(**kwargs)
225246

226247
def _get_default_schema_name(self, connection):
@@ -266,11 +287,11 @@ def get_schema_names(self, connection, **kw):
266287
def get_table_names(self, connection, schema=None, **kw):
    """
    Return the names of all base tables within a schema, sorted by name.

    When `schema` is not given, the effective schema name of the connection
    is used; if that is empty as well, `self.default_schema_name` is used.
    """
    if schema is None:
        schema = self._get_effective_schema_name(connection)

    # The statement is written in `qmark` style, and translated to the
    # paramstyle of the active driver before interpolating the column name.
    statement = self._format_query(
        "SELECT table_name FROM information_schema.tables "
        "WHERE {0} = ? "
        "AND table_type = 'BASE TABLE' "
        "ORDER BY table_name ASC, {0} ASC"
    ).format(self.schema_column)
    parameters = (schema or self.default_schema_name,)

    result = connection.exec_driver_sql(statement, parameters)
    table_names = []
    for record in result.fetchall():
        table_names.append(record[0])
    return table_names
@@ -292,7 +313,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
292313
"AND column_name !~ ?" \
293314
.format(self.schema_column)
294315
cursor = connection.exec_driver_sql(
295-
query,
316+
self._format_query(query),
296317
(table_name,
297318
schema or self.default_schema_name,
298319
r"(.*)\[\'(.*)\'\]") # regex to filter subscript
@@ -331,7 +352,7 @@ def result_fun(result):
331352
return set(rows[0] if rows else [])
332353

333354
pk_result = engine.exec_driver_sql(
334-
query,
355+
self._format_query(query),
335356
(table_name, schema or self.default_schema_name)
336357
)
337358
pks = result_fun(pk_result)
@@ -372,6 +393,17 @@ def has_ilike_operator(self):
372393
server_version_info = self.server_version_info
373394
return server_version_info is not None and server_version_info >= (4, 1, 0)
374395

396+
def _format_query(self, query):
    """
    Translate the `qmark` placeholders of a raw SQL statement into the
    paramstyle of the active driver.

    The statements issued by this dialect are written using `?` markers.
    Only the placeholder shapes used by those statements, `= ?` and `!~ ?`,
    are rewritten; any other text is left untouched.

    - `pyformat` / `format`: each placeholder becomes `%s`.
    - `numeric_dollar`: placeholders become `$1`, `$2`, ... in order of
      appearance.
    - Any other paramstyle (e.g. `qmark`): the statement is returned as is.

    NOTE: Literal `%` characters in the statement are not escaped.

    TODO: Review: Is it legit and sane? Are there alternatives?

    :param query: SQL statement using `qmark`-style placeholders.
    :return: SQL statement using the placeholders of `self.paramstyle`.
    """
    paramstyle = self.paramstyle

    if paramstyle in ("pyformat", "format"):
        return query.replace("= ?", "= %s").replace("!~ ?", "!~ %s")

    if paramstyle == "numeric_dollar":
        # Imported locally to keep this helper self-contained.
        import re
        from itertools import count

        position = count(1)
        return re.sub(
            r"(=|!~) \?",
            lambda match: f"{match.group(1)} ${next(position)}",
            query,
        )

    return query
406+
375407

376408
class DateTrunc(functions.GenericFunction):
377409
name = "date_trunc"
+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# -*- coding: utf-8; -*-
2+
#
3+
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
4+
# license agreements. See the NOTICE file distributed with this work for
5+
# additional information regarding copyright ownership. Crate licenses
6+
# this file to you under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License. You may
8+
# obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
# License for the specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# However, if you have executed another commercial license agreement
19+
# with Crate these terms will supersede the license and you may use the
20+
# software solely pursuant to the terms of the relevant commercial agreement.
21+
from sqlalchemy.engine.reflection import Inspector
22+
from sqlalchemy_postgresql_relaxed.asyncpg import PGDialect_asyncpg_relaxed
23+
from sqlalchemy_postgresql_relaxed.base import PGDialect_relaxed
24+
from sqlalchemy_postgresql_relaxed.psycopg import (
25+
PGDialect_psycopg_relaxed,
26+
PGDialectAsync_psycopg_relaxed,
27+
)
28+
29+
from sqlalchemy_cratedb import dialect
30+
31+
32+
class CrateDialectPostgresAdapter(PGDialect_relaxed, dialect):
    """
    Provide a dialect on top of the relaxed PostgreSQL dialect.

    The reflection methods are bound explicitly to the implementations of
    the CrateDB `dialect`, and the `get_multi_*` variants without a CrateDB
    counterpart report no results.
    """

    inspector = Inspector

    # Need to manually override some methods because of polymorphic inheritance woes.
    # NOTE(review): presumably `PGDialect_relaxed`, listed first in the bases,
    # would otherwise win the method resolution -- confirm.
    # TODO: Investigate if this can be solved using metaprogramming or other techniques.

    # Schema- and table-level reflection.
    has_schema = dialect.has_schema
    has_table = dialect.has_table
    get_schema_names = dialect.get_schema_names
    get_table_names = dialect.get_table_names
    get_view_names = dialect.get_view_names

    # Per-table reflection.
    get_columns = dialect.get_columns
    get_pk_constraint = dialect.get_pk_constraint
    get_foreign_keys = dialect.get_foreign_keys
    get_indexes = dialect.get_indexes

    # Multi-table reflection.
    get_multi_columns = dialect.get_multi_columns
    get_multi_pk_constraint = dialect.get_multi_pk_constraint
    get_multi_foreign_keys = dialect.get_multi_foreign_keys

    # TODO: Those may want to go to dialect instead?
    def get_multi_indexes(self, *args, **kwargs):
        """Report no indexes, regardless of the arguments."""
        return list()

    def get_multi_unique_constraints(self, *args, **kwargs):
        """Report no unique constraints, regardless of the arguments."""
        return list()

    def get_multi_check_constraints(self, *args, **kwargs):
        """Report no check constraints, regardless of the arguments."""
        return list()

    def get_multi_table_comment(self, *args, **kwargs):
        """Report no table comments, regardless of the arguments."""
        return list()
67+
68+
69+
class CrateDialect_psycopg(PGDialect_psycopg_relaxed, CrateDialectPostgresAdapter):
    """
    CrateDB dialect using the `psycopg` driver.

    The asynchronous counterpart is `CrateDialectAsync_psycopg`, which is
    handed out by `get_async_dialect_cls`.
    """

    driver = "psycopg"

    @classmethod
    def import_dbapi(cls):
        """Import the `psycopg` module on demand, and return it."""
        import psycopg as dbapi_module

        return dbapi_module

    @classmethod
    def get_async_dialect_cls(cls, url):
        """Return the dialect class for asynchronous operation; `url` is not used."""
        return CrateDialectAsync_psycopg
81+
82+
83+
class CrateDialectAsync_psycopg(PGDialectAsync_psycopg_relaxed, CrateDialectPostgresAdapter):
    """
    CrateDB dialect using the `psycopg` driver in asynchronous mode.
    """

    is_async = True
    driver = "psycopg_async"
86+
87+
88+
class CrateDialect_asyncpg(PGDialect_asyncpg_relaxed, CrateDialectPostgresAdapter):
    """
    CrateDB dialect using the `asyncpg` driver.
    """

    driver = "asyncpg"

    # TODO: asyncpg may have `paramstyle="numeric_dollar"`. Review this!

    # TODO: AttributeError: module 'asyncpg' has no attribute 'paramstyle'
    #
    # The `import_dbapi` override below is disabled, and kept for reference only.
    #
    #     @classmethod
    #     def import_dbapi(cls):
    #         import asyncpg
    #
    #         return asyncpg
101+
102+
103+
# Dialect aliases, one per driver variant.
dialect_urllib3 = dialect
dialect_psycopg = CrateDialect_psycopg
dialect_psycopg_async = CrateDialectAsync_psycopg
dialect_asyncpg = CrateDialect_asyncpg

tests/engine_test.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import pytest
2+
import sqlalchemy as sa
3+
from sqlalchemy.dialects import registry as dialect_registry
4+
5+
from sqlalchemy_cratedb import SA_VERSION, SA_2_0
6+
7+
if SA_VERSION < SA_2_0:
    # `pytest.skip` raises on its own, an explicit `raise` is not needed.
    pytest.skip("Only supported on SQLAlchemy 2.0 and higher", allow_module_level=True)

# Imported after the version guard above.
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine

# Registering the additional dialects manually seems to be needed when running
# under tests. Apparently, manual registration is not needed under regular
# circumstances, as this is wired through the `sqlalchemy.dialects` entrypoint
# registrations in `pyproject.toml`. It is definitely weird, but c'est la vie.
# The list mirrors the additional entrypoints declared in `pyproject.toml`.
dialect_registry.register("crate.urllib3", "sqlalchemy_cratedb.dialect_more", "dialect_urllib3")
dialect_registry.register("crate.psycopg", "sqlalchemy_cratedb.dialect_more", "dialect_psycopg")
dialect_registry.register("crate.psycopg_async", "sqlalchemy_cratedb.dialect_more", "dialect_psycopg_async")
dialect_registry.register("crate.asyncpg", "sqlalchemy_cratedb.dialect_more", "dialect_asyncpg")


# Statement shared by all test cases.
QUERY = sa.text("SELECT mountain, coordinates FROM sys.summits ORDER BY mountain LIMIT 3;")
22+
23+
24+
def test_engine_sync_vanilla():
    """
    crate:// -- Verify connectivity and data transport with vanilla HTTP-based driver.
    """
    engine = sa.create_engine("crate://crate@localhost:4200/", echo=True)
    try:
        assert isinstance(engine, sa.engine.Engine)
        with engine.connect() as connection:
            result = connection.execute(QUERY)
            assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': [10.95667, 47.18917]}
    finally:
        # Release pooled connections, also when an assertion fails.
        engine.dispose()
33+
34+
35+
def test_engine_sync_urllib3():
    """
    crate+urllib3:// -- Verify connectivity and data transport *explicitly* selecting the HTTP driver.
    """
    engine = sa.create_engine("crate+urllib3://crate@localhost:4200/", isolation_level="AUTOCOMMIT", echo=True)
    try:
        assert isinstance(engine, sa.engine.Engine)
        with engine.connect() as connection:
            result = connection.execute(QUERY)
            assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': [10.95667, 47.18917]}
    finally:
        # Release pooled connections, also when an assertion fails.
        engine.dispose()
44+
45+
46+
def test_engine_sync_psycopg():
    """
    crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
    """
    engine = sa.create_engine("crate+psycopg://crate@localhost:5432/", isolation_level="AUTOCOMMIT", echo=True)
    try:
        assert isinstance(engine, sa.engine.Engine)
        with engine.connect() as connection:
            result = connection.execute(QUERY)
            # Over this driver, the coordinates are expected as a string.
            assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': '(10.95667,47.18917)'}
    finally:
        # Release pooled connections, also when an assertion fails.
        engine.dispose()
55+
56+
57+
@pytest.mark.asyncio
async def test_engine_async_psycopg():
    """
    crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
    This time, in asynchronous mode.
    """
    engine = create_async_engine("crate+psycopg://crate@localhost:5432/", isolation_level="AUTOCOMMIT", echo=True)
    try:
        assert isinstance(engine, AsyncEngine)
        async with engine.begin() as conn:
            result = await conn.execute(QUERY)
            # Over this driver, the coordinates are expected as a string.
            assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': '(10.95667,47.18917)'}
    finally:
        # Release pooled connections, also when an assertion fails.
        await engine.dispose()
68+
69+
70+
@pytest.mark.asyncio
async def test_engine_async_asyncpg():
    """
    crate+asyncpg:// -- Verify connectivity and data transport using the asyncpg driver.
    This exclusively uses asynchronous mode.
    """
    from asyncpg.pgproto.types import Point

    engine = create_async_engine("crate+asyncpg://crate@localhost:5432/", isolation_level="AUTOCOMMIT", echo=True)
    try:
        assert isinstance(engine, AsyncEngine)
        async with engine.begin() as conn:
            result = await conn.execute(QUERY)
            # Over this driver, the coordinates are expected as a `Point` object.
            assert result.mappings().fetchone() == {'mountain': 'Acherkogel', 'coordinates': Point(10.95667, 47.18917)}
    finally:
        # Release pooled connections, also when an assertion fails.
        await engine.dispose()

0 commit comments

Comments
 (0)