Skip to content

feat(litestar): implement initial litestar plugin #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.9.10"
rev: "v0.10.0"
hooks:
- id: ruff
args: ["--fix"]
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
PY_OBJ = "py:obj"

nitpicky = True
nitpick_ignore = []
nitpick_ignore: list[str] = []
nitpick_ignore_regex = [
(PY_RE, r"sqlspec.*\.T"),
]
Expand Down
28 changes: 28 additions & 0 deletions docs/examples/litestar_multi_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from aiosqlite import Connection
from duckdb import DuckDBPyConnection
from litestar import Litestar, get

from sqlspec.adapters.aiosqlite import AiosqliteConfig
from sqlspec.adapters.duckdb import DuckDBConfig
from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec


@get("/test", sync_to_thread=True)
def simple_select(etl_session: DuckDBPyConnection) -> dict[str, str]:
result = etl_session.execute("SELECT 'Hello, world!' AS greeting").fetchall()
return {"greeting": result[0][0]}


@get("/")
async def simple_sqlite(db_connection: Connection) -> dict[str, str]:
result = await db_connection.execute_fetchall("SELECT 'Hello, world!' AS greeting")
return {"greeting": result[0][0]} # type: ignore # noqa: PGH003


sqlspec = SQLSpec(
config=[
DatabaseConfig(config=AiosqliteConfig(), commit_mode="autocommit"),
DatabaseConfig(config=DuckDBConfig(), connection_key="etl_session"),
],
)
app = Litestar(route_handlers=[simple_sqlite, simple_select], plugins=[sqlspec])
20 changes: 20 additions & 0 deletions docs/examples/litestar_single_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from aiosqlite import Connection
from litestar import Litestar, get

from sqlspec.adapters.aiosqlite import AiosqliteConfig
from sqlspec.extensions.litestar import SQLSpec


@get("/")
async def simple_sqlite(db_session: Connection) -> dict[str, str]:
"""Simple select statement.

Returns:
dict[str, str]: The greeting.
"""
result = await db_session.execute_fetchall("SELECT 'Hello, world!' AS greeting")
return {"greeting": result[0][0]} # type: ignore # noqa: PGH003


sqlspec = SQLSpec(config=AiosqliteConfig())
app = Litestar(route_handlers=[simple_sqlite], plugins=[sqlspec])
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ ignore = [
"ARG002", # Unused method argument
"ARG001", # Unused function argument
"CPY001", # pycodestyle - Missing Copywrite notice at the top of the file
"RUF029", # Ruff - function is declared as async but has no awaitable calls
"COM812", # flake8-comma - Missing trailing comma
]
select = ["ALL"]

Expand Down Expand Up @@ -310,6 +312,8 @@ known-first-party = ["sqlspec", "tests"]
"TRY",
"PT012",
"INP001",
"DOC",
"PLC",
]
"tools/**/*.*" = ["D", "ARG", "EM", "TRY", "G", "FBT", "S603", "F811", "PLW0127", "PLR0911"]
"tools/prepare_release.py" = ["S603", "S607"]
Expand Down
6 changes: 5 additions & 1 deletion sqlspec/adapters/adbc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ def connection_params(self) -> "dict[str, Any]":

@contextmanager
def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connection, None, None]":
"""Create and provide a database connection."""
"""Create and provide a database connection.
Yields:
Connection: A database connection instance.
"""
from adbc_driver_manager.dbapi import connect

with connect(**self.connection_params) as connection:
Expand Down
4 changes: 1 addition & 3 deletions sqlspec/adapters/aiosqlite/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def connection_config_dict(self) -> "dict[str, Any]":
Returns:
A string keyed dict of config kwargs for the aiosqlite.connect() function.
"""
return dataclass_to_dict(self, exclude_empty=True, convert_nested=False)
return dataclass_to_dict(self, exclude_empty=True, convert_nested=False, exclude={"pool_instance"})

async def create_connection(self) -> "Connection":
"""Create and return a new database connection.
Expand All @@ -76,8 +76,6 @@ async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGener
Yields:
An Aiosqlite connection instance.

Raises:
ImproperConfigurationError: If the connection could not be established.
"""
connection = await self.create_connection()
try:
Expand Down
13 changes: 12 additions & 1 deletion sqlspec/adapters/asyncmy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,12 @@ def pool_config_dict(self) -> "dict[str, Any]":
ImproperConfigurationError: If the pool configuration is not provided.
"""
if self.pool_config:
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
return dataclass_to_dict(
self.pool_config,
exclude_empty=True,
convert_nested=False,
exclude={"pool_instance"},
)
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
raise ImproperConfigurationError(msg)

