Skip to content

Commit 722e571

Browse files
authored
feat(litestar): implement initial litestar plugin (#20)
* feat(litestar): implement initial `litestar` plugin * feat: updated configuration for databases without pooling * fix: adds a `close_pool` based on config type * fix: add secret manager for `duckdb`
1 parent c7c723f commit 722e571

File tree

35 files changed

+1417
-213
lines changed

35 files changed

+1417
-213
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.9.10"
20+
rev: "v0.10.0"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
PY_OBJ = "py:obj"
4646

4747
nitpicky = True
48-
nitpick_ignore = []
48+
nitpick_ignore: list[str] = []
4949
nitpick_ignore_regex = [
5050
(PY_RE, r"sqlspec.*\.T"),
5151
]

docs/examples/litestar_multi_db.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from aiosqlite import Connection
2+
from duckdb import DuckDBPyConnection
3+
from litestar import Litestar, get
4+
5+
from sqlspec.adapters.aiosqlite import AiosqliteConfig
6+
from sqlspec.adapters.duckdb import DuckDBConfig
7+
from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec
8+
9+
10+
@get("/test", sync_to_thread=True)
11+
def simple_select(etl_session: DuckDBPyConnection) -> dict[str, str]:
12+
result = etl_session.execute("SELECT 'Hello, world!' AS greeting").fetchall()
13+
return {"greeting": result[0][0]}
14+
15+
16+
@get("/")
17+
async def simple_sqlite(db_connection: Connection) -> dict[str, str]:
18+
result = await db_connection.execute_fetchall("SELECT 'Hello, world!' AS greeting")
19+
return {"greeting": result[0][0]} # type: ignore # noqa: PGH003
20+
21+
22+
sqlspec = SQLSpec(
23+
config=[
24+
DatabaseConfig(config=AiosqliteConfig(), commit_mode="autocommit"),
25+
DatabaseConfig(config=DuckDBConfig(), connection_key="etl_session"),
26+
],
27+
)
28+
app = Litestar(route_handlers=[simple_sqlite, simple_select], plugins=[sqlspec])

docs/examples/litestar_single_db.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from aiosqlite import Connection
2+
from litestar import Litestar, get
3+
4+
from sqlspec.adapters.aiosqlite import AiosqliteConfig
5+
from sqlspec.extensions.litestar import SQLSpec
6+
7+
8+
@get("/")
9+
async def simple_sqlite(db_session: Connection) -> dict[str, str]:
10+
"""Simple select statement.
11+
12+
Returns:
13+
dict[str, str]: The greeting.
14+
"""
15+
result = await db_session.execute_fetchall("SELECT 'Hello, world!' AS greeting")
16+
return {"greeting": result[0][0]} # type: ignore # noqa: PGH003
17+
18+
19+
sqlspec = SQLSpec(config=AiosqliteConfig())
20+
app = Litestar(route_handlers=[simple_sqlite], plugins=[sqlspec])

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ ignore = [
267267
"ARG002", # Unused method argument
268268
"ARG001", # Unused function argument
269269
"CPY001", # pycodestyle - Missing Copywrite notice at the top of the file
270+
"RUF029", # Ruff - function is declared as async but has no awaitable calls
271+
"COM812", # flake8-comma - Missing trailing comma
270272
]
271273
select = ["ALL"]
272274

@@ -310,6 +312,8 @@ known-first-party = ["sqlspec", "tests"]
310312
"TRY",
311313
"PT012",
312314
"INP001",
315+
"DOC",
316+
"PLC",
313317
]
314318
"tools/**/*.*" = ["D", "ARG", "EM", "TRY", "G", "FBT", "S603", "F811", "PLW0127", "PLR0911"]
315319
"tools/prepare_release.py" = ["S603", "S607"]

sqlspec/adapters/adbc/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ def connection_params(self) -> "dict[str, Any]":
4040

4141
@contextmanager
4242
def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connection, None, None]":
43-
"""Create and provide a database connection."""
43+
"""Create and provide a database connection.
44+
45+
Yields:
46+
Connection: A database connection instance.
47+
"""
4448
from adbc_driver_manager.dbapi import connect
4549

4650
with connect(**self.connection_params) as connection:

sqlspec/adapters/aiosqlite/config.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def connection_config_dict(self) -> "dict[str, Any]":
5050
Returns:
5151
A string keyed dict of config kwargs for the aiosqlite.connect() function.
5252
"""
53-
return dataclass_to_dict(self, exclude_empty=True, convert_nested=False)
53+
return dataclass_to_dict(self, exclude_empty=True, convert_nested=False, exclude={"pool_instance"})
5454

5555
async def create_connection(self) -> "Connection":
5656
"""Create and return a new database connection.
@@ -76,8 +76,6 @@ async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGener
7676
Yields:
7777
An Aiosqlite connection instance.
7878
79-
Raises:
80-
ImproperConfigurationError: If the connection could not be established.
8179
"""
8280
connection = await self.create_connection()
8381
try:

sqlspec/adapters/asyncmy/config.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,12 @@ def pool_config_dict(self) -> "dict[str, Any]":
130130
ImproperConfigurationError: If the pool configuration is not provided.
131131
"""
132132
if self.pool_config:
133-
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
133+
return dataclass_to_dict(
134+
self.pool_config,
135+
exclude_empty=True,
136+
convert_nested=False,
137+
exclude={"pool_instance"},
138+
)
134139
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
135140
raise ImproperConfigurationError(msg)
136141

