Skip to content

Commit 8810cc7

Browse files
Further splitting of map overloads
1 parent 7adb42d commit 8810cc7

File tree

3 files changed

+43
-15
lines changed

3 files changed

+43
-15
lines changed

pandas-stubs/_typing.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,8 @@ S2 = TypeVar(
568568
| BaseOffset,
569569
)
570570

571+
SN = TypeVar("SN", bound=bool | int | float | complex)
572+
571573
IndexingInt: TypeAlias = (
572574
int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8
573575
)

pandas-stubs/core/series.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ from pandas._libs.tslibs.nattype import NaTType
9494
from pandas._typing import (
9595
S1,
9696
S2,
97+
SN,
9798
AggFuncTypeBase,
9899
AggFuncTypeDictFrame,
99100
AggFuncTypeSeriesToFrame,
@@ -914,6 +915,12 @@ class Series(IndexOpsMixin[S1], NDFrame):
914915
fill_value: int | _str | dict | None = ...,
915916
) -> DataFrame: ...
916917
@overload
918+
def map(
919+
self,
920+
arg: Callable[[SN], S2 | NAType] | Mapping[SN, S2] | Series[S2],
921+
na_action: Literal["ignore"] = ...,
922+
) -> Series[S2]: ...
923+
@overload
917924
def map(
918925
self,
919926
arg: Callable[[S1], S2 | NAType] | Mapping[S1, S2] | Series[S2],

tests/test_series.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3283,21 +3283,6 @@ def callable(x: int | NAType) -> str | NAType:
32833283
return str(x)
32843284
return x
32853285

3286-
def bad_callable(x: int) -> int:
3287-
return x << 1
3288-
3289-
with pytest.raises(TypeError):
3290-
s.map(
3291-
bad_callable, na_action=None # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType]
3292-
)
3293-
with pytest.raises(TypeError):
3294-
s.map(bad_callable) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
3295-
check(
3296-
assert_type(s.map(bad_callable, na_action="ignore"), "pd.Series[int]"),
3297-
pd.Series,
3298-
int,
3299-
)
3300-
33013286
check(
33023287
assert_type(s.map(callable, na_action=None), "pd.Series[str]"), pd.Series, str
33033288
)
@@ -3307,6 +3292,40 @@ def bad_callable(x: int) -> int:
33073292
check(assert_type(s.map(series, na_action=None), "pd.Series[str]"), pd.Series, str)
33083293
check(assert_type(s.map(series), "pd.Series[str]"), pd.Series, str)
33093294

3295+
s2: pd.Series[float] = pd.Series([1.0, pd.NA, 3.0])
3296+
3297+
def callable2(x: float) -> float:
3298+
return x + 1
3299+
3300+
check(
3301+
assert_type(s2.map(callable2, na_action="ignore"), "pd.Series[float]"),
3302+
pd.Series,
3303+
float,
3304+
)
3305+
check(
3306+
assert_type(s2.map(callable2), "pd.Series[float]"),
3307+
pd.Series,
3308+
float,
3309+
)
3310+
if TYPE_CHECKING_INVALID_USAGE:
3311+
s2.map(callable2, na_action=None) # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType]
3312+
3313+
s3: pd.Series[str] = pd.Series(["A", pd.NA, "C"])
3314+
3315+
def callable3(x: str) -> str:
3316+
return x.lower()
3317+
3318+
check(
3319+
assert_type(s3.map(callable3, na_action="ignore"), "pd.Series[str]"),
3320+
pd.Series,
3321+
str,
3322+
)
3323+
if TYPE_CHECKING_INVALID_USAGE:
3324+
s3.map(
3325+
callable3, na_action=None # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType]
3326+
)
3327+
s3.map(callable3) # type: ignore[type-var] # pyright: ignore[reportCallIssue, reportArgumentType]
3328+
33103329

33113330
def test_case_when() -> None:
33123331
c = pd.Series([6, 7, 8, 9], name="c")

0 commit comments

Comments
 (0)