Skip to content

Commit 5bd1674

Browse files
committed
feat: additional progress on new adapters
1 parent ae38ef9 commit 5bd1674

File tree

19 files changed

+1065
-176
lines changed

19 files changed

+1065
-176
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.7.2"
20+
rev: "v0.7.3"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ requires-python = ">=3.9, <4.0"
99
version = "0.2.0"
1010

1111
[project.optional-dependencies]
12+
adbc = ["adbc-driver-manager", "pyarrow"]
1213
aioodbc = ["aioodbc"]
1314
aiosqlite = ["aiosqlite"]
1415
asyncmy = ["asyncmy"]
@@ -28,7 +29,7 @@ fastapi = ["fastapi"]
2829
flask = ["flask"]
2930

3031
[dependency-groups]
31-
dev = [{ include-group = "lint" }, { include-group = "doc" }, { include-group = "test" }]
32+
dev = ["adbc-driver-sqlite", "adbc-driver-postgresql", "adbc-driver-flightsql", { include-group = "lint" }, { include-group = "doc" }, { include-group = "test" }]
3233
doc = [
3334
"auto-pytabs[sphinx]>=0.5.0",
3435
"git-cliff>=2.6.1",

sqlspec/adapters/adbc/__init__.py

Whitespace-only changes.

sqlspec/adapters/adbc/config.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager
4+
from dataclasses import dataclass
5+
from typing import TYPE_CHECKING, TypeVar
6+
7+
from sqlspec.types.configs import GenericDatabaseConfig
8+
from sqlspec.types.empty import Empty
9+
10+
if TYPE_CHECKING:
11+
from collections.abc import Generator
12+
from typing import Any
13+
14+
from adbc_driver_manager.dbapi import Connection, Cursor
15+
16+
from sqlspec.types.empty import EmptyType
17+
18+
__all__ = ("AdbcDatabaseConfig",)
19+
20+
ConnectionT = TypeVar("ConnectionT", bound="Connection")
21+
CursorT = TypeVar("CursorT", bound="Cursor")
22+
23+
24+
@dataclass
25+
class AdbcDatabaseConfig(GenericDatabaseConfig):
26+
"""Configuration for ADBC connections.
27+
28+
This class provides configuration options for ADBC database connections using the
29+
ADBC Driver Manager.([1](https://arrow.apache.org/adbc/current/python/api/adbc_driver_manager.html))
30+
"""
31+
32+
uri: str | EmptyType = Empty
33+
"""Database URI"""
34+
driver_name: str | EmptyType = Empty
35+
"""Name of the ADBC driver to use"""
36+
db_kwargs: dict[str, Any] | None = None
37+
"""Additional database-specific connection parameters"""
38+
39+
@property
40+
def connection_params(self) -> dict[str, Any]:
41+
"""Return the connection parameters as a dict."""
42+
return {
43+
k: v
44+
for k, v in {"uri": self.uri, "driver": self.driver_name, **(self.db_kwargs or {})}.items()
45+
if v is not Empty
46+
}
47+
48+
@contextmanager
49+
def provide_connection(self, *args: Any, **kwargs: Any) -> Generator[Connection, None, None]:
50+
"""Create and provide a database connection."""
51+
from adbc_driver_manager.dbapi import connect
52+
53+
with connect(**self.connection_params) as connection:
54+
yield connection

sqlspec/adapters/adbc/driver.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager
4+
from typing import TYPE_CHECKING, Any
5+
6+
from sqlspec.types.protocols import StatementType, SyncDriverAdapterProtocol
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Callable, Generator, Iterable
10+
11+
from adbc_driver_manager.dbapi import Connection, Cursor
12+
13+
__all__ = ("AdbcAdapter",)
14+
15+
16+
class AdbcAdapter(SyncDriverAdapterProtocol):
17+
"""A synchronous ADBC SQLSpec Adapter."""
18+
19+
is_async: bool = False
20+
21+
def process_sql(self, op_type: StatementType, sql: str) -> str:
22+
"""Process SQL query."""
23+
return sql
24+
25+
def _cursor(self, connection: Connection) -> Cursor:
26+
"""Get a cursor from a connection."""
27+
return connection.cursor()
28+
29+
def _process_row(self, row: Any, column_names: list[str], record_class: Callable | None = None) -> Any:
30+
"""Process a row into the desired format."""
31+
if record_class is None:
32+
return row
33+
return record_class(**{str(k): v for k, v in zip(column_names, row, strict=False)})
34+
35+
def select(
36+
self,
37+
connection: Connection,
38+
sql: str,
39+
parameters: list | dict,
40+
record_class: Callable | None,
41+
) -> Iterable[Any]:
42+
"""Handle a relation-returning SELECT."""
43+
cur = self._cursor(connection)
44+
try:
45+
cur.execute(sql, parameters)
46+
if record_class is None:
47+
yield from cur
48+
else:
49+
column_names = [desc[0] for desc in cur.description]
50+
for row in cur:
51+
yield self._process_row(row, column_names, record_class)
52+
finally:
53+
cur.close()
54+
55+
def select_one(
56+
self,
57+
connection: Connection,
58+
sql: str,
59+
parameters: list | dict,
60+
record_class: Callable | None,
61+
) -> Any | None:
62+
"""Handle a single-row-returning SELECT."""
63+
cur = self._cursor(connection)
64+
try:
65+
cur.execute(sql, parameters)
66+
row = cur.fetchone()
67+
if row is None:
68+
return None
69+
if record_class is None:
70+
return row
71+
column_names = [desc[0] for desc in cur.description or []]
72+
return self._process_row(row, column_names, record_class)
73+
finally:
74+
cur.close()
75+
76+
def select_scalar(
77+
self,
78+
connection: Connection,
79+
sql: str,
80+
parameters: list | dict,
81+
) -> Any | None:
82+
"""Handle a scalar-returning SELECT."""
83+
cur = self._cursor(connection)
84+
try:
85+
cur.execute(sql, parameters)
86+
row = cur.fetchone()
87+
return row[0] if row else None
88+
finally:
89+
cur.close()
90+
91+
@contextmanager
92+
def with_cursor(
93+
self,
94+
connection: Connection,
95+
sql: str,
96+
parameters: list | dict,
97+
) -> Generator[Cursor, None, None]:
98+
"""Execute a query and yield the cursor."""
99+
cur = self._cursor(connection)
100+
try:
101+
cur.execute(sql, parameters)
102+
yield cur
103+
finally:
104+
cur.close()
105+
106+
def insert_update_delete(
107+
self,
108+
connection: Connection,
109+
sql: str,
110+
parameters: list | dict,
111+
) -> int:
112+
"""Handle an INSERT, UPDATE, or DELETE."""
113+
cur = self._cursor(connection)
114+
try:
115+
cur.execute(sql, parameters)
116+
return cur.rowcount
117+
finally:
118+
cur.close()
119+
120+
def insert_update_delete_returning(
121+
self,
122+
connection: Connection,
123+
sql: str,
124+
parameters: list | dict,
125+
record_class: Callable | None = None,
126+
) -> Any:
127+
"""Handle an INSERT, UPDATE, or DELETE with RETURNING clause."""
128+
return self.select_one(connection, sql, parameters, record_class)
129+
130+
def insert_update_delete_many(
131+
self,
132+
connection: Connection,
133+
sql: str,
134+
parameters: list | dict,
135+
) -> int:
136+
"""Handle multiple INSERT, UPDATE, or DELETE operations."""
137+
cur = self._cursor(connection)
138+
try:
139+
cur.executemany(sql, parameters)
140+
return cur.rowcount
141+
finally:
142+
cur.close()

sqlspec/adapters/asyncmy/__init__.py

Whitespace-only changes.

sqlspec/adapters/asyncpg/driver.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections import defaultdict
4-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, cast
55

66
from sqlspec.sql.patterns import VAR_REF
77
from sqlspec.types.protocols import AsyncDriverAdapterProtocol, StatementType
@@ -15,22 +15,23 @@
1515
__all__ = ("AsyncpgAdapter",)
1616

1717

18-
class MaybeAcquire:
18+
class ManagedConnection:
1919
"""Context manager for handling connection acquisition from pools or direct connections."""
2020

2121
def __init__(self, client: Connection | Pool) -> None:
2222
self.client = client
2323
self._managed_conn: Connection | None = None
2424

2525
async def __aenter__(self) -> Connection:
26-
if hasattr(self.client, "acquire"):
27-
self._managed_conn = await self.client.acquire()
28-
return self._managed_conn
29-
return self.client
26+
if "acquire" in dir(self.client):
27+
self._managed_conn = await self.client.acquire() # pyright: ignore[reportAttributeAccessIssue]
28+
return cast("Connection", self._managed_conn)
29+
self._managed_conn = None
30+
return cast("Connection", self.client)
3031

3132
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
3233
if self._managed_conn is not None:
33-
await self.client.release(self._managed_conn)
34+
await self.client.release(self._managed_conn) # pyright: ignore[reportAttributeAccessIssue]
3435

3536

3637
class AsyncpgAdapter(AsyncDriverAdapterProtocol):
@@ -99,7 +100,7 @@ async def select(
99100
record_class: Callable | None,
100101
) -> Iterable[Any]:
101102
"""Handle a relation-returning SELECT."""
102-
async with MaybeAcquire(connection) as conn:
103+
async with ManagedConnection(connection) as conn:
103104
results = await conn.fetch(sql, *parameters)
104105
if record_class is not None:
105106
return [record_class(**dict(rec)) for rec in results]
@@ -113,7 +114,7 @@ async def select_one(
113114
record_class: Callable | None,
114115
) -> Any | None:
115116
"""Handle a single-row-returning SELECT."""
116-
async with MaybeAcquire(connection) as conn:
117+
async with ManagedConnection(connection) as conn:
117118
result = await conn.fetchrow(sql, *parameters)
118119
if result is not None and record_class is not None:
119120
return record_class(**dict(result))
@@ -126,7 +127,7 @@ async def select_scalar(
126127
parameters: list | dict,
127128
) -> Any | None:
128129
"""Handle a scalar-returning SELECT."""
129-
async with MaybeAcquire(connection) as conn:
130+
async with ManagedConnection(connection) as conn:
130131
return await conn.fetchval(sql, *parameters)
131132

132133
async def with_cursor(
@@ -136,7 +137,7 @@ async def with_cursor(
136137
parameters: list | dict,
137138
) -> AsyncGenerator[Cursor | CursorFactory, None]:
138139
"""Execute a query and yield the cursor."""
139-
async with MaybeAcquire(connection) as conn:
140+
async with ManagedConnection(connection) as conn:
140141
stmt = await conn.prepare(sql)
141142
async with conn.transaction():
142143
yield stmt.cursor(*parameters)
@@ -148,7 +149,7 @@ async def insert_update_delete(
148149
parameters: list | dict,
149150
) -> int:
150151
"""Handle an INSERT, UPDATE, or DELETE."""
151-
async with MaybeAcquire(connection) as conn:
152+
async with ManagedConnection(connection) as conn:
152153
result = await conn.execute(sql, *parameters)
153154
if isinstance(result, str):
154155
return int(result.split()[-1])
@@ -172,7 +173,7 @@ async def execute_script(
172173
record_class: Callable | None = None,
173174
) -> Any:
174175
"""Execute a SQL script."""
175-
async with MaybeAcquire(connection) as conn:
176+
async with ManagedConnection(connection) as conn:
176177
if parameters:
177178
result = await conn.fetch(sql, *parameters)
178179
else:

sqlspec/adapters/duckdb/__init__.py

Whitespace-only changes.

sqlspec/adapters/oracledb/driver/_async.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,22 @@ async def select(
3636
record_class: Callable | None,
3737
) -> Iterable[Any]:
3838
"""Handle a relation-returning SELECT."""
39-
cur = await self._cursor(connection)
39+
cur = self._cursor(connection)
4040
try:
4141
await cur.execute(sql, parameters)
42+
column_names = [desc[0] for desc in cur.description]
43+
cur.rowfactory = lambda *args: dict(zip(column_names, args))
44+
data = await cur.fetchall()
4245
if record_class is None:
4346
async for row in cur:
4447
yield row
4548
else:
4649
column_names = [desc[0] for desc in cur.description]
50+
cur.rowfactory = lambda *args: dict(zip(column_names, args))
4751
async for row in cur:
4852
yield record_class(**{str(k): v for k, v in zip(column_names, row, strict=False)})
4953
finally:
50-
await cur.close()
54+
cur.close()
5155

5256
async def select_one(
5357
self,
@@ -90,7 +94,7 @@ async def with_cursor(
9094
parameters: list | dict,
9195
) -> AsyncGenerator[AsyncCursor, None]:
9296
"""Execute a query and yield the cursor."""
93-
cur = await self._cursor(connection)
97+
cur = self._cursor(connection)
9498
try:
9599
await cur.execute(sql, parameters)
96100
yield cur
@@ -104,7 +108,7 @@ async def insert_update_delete(
104108
parameters: list | dict,
105109
) -> int:
106110
"""Handle an INSERT, UPDATE, or DELETE."""
107-
cur = await self._cursor(connection)
111+
cur = self._cursor(connection)
108112
try:
109113
await cur.execute(sql, parameters)
110114
return cur.rowcount

sqlspec/adapters/psycopg/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)