Skip to content

Commit f614041

Browse files
committed
wip
1 parent ebdcd1a commit f614041

File tree

9 files changed

+426
-231
lines changed

9 files changed

+426
-231
lines changed

sqlspec/adapters/asyncpg/config.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ def driver_type(self) -> type[AsyncpgDriver]: # type: ignore[override]
212212

213213
@property
214214
def connection_config_dict(self) -> dict[str, Any]:
215-
"""Return the connection configuration as a dict.
215+
"""Return the connection configuration as a dict for asyncpg.connect().
216+
217+
This method filters out pool-specific parameters that are not valid for asyncpg.connect().
216218
217219
Raises:
218220
ImproperConfigurationError: If the configuration is invalid.
@@ -221,19 +223,60 @@ def connection_config_dict(self) -> dict[str, Any]:
221223
merged_config = {**self.connection_config, **self.pool_config}
222224
config = {k: v for k, v in merged_config.items() if v is not Empty}
223225

226+
# Define pool-specific parameters that should not be passed to asyncpg.connect()
227+
pool_only_params = {
228+
"min_size",
229+
"max_size",
230+
"max_queries",
231+
"max_inactive_connection_lifetime",
232+
"setup",
233+
"init",
234+
"loop",
235+
"connection_class",
236+
"record_class",
237+
}
238+
239+
# Filter out pool-specific parameters for connection creation
240+
connection_config = {k: v for k, v in config.items() if k not in pool_only_params}
241+
224242
# Validate essential connection info
225-
has_dsn = config.get("dsn") is not None
226-
has_host = config.get("host") is not None
243+
has_dsn = connection_config.get("dsn") is not None
244+
has_host = connection_config.get("host") is not None
227245

228246
if not (has_dsn or has_host):
229-
msg = f"AsyncPG configuration requires either 'dsn' or 'host' in pool_config. Current config: {config}"
247+
msg = f"AsyncPG configuration requires either 'dsn' or 'host' in pool_config. Current config: {connection_config}"
230248
raise ImproperConfigurationError(msg)
231249

250+
# Set SSL to False by default for non-SSL environments (asyncpg 0.22.0+ defaults to 'prefer')
251+
if "ssl" not in connection_config:
252+
connection_config["ssl"] = False
253+
254+
return connection_config
255+
256+
@property
257+
def pool_config_dict(self) -> dict[str, Any]:
258+
"""Return the full pool configuration as a dict for asyncpg.create_pool().
259+
260+
Returns:
261+
A dictionary containing all pool configuration parameters.
262+
"""
263+
# Merge connection_config into pool_config, with pool_config taking precedence
264+
merged_config = {**self.connection_config, **self.pool_config}
265+
config = {k: v for k, v in merged_config.items() if v is not Empty}
266+
267+
# Set SSL to False by default for non-SSL environments
268+
if "ssl" not in config:
269+
config["ssl"] = False
270+
271+
# Set reasonable defaults for pool parameters to prevent connection issues
272+
if "max_inactive_connection_lifetime" not in config:
273+
config["max_inactive_connection_lifetime"] = 300.0 # 5 minutes
274+
232275
return config
233276

234277
async def _create_pool_impl(self) -> "Pool[Record]":
235278
"""Create the actual async connection pool."""
236-
pool_args = self.connection_config_dict
279+
pool_args = self.pool_config_dict
237280
return await asyncpg_create_pool(**pool_args)
238281

239282
async def _close_pool_impl(self) -> None:

sqlspec/adapters/asyncpg/driver.py

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,13 @@ def _get_placeholder_style(self) -> ParameterStyle:
6161
async def _execute_impl(
6262
self,
6363
statement: SQL,
64-
parameters: Optional[SQLParameterType] = None,
6564
connection: Optional[AsyncpgConnection] = None,
66-
config: Optional[SQLConfig] = None,
67-
is_many: bool = False,
68-
is_script: bool = False,
6965
**kwargs: Any,
7066
) -> Any:
7167
async with instrument_operation_async(self, "asyncpg_execute", "database"):
7268
conn = self._connection(connection)
73-
if config is not None and config != statement.config:
74-
statement = statement.copy(config=config)
7569