Expand Down Expand Up @@ -179,3 +184,9 @@ async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGener
pool = await self.provide_pool(*args, **kwargs) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
async with pool.acquire() as connection: # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
yield connection # pyright: ignore[reportUnknownMemberType]

async def close_pool(self) -> None:
"""Close the connection pool."""
if self.pool_instance is not None:
await self.pool_instance.close()
self.pool_instance = None
19 changes: 17 additions & 2 deletions sqlspec/adapters/asyncpg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,14 @@ def pool_config_dict(self) -> "dict[str, Any]":
Returns:
A string keyed dict of config kwargs for the Asyncpg :func:`create_pool <asyncpg.pool.create_pool>`
function.

Raises:
ImproperConfigurationError: If no pool_config is provided but a pool_instance is set.
"""
if self.pool_config:
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
return dataclass_to_dict(
self.pool_config, exclude_empty=True, exclude={"pool_instance"}, convert_nested=False
)
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
raise ImproperConfigurationError(msg)

Expand All @@ -107,6 +112,10 @@ async def create_pool(self) -> "Pool": # pyright: ignore[reportMissingTypeArgum

Returns:
Getter that returns the pool instance used by the plugin.

Raises:
ImproperConfigurationError: If neither pool_config nor pool_instance are provided,
or if the pool could not be configured.
"""
if self.pool_instance is not None:
return self.pool_instance
Expand Down Expand Up @@ -136,9 +145,15 @@ def provide_pool(self, *args: "Any", **kwargs: "Any") -> "Awaitable[Pool]": # p
async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[PoolConnectionProxy, None]": # pyright: ignore[reportMissingTypeArgument,reportUnknownParameterType]
"""Create a connection instance.

Returns:
Yields:
A connection instance.
"""
db_pool = await self.provide_pool(*args, **kwargs) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
async with db_pool.acquire() as connection: # pyright: ignore[reportUnknownVariableType]
yield connection

async def close_pool(self) -> None:
"""Close the pool."""
if self.pool_instance is not None:
await self.pool_instance.close()
self.pool_instance = None
78 changes: 75 additions & 3 deletions sqlspec/adapters/duckdb/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Union, cast

from duckdb import DuckDBPyConnection
Expand Down Expand Up @@ -39,6 +39,27 @@ class ExtensionConfig(TypedDict):
"""Optional version of the extension to install"""


@dataclass
class SecretConfig:
"""Configuration for a secret to store in a connection.

This class provides configuration options for storing a secret in a connection for later retrieval.

For details see: https://duckdb.org/docs/api/python/overview#connection-options
"""

secret_type: str = field()
"""The type of secret to store"""
name: str = field()
"""The name of the secret to store"""
persist: bool = field(default=False)
"""Whether to persist the secret"""
value: dict[str, Any] = field(default_factory=dict)
"""The secret value to store"""
replace_if_exists: bool = field(default=True)
"""Whether to replace the secret if it already exists"""


@dataclass
class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):
"""Configuration for DuckDB database connections.
Expand All @@ -63,6 +84,8 @@ class DuckDBConfig(NoPoolSyncConfig[DuckDBPyConnection]):

extensions: "Union[Sequence[ExtensionConfig], ExtensionConfig, EmptyType]" = Empty
"""A sequence of extension configurations to install and configure upon connection creation."""
secrets: "Union[Sequence[SecretConfig], SecretConfig , EmptyType]" = Empty
"""A dictionary of secrets to store in the connection for later retrieval."""

def __post_init__(self) -> None:
"""Post-initialization validation and processing.
Expand All @@ -73,9 +96,10 @@ def __post_init__(self) -> None:
"""
if self.config is Empty:
self.config = {}

if self.extensions is Empty:
self.extensions = []
if self.secrets is Empty:
self.secrets = []
if isinstance(self.extensions, dict):
self.extensions = [self.extensions]
# this is purely for mypy
Expand Down Expand Up @@ -120,6 +144,47 @@ def _configure_extensions(self, connection: "DuckDBPyConnection") -> None:
for extension in cast("list[ExtensionConfig]", self.extensions):
self._configure_extension(connection, extension)

@staticmethod
def _secret_exists(connection: "DuckDBPyConnection", name: "str") -> bool:
"""Check if a secret exists in the connection.

Args:
connection: The DuckDB connection to check for the secret.
name: The name of the secret to check for.

