From 7c5bdc98940279549af2314480b103a1eb4da9f0 Mon Sep 17 00:00:00 2001 From: John Newbery Date: Thu, 24 Apr 2025 17:18:59 +0100 Subject: [PATCH 1/2] Add implementation of AsyncSession.run_sync() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently, `sqlmodel.ext.asyncio.session.AsyncSession` doesn't implement `run_sync()`, which means that any call to `run_sync()` on a sqlmodel `AsyncSession` will be dispatched to the parent `sqlalchemy.ext.asyncio.AsyncSession`. The first argument to sqlalchemy's `AsyncSession.run_sync()` is a callable whose first argument is a `sqlalchemy.orm.Session` object. If we're using this in a repo that uses sqlmodel, we'll actually be passing a callable whose first argument is a `sqlmodel.orm.session.Session`. In practice this works fine - because `sqlmodel.orm.session.Session` is derived from `sqlalchemy.orm.Session`, the implementation of `sqlalchemy.ext.asyncio.AsyncSession.run_sync()` can use the sqlmodel `Session` object in place of the sqlalchemy `Session` object. However, static analysers will complain that the argument to `run_sync()` is of the wrong type. For example, here's a warning from pyright: ``` Pyright: Error: Argument of type "(session: Session, id: UUID) -> int" cannot be assigned to parameter "fn" of type "(Session, **_P@run_sync) -> _T@run_sync" in function "run_sync"   Type "(session: Session, id: UUID) -> int" is not assignable to type "(Session, id: UUID) -> int"     Parameter 1: type "Session" is incompatible with type "Session"       "sqlalchemy.orm.session.Session" is not assignable to "sqlmodel.orm.session.Session" [reportArgumentType] ``` This commit implements a `run_sync()` method on `sqlmodel.ext.asyncio.session.AsyncSession`, which casts the callable to the correct type before dispatching it to the base class. This satisfies the static type checks. --- sqlmodel/ext/asyncio/session.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 467d0bd84e..b69baf73d6 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -1,5 +1,6 @@ from typing import ( Any, + Callable, Dict, Mapping, Optional, @@ -17,15 +18,17 @@ from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession from sqlalchemy.ext.asyncio.result import _ensure_sync_result from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS +from sqlalchemy.orm import Session as _Session from sqlalchemy.orm._typing import OrmExecuteOptionsParameter from sqlalchemy.sql.base import Executable as _Executable from sqlalchemy.util.concurrency import greenlet_spawn -from typing_extensions import deprecated +from typing_extensions import deprecated, Concatenate, ParamSpec from ...orm.session import Session from ...sql.base import Executable from ...sql.expression import Select, SelectOfScalar +_P = ParamSpec("_P") _TSelectParam = TypeVar("_TSelectParam", bound=Any) @@ -148,3 +151,17 @@ async def execute( # type: ignore _parent_execute_state=_parent_execute_state, _add_event=_add_event, ) + + async def run_sync( + self, + fn: Callable[Concatenate[Session, _P], _TSelectParam], + *arg: _P.args, + **kw: _P.kwargs, + ) -> _TSelectParam: + base_fn = cast(Callable[Concatenate[_Session, _P], _TSelectParam], fn) + + return await super().run_sync( + base_fn, + *arg, + **kw, + ) From 63341a105dff5b3559fee8d191152eb74401879a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 May 2025 10:04:50 +0000 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/ext/asyncio/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index b69baf73d6..ab5dd7fe19 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -22,7 +22,7 @@ from sqlalchemy.orm._typing import OrmExecuteOptionsParameter from sqlalchemy.sql.base import Executable as _Executable from sqlalchemy.util.concurrency import greenlet_spawn -from typing_extensions import deprecated, Concatenate, ParamSpec +from typing_extensions import Concatenate, ParamSpec, deprecated from ...orm.session import Session from ...sql.base import Executable