76-
if is_script:
70+
if statement.is_script:
7771
final_sql = statement.to_sql(placeholder_style=ParameterStyle.STATIC)
7872
if self.instrumentation_config.log_queries:
7973
logger.debug("Executing SQL script: %s", final_sql)
@@ -84,37 +78,30 @@ async def _execute_impl(
8478
if self.instrumentation_config.log_queries:
8579
logger.debug("Executing SQL: %s", final_sql)
8680

87-
if is_many:
88-
args_list: list[Sequence[Any]] = []
89-
if parameters and isinstance(parameters, Sequence):
90-
for param_set in parameters:
81+
if statement.is_many:
82+
params_list: list[tuple[Any, ...]] = []
83+
if statement.parameters and isinstance(statement.parameters, Sequence):
84+
for param_set in statement.parameters:
9185
if isinstance(param_set, (list, tuple)):
92-
args_list.append(param_set)
93-
elif isinstance(param_set, dict):
94-
# Convert dict to ordered args based on statement parameters
95-
ordered_params = statement.get_parameters(style=self._get_placeholder_style())
96-
if isinstance(ordered_params, (list, tuple)):
97-
args_list.append(ordered_params)
98-
else:
99-
args_list.append((ordered_params,) if ordered_params is not None else ())
86+
params_list.append(tuple(param_set))
10087
elif param_set is None:
101-
args_list.append(())
88+
params_list.append(())
10289
else:
103-
args_list.append((param_set,))
90+
params_list.append((param_set,))
10491

105-
if self.instrumentation_config.log_parameters and args_list:
106-
logger.debug("Query parameters (batch): %s", args_list)
92+
if self.instrumentation_config.log_parameters and params_list:
93+
logger.debug("Query parameters (batch): %s", params_list)
10794

108-
await conn.executemany(final_sql, args_list)
109-
return len(args_list)
95+
return await conn.executemany(final_sql, params_list)
11096
# Single execution
111-
ordered_params = statement.get_parameters(style=self._get_placeholder_style())
97+
# Use the statement's already-processed parameters directly
98+
processed_params = statement._merged_parameters if hasattr(statement, "_merged_parameters") else None
11299
args: list[Any] = []
113100

114-
if isinstance(ordered_params, (list, tuple)):
115-
args.extend(ordered_params)
116-
elif ordered_params is not None:
117-
args.append(ordered_params)
101+
if isinstance(processed_params, (list, tuple)):
102+
args.extend(processed_params)
103+
elif processed_params is not None:
104+
args.append(processed_params)
118105

119106
if self.instrumentation_config.log_parameters and args:
120107
logger.debug("Query parameters: %s", args)
@@ -246,13 +233,14 @@ async def select_to_arrow(
246233
raise TypeError(msg)
247234

248235
final_sql = stmt_obj.to_sql(placeholder_style=self._get_placeholder_style())
249-
ordered_params = stmt_obj.get_parameters(style=self._get_placeholder_style())
236+
# Use the statement's already-processed parameters instead of calling get_parameters()
237+
processed_params = stmt_obj._merged_parameters if hasattr(stmt_obj, "_merged_parameters") else None
250238

251239
args: list[Any] = []
252-
if isinstance(ordered_params, (list, tuple)):
253-
args.extend(ordered_params)
254-
elif ordered_params is not None:
255-
args.append(ordered_params)
240+
if isinstance(processed_params, (list, tuple)):
241+
args.extend(processed_params)
242+
elif processed_params is not None:
243+
args.append(processed_params)
256244

257245
records = await conn.fetch(final_sql, *args)
258246

sqlspec/adapters/duckdb/config.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -262,24 +262,28 @@ def connection_config_dict(self) -> dict[str, Any]:
262262
# Filter out empty values and prepare config for duckdb.connect()
263263
config_dict = {k: v for k, v in self.connection_config.items() if v is not Empty}
264264

265-
# Extract DuckDB-specific settings into the config parameter
266-
duckdb_settings = {}
267-
connection_params = {}
268-
269265
# Parameters that go directly to duckdb.connect()
270-
direct_params = {"database", "read_only", "config"}
266+
connection_params = {}
267+
duckdb_config_settings = {}
271268

269+
# Only database and read_only go directly to connect()
270+
# Everything else goes into the config dictionary
272271
for key, value in config_dict.items():
273-
if key in direct_params:
274-
connection_params[key] = value
272+
if key == "database":
273+
connection_params["database"] = value
274+
elif key == "read_only":
275+
connection_params["read_only"] = value
276+
elif key == "config":
277+
# If user provided a config dict, merge it
278+
if isinstance(value, dict):
279+
duckdb_config_settings.update(value)
275280
else:
276281
# All other parameters are DuckDB configuration settings
277-
duckdb_settings[key] = value
282+
duckdb_config_settings[key] = value
278283

279-
existing_config: dict[str, Any] = connection_params.get("config", {})
280-
if duckdb_settings:
281-
existing_config.update(duckdb_settings)
282-
connection_params["config"] = existing_config
284+
# Add the config dictionary if we have settings
285+
if duckdb_config_settings:
286+
connection_params["config"] = duckdb_config_settings
283287

284288
return connection_params
285289

sqlspec/adapters/duckdb/driver.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,14 @@ def _execute_impl(
9999
final_exec_params = [list(p) if isinstance(p, (list, tuple)) else [p] for p in parameters]
100100
else:
101101
final_exec_params = []
102+
# Use provided parameters if available, otherwise get from statement
103+
elif parameters is not None:
104+
if isinstance(parameters, list):
105+
final_exec_params = parameters
106+
elif hasattr(parameters, "__iter__") and not isinstance(parameters, (str, bytes)):
107+
final_exec_params = list(parameters)
108+
else:
109+
final_exec_params = [parameters]
102110
else:
103111
single_params = statement.get_parameters(style=self._get_placeholder_style())
104112
if single_params is not None:
@@ -122,9 +130,17 @@ def _execute_impl(
122130
return "SCRIPT EXECUTED"
123131
if is_many:
124132
cursor.executemany(final_sql, cast("list[list[Any]]", final_exec_params))
125-
else:
126-
cursor.execute(final_sql, final_exec_params or [])
127-
return cursor
133+
# For executemany, return cursor info for execute result
134+
return {"rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1}
135+
cursor.execute(final_sql, final_exec_params or [])
136+
137+
# For SELECT queries, fetch the data immediately since cursor will be closed
138+
if self.returns_rows(statement.expression):
139+
fetched_data = cursor.fetchall()
140+
column_names = [col[0] for col in cursor.description or []]
141+
return {"data": fetched_data, "columns": column_names}
142+
# For non-SELECT queries, return cursor info
143+
return {"rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1}
128144

129145
def _wrap_select_result(
130146
self,
@@ -134,9 +150,23 @@ def _wrap_select_result(
134150
**kwargs: Any,
135151
) -> Union[SelectResult[ModelDTOT], SelectResult[dict[str, Any]]]:
136152
with instrument_operation(self, "duckdb_wrap_select", "database"):
137-
cursor = raw_driver_result
138-
fetched_data = cursor.fetchall()
139-
column_names = [col[0] for col in cursor.description or []]
153+
# Handle the new dictionary format from _execute_impl
154+
if isinstance(raw_driver_result, dict) and "data" in raw_driver_result:
155+
fetched_data = raw_driver_result["data"]
156+
column_names = raw_driver_result["columns"]
157+
elif not isinstance(raw_driver_result, dict):
158+
# Fallback for backward compatibility (shouldn't happen with new implementation)
159+
if hasattr(raw_driver_result, "fetchall") and hasattr(raw_driver_result, "description"):
160+
fetched_data = raw_driver_result.fetchall()
161+
column_names = [col[0] for col in raw_driver_result.description or []]
162+
else:
163+
# Should not happen with current implementation
164+
fetched_data = []
165+
column_names = []
166+
else:
167+
# Should not happen with current implementation
168+
fetched_data = []
169+
column_names = []
140170

141171
# Convert to list of dicts
142172
rows_as_dicts: list[dict[str, Any]] = []
@@ -185,8 +215,12 @@ def _wrap_execute_result(
185215
operation_type=operation_type or "SCRIPT",
186216
)
187217

188-
cursor = raw_driver_result
189-
rows_affected = cursor.rowcount if hasattr(cursor, "rowcount") else -1
218+
# Handle the new dictionary format from _execute_impl
219+
if isinstance(raw_driver_result, dict) and "rowcount" in raw_driver_result:
220+
rows_affected = raw_driver_result["rowcount"]
221+
else:
222+
# Fallback for backward compatibility
223+
rows_affected = getattr(raw_driver_result, "rowcount", -1)
190224

191225
if self.instrumentation_config.log_results_count:
192226
logger.debug("Execute operation affected %d rows", rows_affected)

sqlspec/base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from collections.abc import Awaitable, Coroutine
66
from typing import (
77
TYPE_CHECKING,
8-
Annotated,
98
Any,
109
Optional,
1110
Union,
@@ -78,17 +77,17 @@ def _cleanup_pools(self) -> None:
7877
logger.info("Pool cleanup completed. Cleaned %d pools", cleaned_count)
7978

8079
@overload
81-
def add_config(self, config: "SyncConfigT") -> "Annotated[type[SyncConfigT], int]": # pyright: ignore[reportInvalidTypeVarUse]
80+
def add_config(self, config: "SyncConfigT") -> "type[SyncConfigT]": # pyright: ignore[reportInvalidTypeVarUse]
8281
...
8382

8483
@overload
85-
def add_config(self, config: "AsyncConfigT") -> "Annotated[type[AsyncConfigT], int]": # pyright: ignore[reportInvalidTypeVarUse]
84+
def add_config(self, config: "AsyncConfigT") -> "type[AsyncConfigT]": # pyright: ignore[reportInvalidTypeVarUse]
8685
...
8786

8887
def add_config(
8988
self,
9089
config: "Union[SyncConfigT, AsyncConfigT]",
91-
) -> "Union[Annotated[type[SyncConfigT], int], Annotated[type[AsyncConfigT], int]]": # pyright: ignore[reportInvalidTypeVarUse]
90+
) -> "type[Union[SyncConfigT, AsyncConfigT]]": # pyright: ignore[reportInvalidTypeVarUse]
9291
"""Add a configuration instance to the registry.
9392
9493
Args:

0 commit comments

Comments
 (0)