Skip to content

Commit 7371bb0

Browse files
heckadCaselIT
andauthored
Improve typing for the 'get' method in session (#132)
* Improve typing for the 'get' method in session * update other get signatures. Add tests Co-authored-by: Federico Caselli <cfederico87@gmail.com>
1 parent 326c37e commit 7371bb0

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

sqlalchemy-stubs/ext/asyncio/session.pyi

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from typing import Mapping
88
from typing import Optional
99
from typing import Protocol
1010
from typing import Sequence
11+
from typing import Type
1112
from typing import TypeVar
1213
from typing import Union
1314

@@ -105,15 +106,15 @@ class _AsyncSessionProtocol(
105106
) -> None: ...
106107
async def get(
107108
self,
108-
entity: Any,
109+
entity: Type[_T],
109110
ident: Any,
110111
options: Optional[Sequence[Any]] = ...,
111112
populate_existing: bool = ...,
112113
with_for_update: Optional[
113114
Union[Literal[True], Mapping[str, Any]]
114115
] = ...,
115116
identity_token: Optional[Any] = ...,
116-
) -> Any: ...
117+
) -> Optional[_T]: ...
117118
async def stream(
118119
self,
119120
statement: Any,
@@ -155,13 +156,13 @@ class _AsyncSessionTypingCommon(
155156
async def flush(self, objects: Optional[Any] = ...) -> None: ...
156157
async def get(
157158
self,
158-
entity: Any,
159+
entity: Type[_T],
159160
ident: Any,
160161
options: Optional[Any] = ...,
161162
populate_existing: bool = ...,
162163
with_for_update: Optional[Any] = ...,
163164
identity_token: Optional[Any] = ...,
164-
) -> Any: ...
165+
) -> Optional[_T]: ...
165166
async def merge(self, instance: _T, load: bool = ...) -> _T: ...
166167
async def refresh(
167168
self,

sqlalchemy-stubs/orm/session.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,15 +144,15 @@ class _SessionProtocol(
144144
def delete(self, instance: Any) -> None: ...
145145
def get(
146146
self,
147-
entity: Any,
147+
entity: Type[_T],
148148
ident: Any,
149149
options: Optional[Sequence[Any]] = ...,
150150
populate_existing: bool = ...,
151151
with_for_update: Optional[
152152
Union[Literal[True], Mapping[str, Any]]
153153
] = ...,
154154
identity_token: Optional[Any] = ...,
155-
) -> Any: ...
155+
) -> Optional[_T]: ...
156156
def merge(self, instance: _T, load: bool = ...) -> _T: ...
157157
def flush(self, objects: Optional[Collection[Any]] = ...) -> None: ...
158158
@classmethod
@@ -364,15 +364,15 @@ class _SessionTypingCommon(_SessionNoIoTypingCommon):
364364
def flush(self, objects: Optional[Collection[Any]] = ...) -> None: ...
365365
def get(
366366
self,
367-
entity: Any,
367+
entity: Type[_T],
368368
ident: Any,
369369
options: Optional[Sequence[Any]] = ...,
370370
populate_existing: bool = ...,
371371
with_for_update: Optional[
372372
Union[Literal[True], Mapping[str, Any]]
373373
] = ...,
374374
identity_token: Optional[Any] = ...,
375-
) -> Any: ...
375+
) -> Optional[_T]: ...
376376
def bulk_save_objects(
377377
self,
378378
objects: Sequence[Any],

test/files/session_get.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Optional
2+
3+
from sqlalchemy import Column
4+
from sqlalchemy import Integer
5+
from sqlalchemy.ext.asyncio import AsyncSession
6+
from sqlalchemy.orm import registry
7+
from sqlalchemy.orm import Session
8+
9+
mr: registry = registry()
10+
11+
12+
@mr.mapped
13+
class Foo:
14+
id = Column(Integer, primary_key=True)
15+
__tablename__ = "foo"
16+
17+
18+
s = Session()
19+
x: Optional[Foo] = s.get(Foo, 1)
20+
21+
22+
async def go() -> None:
23+
as_ = AsyncSession()
24+
y: Optional[Foo] = await as_.get(Foo, 1)

0 commit comments

Comments
 (0)