Skip to content

Commit 7c5bdc9

Browse files
committed
Add implementation of AsyncSession.run_sync()
Currently, `sqlmodel.ext.asyncio.session.AsyncSession` doesn't implement `run_sync()`, which means that any call to `run_sync()` on a sqlmodel `AsyncSession` will be dispatched to the parent `sqlalchemy.ext.asyncio.AsyncSession`. The first argument to sqlalchemy's `AsyncSession.run_sync()` is a callable whose first argument is a `sqlalchemy.orm.Session` object. If we're using this in a repo that uses sqlmodel, we'll actually be passing a callable whose first argument is a `sqlmodel.orm.session.Session`. In practice this works fine - because `sqlmodel.orm.session.Session` is derived from `sqlalchemy.orm.Session`, the implementation of `sqlalchemy.ext.asyncio.AsyncSession.run_sync()` can use the sqlmodel `Session` object in place of the sqlalchemy `Session` object. However, static analysers will complain that the argument to `run_sync()` is of the wrong type. For example, here's a warning from pyright: ``` Pyright: Error: Argument of type "(session: Session, id: UUID) -> int" cannot be assigned to parameter "fn" of type "(Session, **_P@run_sync) -> _T@run_sync" in function "run_sync"   Type "(session: Session, id: UUID) -> int" is not assignable to type "(Session, id: UUID) -> int"     Parameter 1: type "Session" is incompatible with type "Session"       "sqlalchemy.orm.session.Session" is not assignable to "sqlmodel.orm.session.Session" [reportArgumentType] ``` This commit implements a `run_sync()` method on `sqlmodel.ext.asyncio.session.AsyncSession`, which casts the callable to the correct type before dispatching it to the base class. This satisfies the static type checks.
1 parent 6c0410e commit 7c5bdc9

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

sqlmodel/ext/asyncio/session.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import (
22
Any,
3+
Callable,
34
Dict,
45
Mapping,
56
Optional,
@@ -17,15 +18,17 @@
1718
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
1819
from sqlalchemy.ext.asyncio.result import _ensure_sync_result
1920
from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS
21+
from sqlalchemy.orm import Session as _Session
2022
from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
2123
from sqlalchemy.sql.base import Executable as _Executable
2224
from sqlalchemy.util.concurrency import greenlet_spawn
23-
from typing_extensions import deprecated
25+
from typing_extensions import deprecated, Concatenate, ParamSpec
2426

2527
from ...orm.session import Session
2628
from ...sql.base import Executable
2729
from ...sql.expression import Select, SelectOfScalar
2830

31+
_P = ParamSpec("_P")
2932
_TSelectParam = TypeVar("_TSelectParam", bound=Any)
3033

3134

@@ -148,3 +151,17 @@ async def execute( # type: ignore
148151
_parent_execute_state=_parent_execute_state,
149152
_add_event=_add_event,
150153
)
154+
155+
async def run_sync(
156+
self,
157+
fn: Callable[Concatenate[Session, _P], _TSelectParam],
158+
*arg: _P.args,
159+
**kw: _P.kwargs,
160+
) -> _TSelectParam:
161+
base_fn = cast(Callable[Concatenate[_Session, _P], _TSelectParam], fn)
162+
163+
return await super().run_sync(
164+
base_fn,
165+
*arg,
166+
**kw,
167+
)

0 commit comments

Comments
 (0)