Skip to content

Commit 4631862

Browse files
authored
improve types of over in functions (#158)
1 parent f4b68dc commit 4631862

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

sqlalchemy-stubs/sql/elements.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ _BE = TypeVar("_BE", bound=BinaryExpression[Any])
5959
_QN = TypeVar("_QN", bound=quoted_name)
6060
_TL = TypeVar("_TL", bound=_truncated_label)
6161

62+
_OverByType = Union[ClauseElement, str]
63+
6264
def collate(expression: Any, collation: str) -> BinaryExpression[_TE]: ...
6365
def between(
6466
expr: Any,
@@ -542,8 +544,10 @@ class Over(ColumnElement[_TE]):
542544
def __init__(
543545
self,
544546
element: ColumnElement[_TE],
545-
partition_by: Optional[ClauseElement] = ...,
546-
order_by: Optional[ClauseElement] = ...,
547+
partition_by: Optional[
548+
Union[_OverByType, Sequence[_OverByType]]
549+
] = ...,
550+
order_by: Optional[Union[_OverByType, Sequence[_OverByType]]] = ...,
547551
range_: Optional[Any] = ...,
548552
rows: Optional[Any] = ...,
549553
) -> None: ...

sqlalchemy-stubs/sql/functions.pyi

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ from typing import Any
22
from typing import Generic
33
from typing import Optional
44
from typing import overload
5+
from typing import Sequence
56
from typing import Type
67
from typing import TypeVar
78
from typing import Union
@@ -35,6 +36,8 @@ _T_co = TypeVar("_T_co", covariant=True)
3536
_TE = TypeVar("_TE", bound=type_api.TypeEngine[Any])
3637
_FE = TypeVar("_FE", bound=FunctionElement[Any])
3738

39+
_OverByType = Union[ClauseElement, str]
40+
3841
def register_function(
3942
identifier: str, fn: Any, package: str = ...
4043
) -> None: ...
@@ -63,8 +66,10 @@ class FunctionElement( # type: ignore[misc]
6366
def clauses(self) -> ClauseList[Any]: ...
6467
def over(
6568
self,
66-
partition_by: Optional[ClauseElement] = ...,
67-
order_by: Optional[ClauseElement] = ...,
69+
partition_by: Optional[
70+
Union[_OverByType, Sequence[_OverByType]]
71+
] = ...,
72+
order_by: Optional[Union[_OverByType, Sequence[_OverByType]]] = ...,
6873
rows: Optional[Any] = ...,
6974
range_: Optional[Any] = ...,
7075
) -> Over[_TE]: ...

test/files/functions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from sqlalchemy import Column, Integer, create_engine, func
2+
from sqlalchemy.orm import Session, declarative_base, sessionmaker
3+
4+
Base = declarative_base()
5+
6+
7+
class Foo(Base):
8+
__tablename__ = "foo"
9+
10+
id = Column(Integer(), primary_key=True)
11+
a = Column(Integer())
12+
b = Column(Integer())
13+
14+
15+
func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc())
16+
func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()])
17+
func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()])
18+
func.row_number().over(order_by="a", partition_by=("a", "b"))
19+
func.row_number().over(partition_by="a", order_by=("a", "b"))

0 commit comments

Comments
 (0)