Skip to content

fix: StatementFilter and parameter validation fix #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions docs/examples/litestar_asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,26 @@
# ]
# ///

from typing import Annotated, Optional

from litestar import Litestar, get
from litestar.params import Dependency

from sqlspec.adapters.asyncpg import AsyncpgConfig, AsyncpgDriver, AsyncpgPoolConfig
from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec
from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec, providers
from sqlspec.filters import FilterTypes


@get("/")
async def simple_asyncpg(db_session: AsyncpgDriver) -> dict[str, str]:
return await db_session.select_one("SELECT 'Hello, world!' AS greeting")
@get(
"/",
dependencies=providers.create_filter_dependencies({"search": "greeting", "search_ignore_case": True}),
)
async def simple_asyncpg(
db_session: AsyncpgDriver, filters: Annotated[list[FilterTypes], Dependency(skip_validation=True)]
) -> Optional[dict[str, str]]:
return await db_session.select_one_or_none(
"SELECT greeting FROM (select 'Hello, world!' as greeting) as t", *filters
)


sqlspec = SQLSpec(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ nanoid = ["fastnanoid>=0.4.1"]
oracledb = ["oracledb"]
orjson = ["orjson"]
performance = ["sqlglot[rs]", "msgspec"]
polars = ["polars", "pyarrow"]
psqlpy = ["psqlpy"]
psycopg = ["psycopg[binary,pool]"]
pydantic = ["pydantic", "pydantic-extra-types"]
Expand Down
97 changes: 37 additions & 60 deletions sqlspec/adapters/adbc/driver.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import contextlib
import logging
import re
from collections.abc import Generator, Sequence
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast, overload

from adbc_driver_manager.dbapi import Connection, Cursor
from sqlglot import exp as sqlglot_exp

from sqlspec.base import SyncDriverAdapterProtocol
from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError
from sqlspec.exceptions import SQLParsingError
from sqlspec.filters import StatementFilter
from sqlspec.mixins import ResultConverter, SQLTranslatorMixin, SyncArrowBulkOperationsMixin
from sqlspec.statement import SQLStatement
Expand Down Expand Up @@ -91,7 +91,6 @@ def _process_sql_params( # noqa: C901, PLR0912, PLR0915
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: "StatementFilter",
**kwargs: Any,
) -> "tuple[str, Optional[tuple[Any, ...]]]": # Always returns tuple or None for params
Expand All @@ -108,14 +107,24 @@ def _process_sql_params( # noqa: C901, PLR0912, PLR0915
**kwargs: Additional keyword arguments.

Raises:
ParameterStyleMismatchError: If positional parameters are mixed with keyword arguments.
SQLParsingError: If the SQL statement cannot be parsed.

Returns:
A tuple of (sql, parameters) ready for execution.
"""
passed_parameters: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None
combined_filters_list: list[StatementFilter] = list(filters)

if parameters is not None:
if isinstance(parameters, StatementFilter):
combined_filters_list.insert(0, parameters)
# passed_parameters remains None
else:
# If parameters is not a StatementFilter, it's actual data parameters.
passed_parameters = parameters

# Special handling for SQLite with non-dict parameters and named placeholders
if self.dialect == "sqlite" and parameters is not None and not is_dict(parameters):
if self.dialect == "sqlite" and passed_parameters is not None and not is_dict(passed_parameters):
# First mask out comments and strings to avoid detecting parameters in those
comments = list(SQL_COMMENT_PATTERN.finditer(sql))
strings = list(SQL_STRING_PATTERN.finditer(sql))
Expand All @@ -136,26 +145,15 @@ def _process_sql_params( # noqa: C901, PLR0912, PLR0915
param_positions.sort(reverse=True)
for start, end in param_positions:
sql = sql[:start] + "?" + sql[end:]
if not isinstance(parameters, (list, tuple)):
return sql, (parameters,)
return sql, tuple(parameters)
if not isinstance(passed_parameters, (list, tuple)):
passed_parameters = (passed_parameters,)
passed_parameters = tuple(passed_parameters)

# Standard processing for all other cases
merged_params = parameters
if kwargs:
if is_dict(parameters):
merged_params = {**parameters, **kwargs}
elif parameters is not None:
msg = "Cannot mix positional parameters with keyword arguments for adbc driver."
raise ParameterStyleMismatchError(msg)
else:
merged_params = kwargs
statement = SQLStatement(sql, passed_parameters, kwargs=kwargs, dialect=self.dialect)

# 2. Create SQLStatement with dialect and process
statement = SQLStatement(sql, merged_params, dialect=self.dialect)

# Apply any filters
for filter_obj in filters:
# Apply any filters from combined_filters_list
for filter_obj in combined_filters_list:
statement = statement.apply_filter(filter_obj)

processed_sql, processed_params, parsed_expr = statement.process()
Expand Down Expand Up @@ -284,7 +282,6 @@ def select(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
Expand All @@ -295,7 +292,6 @@ def select(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[ModelDTOT]",
Expand All @@ -305,7 +301,6 @@ def select(
self,
sql: str,
parameters: Optional["StatementParameterType"] = None,
/,
*filters: "StatementFilter",
connection: Optional["AdbcConnection"] = None,
schema_type: "Optional[type[ModelDTOT]]" = None,
Expand Down Expand Up @@ -341,7 +336,6 @@ def select_one(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
Expand All @@ -352,7 +346,6 @@ def select_one(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[ModelDTOT]",
Expand All @@ -362,7 +355,6 @@ def select_one(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: "Optional[type[ModelDTOT]]" = None,
Expand Down Expand Up @@ -396,7 +388,6 @@ def select_one_or_none(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
Expand All @@ -407,7 +398,6 @@ def select_one_or_none(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[ModelDTOT]",
Expand All @@ -417,7 +407,6 @@ def select_one_or_none(
self,
sql: str,
parameters: Optional["StatementParameterType"] = None,
/,
*filters: "StatementFilter",
connection: Optional["AdbcConnection"] = None,
schema_type: "Optional[type[ModelDTOT]]" = None,
Expand Down Expand Up @@ -452,8 +441,7 @@ def select_value(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: StatementFilter,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
**kwargs: Any,
Expand All @@ -463,8 +451,7 @@ def select_value(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: StatementFilter,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[T]",
**kwargs: Any,
Expand All @@ -473,8 +460,7 @@ def select_value(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: StatementFilter,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: "Optional[type[T]]" = None,
**kwargs: Any,
Expand Down Expand Up @@ -508,8 +494,7 @@ def select_value_or_none(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: StatementFilter,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
**kwargs: Any,
Expand All @@ -519,8 +504,7 @@ def select_value_or_none(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: StatementFilter,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[T]",
**kwargs: Any,
Expand All @@ -529,8 +513,7 @@ def select_value_or_none(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: StatementFilter,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: "Optional[type[T]]" = None,
**kwargs: Any,
Expand Down Expand Up @@ -564,22 +547,21 @@ def insert_update_delete(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
**kwargs: Any,
) -> int:
"""Execute an insert, update, or delete statement.

Args:
sql: The SQL statement to execute.
sql: The SQL statement string.
parameters: The parameters for the statement (dict, tuple, list, or None).
*filters: Statement filters to apply.
connection: Optional connection override.
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.

Returns:
The number of rows affected by the statement.
Row count affected by the operation.
"""
connection = self._connection(connection)
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
Expand All @@ -593,8 +575,7 @@ def insert_update_delete_returning(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: StatementFilter,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: None = None,
**kwargs: Any,
Expand All @@ -604,8 +585,7 @@ def insert_update_delete_returning(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: StatementFilter,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: "type[ModelDTOT]",
**kwargs: Any,
Expand All @@ -614,24 +594,23 @@ def insert_update_delete_returning(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: StatementFilter,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
schema_type: "Optional[type[ModelDTOT]]" = None,
**kwargs: Any,
) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
"""Insert, update, or delete data from the database and return result.
"""Insert, update, or delete data with RETURNING clause.

Args:
sql: The SQL statement to execute.
sql: The SQL statement string.
parameters: The parameters for the statement (dict, tuple, list, or None).
*filters: Statement filters to apply.
connection: Optional connection override.
schema_type: Optional schema class for the result.
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.

Returns:
The first row of results.
The returned row data, or None if no row returned.
"""
connection = self._connection(connection)
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
Expand All @@ -648,7 +627,6 @@ def execute_script(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
connection: "Optional[AdbcConnection]" = None,
**kwargs: Any,
) -> str:
Expand All @@ -673,12 +651,11 @@ def execute_script(

# --- Arrow Bulk Operations ---

def select_arrow( # pyright: ignore[reportUnknownParameterType]
def select_arrow(
self,
sql: str,
parameters: "Optional[StatementParameterType]" = None,
/,
*filters: StatementFilter,
*filters: "StatementFilter",
connection: "Optional[AdbcConnection]" = None,
**kwargs: Any,
) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType]
Expand All @@ -692,7 +669,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType]
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.

Returns:
An Apache Arrow Table containing the query results.
An Arrow Table containing the query results.
"""
connection = self._connection(connection)
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
Expand Down
Loading