@@ -179,3 +184,9 @@ async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGener
179184
pool = await self.provide_pool(*args, **kwargs) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
180185
async with pool.acquire() as connection: # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
181186
yield connection # pyright: ignore[reportUnknownMemberType]
187+
188+
async def close_pool(self) -> None:
189+
"""Close the connection pool."""
190+
if self.pool_instance is not None:
191+
await self.pool_instance.close()
192+
self.pool_instance = None

sqlspec/adapters/asyncpg/config.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,14 @@ def pool_config_dict(self) -> "dict[str, Any]":
9696
Returns:
9797
A string keyed dict of config kwargs for the Asyncpg :func:`create_pool <asyncpg.pool.create_pool>`
9898
function.
99+
100+
Raises:
101+
ImproperConfigurationError: If no pool_config is provided but a pool_instance is set.
99102
"""
100103
if self.pool_config:
101-
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
104+
return dataclass_to_dict(
105+
self.pool_config, exclude_empty=True, exclude={"pool_instance"}, convert_nested=False
106+
)
102107
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
103108
raise ImproperConfigurationError(msg)
104109

@@ -107,6 +112,10 @@ async def create_pool(self) -> "Pool": # pyright: ignore[reportMissingTypeArgum
107112
108113
Returns:
109114
Getter that returns the pool instance used by the plugin.
115+
116+
Raises:
117+
ImproperConfigurationError: If neither pool_config nor pool_instance are provided,
118+
or if the pool could not be configured.
110119
"""
111120
if self.pool_instance is not None:
112121
return self.pool_instance
@@ -136,9 +145,15 @@ def provide_pool(self, *args: "Any", **kwargs: "Any") -> "Awaitable[Pool]": # p
136145
async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[PoolConnectionProxy, None]": # pyright: ignore[reportMissingTypeArgument,reportUnknownParameterType]
137146
"""Create a connection instance.
138147
139-
Returns:
148+
Yields:
140149
A connection instance.
141150
"""
142151
db_pool = await self.provide_pool(*args, **kwargs) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
143152
async with db_pool.acquire() as connection: # pyright: ignore[reportUnknownVariableType]
144153
yield connection
154+
155+
async def close_pool(self) -> None:
156+
"""Close the pool."""
157+
if self.pool_instance is not None:
158+
await self.pool_instance.close()
159+
self.pool_instance = None

sqlspec/adapters/duckdb/config.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from contextlib import contextmanager
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, field
33
from typing import TYPE_CHECKING, Any, Union, cast
44

55
from duckdb import DuckDBPyConnection
@@ -39,6 +39,27 @@ class ExtensionConfig(TypedDict):
3939
"""Optional version of the extension to install"""
4040

4141

