Skip to content

Commit 524cce1

Browse files
authored
GH1103 Fixes for Dataframe.squeeze and Series.squeeze (#1199)
* fix type signatures of Dataframe.squeeze and Series.squeeze * add assert_type to tests
1 parent ee800a2 commit 524cce1

File tree

4 files changed

+32
-2
lines changed

4 files changed

+32
-2
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2194,7 +2194,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
21942194
numeric_only: _bool = ...,
21952195
**kwargs: Any,
21962196
) -> Series: ...
2197-
def squeeze(self, axis: Axis | None = ...): ...
2197+
def squeeze(self, axis: Axis | None = ...) -> DataFrame | Series | Scalar: ...
21982198
def std(
21992199
self,
22002200
axis: Axis = ...,

pandas-stubs/core/series.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1223,7 +1223,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
12231223
) -> Series[S1]: ...
12241224
def droplevel(self, level: Level | list[Level], axis: AxisIndex = ...) -> Self: ...
12251225
def pop(self, item: Hashable) -> S1: ...
1226-
def squeeze(self, axis: AxisIndex | None = ...) -> Scalar: ...
1226+
def squeeze(self) -> Series[S1] | Scalar: ...
12271227
def __abs__(self) -> Series[S1]: ...
12281228
def add_prefix(self, prefix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ...
12291229
def add_suffix(self, suffix: _str, axis: AxisIndex | None = ...) -> Series[S1]: ...

tests/test_frame.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3253,6 +3253,25 @@ def test_resample() -> None:
32533253
check(assert_type(df.resample("2min").ohlc(), pd.DataFrame), pd.DataFrame)
32543254

32553255

3256+
def test_squeeze() -> None:
3257+
df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
3258+
check(
3259+
assert_type(df1.squeeze(), Union[pd.DataFrame, pd.Series, Scalar]), pd.DataFrame
3260+
)
3261+
df2 = pd.DataFrame({"a": [1, 2]})
3262+
check(assert_type(df2.squeeze(), Union[pd.DataFrame, pd.Series, Scalar]), pd.Series)
3263+
df3 = pd.DataFrame({"a": [1], "b": [2]})
3264+
check(
3265+
assert_type(df3.squeeze(), Union[pd.DataFrame, pd.Series, Scalar]),
3266+
pd.Series,
3267+
np.integer,
3268+
)
3269+
df4 = pd.DataFrame({"a": [1]})
3270+
check(
3271+
assert_type(df4.squeeze(), Union[pd.DataFrame, pd.Series, Scalar]), np.integer
3272+
)
3273+
3274+
32563275
def test_loc_set() -> None:
32573276
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
32583277
df.loc["a"] = [3, 4]

tests/test_series.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,6 +1732,17 @@ def test_resample() -> None:
17321732
check(assert_type(s.resample("2min").ohlc(), pd.DataFrame), pd.DataFrame)
17331733

17341734

1735+
def test_squeeze() -> None:
1736+
s1 = pd.Series([1, 2, 3])
1737+
check(
1738+
assert_type(s1.squeeze(), Union["pd.Series[int]", Scalar]),
1739+
pd.Series,
1740+
np.integer,
1741+
)
1742+
s2 = pd.Series([1])
1743+
check(assert_type(s2.squeeze(), Union["pd.Series[int]", Scalar]), np.integer)
1744+
1745+
17351746
def test_to_xarray():
17361747
s = pd.Series([1, 2])
17371748
check(assert_type(s.to_xarray(), xr.DataArray), xr.DataArray)

0 commit comments

Comments
 (0)