@@ -3283,21 +3283,6 @@ def callable(x: int | NAType) -> str | NAType:
3283
3283
return str (x )
3284
3284
return x
3285
3285
3286
- def bad_callable (x : int ) -> int :
3287
- return x << 1
3288
-
3289
- with pytest .raises (TypeError ):
3290
- s .map (
3291
- bad_callable , na_action = None # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType]
3292
- )
3293
- with pytest .raises (TypeError ):
3294
- s .map (bad_callable ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
3295
- check (
3296
- assert_type (s .map (bad_callable , na_action = "ignore" ), "pd.Series[int]" ),
3297
- pd .Series ,
3298
- int ,
3299
- )
3300
-
3301
3286
check (
3302
3287
assert_type (s .map (callable , na_action = None ), "pd.Series[str]" ), pd .Series , str
3303
3288
)
@@ -3307,6 +3292,40 @@ def bad_callable(x: int) -> int:
3307
3292
check (assert_type (s .map (series , na_action = None ), "pd.Series[str]" ), pd .Series , str )
3308
3293
check (assert_type (s .map (series ), "pd.Series[str]" ), pd .Series , str )
3309
3294
3295
+ s2 : pd .Series [float ] = pd .Series ([1.0 , pd .NA , 3.0 ])
3296
+
3297
+ def callable2 (x : float ) -> float :
3298
+ return x + 1
3299
+
3300
+ check (
3301
+ assert_type (s2 .map (callable2 , na_action = "ignore" ), "pd.Series[float]" ),
3302
+ pd .Series ,
3303
+ float ,
3304
+ )
3305
+ check (
3306
+ assert_type (s2 .map (callable2 ), "pd.Series[float]" ),
3307
+ pd .Series ,
3308
+ float ,
3309
+ )
3310
+ if TYPE_CHECKING_INVALID_USAGE :
3311
+ s2 .map (callable2 , na_action = None ) # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType]
3312
+
3313
+ s3 : pd .Series [str ] = pd .Series (["A" , pd .NA , "C" ])
3314
+
3315
+ def callable3 (x : str ) -> str :
3316
+ return x .lower ()
3317
+
3318
+ check (
3319
+ assert_type (s3 .map (callable3 , na_action = "ignore" ), "pd.Series[str]" ),
3320
+ pd .Series ,
3321
+ str ,
3322
+ )
3323
+ if TYPE_CHECKING_INVALID_USAGE :
3324
+ s3 .map (
3325
+ callable3 , na_action = None # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType]
3326
+ )
3327
+ s3 .map (callable3 ) # type: ignore[type-var] # pyright: ignore[reportCallIssue, reportArgumentType]
3328
+
3310
3329
3311
3330
def test_case_when () -> None :
3312
3331
c = pd .Series ([6 , 7 , 8 , 9 ], name = "c" )
0 commit comments