Skip to content

Commit 88567bf

Browse files
authored
Improve orm.query (#70)
* Improve `orm.query` * Updates based on PR feedback
1 parent 2265585 commit 88567bf

File tree

2 files changed

+138
-94
lines changed

2 files changed

+138
-94
lines changed

sqlalchemy-stubs/orm/query.pyi

Lines changed: 133 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,41 @@
11
from typing import Any
2+
from typing import Generic
3+
from typing import Iterable
4+
from typing import Iterator
25
from typing import List
36
from typing import Optional
7+
from typing import Tuple
8+
from typing import TypeVar
9+
from typing import Union
10+
11+
from typing_extensions import Protocol
412

513
from . import interfaces
6-
from .context import QueryContext as QueryContext
7-
from .util import aliased as aliased
14+
from .session import _SessionProtocol
15+
from .state import InstanceState
816
from ..sql.annotation import SupportsCloneAnnotations
917
from ..sql.base import Executable
18+
from ..sql.elements import BooleanClauseList
19+
from ..sql.elements import GroupedElement
20+
from ..sql.elements import Label
1021
from ..sql.selectable import _SelectFromElements
11-
from ..sql.selectable import GroupedElement
22+
from ..sql.selectable import Alias
23+
from ..sql.selectable import CTE
24+
from ..sql.selectable import Exists
1225
from ..sql.selectable import HasHints
1326
from ..sql.selectable import HasPrefixes
1427
from ..sql.selectable import HasSuffixes
28+
from ..sql.selectable import ScalarSelect
29+
from ..sql.selectable import Select
1530
from ..sql.selectable import SelectBase
31+
from ..sql.selectable import SelectStatementGrouping
32+
from ..util.langhelpers import _symbol
33+
34+
_T = TypeVar("_T")
35+
_TQuery = TypeVar("_TQuery", bound=Query[Any])
36+
37+
class _WithTransformationFn(Protocol[_TQuery]):
38+
def __call__(self, __query: _TQuery) -> _TQuery: ...
1639

1740
class Query(
1841
_SelectFromElements,
@@ -21,144 +44,163 @@ class Query(
2144
HasSuffixes,
2245
HasHints,
2346
Executable,
47+
Generic[_T],
2448
):
2549
load_options: Any = ...
26-
session: Any = ...
50+
session: Optional[_SessionProtocol] = ...
2751
def __init__(
28-
self, entities: Any, session: Optional[Any] = ...
52+
self,
53+
entities: Iterable[Any],
54+
session: Optional[_SessionProtocol] = ...,
2955
) -> None: ...
3056
@property
31-
def statement(self): ...
57+
def statement(self) -> Union[Select, FromStatement]: ...
3258
def subquery(
3359
self,
34-
name: Optional[Any] = ...,
60+
name: Optional[str] = ...,
3561
with_labels: bool = ...,
3662
reduce_columns: bool = ...,
37-
): ...
38-
def cte(self, name: Optional[Any] = ..., recursive: bool = ...): ...
39-
def label(self, name: Any): ...
40-
def as_scalar(self): ...
41-
def scalar_subquery(self): ...
42-
def __clause_element__(self): ...
43-
def only_return_tuples(self, value: Any) -> None: ...
63+
) -> Alias: ...
64+
def cte(self, name: Optional[str] = ..., recursive: bool = ...) -> CTE: ...
65+
def label(self, name: str) -> Label[Any]: ...
66+
def as_scalar(self) -> ScalarSelect: ...
67+
def scalar_subquery(self) -> ScalarSelect: ...
68+
@property
69+
def selectable(self) -> Union[Select, FromStatement]: ...
70+
def __clause_element__(self) -> Union[Select, FromStatement]: ...
71+
def only_return_tuples(self: _TQuery, value: bool) -> _TQuery: ...
4472
@property
45-
def is_single_entity(self): ...
46-
def enable_eagerloads(self, value: Any) -> None: ...
47-
def with_labels(self): ...
48-
apply_labels: Any = ...
73+
def is_single_entity(self) -> bool: ...
74+
def enable_eagerloads(self: _TQuery, value: bool) -> _TQuery: ...
75+
def with_labels(self: _TQuery) -> _TQuery: ...
76+
def apply_labels(self: _TQuery) -> _TQuery: ...
4977
@property
50-
def get_label_style(self): ...
51-
def set_label_style(self, style: Any): ...
52-
def enable_assertions(self, value: Any) -> None: ...
78+
def get_label_style(self) -> _symbol: ...
79+
def set_label_style(self: _TQuery, style: _symbol) -> _TQuery: ...
80+
def enable_assertions(self: _TQuery, value: bool) -> _TQuery: ...
5381
@property
54-
def whereclause(self): ...
82+
def whereclause(self) -> BooleanClauseList[Any]: ...
5583
def with_polymorphic(
56-
self,
84+
self: _TQuery,
5785
cls_or_mappers: Any,
5886
selectable: Optional[Any] = ...,
5987
polymorphic_on: Optional[Any] = ...,
60-
) -> None: ...
61-
def yield_per(self, count: Any) -> None: ...
62-
def get(self, ident: Any): ...
88+
) -> _TQuery: ...
89+
def yield_per(self: _TQuery, count: int) -> _TQuery: ...
90+
def get(self, ident: Any) -> Optional[_T]: ...
6391
@property
64-
def lazy_loaded_from(self): ...
65-
def correlate(self, *fromclauses: Any) -> None: ...
66-
def autoflush(self, setting: Any) -> None: ...
67-
def populate_existing(self) -> None: ...
92+
def lazy_loaded_from(self) -> InstanceState: ...
93+
def correlate(self: _TQuery, *fromclauses: Any) -> _TQuery: ...
94+
def autoflush(self: _TQuery, setting: Any) -> _TQuery: ...
95+
def populate_existing(self: _TQuery) -> _TQuery: ...
6896
def with_parent(
69-
self,
97+
self: _TQuery,
7098
instance: Any,
7199
property: Optional[Any] = ...,
72100
from_entity: Optional[Any] = ...,
73-
): ...
74-
def add_entity(self, entity: Any, alias: Optional[Any] = ...) -> None: ...
75-
def with_session(self, session: Any) -> None: ...
76-
def from_self(self, *entities: Any): ...
77-
def values(self, *columns: Any): ...
78-
def value(self, column: Any): ...
79-
def with_entities(self, *entities: Any) -> None: ...
80-
def add_columns(self, *column: Any) -> None: ...
81-
def add_column(self, column: Any): ...
82-
def options(self, *args: Any) -> None: ...
83-
def with_transformation(self, fn: Any): ...
84-
def get_execution_options(self): ...
85-
def execution_options(self, **kwargs: Any) -> None: ...
101+
) -> _TQuery: ...
102+
def add_entity(
103+
self: _TQuery, entity: Any, alias: Optional[Any] = ...
104+
) -> _TQuery: ...
105+
def with_session(
106+
self: _TQuery, session: Optional[_SessionProtocol]
107+
) -> _TQuery: ...
108+
def from_self(self: _TQuery, *entities: Any) -> _TQuery: ...
109+
def values(self, *columns: Any) -> Iterator[Tuple[Any, ...]]: ...
110+
def value(self, column: Any) -> Optional[Any]: ...
111+
def with_entities(self: _TQuery, *entities: Any) -> _TQuery: ...
112+
def add_columns(self: _TQuery, *column: Any) -> _TQuery: ...
113+
def add_column(self: _TQuery, column: Any) -> _TQuery: ...
114+
def options(self: _TQuery, *args: Any) -> _TQuery: ...
115+
def with_transformation(
116+
self: _TQuery, fn: _WithTransformationFn[_TQuery]
117+
) -> _TQuery: ...
118+
def get_execution_options(self) -> Any: ...
119+
def execution_options(self: _TQuery, **kwargs: Any) -> _TQuery: ...
86120
def with_for_update(
87-
self,
121+
self: _TQuery,
88122
read: bool = ...,
89123
nowait: bool = ...,
90124
of: Optional[Any] = ...,
91125
skip_locked: bool = ...,
92126
key_share: bool = ...,
93-
) -> None: ...
94-
def params(self, *args: Any, **kwargs: Any) -> None: ...
95-
def where(self, *criterion: Any): ...
96-
def filter(self, *criterion: Any) -> None: ...
97-
def filter_by(self, **kwargs: Any): ...
98-
def order_by(self, *clauses: Any) -> None: ...
99-
def group_by(self, *clauses: Any) -> None: ...
100-
def having(self, criterion: Any) -> None: ...
101-
def union(self, *q: Any): ...
102-
def union_all(self, *q: Any): ...
103-
def intersect(self, *q: Any): ...
104-
def intersect_all(self, *q: Any): ...
105-
def except_(self, *q: Any): ...
106-
def except_all(self, *q: Any): ...
107-
def join(self, target: Any, *props: Any, **kwargs: Any) -> None: ...
108-
def outerjoin(self, target: Any, *props: Any, **kwargs: Any): ...
109-
def reset_joinpoint(self) -> None: ...
110-
def select_from(self, *from_obj: Any) -> None: ...
111-
def select_entity_from(self, from_obj: Any) -> None: ...
112-
def __getitem__(self, item: Any): ...
113-
def slice(self, start: Any, stop: Any) -> None: ...
114-
def limit(self, limit: Any) -> None: ...
115-
def offset(self, offset: Any) -> None: ...
116-
def distinct(self, *expr: Any) -> None: ...
117-
def all(self) -> List[Any]: ...
118-
def from_statement(self, statement: Any) -> None: ...
119-
def first(self): ...
120-
def one_or_none(self): ...
121-
def one(self): ...
122-
def scalar(self): ...
123-
def __iter__(self) -> Any: ...
127+
) -> _TQuery: ...
128+
def params(self: _TQuery, *args: Any, **kwargs: Any) -> _TQuery: ...
129+
def where(self: _TQuery, *criterion: Any) -> _TQuery: ...
130+
def filter(self: _TQuery, *criterion: Any) -> _TQuery: ...
131+
def filter_by(self: _TQuery, **kwargs: Any) -> _TQuery: ...
132+
def order_by(self: _TQuery, *clauses: Any) -> _TQuery: ...
133+
def group_by(self: _TQuery, *clauses: Any) -> _TQuery: ...
134+
def having(self: _TQuery, criterion: Any) -> _TQuery: ...
135+
def union(self: _TQuery, *q: Any) -> _TQuery: ...
136+
def union_all(self: _TQuery, *q: Any) -> _TQuery: ...
137+
def intersect(self: _TQuery, *q: Any) -> _TQuery: ...
138+
def intersect_all(self: _TQuery, *q: Any) -> _TQuery: ...
139+
def except_(self: _TQuery, *q: Any) -> _TQuery: ...
140+
def except_all(self: _TQuery, *q: Any) -> _TQuery: ...
141+
def join(
142+
self: _TQuery, target: Any, *props: Any, **kwargs: Any
143+
) -> _TQuery: ...
144+
def outerjoin(
145+
self: _TQuery, target: Any, *props: Any, **kwargs: Any
146+
) -> _TQuery: ...
147+
def reset_joinpoint(self: _TQuery) -> _TQuery: ...
148+
def select_from(self: _TQuery, *from_obj: Any) -> _TQuery: ...
149+
def select_entity_from(self: _TQuery, from_obj: Any) -> _TQuery: ...
150+
def __getitem__(self, item: Any) -> Any: ...
151+
def slice(self: _TQuery, start: Any, stop: Any) -> _TQuery: ...
152+
def limit(self: _TQuery, limit: Any) -> _TQuery: ...
153+
def offset(self: _TQuery, offset: Any) -> _TQuery: ...
154+
def distinct(self: _TQuery, *expr: Any) -> _TQuery: ...
155+
def all(self) -> List[_T]: ...
156+
def from_statement(self: _TQuery, statement: Any) -> _TQuery: ...
157+
def first(self) -> Optional[_T]: ...
158+
def one_or_none(self) -> Optional[_T]: ...
159+
def one(self) -> _T: ...
160+
def scalar(self) -> Any: ... # type: ignore[override]
161+
def __iter__(self) -> Iterator[_T]: ...
124162
@property
125-
def column_descriptions(self): ...
126-
def instances(self, result_proxy: Any, context: Optional[Any] = ...): ...
127-
def merge_result(self, iterator: Any, load: bool = ...): ...
128-
def exists(self): ...
129-
def count(self): ...
130-
def delete(self, synchronize_session: str = ...): ...
163+
def column_descriptions(self) -> List[Any]: ...
164+
def instances(
165+
self, result_proxy: Any, context: Optional[Any] = ...
166+
) -> Any: ...
167+
def merge_result(self, iterator: Any, load: bool = ...) -> Any: ...
168+
def exists(self) -> Exists: ...
169+
def count(self) -> int: ...
170+
def delete(self, synchronize_session: str = ...) -> int: ...
131171
def update(
132172
self,
133173
values: Any,
134174
synchronize_session: str = ...,
135175
update_args: Optional[Any] = ...,
136-
): ...
176+
) -> int: ...
137177

