Skip to content

Commit f761e14

Browse files
Task/trans ctx typing (#128)
* Fix return typing of AsyncEngine._trans_ctx.__aenter__ * StartableContext is made generic as to remove the assumption that it'll always be returning itself which is not the case for for it's behaviour in _trans_ctx * tests: add testcase for #109 Co-authored-by: Faster Speeding <luke@lmbyrne.dev>
1 parent 424b378 commit f761e14

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

sqlalchemy-stubs/ext/asyncio/base.pyi

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
import abc
22
from typing import Any
33
from typing import Generator
4+
from typing import Generic
45
from typing import TypeVar
56

6-
_TStartableContext = TypeVar("_TStartableContext", bound=StartableContext)
7+
_T = TypeVar("_T")
78

8-
class StartableContext(abc.ABC, metaclass=abc.ABCMeta):
9+
class StartableContext(abc.ABC, Generic[_T], metaclass=abc.ABCMeta):
910
@abc.abstractmethod
10-
async def start(self: _TStartableContext) -> _TStartableContext: ...
11-
def __await__(
12-
self: _TStartableContext,
13-
) -> Generator[Any, None, _TStartableContext]: ...
14-
async def __aenter__(self: _TStartableContext) -> _TStartableContext: ...
11+
async def start(self) -> _T: ...
12+
def __await__(self) -> Generator[Any, None, _T]: ...
13+
async def __aenter__(self) -> _T: ...
1514
@abc.abstractmethod
1615
async def __aexit__(
1716
self, type_: Any, value: Any, traceback: Any

sqlalchemy-stubs/ext/asyncio/engine.pyi

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ def create_async_engine(*arg: Any, **kw: Any) -> AsyncEngine: ...
2222

2323
class AsyncConnectable: ...
2424

25-
class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
25+
class AsyncConnection(
26+
ProxyComparable, StartableContext["AsyncConnection"], AsyncConnectable
27+
):
2628
# copied from future.Connection via create_proxy_methods
2729
@property
2830
def closed(self) -> bool: ...
@@ -102,12 +104,12 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
102104
def update_execution_options(self, **opt: Any) -> None: ...
103105
def get_execution_options(self) -> Mapping[Any, Any]: ...
104106
# end copied
105-
class _trans_ctx(StartableContext):
107+
class _trans_ctx(StartableContext[AsyncConnection]):
106108
conn: AsyncConnection = ...
107109
def __init__(self, conn: AsyncConnection) -> None: ...
108110
transaction: Any = ...
109-
async def start(self) -> AsyncConnection: ... # type: ignore[override]
110-
def __await__(self) -> Generator[Any, None, AsyncConnection]: ... # type: ignore[override]
111+
async def start(self) -> AsyncConnection: ...
112+
def __await__(self) -> Generator[Any, None, AsyncConnection]: ...
111113
async def __aexit__(
112114
self, type_: Any, value: Any, traceback: Any
113115
) -> None: ...
@@ -119,7 +121,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
119121
def execution_options(self, **opt: Any) -> AsyncEngine: ...
120122
async def dispose(self) -> None: ...
121123

122-
class AsyncTransaction(ProxyComparable, StartableContext):
124+
class AsyncTransaction(ProxyComparable, StartableContext["AsyncTransaction"]):
123125
connection: AsyncConnection = ...
124126
sync_transaction: Optional[Transaction] = ...
125127
nested: bool = ...

sqlalchemy-stubs/ext/asyncio/session.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ class _AsyncSessionContextManager:
151151
self, type_: Any, value: Any, traceback: Any
152152
) -> None: ...
153153

154-
class AsyncSessionTransaction(StartableContext):
154+
class AsyncSessionTransaction(StartableContext["AsyncSessionTransaction"]):
155155
session: AsyncSession = ...
156156
nested: bool = ...
157157
sync_transaction: Optional[Any] = ...
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from sqlalchemy import literal
2+
from sqlalchemy import select
3+
from sqlalchemy.ext import asyncio
4+
5+
6+
async def test() -> None:
7+
database = asyncio.create_async_engine("", future=True)
8+
9+
trans_ctx = database.begin()
10+
async with trans_ctx as connection:
11+
await connection.execute(select(literal(1)))

0 commit comments

Comments
 (0)