From 9a78e2054e2e6bdd460cb4a274816bcbb075b438 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 12 May 2025 16:51:31 +0000 Subject: [PATCH 1/3] feat: merge statement filters --- docs/examples/litestar_asyncpg.py | 19 ++- sqlspec/adapters/adbc/driver.py | 33 +--- sqlspec/adapters/aiosqlite/driver.py | 43 ++--- sqlspec/adapters/asyncmy/driver.py | 74 ++++---- sqlspec/adapters/asyncpg/driver.py | 76 ++++----- sqlspec/adapters/bigquery/driver.py | 247 +++++++++------------------ sqlspec/base.py | 41 ----- sqlspec/filters.py | 3 +- 8 files changed, 180 insertions(+), 356 deletions(-) diff --git a/docs/examples/litestar_asyncpg.py b/docs/examples/litestar_asyncpg.py index d6a1103..324c82f 100644 --- a/docs/examples/litestar_asyncpg.py +++ b/docs/examples/litestar_asyncpg.py @@ -13,15 +13,26 @@ # ] # /// +from typing import Annotated, Optional + from litestar import Litestar, get +from litestar.params import Dependency from sqlspec.adapters.asyncpg import AsyncpgConfig, AsyncpgDriver, AsyncpgPoolConfig -from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec +from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec, providers +from sqlspec.filters import FilterTypes -@get("/") -async def simple_asyncpg(db_session: AsyncpgDriver) -> dict[str, str]: - return await db_session.select_one("SELECT 'Hello, world!' AS greeting") +@get( + "/", + dependencies=providers.create_filter_dependencies({"search": "greeting", "search_ignore_case": True}), +) +async def simple_asyncpg( + db_session: AsyncpgDriver, filters: Annotated[list[FilterTypes], Dependency(skip_validation=True)] +) -> Optional[dict[str, str]]: + return await db_session.select_one_or_none( + "SELECT greeting FROM (select 'Hello, world!' as greeting) as t", *filters + ) sqlspec = SQLSpec( diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 584e964..711ab60 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -91,7 +91,6 @@ def _process_sql_params( # noqa: C901, PLR0912, PLR0915 self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[tuple[Any, ...]]]": # Always returns tuple or None for params @@ -284,7 +283,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: None = None, @@ -295,7 +293,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: "type[ModelDTOT]", @@ -305,7 +302,6 @@ def select( self, sql: str, parameters: Optional["StatementParameterType"] = None, - /, *filters: "StatementFilter", connection: Optional["AdbcConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -341,7 +337,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: None = None, @@ -352,7 +347,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: "type[ModelDTOT]", @@ -362,7 +356,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -396,7 +389,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: None = None, @@ -407,7 +399,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: "type[ModelDTOT]", @@ -417,7 +408,6 @@ def select_one_or_none( self, sql: str, parameters: Optional["StatementParameterType"] = None, - /, *filters: "StatementFilter", connection: Optional["AdbcConnection"] = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -452,7 +442,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: StatementFilter, connection: "Optional[AdbcConnection]" = None, schema_type: None = None, @@ -463,7 +452,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: StatementFilter, connection: "Optional[AdbcConnection]" = None, schema_type: "type[T]", @@ -473,7 +461,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: StatementFilter, connection: "Optional[AdbcConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -508,7 +495,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: StatementFilter, connection: "Optional[AdbcConnection]" = None, schema_type: None = None, @@ -519,7 +505,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: StatementFilter, connection: "Optional[AdbcConnection]" = None, schema_type: "type[T]", @@ -529,7 +514,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: StatementFilter, connection: "Optional[AdbcConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -564,7 +548,6 @@ def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, **kwargs: Any, @@ -572,14 +555,14 @@ def insert_update_delete( """Execute an insert, update, or delete statement. Args: - sql: The SQL statement to execute. + sql: The SQL statement string. parameters: The parameters for the statement (dict, tuple, list, or None). *filters: Statement filters to apply. connection: Optional connection override. **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: - The number of rows affected by the statement. + Row count affected by the operation. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) @@ -593,7 +576,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: StatementFilter, connection: "Optional[AdbcConnection]" = None, schema_type: None = None, @@ -604,7 +586,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: StatementFilter, connection: "Optional[AdbcConnection]" = None, schema_type: "type[ModelDTOT]", @@ -614,16 +595,15 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: StatementFilter, connection: "Optional[AdbcConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Insert, update, or delete data from the database and return result. + """Insert, update, or delete data with RETURNING clause. Args: - sql: The SQL statement to execute. + sql: The SQL statement string. parameters: The parameters for the statement (dict, tuple, list, or None). *filters: Statement filters to apply. connection: Optional connection override. @@ -631,7 +611,7 @@ def insert_update_delete_returning( **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: - The first row of results. + The returned row data, or None if no row returned. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) @@ -648,7 +628,6 @@ def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[AdbcConnection]" = None, **kwargs: Any, ) -> str: @@ -692,7 +671,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: - An Apache Arrow Table containing the query results. + An Arrow Table containing the query results. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index 41693ba..aabdcb2 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -6,14 +6,14 @@ from sqlglot import exp from sqlspec.base import AsyncDriverAdapterProtocol +from sqlspec.filters import StatementFilter from sqlspec.mixins import ResultConverter, SQLTranslatorMixin from sqlspec.statement import SQLStatement from sqlspec.typing import is_dict if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Sequence + from collections.abc import AsyncGenerator, Mapping, Sequence # Added Mapping, Sequence - from sqlspec.filters import StatementFilter from sqlspec.typing import ModelDTOT, StatementParameterType, T __all__ = ("AiosqliteConnection", "AiosqliteDriver") @@ -51,7 +51,6 @@ def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": @@ -63,17 +62,27 @@ def _process_sql_params( Args: sql: SQL statement. - parameters: Query parameters. + parameters: Query parameters. Can be data or a StatementFilter. *filters: Statement filters to apply. **kwargs: Additional keyword arguments. Returns: Tuple of processed SQL and parameters. """ - statement = SQLStatement(sql, parameters, kwargs=kwargs, dialect=self.dialect) + passed_parameters: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + combined_filters_list: list[StatementFilter] = list(filters) - # Apply any filters - for filter_obj in filters: + if parameters is not None: + if isinstance(parameters, StatementFilter): + combined_filters_list.insert(0, parameters) + # _actual_data_params remains None + else: + # If parameters is not a StatementFilter, it's actual data parameters. + passed_parameters = parameters # type: ignore[assignment] + + statement = SQLStatement(sql, passed_parameters, kwargs=kwargs, dialect=self.dialect) + + for filter_obj in combined_filters_list: statement = statement.apply_filter(filter_obj) processed_sql, processed_params, parsed_expr = statement.process() @@ -121,7 +130,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, @@ -132,7 +140,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[ModelDTOT]", @@ -142,7 +149,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -174,7 +180,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, @@ -185,7 +190,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[ModelDTOT]", @@ -195,7 +199,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -226,7 +229,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, @@ -237,7 +239,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[ModelDTOT]", @@ -247,7 +248,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -279,7 +279,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, @@ -290,7 +289,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[T]", @@ -300,7 +298,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -330,7 +327,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, @@ -341,7 +337,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[T]", @@ -351,7 +346,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -381,7 +375,6 @@ async def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, **kwargs: Any, @@ -404,7 +397,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: None = None, @@ -415,7 +407,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "type[ModelDTOT]", @@ -425,7 +416,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AiosqliteConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -458,7 +448,6 @@ async def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[AiosqliteConnection]" = None, **kwargs: Any, ) -> str: diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index f6ea8b5..3ee2694 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -1,7 +1,7 @@ # type: ignore import logging import re -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, Optional, Union, overload @@ -9,14 +9,16 @@ from sqlspec.base import AsyncDriverAdapterProtocol from sqlspec.exceptions import ParameterStyleMismatchError +from sqlspec.filters import StatementFilter from sqlspec.mixins import ResultConverter, SQLTranslatorMixin from sqlspec.statement import SQLStatement from sqlspec.typing import is_dict if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from asyncmy.cursors import Cursor - from sqlspec.filters import StatementFilter from sqlspec.typing import ModelDTOT, StatementParameterType, T __all__ = ("AsyncmyDriver",) @@ -55,7 +57,6 @@ def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": @@ -63,7 +64,7 @@ def _process_sql_params( Args: sql: The SQL statement to process. - parameters: The parameters to bind to the statement. + parameters: The parameters to bind to the statement. Can be data or a StatementFilter. *filters: Statement filters to apply. **kwargs: Additional keyword arguments. @@ -73,41 +74,52 @@ def _process_sql_params( Returns: A tuple of (sql, parameters) ready for execution. """ + # Convert filters tuple to a list to allow modification + current_filters: list[StatementFilter] = list(filters) + actual_parameters: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + + if parameters is not None: + if isinstance(parameters, StatementFilter): + current_filters.insert(0, parameters) + # actual_parameters remains None + else: + actual_parameters = parameters # type: ignore[assignment] + # Handle MySQL-specific placeholders (%s) which SQLGlot doesn't parse well # If %s placeholders are present, handle them directly mysql_placeholders_count = len(MYSQL_PLACEHOLDER_PATTERN.findall(sql)) if mysql_placeholders_count > 0: # For MySQL format placeholders, minimal processing is needed - if parameters is None: + if actual_parameters is None: if mysql_placeholders_count > 0: msg = f"asyncmy: SQL statement contains {mysql_placeholders_count} format placeholders ('%s'), but no parameters were provided. SQL: {sql}" raise ParameterStyleMismatchError(msg) return sql, None # Convert dict to tuple if needed - if is_dict(parameters): + if is_dict(actual_parameters): # MySQL's %s placeholders require positional params msg = "asyncmy: Dictionary parameters provided with '%s' placeholders. MySQL format placeholders require tuple/list parameters." raise ParameterStyleMismatchError(msg) # Convert to tuple (handles both scalar and sequence cases) - if not isinstance(parameters, (list, tuple)): + if not isinstance(actual_parameters, (list, tuple)): # Scalar parameter case - return sql, (parameters,) + return sql, (actual_parameters,) # Sequence parameter case - ensure appropriate length - if len(parameters) != mysql_placeholders_count: - msg = f"asyncmy: Parameter count mismatch. SQL expects {mysql_placeholders_count} '%s' placeholders, but {len(parameters)} parameters were provided. SQL: {sql}" + if len(actual_parameters) != mysql_placeholders_count: # type: ignore[arg-type] + msg = f"asyncmy: Parameter count mismatch. SQL expects {mysql_placeholders_count} '%s' placeholders, but {len(actual_parameters)} parameters were provided. SQL: {sql}" # type: ignore[arg-type] raise ParameterStyleMismatchError(msg) - return sql, tuple(parameters) + return sql, tuple(actual_parameters) # type: ignore[arg-type] # Create a SQLStatement with MySQL dialect - statement = SQLStatement(sql, parameters, kwargs=kwargs, dialect=self.dialect) + statement = SQLStatement(sql, actual_parameters, kwargs=kwargs, dialect=self.dialect) # Apply any filters - for filter_obj in filters: + for filter_obj in current_filters: # Use the modified list of filters statement = statement.apply_filter(filter_obj) # Process the statement for execution @@ -136,7 +148,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, @@ -147,7 +158,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[ModelDTOT]", @@ -157,7 +167,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -171,22 +180,18 @@ async def select( connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) async with self._with_cursor(connection) as cursor: - await cursor.execute(sql, parameters) + await cursor.execute(final_sql, final_params) results = await cursor.fetchall() if not results: return [] column_names = [c[0] for c in cursor.description or []] - - # Convert to dicts first - dict_results = [dict(zip(column_names, row)) for row in results] - return self.to_schema(dict_results, schema_type=schema_type) + return self.to_schema(dict_[dict(zip(column_names, row)) for row in results]esults, schema_type=schema_type) @overload async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, @@ -197,7 +202,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[ModelDTOT]", @@ -207,7 +211,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -225,17 +228,13 @@ async def select_one( result = await cursor.fetchone() result = self.check_not_found(result) column_names = [c[0] for c in cursor.description or []] - - # Convert to dict and use ResultConverter - dict_result = dict(zip(column_names, result)) - return self.to_schema(dict_result, schema_type=schema_type) + return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) @overload async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, @@ -246,7 +245,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[ModelDTOT]", @@ -256,7 +254,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -275,17 +272,13 @@ async def select_one_or_none( if result is None: return None column_names = [c[0] for c in cursor.description or []] - - # Convert to dict and use ResultConverter - dict_result = dict(zip(column_names, result)) - return self.to_schema(dict_result, schema_type=schema_type) + return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) @overload async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, @@ -296,7 +289,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[T]", @@ -306,7 +298,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -333,7 +324,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, @@ -344,7 +334,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[T]", @@ -354,7 +343,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -381,7 +369,6 @@ async def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, **kwargs: Any, @@ -402,7 +389,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: None = None, @@ -413,7 +399,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "type[ModelDTOT]", @@ -423,7 +408,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncmyConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index ff35810..9b77a56 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -8,6 +8,7 @@ from typing_extensions import TypeAlias from sqlspec.base import AsyncDriverAdapterProtocol +from sqlspec.filters import StatementFilter from sqlspec.mixins import ResultConverter, SQLTranslatorMixin from sqlspec.statement import SQLStatement @@ -18,7 +19,6 @@ from asyncpg.connection import Connection from asyncpg.pool import PoolConnectionProxy - from sqlspec.filters import StatementFilter from sqlspec.typing import ModelDTOT, StatementParameterType, T __all__ = ("AsyncpgConnection", "AsyncpgDriver") @@ -69,7 +69,6 @@ def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": @@ -80,22 +79,33 @@ def _process_sql_params( Args: sql: SQL statement. - parameters: Query parameters. + parameters: Query parameters. Can be data or a StatementFilter. *filters: Statement filters to apply. **kwargs: Additional keyword arguments. Returns: Tuple of processed SQL and parameters. """ - # Handle scalar parameter by converting to a single-item tuple - if parameters is not None and not isinstance(parameters, (list, tuple, dict)): - parameters = (parameters,) + data_params_for_statement: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + combined_filters_list: list[StatementFilter] = list(filters) + + if parameters is not None: + if isinstance(parameters, StatementFilter): + combined_filters_list.insert(0, parameters) + # data_params_for_statement remains None + else: + # If parameters is not a StatementFilter, it's actual data parameters. + data_params_for_statement = parameters + + # Handle scalar parameter by converting to a single-item tuple if it's data + if data_params_for_statement is not None and not isinstance(data_params_for_statement, (list, tuple, dict)): + data_params_for_statement = (data_params_for_statement,) # Create a SQLStatement with PostgreSQL dialect - statement = SQLStatement(sql, parameters, kwargs=kwargs, dialect=self.dialect) + statement = SQLStatement(sql, data_params_for_statement, kwargs=kwargs, dialect=self.dialect) - # Apply any filters - for filter_obj in filters: + # Apply any filters from the combined list + for filter_obj in combined_filters_list: statement = statement.apply_filter(filter_obj) # Process the statement @@ -164,7 +174,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: None = None, @@ -175,7 +184,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "type[ModelDTOT]", @@ -185,7 +193,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -194,9 +201,9 @@ async def select( """Fetch data from the database. Args: - *filters: Statement filters to apply. sql: SQL statement. - parameters: Query parameters. + parameters: Query parameters. Can be data or a StatementFilter. + *filters: Statement filters to apply. connection: Optional connection to use. schema_type: Optional schema class for the result. **kwargs: Additional keyword arguments. @@ -218,7 +225,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: None = None, @@ -229,7 +235,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "type[ModelDTOT]", @@ -239,7 +244,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -248,9 +252,9 @@ async def select_one( """Fetch one row from the database. Args: - *filters: Statement filters to apply. sql: SQL statement. - parameters: Query parameters. + parameters: Query parameters. Can be data or a StatementFilter. + *filters: Statement filters to apply. connection: Optional connection to use. schema_type: Optional schema class for the result. **kwargs: Additional keyword arguments. @@ -270,7 +274,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: None = None, @@ -281,7 +284,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "type[ModelDTOT]", @@ -291,7 +293,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -300,9 +301,9 @@ async def select_one_or_none( """Fetch one row from the database. Args: - *filters: Statement filters to apply. sql: SQL statement. - parameters: Query parameters. + parameters: Query parameters. Can be data or a StatementFilter. + *filters: Statement filters to apply. connection: Optional connection to use. schema_type: Optional schema class for the result. **kwargs: Additional keyword arguments. @@ -323,7 +324,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: None = None, @@ -334,7 +334,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "type[T]", @@ -344,7 +343,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -353,9 +351,9 @@ async def select_value( """Fetch a single value from the database. Args: - *filters: Statement filters to apply. sql: SQL statement. - parameters: Query parameters. + parameters: Query parameters. Can be data or a StatementFilter. + *filters: Statement filters to apply. connection: Optional connection to use. schema_type: Optional schema class for the result. **kwargs: Additional keyword arguments. @@ -377,7 +375,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: None = None, @@ -388,7 +385,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "type[T]", @@ -398,7 +394,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -407,9 +402,9 @@ async def select_value_or_none( """Fetch a single value from the database. Args: - *filters: Statement filters to apply. sql: SQL statement. - parameters: Query parameters. + parameters: Query parameters. Can be data or a StatementFilter. + *filters: Statement filters to apply. connection: Optional connection to use. schema_type: Optional schema class for the result. **kwargs: Additional keyword arguments. @@ -431,7 +426,6 @@ async def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: Optional["AsyncpgConnection"] = None, **kwargs: Any, @@ -439,9 +433,9 @@ async def insert_update_delete( """Insert, update, or delete data from the database. Args: - *filters: Statement filters to apply. sql: SQL statement. - parameters: Query parameters. + parameters: Query parameters. Can be data or a StatementFilter. + *filters: Statement filters to apply. connection: Optional connection to use. **kwargs: Additional keyword arguments. @@ -463,7 +457,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: None = None, @@ -474,7 +467,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "type[ModelDTOT]", @@ -484,7 +476,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[AsyncpgConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -493,9 +484,9 @@ async def insert_update_delete_returning( """Insert, update, or delete data from the database and return the affected row. Args: - *filters: Statement filters to apply. sql: SQL statement. - parameters: Query parameters. + parameters: Query parameters. Can be data or a StatementFilter. + *filters: Statement filters to apply. connection: Optional connection to use. schema_type: Optional schema class for the result. **kwargs: Additional keyword arguments. @@ -516,7 +507,6 @@ async def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[AsyncpgConnection]" = None, **kwargs: Any, ) -> str: diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index ed9d8e6..9be92e1 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -1,7 +1,7 @@ import contextlib import datetime import logging -from collections.abc import Iterator, Sequence +from collections.abc import Iterator, Mapping, Sequence from decimal import Decimal from typing import ( TYPE_CHECKING, @@ -69,8 +69,6 @@ def _get_bq_param_type(value: Any) -> "tuple[Optional[str], Optional[str]]": if isinstance(value, float): return "FLOAT64", None if isinstance(value, Decimal): - # Precision/scale might matter, but BQ client handles conversion. - # Defaulting to BIGNUMERIC, NUMERIC might be desired in some cases though (User change) return "BIGNUMERIC", None if isinstance(value, str): return "STRING", None @@ -78,23 +76,17 @@ def _get_bq_param_type(value: Any) -> "tuple[Optional[str], Optional[str]]": return "BYTES", None if isinstance(value, datetime.date): return "DATE", None - # DATETIME is for timezone-naive values if isinstance(value, datetime.datetime) and value.tzinfo is None: return "DATETIME", None - # TIMESTAMP is for timezone-aware values if isinstance(value, datetime.datetime) and value.tzinfo is not None: return "TIMESTAMP", None if isinstance(value, datetime.time): return "TIME", None - # Handle Arrays - Determine element type if isinstance(value, (list, tuple)): if not value: - # Cannot determine type of empty array, BQ requires type. - # Raise or default? Defaulting is risky. Let's raise. msg = "Cannot determine BigQuery ARRAY type for empty sequence." raise SQLSpecError(msg) - # Infer type from first element first_element = value[0] element_type, _ = BigQueryDriver._get_bq_param_type(first_element) if element_type is None: @@ -102,55 +94,59 @@ def _get_bq_param_type(value: Any) -> "tuple[Optional[str], Optional[str]]": raise SQLSpecError(msg) return "ARRAY", element_type - # Handle Structs (basic dict mapping) - Requires careful handling - # if isinstance(value, dict): - # # This requires recursive type mapping for sub-fields. - # # For simplicity, users might need to construct StructQueryParameter manually. - # # return "STRUCT", None # Placeholder if implementing # noqa: ERA001 - # raise SQLSpecError("Automatic STRUCT mapping not implemented. Please use bigquery.StructQueryParameter.") # noqa: ERA001 - - return None, None # Unsupported type + return None, None def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": """Process SQL and parameters using SQLStatement with dialect support. + This method also handles the separation of StatementFilter instances that might be + passed in the 'parameters' argument. + Args: sql: The SQL statement to process. - parameters: The parameters to bind to the statement. - *filters: Statement filters to apply. - **kwargs: Additional keyword arguments. + parameters: The parameters to bind to the statement. This can be a + Mapping (dict), Sequence (list/tuple), a single StatementFilter, or None. + *filters: Additional statement filters to apply. + **kwargs: Additional keyword arguments (treated as named parameters for the SQL statement). Raises: ParameterStyleMismatchError: If pre-formatted BigQuery parameters are mixed with keyword arguments. Returns: - A tuple of (sql, parameters) ready for execution. + A tuple of (processed_sql, processed_parameters) ready for execution. """ - # Special case: check for pre-formatted BQ parameters + passed_parameters: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + combined_filters_list: list[StatementFilter] = list(filters) + + if parameters is not None: + if isinstance(parameters, StatementFilter): + combined_filters_list.insert(0, parameters) + else: + passed_parameters = parameters + if ( - isinstance(parameters, (list, tuple)) - and parameters - and all(isinstance(p, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)) for p in parameters) + isinstance(passed_parameters, (list, tuple)) + and passed_parameters + and all( + isinstance(p, (bigquery.ScalarQueryParameter, bigquery.ArrayQueryParameter)) for p in passed_parameters + ) ): if kwargs: msg = "Cannot mix pre-formatted BigQuery parameters with keyword arguments." raise ParameterStyleMismatchError(msg) - return sql, parameters + return sql, passed_parameters - statement = SQLStatement(sql, parameters, kwargs=kwargs, dialect=self.dialect) + statement = SQLStatement(sql, passed_parameters, kwargs=kwargs, dialect=self.dialect) - # Apply any filters - for filter_obj in filters: + for filter_obj in combined_filters_list: statement = statement.apply_filter(filter_obj) - # Process the statement for execution processed_sql, processed_params, _ = statement.process() return processed_sql, processed_params @@ -159,8 +155,7 @@ def _run_query_job( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, job_config: "Optional[QueryJobConfig]" = None, is_script: bool = False, @@ -168,19 +163,15 @@ def _run_query_job( ) -> "QueryJob": conn = self._connection(connection) - # Determine the final job config, creating a new one if necessary - # to avoid modifying a shared default config. if job_config: - final_job_config = job_config # Use the provided config directly + final_job_config = job_config elif self._default_query_job_config: - final_job_config = QueryJobConfig() + final_job_config = QueryJobConfig.from_api_repr(self._default_query_job_config.to_api_repr()) # type: ignore[no-untyped-call] else: - final_job_config = QueryJobConfig() # Create a fresh config + final_job_config = QueryJobConfig() - # Process SQL and parameters final_sql, processed_params = self._process_sql_params(sql, parameters, *filters, **kwargs) - # Handle pre-formatted parameters if ( isinstance(processed_params, (list, tuple)) and processed_params @@ -189,31 +180,24 @@ def _run_query_job( ) ): final_job_config.query_parameters = list(processed_params) - # Convert regular parameters to BigQuery parameters elif isinstance(processed_params, dict): - # Convert dict params to BQ ScalarQueryParameter final_job_config.query_parameters = [ bigquery.ScalarQueryParameter(name, self._get_bq_param_type(value)[0], value) for name, value in processed_params.items() ] elif isinstance(processed_params, (list, tuple)): - # Convert list params to BQ ScalarQueryParameter final_job_config.query_parameters = [ bigquery.ScalarQueryParameter(None, self._get_bq_param_type(value)[0], value) for value in processed_params ] - # Determine which kwargs to pass to the actual query method - # We only want to pass kwargs that were *not* treated as SQL parameters final_query_kwargs = {} - if parameters is not None and kwargs: # Params came via arg, kwargs are separate + if parameters is not None and kwargs: final_query_kwargs = kwargs - # Else: If params came via kwargs, they are already handled, so don't pass them again - # Execute query return conn.query( final_sql, - job_config=final_job_config, + job_config=final_job_config, # pyright: ignore **final_query_kwargs, ) @@ -238,15 +222,12 @@ def _rows_to_results( schema_type: "Optional[type[ModelDTOT]]" = None, ) -> Sequence[Union[ModelDTOT, dict[str, Any]]]: processed_results = [] - # Create a quick lookup map for schema fields from the passed schema schema_map = {field.name: field for field in schema} for row in rows: - # row here is now a Row object from the iterator row_dict = {} - for key, value in row.items(): # Use row.items() on the Row object + for key, value in row.items(): field = schema_map.get(key) - # Workaround remains the same if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value: try: parsed_value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc) @@ -263,8 +244,7 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -274,8 +254,7 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, @@ -284,27 +263,12 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, job_config: "Optional[QueryJobConfig]" = None, **kwargs: Any, ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": - """Fetch data from the database. - - Args: - sql: The SQL query string. - parameters: The parameters for the query (dict, tuple, list, or None). - *filters: Statement filters to apply. - connection: Optional connection override. - schema_type: Optional schema class for the result. - job_config: Optional job configuration. - **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. - - Returns: - List of row data as either model instances or dictionaries. - """ query_job = self._run_query_job( sql, parameters, *filters, connection=connection, job_config=job_config, **kwargs ) @@ -315,8 +279,7 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -326,8 +289,7 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, @@ -336,8 +298,7 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, job_config: "Optional[QueryJobConfig]" = None, @@ -348,14 +309,8 @@ def select_one( ) rows_iterator = query_job.result() try: - # Pass the iterator containing only the first row to _rows_to_results - # This ensures the timestamp workaround is applied consistently. - # We need to pass the original iterator for schema access, but only consume one row. first_row = next(rows_iterator) - # Create a simple iterator yielding only the first row for processing single_row_iter = iter([first_row]) - # We need RowIterator type for schema, create mock/proxy if needed, or pass schema - # Let's try passing schema directly to _rows_to_results (requires modifying it) results = self._rows_to_results(single_row_iter, rows_iterator.schema, schema_type) return results[0] except StopIteration: @@ -367,8 +322,7 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -378,8 +332,7 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, @@ -388,8 +341,7 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, job_config: "Optional[QueryJobConfig]" = None, @@ -401,9 +353,7 @@ def select_one_or_none( rows_iterator = query_job.result() try: first_row = next(rows_iterator) - # Create a simple iterator yielding only the first row for processing single_row_iter = iter([first_row]) - # Pass schema directly results = self._rows_to_results(single_row_iter, rows_iterator.schema, schema_type) return results[0] except StopIteration: @@ -414,8 +364,7 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "Optional[type[T]]" = None, job_config: "Optional[QueryJobConfig]" = None, @@ -426,8 +375,7 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "type[T]", **kwargs: Any, @@ -436,8 +384,7 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "Optional[type[T]]" = None, job_config: "Optional[QueryJobConfig]" = None, @@ -450,8 +397,7 @@ def select_value( try: first_row = next(iter(rows)) value = first_row[0] - # Apply timestamp workaround if necessary - field = rows.schema[0] # Get schema for the first column + field = rows.schema[0] if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value: with contextlib.suppress(ValueError): value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc) @@ -466,8 +412,7 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -477,8 +422,7 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "type[T]", **kwargs: Any, @@ -487,8 +431,7 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "Optional[type[T]]" = None, job_config: "Optional[QueryJobConfig]" = None, @@ -506,8 +449,7 @@ def select_value_or_none( try: first_row = next(iter(rows)) value = first_row[0] - # Apply timestamp workaround if necessary - field = rows.schema[0] # Get schema for the first column + field = rows.schema[0] if field and field.field_type == "TIMESTAMP" and isinstance(value, str) and "." in value: with contextlib.suppress(ValueError): value = datetime.datetime.fromtimestamp(float(value), tz=datetime.timezone.utc) @@ -520,32 +462,23 @@ def insert_update_delete( self, sql: str, parameters: Optional[StatementParameterType] = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: Optional["BigQueryConnection"] = None, job_config: Optional[QueryJobConfig] = None, **kwargs: Any, ) -> int: - """Executes INSERT, UPDATE, DELETE and returns affected row count. - - Returns: - int: The number of rows affected by the DML statement. - """ query_job = self._run_query_job( sql, parameters, *filters, connection=connection, job_config=job_config, **kwargs ) - # DML statements might not return rows, check job properties - # num_dml_affected_rows might be None initially, wait might be needed - query_job.result() # Ensure completion - return query_job.num_dml_affected_rows or 0 # Return 0 if None + query_job.result() + return query_job.num_dml_affected_rows or 0 @overload def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -555,8 +488,7 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, @@ -565,31 +497,23 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, job_config: "Optional[QueryJobConfig]" = None, **kwargs: Any, ) -> Union[ModelDTOT, dict[str, Any]]: - """BigQuery DML RETURNING equivalent is complex, often requires temp tables or scripting.""" msg = "BigQuery does not support `RETURNING` clauses directly in the same way as some other SQL databases. Consider multi-statement queries or alternative approaches." raise NotImplementedError(msg) def execute_script( self, - sql: str, # Expecting a script here - parameters: "Optional[StatementParameterType]" = None, # Parameters might be complex in scripts - /, + sql: str, + parameters: "Optional[StatementParameterType]" = None, connection: "Optional[BigQueryConnection]" = None, job_config: "Optional[QueryJobConfig]" = None, **kwargs: Any, ) -> str: - """Executes a BigQuery script and returns the job ID. - - Returns: - str: The job ID of the executed script. - """ query_job = self._run_query_job( sql, parameters, @@ -600,14 +524,11 @@ def execute_script( ) return str(query_job.job_id) - # --- Mixin Implementations --- - def select_arrow( # pyright: ignore self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[BigQueryConnection]" = None, job_config: "Optional[QueryJobConfig]" = None, **kwargs: Any, @@ -615,10 +536,8 @@ def select_arrow( # pyright: ignore conn = self._connection(connection) final_job_config = job_config or self._default_query_job_config or QueryJobConfig() - # Process SQL and parameters using SQLStatement processed_sql, processed_params = self._process_sql_params(sql, parameters, *filters, **kwargs) - # Convert parameters to BigQuery format if isinstance(processed_params, dict): query_parameters = [] for key, value in processed_params.items(): @@ -633,16 +552,14 @@ def select_arrow( # pyright: ignore raise SQLSpecError(msg) final_job_config.query_parameters = query_parameters elif isinstance(processed_params, (list, tuple)): - # Convert sequence parameters final_job_config.query_parameters = [ bigquery.ScalarQueryParameter(None, self._get_bq_param_type(value)[0], value) for value in processed_params ] - # Execute the query and get Arrow table try: query_job = conn.query(processed_sql, job_config=final_job_config) - arrow_table = query_job.to_arrow() # Waits for job completion + arrow_table = query_job.to_arrow() except Exception as e: msg = f"BigQuery Arrow query execution failed: {e!s}" raise SQLSpecError(msg) from e @@ -650,31 +567,34 @@ def select_arrow( # pyright: ignore def select_to_parquet( self, - sql: str, # Expects table ID: project.dataset.table + sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", destination_uri: "Optional[str]" = None, connection: "Optional[BigQueryConnection]" = None, job_config: "Optional[bigquery.ExtractJobConfig]" = None, **kwargs: Any, ) -> None: - """Exports a BigQuery table to Parquet files in Google Cloud Storage. - - Raises: - NotImplementedError: If the SQL is not a fully qualified table ID or if parameters are provided. - NotFoundError: If the source table is not found. - SQLSpecError: If the Parquet export fails. - """ if destination_uri is None: msg = "destination_uri is required" raise SQLSpecError(msg) conn = self._connection(connection) - if "." not in sql or parameters is not None: - msg = "select_to_parquet currently expects a fully qualified table ID (project.dataset.table) as the `sql` argument and no `parameters`." + + if parameters is not None: + msg = ( + "select_to_parquet expects a fully qualified table ID (e.g., 'project.dataset.table') " + "as the `sql` argument and does not support `parameters`." + ) raise NotImplementedError(msg) - source_table_ref = bigquery.TableReference.from_string(sql, default_project=conn.project) + try: + source_table_ref = bigquery.TableReference.from_string(sql, default_project=conn.project) + except ValueError as e: + msg = ( + "select_to_parquet expects a fully qualified table ID (e.g., 'project.dataset.table') " + f"as the `sql` argument. Parsing failed for input '{sql}': {e!s}" + ) + raise NotImplementedError(msg) from e final_extract_config = job_config or bigquery.ExtractJobConfig() # type: ignore[no-untyped-call] final_extract_config.destination_format = bigquery.DestinationFormat.PARQUET @@ -684,9 +604,8 @@ def select_to_parquet( source_table_ref, destination_uri, job_config=final_extract_config, - # Location is correctly inferred by the client library ) - extract_job.result() # Wait for completion + extract_job.result() except NotFound: msg = f"Source table not found for Parquet export: {source_table_ref}" @@ -699,12 +618,4 @@ def select_to_parquet( raise SQLSpecError(msg) def _connection(self, connection: "Optional[BigQueryConnection]" = None) -> "BigQueryConnection": - """Get the connection to use for the operation. - - Args: - connection: Optional connection to use. - - Returns: - The connection to use. - """ return connection or self.connection diff --git a/sqlspec/base.py b/sqlspec/base.py index 656c3f7..d6483d7 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -538,7 +538,6 @@ def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": @@ -577,7 +576,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -590,7 +588,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[ModelDTOT]", @@ -602,7 +599,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: Optional[type[ModelDTOT]] = None, @@ -615,7 +611,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -628,7 +623,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[ModelDTOT]", @@ -640,7 +634,6 @@ def select_one( self, sql: str, parameters: Optional[StatementParameterType] = None, - /, *filters: "StatementFilter", connection: Optional[ConnectionT] = None, schema_type: Optional[type[ModelDTOT]] = None, @@ -653,7 +646,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -666,7 +658,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[ModelDTOT]", @@ -678,7 +669,6 @@ def select_one_or_none( self, sql: str, parameters: Optional[StatementParameterType] = None, - /, *filters: "StatementFilter", connection: Optional[ConnectionT] = None, schema_type: Optional[type[ModelDTOT]] = None, @@ -691,7 +681,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -704,7 +693,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[T]", @@ -716,7 +704,6 @@ def select_value( self, sql: str, parameters: Optional[StatementParameterType] = None, - /, *filters: "StatementFilter", connection: Optional[ConnectionT] = None, schema_type: Optional[type[T]] = None, @@ -729,7 +716,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -742,7 +728,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[T]", @@ -754,7 +739,6 @@ def select_value_or_none( self, sql: str, parameters: Optional[StatementParameterType] = None, - /, *filters: "StatementFilter", connection: Optional[ConnectionT] = None, schema_type: Optional[type[T]] = None, @@ -766,7 +750,6 @@ def insert_update_delete( self, sql: str, parameters: Optional[StatementParameterType] = None, - /, *filters: "StatementFilter", connection: Optional[ConnectionT] = None, **kwargs: Any, @@ -778,7 +761,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -791,7 +773,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[ModelDTOT]", @@ -803,7 +784,6 @@ def insert_update_delete_returning( self, sql: str, parameters: Optional[StatementParameterType] = None, - /, *filters: "StatementFilter", connection: Optional[ConnectionT] = None, schema_type: Optional[type[ModelDTOT]] = None, @@ -815,7 +795,6 @@ def execute_script( self, sql: str, parameters: Optional[StatementParameterType] = None, - /, connection: Optional[ConnectionT] = None, **kwargs: Any, ) -> str: ... @@ -833,7 +812,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -846,7 +824,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[ModelDTOT]", @@ -858,7 +835,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -871,7 +847,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -884,7 +859,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[ModelDTOT]", @@ -896,7 +870,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -909,7 +882,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -922,7 +894,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[ModelDTOT]", @@ -934,7 +905,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -947,7 +917,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -960,7 +929,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[T]", @@ -972,7 +940,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[T]]" = None, @@ -985,7 +952,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -998,7 +964,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[T]", @@ -1010,7 +975,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[T]]" = None, @@ -1022,7 +986,6 @@ async def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, **kwargs: Any, @@ -1034,7 +997,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: None = None, @@ -1047,7 +1009,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "type[ModelDTOT]", @@ -1059,7 +1020,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -1071,7 +1031,6 @@ async def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[ConnectionT]" = None, **kwargs: Any, ) -> str: ... diff --git a/sqlspec/filters.py b/sqlspec/filters.py index 3655f40..fe1bb86 100644 --- a/sqlspec/filters.py +++ b/sqlspec/filters.py @@ -4,7 +4,7 @@ from collections import abc from dataclasses import dataclass from datetime import datetime -from typing import Any, Generic, Literal, Optional, Protocol, Union, cast +from typing import Any, Generic, Literal, Optional, Protocol, Union, cast, runtime_checkable from sqlglot import exp from typing_extensions import TypeAlias, TypeVar @@ -30,6 +30,7 @@ T = TypeVar("T") +@runtime_checkable class StatementFilter(Protocol): """Protocol for filters that can be appended to a statement.""" From b63b612de682810cce164b15ec7a74cbb729e927 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 12 May 2025 17:47:54 +0000 Subject: [PATCH 2/3] feat: more updates --- pyproject.toml | 1 + sqlspec/adapters/aiosqlite/driver.py | 2 +- sqlspec/adapters/asyncmy/driver.py | 4 +- sqlspec/adapters/asyncpg/driver.py | 2 +- sqlspec/adapters/duckdb/driver.py | 69 +++------- sqlspec/adapters/oracledb/driver.py | 139 +++++++++----------- sqlspec/adapters/psqlpy/driver.py | 95 ++++++-------- sqlspec/adapters/psycopg/driver.py | 183 +++++++++++---------------- sqlspec/adapters/sqlite/driver.py | 50 +++----- 9 files changed, 210 insertions(+), 335 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f45bdbf..c1fc940 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ nanoid = ["fastnanoid>=0.4.1"] oracledb = ["oracledb"] orjson = ["orjson"] performance = ["sqlglot[rs]", "msgspec"] +polars = ["polars", "pyarrow"] psqlpy = ["psqlpy"] psycopg = ["psycopg[binary,pool]"] pydantic = ["pydantic", "pydantic-extra-types"] diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index aabdcb2..fd11d05 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -78,7 +78,7 @@ def _process_sql_params( # _actual_data_params remains None else: # If parameters is not a StatementFilter, it's actual data parameters. - passed_parameters = parameters # type: ignore[assignment] + passed_parameters = parameters statement = SQLStatement(sql, passed_parameters, kwargs=kwargs, dialect=self.dialect) diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index 3ee2694..8e3e607 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -180,12 +180,12 @@ async def select( connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) async with self._with_cursor(connection) as cursor: - await cursor.execute(final_sql, final_params) + await cursor.execute(sql, parameters) results = await cursor.fetchall() if not results: return [] column_names = [c[0] for c in cursor.description or []] - return self.to_schema(dict_[dict(zip(column_names, row)) for row in results]esults, schema_type=schema_type) + return self.to_schema([dict(zip(column_names, row)) for row in results], schema_type=schema_type) @overload async def select_one( diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index 9b77a56..1aba8e6 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -13,7 +13,7 @@ from sqlspec.statement import SQLStatement if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Mapping, Sequence from asyncpg import Record from asyncpg.connection import Connection diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index 9d210f3..3df0a7a 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -5,14 +5,14 @@ from duckdb import DuckDBPyConnection from sqlspec.base import SyncDriverAdapterProtocol +from sqlspec.filters import StatementFilter from sqlspec.mixins import ResultConverter, SQLTranslatorMixin, SyncArrowBulkOperationsMixin from sqlspec.statement import SQLStatement from sqlspec.typing import ArrowTable, StatementParameterType if TYPE_CHECKING: - from collections.abc import Generator, Sequence + from collections.abc import Generator, Mapping, Sequence - from sqlspec.filters import StatementFilter from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T __all__ = ("DuckDBConnection", "DuckDBDriver") @@ -58,7 +58,6 @@ def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": @@ -77,10 +76,18 @@ def _process_sql_params( Returns: Tuple of processed SQL and parameters. """ - statement = SQLStatement(sql, parameters, kwargs=kwargs, dialect=self.dialect) - - # Apply any filters - for filter_obj in filters: + data_params_for_statement: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + combined_filters_list: list[StatementFilter] = list(filters) + + if parameters is not None: + if isinstance(parameters, StatementFilter): + combined_filters_list.insert(0, parameters) + else: + data_params_for_statement = parameters + if data_params_for_statement is not None and not isinstance(data_params_for_statement, (list, tuple, dict)): + data_params_for_statement = (data_params_for_statement,) + statement = SQLStatement(sql, data_params_for_statement, kwargs=kwargs, dialect=self.dialect) + for filter_obj in combined_filters_list: statement = statement.apply_filter(filter_obj) processed_sql, processed_params, _ = statement.process() @@ -98,7 +105,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, @@ -109,7 +115,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "type[ModelDTOT]", @@ -119,7 +124,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -138,17 +142,13 @@ def select( if not results: return [] column_names = [column[0] for column in cursor.description or []] - - # Convert to dicts first - dict_results = [dict(zip(column_names, row)) for row in results] - return self.to_schema(dict_results, schema_type=schema_type) + return self.to_schema([dict(zip(column_names, row)) for row in results], schema_type=schema_type) @overload def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, @@ -159,7 +159,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "type[ModelDTOT]", @@ -169,7 +168,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -187,17 +185,13 @@ def select_one( result = cursor.fetchone() result = self.check_not_found(result) column_names = [column[0] for column in cursor.description or []] - - # Convert to dict and use ResultConverter - dict_result = dict(zip(column_names, result)) - return self.to_schema(dict_result, schema_type=schema_type) + return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) @overload def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, @@ -208,7 +202,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "type[ModelDTOT]", @@ -218,7 +211,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -237,17 +229,13 @@ def select_one_or_none( if result is None: return None column_names = [column[0] for column in cursor.description or []] - - # Convert to dict and use ResultConverter - dict_result = dict(zip(column_names, result)) - return self.to_schema(dict_result, schema_type=schema_type) + return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) @overload def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, @@ -258,7 +246,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "type[T]", @@ -268,7 +255,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -295,7 +281,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, @@ -306,7 +291,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "type[T]", @@ -316,7 +300,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -337,7 +320,6 @@ def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, **kwargs: Any, @@ -354,7 +336,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: None = None, @@ -365,7 +346,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "type[ModelDTOT]", @@ -375,7 +355,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -389,16 +368,12 @@ def insert_update_delete_returning( result = cursor.fetchall() result = self.check_not_found(result) column_names = [col[0] for col in cursor.description or []] - - # Convert to dict and use ResultConverter - dict_result = dict(zip(column_names, result[0])) - return self.to_schema(dict_result, schema_type=schema_type) + return self.to_schema(dict(zip(column_names, result[0])), schema_type=schema_type) def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[DuckDBConnection]" = None, **kwargs: Any, ) -> str: @@ -415,8 +390,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *, + *filters: "StatementFilter", connection: "Optional[DuckDBConnection]" = None, **kwargs: Any, ) -> "ArrowTable": @@ -425,6 +399,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] Args: sql: The SQL query string. parameters: Parameters for the query. + *filters: Optional filters to apply to the SQL statement. connection: Optional connection override. **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. @@ -432,10 +407,6 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] An Apache Arrow Table containing the query results. """ connection = self._connection(connection) - - # Extract filters from kwargs if present - filters = kwargs.pop("filters", []) - sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) with self._with_cursor(connection) as cursor: params = [] if parameters is None else parameters diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index ca21752..94f5030 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -5,6 +5,7 @@ from oracledb import AsyncConnection, AsyncCursor, Connection, Cursor from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol +from sqlspec.filters import StatementFilter from sqlspec.mixins import ( AsyncArrowBulkOperationsMixin, ResultConverter, @@ -15,9 +16,8 @@ from sqlspec.typing import ArrowTable, StatementParameterType, T if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Generator, Sequence + from collections.abc import AsyncGenerator, Generator, Mapping, Sequence - from sqlspec.filters import StatementFilter from sqlspec.typing import ModelDTOT __all__ = ("OracleAsyncConnection", "OracleAsyncDriver", "OracleSyncConnection", "OracleSyncDriver") @@ -37,7 +37,6 @@ def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], dict[str, Any]]]]": @@ -52,15 +51,22 @@ def _process_sql_params( Returns: A tuple of (sql, parameters) ready for execution. """ - # Special case: Oracle treats empty dicts as None - if isinstance(parameters, dict) and not parameters and not kwargs: + data_params_for_statement: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + combined_filters_list: list[StatementFilter] = list(filters) + + if parameters is not None: + if isinstance(parameters, StatementFilter): + combined_filters_list.insert(0, parameters) + else: + data_params_for_statement = parameters + if data_params_for_statement is not None and not isinstance(data_params_for_statement, (list, tuple, dict)): + data_params_for_statement = (data_params_for_statement,) + + if isinstance(data_params_for_statement, dict) and not data_params_for_statement and not kwargs: return sql, None - # Create a SQLStatement with appropriate dialect - statement = SQLStatement(sql, parameters, kwargs=kwargs, dialect=self.dialect) - - # Apply any filters - for filter_obj in filters: + statement = SQLStatement(sql, data_params_for_statement, kwargs=kwargs, dialect=self.dialect) + for filter_obj in combined_filters_list: statement = statement.apply_filter(filter_obj) processed_sql, processed_params, _ = statement.process() @@ -102,7 +108,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, @@ -113,7 +118,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -123,7 +127,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -149,7 +152,7 @@ def select( results = cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] if not results: return [] - # Get column names + # Get column names from description column_names = [col[0] for col in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] return self.to_schema([dict(zip(column_names, row)) for row in results], schema_type=schema_type) @@ -159,7 +162,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, @@ -170,7 +172,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -180,7 +181,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -217,7 +217,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, @@ -228,7 +227,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -238,7 +236,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -276,7 +273,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, @@ -287,7 +283,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[T]", @@ -297,7 +292,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -333,7 +327,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, @@ -344,7 +337,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[T]", @@ -354,7 +346,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -390,7 +381,6 @@ def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, **kwargs: Any, @@ -419,7 +409,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: None = None, @@ -430,7 +419,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -440,7 +428,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -473,7 +460,6 @@ def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[OracleSyncConnection]" = None, **kwargs: Any, ) -> str: @@ -482,7 +468,6 @@ def execute_script( Args: sql: The SQL script to execute. parameters: The parameters for the script (dict, tuple, list, or None). - *filters: Statement filters to apply. connection: Optional connection override. **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. @@ -557,7 +542,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, @@ -568,7 +552,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -578,7 +561,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -586,6 +568,14 @@ async def select( ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]": """Fetch data from the database. + Args: + sql: The SQL query string. + parameters: The parameters for the query (dict, tuple, list, or None). + *filters: Statement filters to apply. + connection: Optional connection override. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. + Returns: List of row data as either model instances or dictionaries. """ @@ -597,20 +587,16 @@ async def select( results = await cursor.fetchall() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] if not results: return [] - # Get column names + # Get column names from description column_names = [col[0] for col in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if schema_type: - return [cast("ModelDTOT", schema_type(**dict(zip(column_names, row)))) for row in results] # pyright: ignore - - return [dict(zip(column_names, row)) for row in results] # pyright: ignore + return self.to_schema([dict(zip(column_names, row)) for row in results], schema_type=schema_type) @overload async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, @@ -621,7 +607,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -631,7 +616,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -639,6 +623,14 @@ async def select_one( ) -> "Union[ModelDTOT, dict[str, Any]]": """Fetch one row from the database. + Args: + sql: The SQL query string. + parameters: The parameters for the query (dict, tuple, list, or None). + *filters: Statement filters to apply. + connection: Optional connection override. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. + Returns: The first row of the query results. """ @@ -649,20 +641,15 @@ async def select_one( await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] result = self.check_not_found(result) # pyright: ignore[reportUnknownArgumentType] - # Get column names column_names = [col[0] for col in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) # pyright: ignore[reportUnknownArgumentType] - # Always return dictionaries - return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) @overload async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, @@ -673,7 +660,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -683,16 +669,23 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": - """Fetch one row from the database. + """Fetch one row from the database or return None if no rows found. + + Args: + sql: The SQL query string. + parameters: The parameters for the query (dict, tuple, list, or None). + *filters: Statement filters to apply. + connection: Optional connection override. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: - The first row of the query results. + The first row of the query results, or None if no results found. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) @@ -700,24 +693,16 @@ async def select_one_or_none( async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - if result is None: return None - - # Get column names column_names = [col[0] for col in cursor.description or []] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - - if schema_type is not None: - return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) # pyright: ignore[reportUnknownArgumentType] - # Always return dictionaries - return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) @overload async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, @@ -728,7 +713,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[T]", @@ -738,7 +722,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -763,7 +746,7 @@ async def select_value( async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - result = self.check_not_found(result) + result = self.check_not_found(result) # pyright: ignore[reportUnknownArgumentType] if schema_type is None: return result[0] # pyright: ignore[reportUnknownArgumentType] @@ -774,7 +757,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, @@ -785,7 +767,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[T]", @@ -795,7 +776,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -831,12 +811,11 @@ async def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, **kwargs: Any, ) -> int: - """Execute an insert, update, or delete statement. + """Insert, update, or delete data from the database. Args: sql: The SQL statement to execute. @@ -846,21 +825,20 @@ async def insert_update_delete( **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: - The number of rows affected by the statement. + Row count affected by the operation. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] - return cursor.rowcount # pyright: ignore[reportUnknownMemberType] + return cursor.rowcount # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] @overload async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: None = None, @@ -871,7 +849,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -881,7 +858,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -889,8 +865,16 @@ async def insert_update_delete_returning( ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": """Insert, update, or delete data from the database and return result. + Args: + sql: The SQL statement with RETURNING clause. + parameters: The parameters for the statement (dict, tuple, list, or None). + *filters: Statement filters to apply. + connection: Optional connection override. + schema_type: Optional schema class for the result. + **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. + Returns: - The first row of results. + The returned row data, as either a model instance or dictionary. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) @@ -898,7 +882,6 @@ async def insert_update_delete_returning( async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] result = await cursor.fetchone() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - if result is None: return None @@ -907,14 +890,12 @@ async def insert_update_delete_returning( if schema_type is not None: return cast("ModelDTOT", schema_type(**dict(zip(column_names, result)))) # pyright: ignore[reportUnknownArgumentType] - # Always return dictionaries return dict(zip(column_names, result)) # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] async def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[OracleAsyncConnection]" = None, **kwargs: Any, ) -> str: diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index c01e300..ba5babc 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -17,7 +17,7 @@ from sqlspec.typing import is_dict if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Mapping, Sequence from psqlpy import QueryResult @@ -64,8 +64,7 @@ def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], dict[str, Any]]]]": """Process SQL and parameters for psqlpy. @@ -82,15 +81,19 @@ def _process_sql_params( Raises: SQLParsingError: If the SQL parsing fails. """ - # Handle scalar parameter by converting to a single-item tuple - if parameters is not None and not isinstance(parameters, (list, tuple, dict)): - parameters = (parameters,) - - # Create and process the statement - statement = SQLStatement(sql=sql, parameters=parameters, kwargs=kwargs, dialect=self.dialect) - - # Apply any filters - for filter_obj in filters: + data_params_for_statement: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + combined_filters_list: list[StatementFilter] = list(filters) + + if parameters is not None: + if isinstance(parameters, StatementFilter): + combined_filters_list.insert(0, parameters) + else: + data_params_for_statement = parameters + if data_params_for_statement is not None and not isinstance(data_params_for_statement, (list, tuple, dict)): + data_params_for_statement = (data_params_for_statement,) + statement = SQLStatement(sql, data_params_for_statement, kwargs=kwargs, dialect=self.dialect) + + for filter_obj in combined_filters_list: statement = statement.apply_filter(filter_obj) # Process the statement @@ -162,8 +165,7 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -173,8 +175,7 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, @@ -183,8 +184,7 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, @@ -217,8 +217,7 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -228,8 +227,7 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, @@ -238,8 +236,7 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, @@ -275,8 +272,7 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -286,8 +282,7 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, @@ -296,8 +291,7 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, @@ -332,8 +326,7 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -343,8 +336,7 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[T]", **kwargs: Any, @@ -353,8 +345,7 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, @@ -388,8 +379,7 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -399,8 +389,7 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[T]", **kwargs: Any, @@ -409,8 +398,7 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, @@ -446,8 +434,7 @@ async def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, **kwargs: Any, ) -> int: @@ -477,8 +464,7 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -488,8 +474,7 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, @@ -498,16 +483,15 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[PsqlpyConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, ) -> "Union[ModelDTOT, dict[str, Any]]": - """Insert, update, or delete data from the database and return result. + """Insert, update, or delete data with RETURNING clause. Args: - sql: The SQL statement to execute. + sql: The SQL statement with RETURNING clause. parameters: The parameters for the statement (dict, tuple, list, or None). *filters: Statement filters to apply. connection: Optional connection override. @@ -515,15 +499,15 @@ async def insert_update_delete_returning( **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. Returns: - The first row of results. + The returned row data, as either a model instance or dictionary. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) parameters = parameters or [] - result = await connection.execute(sql, parameters=parameters) - dict_results = result.result() + result = await connection.fetch(sql, parameters=parameters) + dict_results = result.result() if not dict_results: self.check_not_found(None) @@ -533,7 +517,6 @@ async def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[PsqlpyConnection]" = None, **kwargs: Any, ) -> str: diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 1fd00e8..7fe86e2 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -8,14 +8,14 @@ from sqlspec.base import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol from sqlspec.exceptions import ParameterStyleMismatchError +from sqlspec.filters import StatementFilter from sqlspec.mixins import ResultConverter, SQLTranslatorMixin from sqlspec.statement import SQLStatement from sqlspec.typing import is_dict if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Generator, Sequence + from collections.abc import AsyncGenerator, Generator, Mapping, Sequence - from sqlspec.filters import StatementFilter from sqlspec.typing import ModelDTOT, StatementParameterType, T logger = logging.getLogger("sqlspec") @@ -38,7 +38,6 @@ def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": @@ -56,10 +55,20 @@ def _process_sql_params( Returns: A tuple of (sql, parameters) ready for execution. """ - statement = SQLStatement(sql, parameters, kwargs=kwargs, dialect=self.dialect) + data_params_for_statement: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + combined_filters_list: list[StatementFilter] = list(filters) + + if parameters is not None: + if isinstance(parameters, StatementFilter): + combined_filters_list.insert(0, parameters) + else: + data_params_for_statement = parameters + if data_params_for_statement is not None and not isinstance(data_params_for_statement, (list, tuple, dict)): + data_params_for_statement = (data_params_for_statement,) + statement = SQLStatement(sql, data_params_for_statement, kwargs=kwargs, dialect=self.dialect) # Apply all statement filters - for filter_obj in filters: + for filter_obj in combined_filters_list: statement = statement.apply_filter(filter_obj) processed_sql, processed_params, _ = statement.process() @@ -118,7 +127,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, @@ -129,7 +137,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -139,7 +146,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", schema_type: "Optional[type[ModelDTOT]]" = None, connection: "Optional[PsycopgSyncConnection]" = None, @@ -165,7 +171,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, @@ -176,7 +181,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -186,7 +190,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -201,17 +204,15 @@ def select_one( sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) - row = cursor.fetchone() - row = self.check_not_found(row) - - return self.to_schema(cast("dict[str, Any]", row), schema_type=schema_type) + result = cursor.fetchone() + result = self.check_not_found(result) + return self.to_schema(cast("dict[str, Any]", result), schema_type=schema_type) @overload def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, @@ -222,7 +223,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -232,7 +232,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -241,23 +240,22 @@ def select_one_or_none( """Fetch one row from the database. Returns: - The first row of the query results. + The first row of the query results, or None if no results. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) - row = cursor.fetchone() - if row is None: + result = cursor.fetchone() + if result is None: return None - return self.to_schema(cast("dict[str, Any]", row), schema_type=schema_type) + return self.to_schema(cast("dict[str, Any]", result), schema_type=schema_type) @overload def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, @@ -268,7 +266,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[T]", @@ -278,7 +275,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -287,26 +283,25 @@ def select_value( """Fetch a single value from the database. Returns: - The first value from the first row of results, or None if no results. + The first value from the first row of results. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) - row = cursor.fetchone() - row = self.check_not_found(row) - val = next(iter(row.values())) if row else None - val = self.check_not_found(val) - if schema_type is not None: - return schema_type(val) # type: ignore[call-arg] - return val + result = cursor.fetchone() + result = self.check_not_found(result) + + value = next(iter(result.values())) # Get the first value from the row + if schema_type is None: + return value + return schema_type(value) # type: ignore[call-arg] @overload def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, @@ -317,7 +312,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[T]", @@ -327,7 +321,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -342,29 +335,27 @@ def select_value_or_none( sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) - row = cursor.fetchone() - if row is None: - return None - val = next(iter(row.values())) if row else None - if val is None: + result = cursor.fetchone() + if result is None: return None - if schema_type is not None: - return schema_type(val) # type: ignore[call-arg] - return val + + value = next(iter(result.values())) # Get the first value from the row + if schema_type is None: + return value + return schema_type(value) # type: ignore[call-arg] def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, **kwargs: Any, ) -> int: - """Execute an INSERT, UPDATE, or DELETE query and return the number of affected rows. + """Insert, update, or delete data from the database. Returns: - The number of rows affected by the operation. + Row count affected by the operation. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) @@ -377,7 +368,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: None = None, @@ -388,7 +378,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -398,33 +387,28 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgSyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Insert, update, or delete data from the database and return result. + ) -> "Union[ModelDTOT, dict[str, Any]]": + """Insert, update, or delete data with RETURNING clause. Returns: - The first row of results. + The returned row data, as either a model instance or dictionary. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) with self._with_cursor(connection) as cursor: cursor.execute(sql, parameters) result = cursor.fetchone() - - if result is None: - return None - + result = self.check_not_found(result) return self.to_schema(cast("dict[str, Any]", result), schema_type=schema_type) def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[PsycopgSyncConnection]" = None, **kwargs: Any, ) -> str: @@ -468,7 +452,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, @@ -479,7 +462,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -489,7 +471,6 @@ async def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", schema_type: "Optional[type[ModelDTOT]]" = None, connection: "Optional[PsycopgAsyncConnection]" = None, @@ -507,6 +488,7 @@ async def select( results = await cursor.fetchall() if not results: return [] + return self.to_schema(cast("Sequence[dict[str, Any]]", results), schema_type=schema_type) @overload @@ -514,7 +496,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, @@ -525,7 +506,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -535,7 +515,6 @@ async def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -550,17 +529,15 @@ async def select_one( sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) - row = await cursor.fetchone() - row = self.check_not_found(row) - - return self.to_schema(cast("dict[str, Any]", row), schema_type=schema_type) + result = await cursor.fetchone() + result = self.check_not_found(result) + return self.to_schema(cast("dict[str, Any]", result), schema_type=schema_type) @overload async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, @@ -571,7 +548,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -581,7 +557,6 @@ async def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", schema_type: "Optional[type[ModelDTOT]]" = None, connection: "Optional[PsycopgAsyncConnection]" = None, @@ -590,25 +565,22 @@ async def select_one_or_none( """Fetch one row from the database. Returns: - The first row of the query results. + The first row of the query results, or None if no results. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) - row = await cursor.fetchone() - if row is None: + result = await cursor.fetchone() + if result is None: return None - - # Use self.to_schema from ResultConverter mixin - return self.to_schema(cast("dict[str, Any]", row), schema_type=schema_type) + return self.to_schema(cast("dict[str, Any]", result), schema_type=schema_type) @overload async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, @@ -619,7 +591,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[T]", @@ -629,7 +600,6 @@ async def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -638,26 +608,25 @@ async def select_value( """Fetch a single value from the database. Returns: - The first value from the first row of results, or None if no results. + The first value from the first row of results. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) - row = await cursor.fetchone() - row = self.check_not_found(row) - val = next(iter(row.values())) if row else None - val = self.check_not_found(val) - if schema_type is not None: - return schema_type(val) # type: ignore[call-arg] - return val + result = await cursor.fetchone() + result = self.check_not_found(result) + + value = next(iter(result.values())) # Get the first value from the row + if schema_type is None: + return value + return schema_type(value) # type: ignore[call-arg] @overload async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, @@ -668,7 +637,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[T]", @@ -678,7 +646,6 @@ async def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -693,29 +660,27 @@ async def select_value_or_none( sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) - row = await cursor.fetchone() - if row is None: - return None - val = next(iter(row.values())) if row else None - if val is None: + result = await cursor.fetchone() + if result is None: return None - if schema_type is not None: - return schema_type(val) # type: ignore[call-arg] - return val + + value = next(iter(result.values())) # Get the first value from the row + if schema_type is None: + return value + return schema_type(value) # type: ignore[call-arg] async def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, **kwargs: Any, ) -> int: - """Execute an INSERT, UPDATE, or DELETE query and return the number of affected rows. + """Insert, update, or delete data from the database. Returns: - The number of rows affected by the operation. + Row count affected by the operation. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) @@ -728,7 +693,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: None = None, @@ -739,7 +703,6 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "type[ModelDTOT]", @@ -749,32 +712,28 @@ async def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[PsycopgAsyncConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, - ) -> "Optional[Union[dict[str, Any], ModelDTOT]]": - """Insert, update, or delete data from the database and return result. + ) -> "Union[ModelDTOT, dict[str, Any]]": + """Insert, update, or delete data with RETURNING clause. Returns: - The first row of results. + The returned row data, as either a model instance or dictionary. """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) async with self._with_cursor(connection) as cursor: await cursor.execute(sql, parameters) result = await cursor.fetchone() - if result is None: - return None - + result = self.check_not_found(result) return self.to_schema(cast("dict[str, Any]", result), schema_type=schema_type) async def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[PsycopgAsyncConnection]" = None, **kwargs: Any, ) -> str: diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index d4b8975..aa59438 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -5,14 +5,14 @@ from typing import TYPE_CHECKING, Any, Optional, Union, overload from sqlspec.base import SyncDriverAdapterProtocol +from sqlspec.filters import StatementFilter from sqlspec.mixins import ResultConverter, SQLTranslatorMixin from sqlspec.statement import SQLStatement from sqlspec.typing import is_dict if TYPE_CHECKING: - from collections.abc import Generator, Sequence + from collections.abc import Generator, Mapping, Sequence - from sqlspec.filters import StatementFilter from sqlspec.typing import ModelDTOT, StatementParameterType, T __all__ = ("SqliteConnection", "SqliteDriver") @@ -51,7 +51,6 @@ def _process_sql_params( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", **kwargs: Any, ) -> "tuple[str, Optional[Union[tuple[Any, ...], list[Any], dict[str, Any]]]]": @@ -71,9 +70,19 @@ def _process_sql_params( A tuple of (processed SQL, processed parameters). """ # Create a SQLStatement with SQLite dialect - statement = SQLStatement(sql, parameters, kwargs=kwargs, dialect=self.dialect) - - for filter_obj in filters: + data_params_for_statement: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + combined_filters_list: list[StatementFilter] = list(filters) + + if parameters is not None: + if isinstance(parameters, StatementFilter): + combined_filters_list.insert(0, parameters) + else: + data_params_for_statement = parameters + if data_params_for_statement is not None and not isinstance(data_params_for_statement, (list, tuple, dict)): + data_params_for_statement = (data_params_for_statement,) + statement = SQLStatement(sql, data_params_for_statement, kwargs=kwargs, dialect=self.dialect) + + for filter_obj in combined_filters_list: statement = statement.apply_filter(filter_obj) processed_sql, processed_params, _ = statement.process() @@ -95,7 +104,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: None = None, @@ -106,7 +114,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "type[ModelDTOT]", @@ -116,7 +123,6 @@ def select( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -146,7 +152,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: None = None, @@ -157,7 +162,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "type[ModelDTOT]", @@ -167,7 +171,6 @@ def select_one( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -186,10 +189,7 @@ def select_one( cursor.execute(sql, parameters or []) result = cursor.fetchone() result = self.check_not_found(result) - - # Get column names column_names = [column[0] for column in cursor.description] - return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) @overload @@ -197,7 +197,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: None = None, @@ -208,7 +207,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "type[ModelDTOT]", @@ -218,7 +216,6 @@ def select_one_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -238,9 +235,7 @@ def select_one_or_none( if result is None: return None - # Get column names column_names = [column[0] for column in cursor.description] - return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) @overload @@ -248,7 +243,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: None = None, @@ -259,7 +253,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "type[T]", @@ -269,7 +262,6 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -287,8 +279,6 @@ def select_value( cursor.execute(sql, parameters or []) result = cursor.fetchone() result = self.check_not_found(result) - - # Return first value from the row result_value = result[0] if schema_type is None: return result_value @@ -299,7 +289,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: None = None, @@ -310,7 +299,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "type[T]", @@ -320,7 +308,6 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "Optional[type[T]]" = None, @@ -339,8 +326,6 @@ def select_value_or_none( result = cursor.fetchone() if result is None: return None - - # Return first value from the row result_value = result[0] if schema_type is None: return result_value @@ -350,7 +335,6 @@ def insert_update_delete( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, **kwargs: Any, @@ -372,7 +356,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: None = None, @@ -383,7 +366,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "type[ModelDTOT]", @@ -393,7 +375,6 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, *filters: "StatementFilter", connection: "Optional[SqliteConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, @@ -418,7 +399,6 @@ def execute_script( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, connection: "Optional[SqliteConnection]" = None, **kwargs: Any, ) -> str: From 379b15a35fb85870c70131f414b993c35e3a10b3 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 12 May 2025 18:20:39 +0000 Subject: [PATCH 3/3] feat: more fixes --- sqlspec/adapters/adbc/driver.py | 64 +++++++------ sqlspec/adapters/aiosqlite/driver.py | 134 ++++++++++++--------------- sqlspec/adapters/bigquery/driver.py | 2 +- sqlspec/adapters/oracledb/driver.py | 13 ++- sqlspec/mixins.py | 19 ++-- uv.lock | 8 +- 6 files changed, 111 insertions(+), 129 deletions(-) diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 711ab60..7cd4987 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -1,7 +1,7 @@ import contextlib import logging import re -from collections.abc import Generator, Sequence +from collections.abc import Generator, Mapping, Sequence from contextlib import contextmanager from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast, overload @@ -9,7 +9,7 @@ from sqlglot import exp as sqlglot_exp from sqlspec.base import SyncDriverAdapterProtocol -from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError +from sqlspec.exceptions import SQLParsingError from sqlspec.filters import StatementFilter from sqlspec.mixins import ResultConverter, SQLTranslatorMixin, SyncArrowBulkOperationsMixin from sqlspec.statement import SQLStatement @@ -107,14 +107,24 @@ def _process_sql_params( # noqa: C901, PLR0912, PLR0915 **kwargs: Additional keyword arguments. Raises: - ParameterStyleMismatchError: If positional parameters are mixed with keyword arguments. SQLParsingError: If the SQL statement cannot be parsed. Returns: A tuple of (sql, parameters) ready for execution. """ + passed_parameters: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + combined_filters_list: list[StatementFilter] = list(filters) + + if parameters is not None: + if isinstance(parameters, StatementFilter): + combined_filters_list.insert(0, parameters) + # passed_parameters remains None + else: + # If parameters is not a StatementFilter, it's actual data parameters. + passed_parameters = parameters + # Special handling for SQLite with non-dict parameters and named placeholders - if self.dialect == "sqlite" and parameters is not None and not is_dict(parameters): + if self.dialect == "sqlite" and passed_parameters is not None and not is_dict(passed_parameters): # First mask out comments and strings to avoid detecting parameters in those comments = list(SQL_COMMENT_PATTERN.finditer(sql)) strings = list(SQL_STRING_PATTERN.finditer(sql)) @@ -135,26 +145,15 @@ def _process_sql_params( # noqa: C901, PLR0912, PLR0915 param_positions.sort(reverse=True) for start, end in param_positions: sql = sql[:start] + "?" + sql[end:] - if not isinstance(parameters, (list, tuple)): - return sql, (parameters,) - return sql, tuple(parameters) + if not isinstance(passed_parameters, (list, tuple)): + passed_parameters = (passed_parameters,) + passed_parameters = tuple(passed_parameters) # Standard processing for all other cases - merged_params = parameters - if kwargs: - if is_dict(parameters): - merged_params = {**parameters, **kwargs} - elif parameters is not None: - msg = "Cannot mix positional parameters with keyword arguments for adbc driver." - raise ParameterStyleMismatchError(msg) - else: - merged_params = kwargs + statement = SQLStatement(sql, passed_parameters, kwargs=kwargs, dialect=self.dialect) - # 2. Create SQLStatement with dialect and process - statement = SQLStatement(sql, merged_params, dialect=self.dialect) - - # Apply any filters - for filter_obj in filters: + # Apply any filters from combined_filters_list + for filter_obj in combined_filters_list: statement = statement.apply_filter(filter_obj) processed_sql, processed_params, parsed_expr = statement.process() @@ -442,7 +441,7 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -452,7 +451,7 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: "type[T]", **kwargs: Any, @@ -461,7 +460,7 @@ def select_value( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, @@ -495,7 +494,7 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -505,7 +504,7 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: "type[T]", **kwargs: Any, @@ -514,7 +513,7 @@ def select_value_or_none( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: "Optional[type[T]]" = None, **kwargs: Any, @@ -576,7 +575,7 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: None = None, **kwargs: Any, @@ -586,7 +585,7 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: "type[ModelDTOT]", **kwargs: Any, @@ -595,7 +594,7 @@ def insert_update_delete_returning( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, schema_type: "Optional[type[ModelDTOT]]" = None, **kwargs: Any, @@ -652,12 +651,11 @@ def execute_script( # --- Arrow Bulk Operations --- - def select_arrow( # pyright: ignore[reportUnknownParameterType] + def select_arrow( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *filters: StatementFilter, + *filters: "StatementFilter", connection: "Optional[AdbcConnection]" = None, **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index fd11d05..e15c8f6 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -162,18 +162,13 @@ async def select( connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) - # Execute the query - cursor = await connection.execute(sql, parameters or ()) - results = await cursor.fetchall() - if not results: - return [] - - # Get column names - column_names = [column[0] for column in cursor.description] - - # Convert to dicts first - dict_results = [dict(zip(column_names, row)) for row in results] - return self.to_schema(dict_results, schema_type=schema_type) + async with self._with_cursor(connection) as cursor: + await cursor.execute(sql, parameters or ()) + results = await cursor.fetchall() + if not results: + return [] + column_names = [column[0] for column in cursor.description] + return self.to_schema([dict(zip(column_names, row)) for row in results], schema_type=schema_type) @overload async def select_one( @@ -212,17 +207,14 @@ async def select_one( connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) - # Execute the query - cursor = await connection.execute(sql, parameters or ()) - result = await cursor.fetchone() - result = self.check_not_found(result) - - # Get column names - column_names = [column[0] for column in cursor.description] + async with self._with_cursor(connection) as cursor: + await cursor.execute(sql, parameters or ()) + result = await cursor.fetchone() + result = self.check_not_found(result) - # Convert to dict and then use ResultConverter - dict_result = dict(zip(column_names, result)) - return self.to_schema(dict_result, schema_type=schema_type) + # Get column names + column_names = [column[0] for column in cursor.description] + return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) @overload async def select_one_or_none( @@ -261,18 +253,13 @@ async def select_one_or_none( connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) - # Execute the query - cursor = await connection.execute(sql, parameters or ()) - result = await cursor.fetchone() - if result is None: - return None - - # Get column names - column_names = [column[0] for column in cursor.description] - - # Convert to dict and then use ResultConverter - dict_result = dict(zip(column_names, result)) - return self.to_schema(dict_result, schema_type=schema_type) + async with self._with_cursor(connection) as cursor: + await cursor.execute(sql, parameters or ()) + result = await cursor.fetchone() + if result is None: + return None + column_names = [column[0] for column in cursor.description] + return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) @overload async def select_value( @@ -311,16 +298,16 @@ async def select_value( connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) - # Execute the query - cursor = await connection.execute(sql, parameters or ()) - result = await cursor.fetchone() - result = self.check_not_found(result) + async with self._with_cursor(connection) as cursor: + await cursor.execute(sql, parameters or ()) + result = await cursor.fetchone() + result = self.check_not_found(result) - # Return first value from the row - result_value = result[0] - if schema_type is None: - return result_value - return schema_type(result_value) # type: ignore[call-arg] + # Return first value from the row + result_value = result[0] + if schema_type is None: + return result_value + return schema_type(result_value) # type: ignore[call-arg] @overload async def select_value_or_none( @@ -359,17 +346,16 @@ async def select_value_or_none( connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) - # Execute the query - cursor = await connection.execute(sql, parameters or ()) - result = await cursor.fetchone() - if result is None: - return None - - # Return first value from the row - result_value = result[0] - if schema_type is None: - return result_value - return schema_type(result_value) # type: ignore[call-arg] + async with self._with_cursor(connection) as cursor: + # Execute the query + await cursor.execute(sql, parameters or ()) + result = await cursor.fetchone() + if result is None: + return None + result_value = result[0] + if schema_type is None: + return result_value + return schema_type(result_value) # type: ignore[call-arg] async def insert_update_delete( self, @@ -386,11 +372,10 @@ async def insert_update_delete( """ connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) - - # Execute the query - cursor = await connection.execute(sql, parameters or ()) - await connection.commit() - return cursor.rowcount + async with self._with_cursor(connection) as cursor: + # Execute the query + await cursor.execute(sql, parameters or ()) + return cursor.rowcount @overload async def insert_update_delete_returning( @@ -429,20 +414,13 @@ async def insert_update_delete_returning( connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) - # Execute the query - cursor = await connection.execute(sql, parameters or ()) - result = await cursor.fetchone() - await connection.commit() - await cursor.close() - - result = self.check_not_found(result) - - # Get column names - column_names = [column[0] for column in cursor.description] - - # Convert to dict and then use ResultConverter - dict_result = dict(zip(column_names, result)) - return self.to_schema(dict_result, schema_type=schema_type) + async with self._with_cursor(connection) as cursor: + # Execute the query + await cursor.execute(sql, parameters or ()) + result = await cursor.fetchone() + result = self.check_not_found(result) + column_names = [column[0] for column in cursor.description] + return self.to_schema(dict(zip(column_names, result)), schema_type=schema_type) async def execute_script( self, @@ -459,10 +437,12 @@ async def execute_script( connection = self._connection(connection) sql, parameters = self._process_sql_params(sql, parameters, **kwargs) - # Execute the script - await connection.executescript(sql) - await connection.commit() - return "Script executed successfully." + async with self._with_cursor(connection) as cursor: + if parameters: + await cursor.execute(sql, parameters) + else: + await cursor.executescript(sql) + return "DONE" def _connection(self, connection: "Optional[AiosqliteConnection]" = None) -> "AiosqliteConnection": """Get the connection to use for the operation. diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index 9be92e1..1ca4814 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -166,7 +166,7 @@ def _run_query_job( if job_config: final_job_config = job_config elif self._default_query_job_config: - final_job_config = QueryJobConfig.from_api_repr(self._default_query_job_config.to_api_repr()) # type: ignore[no-untyped-call] + final_job_config = QueryJobConfig.from_api_repr(self._default_query_job_config.to_api_repr()) # type: ignore[assignment] else: final_job_config = QueryJobConfig() diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 94f5030..b89da79 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -485,8 +485,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *, + *filters: "StatementFilter", connection: "Optional[OracleSyncConnection]" = None, **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] @@ -497,7 +496,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) results = connection.fetch_df_all(sql, parameters) return cast("ArrowTable", ArrowTable.from_arrays(arrays=results.column_arrays(), names=results.column_names())) # pyright: ignore @@ -917,12 +916,11 @@ async def execute_script( await cursor.execute(sql, parameters) # pyright: ignore[reportUnknownMemberType] return str(cursor.rowcount) # pyright: ignore[reportUnknownMemberType] - async def select_arrow( # pyright: ignore[reportUnknownParameterType] + async def select_arrow( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *, + *filters: "StatementFilter", connection: "Optional[OracleAsyncConnection]" = None, **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType] @@ -931,6 +929,7 @@ async def select_arrow( # pyright: ignore[reportUnknownParameterType] Args: sql: The SQL query string. parameters: Parameters for the query. + filters: Statement filters to apply. connection: Optional connection override. **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. @@ -939,7 +938,7 @@ async def select_arrow( # pyright: ignore[reportUnknownParameterType] """ connection = self._connection(connection) - sql, parameters = self._process_sql_params(sql, parameters, **kwargs) + sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs) results = await connection.fetch_df_all(sql, parameters) return ArrowTable.from_arrays(arrays=results.column_arrays(), names=results.column_names()) # pyright: ignore diff --git a/sqlspec/mixins.py b/sqlspec/mixins.py index c59fa80..97c59d6 100644 --- a/sqlspec/mixins.py +++ b/sqlspec/mixins.py @@ -34,6 +34,7 @@ ) if TYPE_CHECKING: + from sqlspec.filters import StatementFilter from sqlspec.typing import ArrowTable __all__ = ( @@ -51,12 +52,11 @@ class SyncArrowBulkOperationsMixin(Generic[ConnectionT]): __supports_arrow__: "ClassVar[bool]" = True @abstractmethod - def select_arrow( # pyright: ignore[reportUnknownParameterType] + def select_arrow( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *, + *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] @@ -65,6 +65,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType] Args: sql: The SQL query string. parameters: Parameters for the query. + filters: Optional filters to apply to the query. connection: Optional connection override. **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. @@ -80,12 +81,11 @@ class AsyncArrowBulkOperationsMixin(Generic[ConnectionT]): __supports_arrow__: "ClassVar[bool]" = True @abstractmethod - async def select_arrow( # pyright: ignore[reportUnknownParameterType] + async def select_arrow( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *, + *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, **kwargs: Any, ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType] @@ -94,6 +94,7 @@ async def select_arrow( # pyright: ignore[reportUnknownParameterType] Args: sql: The SQL query string. parameters: Parameters for the query. + filters: Optional filters to apply to the query. connection: Optional connection override. **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict. @@ -111,8 +112,7 @@ def select_to_parquet( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *, + *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, **kwargs: Any, ) -> None: @@ -128,8 +128,7 @@ async def select_to_parquet( self, sql: str, parameters: "Optional[StatementParameterType]" = None, - /, - *, + *filters: "StatementFilter", connection: "Optional[ConnectionT]" = None, **kwargs: Any, ) -> None: diff --git a/uv.lock b/uv.lock index 88c7df7..df262e8 100644 --- a/uv.lock +++ b/uv.lock @@ -3745,6 +3745,10 @@ performance = [ { name = "msgspec" }, { name = "sqlglot", extra = ["rs"] }, ] +polars = [ + { name = "polars" }, + { name = "pyarrow" }, +] psqlpy = [ { name = "psqlpy" }, ] @@ -3887,9 +3891,11 @@ requires-dist = [ { name = "msgspec", marker = "extra == 'performance'" }, { name = "oracledb", marker = "extra == 'oracledb'" }, { name = "orjson", marker = "extra == 'orjson'" }, + { name = "polars", marker = "extra == 'polars'" }, { name = "psqlpy", marker = "extra == 'psqlpy'" }, { name = "psycopg", extras = ["binary", "pool"], marker = "extra == 'psycopg'" }, { name = "pyarrow", marker = "extra == 'adbc'" }, + { name = "pyarrow", marker = "extra == 'polars'" }, { name = "pydantic", marker = "extra == 'pydantic'" }, { name = "pydantic-extra-types", marker = "extra == 'pydantic'" }, { name = "pymssql", marker = "extra == 'pymssql'" }, @@ -3899,7 +3905,7 @@ requires-dist = [ { name = "typing-extensions" }, { name = "uuid-utils", marker = "extra == 'uuid'", specifier = ">=0.6.1" }, ] -provides-extras = ["adbc", "aioodbc", "aiosqlite", "asyncmy", "asyncpg", "bigquery", "duckdb", "fastapi", "flask", "litestar", "msgspec", "nanoid", "oracledb", "orjson", "performance", "psqlpy", "psycopg", "pydantic", "pymssql", "pymysql", "spanner", "uuid"] +provides-extras = ["adbc", "aioodbc", "aiosqlite", "asyncmy", "asyncpg", "bigquery", "duckdb", "fastapi", "flask", "litestar", "msgspec", "nanoid", "oracledb", "orjson", "performance", "polars", "psqlpy", "psycopg", "pydantic", "pymssql", "pymysql", "spanner", "uuid"] [package.metadata.requires-dev] build = [{ name = "bump-my-version" }]