Skip to content

Commit 76ee8bb

Browse files
committed
go back to S1 in groupby, set defaults for ByT and _TT
1 parent d5d9428 commit 76ee8bb

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

pandas-stubs/_typing.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,7 @@ ByT = TypeVar(
892892
| Period
893893
| Interval[int | float | Timestamp | Timedelta]
894894
| tuple,
895+
default=Any,
895896
)
896897
# Use a distinct SeriesByT when using groupby with Series of known dtype.
897898
# Essentially, an intersection between Series S1 TypeVar, and ByT TypeVar

pandas-stubs/core/groupby/generic.pyi

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ from typing_extensions import (
3030

3131
from pandas._libs.tslibs.timestamps import Timestamp
3232
from pandas._typing import (
33-
S2,
33+
S1,
3434
AggFuncTypeBase,
3535
AggFuncTypeFrame,
3636
ByT,
@@ -52,7 +52,7 @@ class NamedAgg(NamedTuple):
5252
column: str
5353
aggfunc: AggScalar
5454

55-
class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
55+
class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]):
5656
@overload
5757
def aggregate(
5858
self,
@@ -114,7 +114,7 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
114114
self,
115115
indices: TakeIndexer,
116116
**kwargs,
117-
) -> Series[S2]: ...
117+
) -> Series[S1]: ...
118118
def skew(
119119
self,
120120
skipna: bool = ...,
@@ -125,10 +125,10 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
125125
def plot(self) -> GroupByPlot[Self]: ...
126126
def nlargest(
127127
self, n: int = ..., keep: NsmallestNlargestKeep = ...
128-
) -> Series[S2]: ...
128+
) -> Series[S1]: ...
129129
def nsmallest(
130130
self, n: int = ..., keep: NsmallestNlargestKeep = ...
131-
) -> Series[S2]: ...
131+
) -> Series[S1]: ...
132132
def idxmin(self, skipna: bool = ...) -> Series: ...
133133
def idxmax(self, skipna: bool = ...) -> Series: ...
134134
def corr(
@@ -166,9 +166,9 @@ class SeriesGroupBy(GroupBy[Series[S2]], Generic[S2, ByT]):
166166
@final # type: ignore[misc]
167167
def __iter__( # pyright: ignore[reportIncompatibleMethodOverride]
168168
self,
169-
) -> Iterator[tuple[ByT, Series[S2]]]: ...
169+
) -> Iterator[tuple[ByT, Series[S1]]]: ...
170170

171-
_TT = TypeVar("_TT", bound=Literal[True, False])
171+
_TT = TypeVar("_TT", bound=Literal[True, False], default=Literal[True])
172172

173173
class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
174174
# error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
@@ -217,7 +217,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]):
217217
def filter(
218218
self, func: Callable, dropna: bool = ..., *args, **kwargs
219219
) -> DataFrame: ...
220-
@overload
220+
@overload # type: ignore[override]
221221
def __getitem__(self, key: Scalar) -> SeriesGroupBy[Any, ByT]: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
222222
@overload
223223
def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride]

0 commit comments

Comments
 (0)