138178
class FromStatement(GroupedElement, SelectBase, Executable):
139179
__visit_name__: str = ...
140180
element: Any = ...
141-
def __init__(self, entities: Any, element: Any) -> None: ...
142-
def get_label_style(self): ...
143-
def set_label_style(self, label_style: Any): ...
144-
def get_children(self, **kw: Any) -> None: ...
181+
def __init__(self, entities: Iterable[Any], element: Any) -> None: ...
182+
def get_label_style(self) -> _symbol: ...
183+
def set_label_style(
184+
self, label_style: _symbol
185+
) -> SelectStatementGrouping: ...
186+
def get_children(self, **kw: Any) -> Iterable[Any]: ... # type: ignore[override]
145187

146188
class AliasOption(interfaces.LoaderOption):
147189
def __init__(self, alias: Any) -> None: ...
148190
def process_compile_state(self, compile_state: Any) -> None: ...
149191

150192
class BulkUD:
151-
query: Any = ...
193+
query: Query[Any] = ...
152194
mapper: Any = ...
153-
def __init__(self, query: Any) -> None: ...
195+
def __init__(self, query: Query[Any]) -> None: ...
154196
@property
155-
def session(self): ...
197+
def session(self) -> Optional[_SessionProtocol]: ...
156198

157199
class BulkUpdate(BulkUD):
158200
values: Any = ...
159201
update_kwargs: Any = ...
160202
def __init__(
161-
self, query: Any, values: Any, update_kwargs: Any
203+
self, query: Query[Any], values: Any, update_kwargs: Any
162204
) -> None: ...
163205

