-
-
Notifications
You must be signed in to change notification settings - Fork 144
GH456 First attempt GroupBy.transform improved typing #1242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
020f93d
106a6f5
3bba101
053b7e7
4141a06
f9863d0
e26b4c1
96abf3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,56 @@ | ||
from collections.abc import Hashable | ||
import dataclasses | ||
from typing import ( | ||
Literal, | ||
TypeAlias, | ||
) | ||
|
||
@dataclasses.dataclass(order=True, frozen=True) | ||
class OutputKey: | ||
label: Hashable | ||
position: int | ||
|
||
ReductionKernelType: TypeAlias = Literal[ | ||
"all", | ||
"any", | ||
"corrwith", | ||
"count", | ||
"first", | ||
"idxmax", | ||
"idxmin", | ||
"last", | ||
"max", | ||
"mean", | ||
"median", | ||
"min", | ||
"nunique", | ||
"prod", | ||
# as long as `quantile`'s signature accepts only | ||
# a single quantile value, it's a reduction. | ||
# GH#27526 might change that. | ||
"quantile", | ||
"sem", | ||
"size", | ||
"skew", | ||
"std", | ||
"sum", | ||
"var", | ||
] | ||
|
||
TransformationKernelType: TypeAlias = Literal[ | ||
"bfill", | ||
"cumcount", | ||
"cummax", | ||
"cummin", | ||
"cumprod", | ||
"cumsum", | ||
"diff", | ||
"ffill", | ||
"fillna", | ||
"ngroup", | ||
"pct_change", | ||
"rank", | ||
"shift", | ||
] | ||
|
||
TransformReductionListType: TypeAlias = ReductionKernelType | TransformationKernelType |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1078,25 +1078,111 @@ def test_types_groupby_agg() -> None: | |
r"The provided callable <built-in function (min|sum)> is currently using", | ||
upper="2.2.99", | ||
): | ||
check(assert_type(s.groupby(level=0).agg(sum), pd.Series), pd.Series) | ||
|
||
def sum_sr(s: pd.Series[int]) -> int: | ||
# type of `sum` not well inferred by mypy | ||
return s.sum() | ||
|
||
check( | ||
assert_type(s.groupby(level=0).agg(sum_sr), "pd.Series[int]"), | ||
pd.Series, | ||
np.integer, | ||
) | ||
check( | ||
assert_type(s.groupby(level=0).agg([min, sum]), pd.DataFrame), pd.DataFrame | ||
) | ||
|
||
|
||
def test_types_groupby_transform() -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you should add tests for two of the string transform arguments (e.g., "mean", "first") |
||
s: pd.Series[int] = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"]) | ||
|
||
def transform_func( | ||
x: pd.Series[int], pos_arg: bool, kw_arg: str | ||
) -> pd.Series[float]: | ||
return x / (2.0 if pos_arg else 1.0) | ||
|
||
check( | ||
assert_type( | ||
s.groupby(lambda x: x).transform(transform_func, True, kw_arg="foo"), | ||
"pd.Series[float]", | ||
), | ||
Dr-Irv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pd.Series, | ||
float, | ||
) | ||
check( | ||
assert_type( | ||
s.groupby(lambda x: x).transform( | ||
transform_func, True, engine="cython", kw_arg="foo" | ||
), | ||
"pd.Series[float]", | ||
), | ||
pd.Series, | ||
float, | ||
) | ||
check( | ||
assert_type( | ||
s.groupby(lambda x: x).transform("mean"), | ||
"pd.Series", | ||
), | ||
pd.Series, | ||
) | ||
check( | ||
assert_type( | ||
s.groupby(lambda x: x).transform("first"), | ||
"pd.Series", | ||
), | ||
pd.Series, | ||
) | ||
|
||
|
||
def test_types_groupby_aggregate() -> None: | ||
s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"]) | ||
check(assert_type(s.groupby(level=0).aggregate("sum"), pd.Series), pd.Series) | ||
check( | ||
assert_type(s.groupby(level=0).aggregate(["min", "sum"]), pd.DataFrame), | ||
pd.DataFrame, | ||
) | ||
|
||
def func(s: pd.Series[int]) -> float: | ||
return s.astype(float).min() | ||
|
||
s = pd.Series([1, 2, 3, 4]) | ||
check( | ||
assert_type(s.groupby([1, 1, 2, 2]).agg(func), "pd.Series[float]"), | ||
pd.Series, | ||
np.floating, | ||
) | ||
check( | ||
assert_type(s.groupby(level=0).aggregate(func), "pd.Series[float]"), | ||
pd.Series, | ||
np.floating, | ||
) | ||
check( | ||
assert_type( | ||
s.groupby(level=0).aggregate(func, engine="cython"), "pd.Series[float]" | ||
), | ||
pd.Series, | ||
np.floating, | ||
) | ||
|
||
# test below passes with mypy but pyright correctly sees it as pd.Series[float] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just have to change the comment to say "fails with mypy" |
||
# check(assert_type(s.groupby([1,1,2,2]).agg(lambda x: x.astype(float).min()), pd.Series), pd.Series, float) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Keep the commented test in there so it is still there and executes, since it works for both type checkers, but comment out the one that is "better" that has There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am forced to comment it out because There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess you have to keep it commented out. Do you have a test like this that passes both checkers: func: Callable[[pd.Series], float] = lambda x: x.astype(float).min()
check(assert_type(s.groupby([1,1,2,2]).agg(func), "pd.Series[float]"), pd.Series, float) So you can have the "preferred" version in there commented out, but I think the above test would pass both type checkers. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually that also fails to pass with mypy (pyright is fine with it). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried a bunch of ideas and couldn't get it to work. It's probably a mypy bug, but I couldn't come up with a simple example that illustrates the problem. |
||
|
||
with pytest_warns_bounded( | ||
FutureWarning, | ||
r"The provided callable <built-in function (min|sum)> is currently using", | ||
upper="2.2.99", | ||
): | ||
check(assert_type(s.groupby(level=0).aggregate(sum), pd.Series), pd.Series) | ||
|
||
def sum_sr(s: pd.Series[int]) -> int: | ||
# type of `sum` not well inferred by mypy | ||
return s.sum() | ||
|
||
check( | ||
assert_type(s.groupby(level=0).aggregate(sum_sr), "pd.Series[int]"), | ||
pd.Series, | ||
np.integer, | ||
) | ||
check( | ||
assert_type(s.groupby(level=0).aggregate([min, sum]), pd.DataFrame), | ||
pd.DataFrame, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before this overload, you could add this overload:
Then you know that if you start with a
Series
with a known type, then the return type would be inferred from the callable. And it works with a lambda function, e.g.:In this case,
q
would have typeSeries[float]
, which is what you want.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that's because the type of
new_func
isn't clear.But I think it would work if you did
check(assert_type(s.groupby([1,1,2,2]).agg(lambda x: x.astype(float).min()), "pd.Series[int]"), pd.Series, int)
Because then it can know that
x
is aSeries[int]
and that thelambda
becomesSeries[int]
Can you try that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried that for the last push, see
pandas-stubs/tests/test_series.py
Line 1167 in f9863d0
It fails in all CI:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I look with how
mypy
reads the type of the lambda, it has no idea about the type ofx
:so that may explain why it fails on lambda expressions whatsoever.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK - so we can leave the
lambda
test in, but just have itassert_type()
againstSeries
instead ofSeries[float]