diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index baa358a8a..7ec2fe200 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -568,6 +568,8 @@ S2 = TypeVar( | BaseOffset, ) +SN = TypeVar("SN", bound=bool | int | float | complex) + IndexingInt: TypeAlias = ( int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8 ) diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 209237746..13a657d8d 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -94,6 +94,7 @@ from pandas._libs.tslibs.nattype import NaTType from pandas._typing import ( S1, S2, + SN, AggFuncTypeBase, AggFuncTypeDictFrame, AggFuncTypeSeriesToFrame, @@ -916,10 +917,16 @@ class Series(IndexOpsMixin[S1], NDFrame): @overload def map( self, - arg: Callable[[S1], S2 | NAType] | Mapping[S1, S2] | Series[S2], + arg: Callable[[SN], S2 | NAType] | Mapping[SN, S2] | Series[S2], na_action: Literal["ignore"] = ..., ) -> Series[S2]: ... @overload + def map( + self, + arg: Callable[[S1], S2 | NAType] | Mapping[S1, S2] | Series[S2], + na_action: Literal["ignore"], + ) -> Series[S2]: ... + @overload def map( self, arg: Callable[[S1 | NAType], S2 | NAType] | Mapping[S1, S2] | Series[S2], diff --git a/tests/test_series.py b/tests/test_series.py index 9c99dba41..306910323 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -3276,6 +3276,7 @@ def test_map_na() -> None: mapping = {1: "a", 2: "b", 3: "c"} check(assert_type(s.map(mapping, na_action=None), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.map(mapping), "pd.Series[str]"), pd.Series, str) def callable(x: int | NAType) -> str | NAType: if isinstance(x, int): @@ -3285,9 +3286,45 @@ def callable(x: int | NAType) -> str | NAType: check( assert_type(s.map(callable, na_action=None), "pd.Series[str]"), pd.Series, str ) + check(assert_type(s.map(callable), "pd.Series[str]"), pd.Series, str) series = pd.Series(["a", "b", "c"]) check(assert_type(s.map(series, na_action=None), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.map(series), "pd.Series[str]"), pd.Series, str) + + s2: pd.Series[float] = pd.Series([1.0, pd.NA, 3.0]) + + def callable2(x: float) -> float: + return x + 1 + + check( + assert_type(s2.map(callable2, na_action="ignore"), "pd.Series[float]"), + pd.Series, + float, + ) + check( + assert_type(s2.map(callable2), "pd.Series[float]"), + pd.Series, + float, + ) + if TYPE_CHECKING_INVALID_USAGE: + s2.map(callable2, na_action=None) # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType] + + s3: pd.Series[str] = pd.Series(["A", pd.NA, "C"]) + + def callable3(x: str) -> str: + return x.lower() + + check( + assert_type(s3.map(callable3, na_action="ignore"), "pd.Series[str]"), + pd.Series, + str, + ) + if TYPE_CHECKING_INVALID_USAGE: + s3.map( + callable3, na_action=None # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType] + ) + s3.map(callable3) # type: ignore[type-var] # pyright: ignore[reportCallIssue, reportArgumentType] def test_case_when() -> None: