From ab8b011f8e624364cce5b7a7db6785d420fc8455 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Mon, 1 Jul 2024 18:23:21 -0400 Subject: [PATCH] add case_when --- pandas-stubs/core/series.pyi | 11 +++++++++++ tests/test_series.py | 13 +++++++++++++ 2 files changed, 24 insertions(+) diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 41ee06d5c..209237746 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1426,6 +1426,17 @@ class Series(IndexOpsMixin[S1], NDFrame): axis: AxisIndex | None = ..., level: Level | None = ..., ) -> Series[S1]: ... + def case_when( + self, + caselist: list[ + tuple[ + Sequence[bool] + | Series[bool] + | Callable[[Series], Series | np.ndarray | Sequence[bool]], + ListLikeU | Scalar | Callable[[Series], Series | np.ndarray], + ], + ], + ) -> Series: ... def truncate( self, before: date | _str | int | None = ..., diff --git a/tests/test_series.py b/tests/test_series.py index 470bfb431..9c99dba41 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -3288,3 +3288,16 @@ def callable(x: int | NAType) -> str | NAType: series = pd.Series(["a", "b", "c"]) check(assert_type(s.map(series, na_action=None), "pd.Series[str]"), pd.Series, str) + + +def test_case_when() -> None: + c = pd.Series([6, 7, 8, 9], name="c") + a = pd.Series([0, 0, 1, 2]) + b = pd.Series([0, 3, 4, 5]) + r = c.case_when( + caselist=[ + (a.gt(0), a), + (b.gt(0), b), + ] + ) + check(assert_type(r, pd.Series), pd.Series)