Skip to content

Commit 53a4ac9

Browse files
committed
Revert where
1 parent 437d73a commit 53a4ac9

File tree

2 files changed

+12
-49
lines changed

2 files changed

+12
-49
lines changed

array_api_compat/common/_helpers.py

Lines changed: 12 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -815,23 +815,6 @@ def is_writeable_array(x):
815815
return True
816816

817817

818-
def _parse_copy_param(x, copy: bool | None | Literal["_force_false"]) -> bool:
819-
"""Preprocess and validate a copy parameter, in line with the same
820-
parameter in np.asarray(), np.astype(), etc.
821-
"""
822-
if copy is True:
823-
return True
824-
elif copy is False:
825-
if not is_writeable_array(x):
826-
raise ValueError("Cannot avoid modifying parameter in place")
827-
return False
828-
elif copy is None:
829-
return not is_writeable_array(x)
830-
elif copy == "_force_false":
831-
return False
832-
raise ValueError(f"Invalid value for copy: {copy!r}")
833-
834-
835818
_undef = object()
836819

837820

@@ -947,7 +930,15 @@ def _common(
947930
"(same for all other methods)."
948931
)
949932

950-
copy = _parse_copy_param(self.x, copy)
933+
if copy is False:
934+
if not is_writeable_array(self.x):
935+
raise ValueError("Cannot avoid modifying parameter in place")
936+
elif copy is None:
937+
copy = not is_writeable_array(self.x)
938+
elif copy == "_force_false":
939+
copy = False
940+
elif copy is not True:
941+
raise ValueError(f"Invalid value for copy: {copy!r}")
951942

952943
if copy and is_jax_array(self.x):
953944
# Use JAX's at[]
@@ -956,6 +947,9 @@ def _common(
956947
return getattr(at_, at_op)(*args, **kwargs), None
957948

958949
# Emulate at[] behaviour for non-JAX arrays
950+
# FIXME We blindly expect the output of x.copy() to be always writeable.
951+
# This holds true for read-only numpy arrays, but not necessarily for
952+
# other backends.
959953
x = self.x.copy() if copy else self.x
960954
return None, x
961955

@@ -1047,35 +1041,6 @@ def max(self, y, /, **kwargs):
10471041
return self._iop("max", xp.maximum, y, **kwargs)
10481042

10491043

1050-
def where(condition, x=None, y=None, /, copy: bool | None = True):
1051-
"""Return elements from x when condition is True and from y when
1052-
it is False.
1053-
1054-
This is a wrapper around xp.where that adds the copy parameter:
1055-
1056-
None
1057-
x *may* be modified in place if it is possible and beneficial
1058-
for performance. You should not use x after calling this function.
1059-
True
1060-
Ensure that the inputs are not modified.
1061-
This is the default, in line with np.where.
1062-
False
1063-
Raise ValueError if a copy cannot be avoided.
1064-
"""
1065-
if x is None and y is None:
1066-
xp = array_namespace(condition, use_compat=False)
1067-
return xp.where(condition)
1068-
1069-
copy = _parse_copy_param(x, copy)
1070-
xp = array_namespace(condition, x, y, use_compat=False)
1071-
if copy:
1072-
return xp.where(condition, x, y)
1073-
else:
1074-
condition, x, y = xp.broadcast_arrays(condition, x, y)
1075-
x[condition] = y[condition]
1076-
return x
1077-
1078-
10791044
__all__ = [
10801045
"array_namespace",
10811046
"device",
@@ -1100,7 +1065,6 @@ def where(condition, x=None, y=None, /, copy: bool | None = True):
11001065
"size",
11011066
"to_device",
11021067
"at",
1103-
"where",
11041068
]
11051069

11061070
_all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys']

docs/helper-functions.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ instead, which would be wrapped.
3737
.. autofunction:: to_device
3838
.. autofunction:: size
3939
.. autofunction:: at
40-
.. autofunction:: where
4140

4241
Inspection Helpers
4342
------------------

0 commit comments

Comments
 (0)