164206
class BulkDelete(BulkUD): ...

sqlalchemy-stubs/orm/session.pyi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ _TSessionTransaction = TypeVar(
3939
)
4040

4141
_TSharedSessionProtocol = TypeVar(
42-
"_TSharedSessionProtocol", bound=_SharedSessionProtocol
42+
"_TSharedSessionProtocol", bound=_SharedSessionProtocol[Any]
4343
)
4444
_TSessionTransactionProtocol = TypeVar(
4545
"_TSessionTransactionProtocol", bound=_SessionTransactionProtocol
@@ -406,7 +406,9 @@ class Session(_SessionClassMethods):
406406
binds: Optional[Mapping[Any, Union[Connection, Engine]]] = ...,
407407
enable_baked_queries: bool = ...,
408408
info: Optional[Mapping[Any, Any]] = ...,
409-
query_cls: Optional[Union[Query, Callable[..., Query]]] = ...,
409+
query_cls: Optional[
410+
Union[Type[Query[Any]], Callable[..., Query[Any]]]
411+
] = ...,
410412
) -> None: ...
411413
connection_callable: Any = ...
412414
def __enter__(self: _TSession) -> _TSession: ...
@@ -473,7 +475,7 @@ class Session(_SessionClassMethods):
473475
_sa_skip_events: Optional[Any] = ...,
474476
_sa_skip_for_implicit_returning: bool = ...,
475477
) -> Union[Connection, Engine]: ...
476-
def query(self, *entities: Any, **kwargs: Any) -> Query: ...
478+
def query(self, *entities: Any, **kwargs: Any) -> Query[Any]: ...
477479
@property
478480
def no_autoflush(self: _TSession) -> ContextManager[_TSession]: ...
479481
def refresh(

0 commit comments

Comments
 (0)