42+
@dataclass
43+
class SecretConfig:
44+
"""Configuration for a secret to store in a connection.
45+
46+
This class provides configuration options for storing a secret in a connection for later retrieval.
47+
48+
For details see: https://duckdb.org/docs/api/python/overview#connection-options
49+
"""
50+
51+
secret_type: str = field()
52+
"""The type of secret to store"""
53+
name: str = field()
54+
"""The name of the secret to store"""
55+
persist: bool = field(default=False)
56+
"""Whether to persist the secret"""
57+
value: dict[str, Any] = field(default_factory=dict)
58+
"""The secret value to store"""
59+
replace_if_exists: bool = field(default=True)
60+
"""Whether to replace the secret if it already exists"""
61+
62+
4263
@dataclass
4364
class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
4465
"""Configuration for DuckDB database connections.
@@ -63,6 +84,8 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
6384

6485
extensions: "Union[Sequence[ExtensionConfig], ExtensionConfig, EmptyType]" = Empty
6586
"""A sequence of extension configurations to install and configure upon connection creation."""
87+
secrets: "Union[Sequence[SecretConfig], SecretConfig , EmptyType]" = Empty
88+
"""A dictionary of secrets to store in the connection for later retrieval."""
6689

6790
def __post_init__(self) -> None:
6891
"""Post-initialization validation and processing.
@@ -73,9 +96,10 @@ def __post_init__(self) -> None:
7396
"""
7497
if self.config is Empty:
7598
self.config = {}
76-
7799
if self.extensions is Empty:
78100
self.extensions = []
101+
if self.secrets is Empty:
102+
self.secrets = []
79103
if isinstance(self.extensions, dict):
80104
self.extensions = [self.extensions]
81105
# this is purely for mypy
@@ -120,6 +144,47 @@ def _configure_extensions(self, connection: "DuckDBPyConnection") -> None:
120144
for extension in cast("list[ExtensionConfig]", self.extensions):
121145
self._configure_extension(connection, extension)
122146

147+
@staticmethod
148+
def _secret_exists(connection: "DuckDBPyConnection", name: "str") -> bool:
149+
"""Check if a secret exists in the connection.
150+
151+
Args:
152+
connection: The DuckDB connection to check for the secret.
153+
name: The name of the secret to check for.
154+
155+
Returns:
156+
bool: True if the secret exists, False otherwise.
157+
"""
158+
results = connection.execute("select 1 from duckdb_secrets() where name=?", name).fetchone()
159+
return results is not None
160+
161+
@classmethod
162+
def _configure_secrets(
163+
cls,
164+
connection: "DuckDBPyConnection",
165+
secrets: "list[SecretConfig]",
166+
) -> None:
167+
"""Configure persistent secrets for the connection.
168+
169+
Args:
170+
connection: The DuckDB connection to configure secrets for.
171+
secrets: The list of secrets to store in the connection.
172+
173+
Raises:
174+
ImproperConfigurationError: If a secret could not be stored in the connection.
175+
"""
176+
try:
177+
for secret in secrets:
178+
secret_exists = cls._secret_exists(connection, secret.name)
179+
if not secret_exists or secret.replace_if_exists:
180+
connection.execute(f"""create or replace {"persistent" if secret.persist else ""} secret {secret.name} (
181+
type {secret.secret_type},
182+
{" ,".join([f"{k} '{v}'" for k, v in secret.value.items()])}
183+
) """)
184+
except Exception as e:
185+
msg = f"Failed to store secret. Error: {e!s}"
186+
raise ImproperConfigurationError(msg) from e
187+
123188
@staticmethod
124189
def _configure_extension(connection: "DuckDBPyConnection", extension: ExtensionConfig) -> None:
125190
"""Configure a single extension for the connection.
@@ -156,7 +221,12 @@ def connection_config_dict(self) -> "dict[str, Any]":
156221
Returns:
157222
A string keyed dict of config kwargs for the duckdb.connect() function.
158223
"""
159-
config = dataclass_to_dict(self, exclude_empty=True, exclude={"extensions"}, convert_nested=False)
224+
config = dataclass_to_dict(
225+
self,
226+
exclude_empty=True,
227+
exclude={"extensions", "pool_instance", "secrets"},
228+
convert_nested=False,
229+
)
160230
if not config.get("database"):
161231
config["database"] = ":memory:"
162232
return config
@@ -176,6 +246,8 @@ def create_connection(self) -> "DuckDBPyConnection":
176246
connection = duckdb.connect(**self.connection_config_dict) # pyright: ignore[reportUnknownMemberType]
177247
self._configure_extensions(connection)
178248
self._configure_connection(connection)
249+
self._configure_secrets(connection, cast("list[SecretConfig]", self.secrets))
250+
179251
except Exception as e:
180252
msg = f"Could not configure the DuckDB connection. Error: {e!s}"
181253
raise ImproperConfigurationError(msg) from e

sqlspec/adapters/oracledb/config/_asyncio.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,27 @@ class OracleAsyncDatabaseConfig(AsyncDatabaseConfig[AsyncConnection, AsyncConnec
5454
def pool_config_dict(self) -> "dict[str, Any]":
5555
"""Return the pool configuration as a dict.
5656
57+
Raises:
58+
ImproperConfigurationError: If no pool_config is provided but a pool_instance
59+
5760
Returns:
5861
A string keyed dict of config kwargs for the Asyncpg :func:`create_pool <oracledb.pool.create_pool>`
5962
function.
6063
"""
6164
if self.pool_config is not None:
62-
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
65+
return dataclass_to_dict(
66+
self.pool_config, exclude_empty=True, convert_nested=False, exclude={"pool_instance"}
67+
)
6368
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
6469
raise ImproperConfigurationError(msg)
6570

6671
async def create_pool(self) -> "AsyncConnectionPool":
6772
"""Return a pool. If none exists yet, create one.
6873
74+
Raises:
75+
ImproperConfigurationError: If neither pool_config nor pool_instance are provided,
76+
or if the pool could not be configured.
77+
6978
Returns:
7079
Getter that returns the pool instance used by the plugin.
7180
"""
@@ -95,8 +104,8 @@ def provide_pool(self, *args: "Any", **kwargs: "Any") -> "Awaitable[AsyncConnect
95104
async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[AsyncConnection, None]":
96105
"""Create a connection instance.
97106
98-
Returns:
99-
A connection instance.
107+
Yields:
108+
AsyncConnection: A connection instance.
100109
"""
101110
db_pool = await self.provide_pool(*args, **kwargs)
102111
async with db_pool.acquire() as connection: # pyright: ignore[reportUnknownMemberType]

sqlspec/adapters/oracledb/config/_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
@dataclass
30-
class OracleGenericPoolConfig(Generic[ConnectionT, PoolT], GenericPoolConfig):
30+
class OracleGenericPoolConfig(GenericPoolConfig, Generic[ConnectionT, PoolT]):
3131
"""Configuration for Oracle database connection pools.
3232
3333
This class provides configuration options for both synchronous and asynchronous Oracle

0 commit comments

Comments
 (0)