Skip to content

Commit b12c28d

Browse files
GH456 First attempt GroupBy.transform improved typing (#1242)
* GH456 First attempt GroupBy.transform improved typing * GH456 Attempt GroupBy.aggregate improved typing * GH456 Attempt GroupBy.aggregate improved typing * GH456 PR Feedback * GH456 PR Feedback * GH456 PR Feedback * GH456 PR Feedback * GH456 PR Feedback
1 parent d4d7e4c commit b12c28d

File tree

3 files changed

+199
-11
lines changed

3 files changed

+199
-11
lines changed

pandas-stubs/core/groupby/base.pyi

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,56 @@
11
from collections.abc import Hashable
22
import dataclasses
3+
from typing import (
4+
Literal,
5+
TypeAlias,
6+
)
37

48
@dataclasses.dataclass(order=True, frozen=True)
59
class OutputKey:
610
label: Hashable
711
position: int
12+
13+
ReductionKernelType: TypeAlias = Literal[
14+
"all",
15+
"any",
16+
"corrwith",
17+
"count",
18+
"first",
19+
"idxmax",
20+
"idxmin",
21+
"last",
22+
"max",
23+
"mean",
24+
"median",
25+
"min",
26+
"nunique",
27+
"prod",
28+
# as long as `quantile`'s signature accepts only
29+
# a single quantile value, it's a reduction.
30+
# GH#27526 might change that.
31+
"quantile",
32+
"sem",
33+
"size",
34+
"skew",
35+
"std",
36+
"sum",
37+
"var",
38+
]
39+
40+
TransformationKernelType: TypeAlias = Literal[
41+
"bfill",
42+
"cumcount",
43+
"cummax",
44+
"cummin",
45+
"cumprod",
46+
"cumsum",
47+
"diff",
48+
"ffill",
49+
"fillna",
50+
"ngroup",
51+
"pct_change",
52+
"rank",
53+
"shift",
54+
]
55+
56+
TransformReductionListType: TypeAlias = ReductionKernelType | TransformationKernelType

pandas-stubs/core/groupby/generic.pyi

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ from collections.abc import (
77
)
88
from typing import (
99
Any,
10+
Concatenate,
1011
Generic,
1112
Literal,
1213
NamedTuple,
@@ -18,11 +19,15 @@ from typing import (
1819
from matplotlib.axes import Axes as PlotAxes
1920
import numpy as np
2021
from pandas.core.frame import DataFrame
22+
from pandas.core.groupby.base import TransformReductionListType
2123
from pandas.core.groupby.groupby import (
2224
GroupBy,
2325
GroupByPlot,
2426
)
25-
from pandas.core.series import Series
27+
from pandas.core.series import (
28+
Series,
29+
UnknownSeries,
30+
)
2631
from typing_extensions import (
2732
Self,
2833
TypeAlias,
@@ -31,6 +36,7 @@ from typing_extensions import (
3136
from pandas._libs.tslibs.timestamps import Timestamp
3237
from pandas._typing import (
3338
S1,
39+
S2,
3440
AggFuncTypeBase,
3541
AggFuncTypeFrame,
3642
ByT,
@@ -40,6 +46,7 @@ from pandas._typing import (
4046
Level,
4147
ListLike,
4248
NsmallestNlargestKeep,
49+
P,
4350
Scalar,
4451
TakeIndexer,
4552
WindowingEngine,
@@ -53,10 +60,30 @@ class NamedAgg(NamedTuple):
5360
aggfunc: AggScalar
5461

5562
class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
63+
@overload
64+
def aggregate(
65+
self,
66+
func: Callable[Concatenate[Series[S1], P], S2],
67+
/,
68+
*args,
69+
engine: WindowingEngine = ...,
70+
engine_kwargs: WindowingEngineKwargs = ...,
71+
**kwargs,
72+
) -> Series[S2]: ...
73+
@overload
74+
def aggregate(
75+
self,
76+
func: Callable[[Series], S2],
77+
*args,
78+
engine: WindowingEngine = ...,
79+
engine_kwargs: WindowingEngineKwargs = ...,
80+
**kwargs,
81+
) -> Series[S2]: ...
5682
@overload
5783
def aggregate(
5884
self,
5985
func: list[AggFuncTypeBase],
86+
/,
6087
*args,
6188
engine: WindowingEngine = ...,
6289
engine_kwargs: WindowingEngineKwargs = ...,
@@ -66,20 +93,34 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
6693
def aggregate(
6794
self,
6895
func: AggFuncTypeBase | None = ...,
96+
/,
6997
*args,
7098
engine: WindowingEngine = ...,
7199
engine_kwargs: WindowingEngineKwargs = ...,
72100
**kwargs,
73-
) -> Series: ...
101+
) -> UnknownSeries: ...
74102
agg = aggregate
103+
@overload
75104
def transform(
76105
self,
77-
func: Callable | str,
78-
*args,
106+
func: Callable[Concatenate[Series[S1], P], Series[S2]],
107+
/,
108+
*args: Any,
79109
engine: WindowingEngine = ...,
80110
engine_kwargs: WindowingEngineKwargs = ...,
81-
**kwargs,
82-
) -> Series: ...
111+
**kwargs: Any,
112+
) -> Series[S2]: ...
113+
@overload
114+
def transform(
115+
self,
116+
func: Callable,
117+
*args: Any,
118+
**kwargs: Any,
119+
) -> UnknownSeries: ...
120+
@overload
121+
def transform(
122+
self, func: TransformReductionListType, *args, **kwargs
123+
) -> UnknownSeries: ...
83124
def filter(
84125
self, func: Callable | str, dropna: bool = ..., *args, **kwargs
85126
) -> Series: ...
@@ -206,13 +247,25 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
206247
**kwargs,
207248
) -> DataFrame: ...
208249
agg = aggregate
250+
@overload
209251
def transform(
210252
self,
211-
func: Callable | str,
212-
*args,
253+
func: Callable[Concatenate[DataFrame, P], DataFrame],
254+
*args: Any,
213255
engine: WindowingEngine = ...,
214256
engine_kwargs: WindowingEngineKwargs = ...,
215-
**kwargs,
257+
**kwargs: Any,
258+
) -> DataFrame: ...
259+
@overload
260+
def transform(
261+
self,
262+
func: Callable,
263+
*args: Any,
264+
**kwargs: Any,
265+
) -> DataFrame: ...
266+
@overload
267+
def transform(
268+
self, func: TransformReductionListType, *args, **kwargs
216269
) -> DataFrame: ...
217270
def filter(
218271
self, func: Callable, dropna: bool = ..., *args, **kwargs

tests/test_series.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,25 +1078,111 @@ def test_types_groupby_agg() -> None:
10781078
r"The provided callable <built-in function (min|sum)> is currently using",
10791079
upper="2.3.99",
10801080
):
1081-
check(assert_type(s.groupby(level=0).agg(sum), pd.Series), pd.Series)
1081+
1082+
def sum_sr(s: pd.Series[int]) -> int:
1083+
# type of `sum` not well inferred by mypy
1084+
return s.sum()
1085+
1086+
check(
1087+
assert_type(s.groupby(level=0).agg(sum_sr), "pd.Series[int]"),
1088+
pd.Series,
1089+
np.integer,
1090+
)
10821091
check(
10831092
assert_type(s.groupby(level=0).agg([min, sum]), pd.DataFrame), pd.DataFrame
10841093
)
10851094

10861095

1096+
def test_types_groupby_transform() -> None:
1097+
s: pd.Series[int] = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"])
1098+
1099+
def transform_func(
1100+
x: pd.Series[int], pos_arg: bool, kw_arg: str
1101+
) -> pd.Series[float]:
1102+
return x / (2.0 if pos_arg else 1.0)
1103+
1104+
check(
1105+
assert_type(
1106+
s.groupby(lambda x: x).transform(transform_func, True, kw_arg="foo"),
1107+
"pd.Series[float]",
1108+
),
1109+
pd.Series,
1110+
float,
1111+
)
1112+
check(
1113+
assert_type(
1114+
s.groupby(lambda x: x).transform(
1115+
transform_func, True, engine="cython", kw_arg="foo"
1116+
),
1117+
"pd.Series[float]",
1118+
),
1119+
pd.Series,
1120+
float,
1121+
)
1122+
check(
1123+
assert_type(
1124+
s.groupby(lambda x: x).transform("mean"),
1125+
"pd.Series",
1126+
),
1127+
pd.Series,
1128+
)
1129+
check(
1130+
assert_type(
1131+
s.groupby(lambda x: x).transform("first"),
1132+
"pd.Series",
1133+
),
1134+
pd.Series,
1135+
)
1136+
1137+
10871138
def test_types_groupby_aggregate() -> None:
10881139
s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"])
10891140
check(assert_type(s.groupby(level=0).aggregate("sum"), pd.Series), pd.Series)
10901141
check(
10911142
assert_type(s.groupby(level=0).aggregate(["min", "sum"]), pd.DataFrame),
10921143
pd.DataFrame,
10931144
)
1145+
1146+
def func(s: pd.Series[int]) -> float:
1147+
return s.astype(float).min()
1148+
1149+
s = pd.Series([1, 2, 3, 4])
1150+
check(
1151+
assert_type(s.groupby([1, 1, 2, 2]).agg(func), "pd.Series[float]"),
1152+
pd.Series,
1153+
np.floating,
1154+
)
1155+
check(
1156+
assert_type(s.groupby(level=0).aggregate(func), "pd.Series[float]"),
1157+
pd.Series,
1158+
np.floating,
1159+
)
1160+
check(
1161+
assert_type(
1162+
s.groupby(level=0).aggregate(func, engine="cython"), "pd.Series[float]"
1163+
),
1164+
pd.Series,
1165+
np.floating,
1166+
)
1167+
1168+
# test below fails with mypy but pyright correctly sees it as pd.Series[float]
1169+
# check(assert_type(s.groupby([1,1,2,2]).agg(lambda x: x.astype(float).min()), "pd.Series[float]"), pd.Series, float)
1170+
10941171
with pytest_warns_bounded(
10951172
FutureWarning,
10961173
r"The provided callable <built-in function (min|sum)> is currently using",
10971174
upper="2.3.99",
10981175
):
1099-
check(assert_type(s.groupby(level=0).aggregate(sum), pd.Series), pd.Series)
1176+
1177+
def sum_sr(s: pd.Series[int]) -> int:
1178+
# type of `sum` not well inferred by mypy
1179+
return s.sum()
1180+
1181+
check(
1182+
assert_type(s.groupby(level=0).aggregate(sum_sr), "pd.Series[int]"),
1183+
pd.Series,
1184+
np.integer,
1185+
)
11001186
check(
11011187
assert_type(s.groupby(level=0).aggregate([min, sum]), pd.DataFrame),
11021188
pd.DataFrame,

0 commit comments

Comments
 (0)