Returns:
bool: True if the secret exists, False otherwise.
"""
results = connection.execute("select 1 from duckdb_secrets() where name=?", name).fetchone()
return results is not None

@classmethod
def _configure_secrets(
cls,
connection: "DuckDBPyConnection",
secrets: "list[SecretConfig]",
) -> None:
"""Configure persistent secrets for the connection.

Args:
connection: The DuckDB connection to configure secrets for.
secrets: The list of secrets to store in the connection.

Raises:
ImproperConfigurationError: If a secret could not be stored in the connection.
"""
try:
for secret in secrets:
secret_exists = cls._secret_exists(connection, secret.name)
if not secret_exists or secret.replace_if_exists:
connection.execute(f"""create or replace {"persistent" if secret.persist else ""} secret {secret.name} (
type {secret.secret_type},
{" ,".join([f"{k} '{v}'" for k, v in secret.value.items()])}
) """)
except Exception as e:
msg = f"Failed to store secret. Error: {e!s}"
raise ImproperConfigurationError(msg) from e

@staticmethod
def _configure_extension(connection: "DuckDBPyConnection", extension: ExtensionConfig) -> None:
"""Configure a single extension for the connection.
Expand Down Expand Up @@ -156,7 +221,12 @@ def connection_config_dict(self) -> "dict[str, Any]":
Returns:
A string keyed dict of config kwargs for the duckdb.connect() function.
"""
config = dataclass_to_dict(self, exclude_empty=True, exclude={"extensions"}, convert_nested=False)
config = dataclass_to_dict(
self,
exclude_empty=True,
exclude={"extensions", "pool_instance", "secrets"},
convert_nested=False,
)
if not config.get("database"):
config["database"] = ":memory:"
return config
Expand All @@ -176,6 +246,8 @@ def create_connection(self) -> "DuckDBPyConnection":
connection = duckdb.connect(**self.connection_config_dict) # pyright: ignore[reportUnknownMemberType]
self._configure_extensions(connection)
self._configure_connection(connection)
self._configure_secrets(connection, cast("list[SecretConfig]", self.secrets))

except Exception as e:
msg = f"Could not configure the DuckDB connection. Error: {e!s}"
raise ImproperConfigurationError(msg) from e
Expand Down
15 changes: 12 additions & 3 deletions sqlspec/adapters/oracledb/config/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,27 @@ class OracleAsyncDatabaseConfig(AsyncDatabaseConfig[AsyncConnection, AsyncConnec
def pool_config_dict(self) -> "dict[str, Any]":
"""Return the pool configuration as a dict.

Raises:
ImproperConfigurationError: If no pool_config is provided but a pool_instance

Returns:
A string keyed dict of config kwargs for the Asyncpg :func:`create_pool <oracledb.pool.create_pool>`
function.
"""
if self.pool_config is not None:
return dataclass_to_dict(self.pool_config, exclude_empty=True, convert_nested=False)
return dataclass_to_dict(
self.pool_config, exclude_empty=True, convert_nested=False, exclude={"pool_instance"}
)
msg = "'pool_config' methods can not be used when a 'pool_instance' is provided."
raise ImproperConfigurationError(msg)

async def create_pool(self) -> "AsyncConnectionPool":
"""Return a pool. If none exists yet, create one.

Raises:
ImproperConfigurationError: If neither pool_config nor pool_instance are provided,
or if the pool could not be configured.

Returns:
Getter that returns the pool instance used by the plugin.
"""
Expand Down Expand Up @@ -95,8 +104,8 @@ def provide_pool(self, *args: "Any", **kwargs: "Any") -> "Awaitable[AsyncConnect
async def provide_connection(self, *args: "Any", **kwargs: "Any") -> "AsyncGenerator[AsyncConnection, None]":
"""Create a connection instance.

Returns:
A connection instance.
Yields:
AsyncConnection: A connection instance.
"""
db_pool = await self.provide_pool(*args, **kwargs)
async with db_pool.acquire() as connection: # pyright: ignore[reportUnknownMemberType]
Expand Down
2 changes: 1 addition & 1 deletion sqlspec/adapters/oracledb/config/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@dataclass
class OracleGenericPoolConfig(Generic[ConnectionT, PoolT], GenericPoolConfig):
class OracleGenericPoolConfig(GenericPoolConfig, Generic[ConnectionT, PoolT]):
"""Configuration for Oracle database connection pools.

This class provides configuration options for both synchronous and asynchronous Oracle
Expand Down
Loading