Skip to content

Commit 5817176

Browse files
committed
nits
1 parent 0e8706e commit 5817176

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

array_api_compat/common/_helpers.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,14 @@ def __getitem__(self, idx):
903903
self.idx = idx
904904
return self
905905

906-
def _common(self, at_op, y=_undef, mode: str = "promise_in_bounds", **kwargs):
906+
def _common(
907+
self,
908+
at_op: str,
909+
y=_undef,
910+
copy: bool | None = True,
911+
mode: str = "promise_in_bounds",
912+
**kwargs,
913+
):
907914
"""Validate kwargs and perform common prepocessing.
908915
909916
Returns
@@ -1028,7 +1035,7 @@ def max(self, y, /, **kwargs):
10281035
xp = array_namespace(self.x)
10291036
return self._iop("max", xp.maximum, y, **kwargs)
10301037

1031-
def where(condition, x, y, /, copy: bool | None = True):
1038+
def where(condition, x=None, y=None, /, copy: bool | None = True):
10321039
"""Return elements from x when condition is True and from y when
10331040
it is False.
10341041
@@ -1043,6 +1050,10 @@ def where(condition, x, y, /, copy: bool | None = True):
10431050
False
10441051
Raise ValueError if a copy cannot be avoided.
10451052
"""
1053+
if x is None and y is None:
1054+
xp = array_namespace(condition)
1055+
return xp.where(condition)
1056+
10461057
copy = _parse_copy_param(x, copy)
10471058
xp = array_namespace(condition, x, y)
10481059
if copy:

0 commit comments

Comments
 (0)