@@ -815,23 +815,6 @@ def is_writeable_array(x):
815
815
return True
816
816
817
817
818
- def _parse_copy_param (x , copy : bool | None | Literal ["_force_false" ]) -> bool :
819
- """Preprocess and validate a copy parameter, in line with the same
820
- parameter in np.asarray(), np.astype(), etc.
821
- """
822
- if copy is True :
823
- return True
824
- elif copy is False :
825
- if not is_writeable_array (x ):
826
- raise ValueError ("Cannot avoid modifying parameter in place" )
827
- return False
828
- elif copy is None :
829
- return not is_writeable_array (x )
830
- elif copy == "_force_false" :
831
- return False
832
- raise ValueError (f"Invalid value for copy: { copy !r} " )
833
-
834
-
835
818
_undef = object ()
836
819
837
820
@@ -947,7 +930,15 @@ def _common(
947
930
"(same for all other methods)."
948
931
)
949
932
950
- copy = _parse_copy_param (self .x , copy )
933
+ if copy is False :
934
+ if not is_writeable_array (self .x ):
935
+ raise ValueError ("Cannot avoid modifying parameter in place" )
936
+ elif copy is None :
937
+ copy = not is_writeable_array (self .x )
938
+ elif copy == "_force_false" :
939
+ copy = False
940
+ elif copy is not True :
941
+ raise ValueError (f"Invalid value for copy: { copy !r} " )
951
942
952
943
if copy and is_jax_array (self .x ):
953
944
# Use JAX's at[]
@@ -956,6 +947,9 @@ def _common(
956
947
return getattr (at_ , at_op )(* args , ** kwargs ), None
957
948
958
949
# Emulate at[] behaviour for non-JAX arrays
950
+ # FIXME We blindly expect the output of x.copy() to be always writeable.
951
+ # This holds true for read-only numpy arrays, but not necessarily for
952
+ # other backends.
959
953
x = self .x .copy () if copy else self .x
960
954
return None , x
961
955
@@ -1047,35 +1041,6 @@ def max(self, y, /, **kwargs):
1047
1041
return self ._iop ("max" , xp .maximum , y , ** kwargs )
1048
1042
1049
1043
1050
- def where (condition , x = None , y = None , / , copy : bool | None = True ):
1051
- """Return elements from x when condition is True and from y when
1052
- it is False.
1053
-
1054
- This is a wrapper around xp.where that adds the copy parameter:
1055
-
1056
- None
1057
- x *may* be modified in place if it is possible and beneficial
1058
- for performance. You should not use x after calling this function.
1059
- True
1060
- Ensure that the inputs are not modified.
1061
- This is the default, in line with np.where.
1062
- False
1063
- Raise ValueError if a copy cannot be avoided.
1064
- """
1065
- if x is None and y is None :
1066
- xp = array_namespace (condition , use_compat = False )
1067
- return xp .where (condition )
1068
-
1069
- copy = _parse_copy_param (x , copy )
1070
- xp = array_namespace (condition , x , y , use_compat = False )
1071
- if copy :
1072
- return xp .where (condition , x , y )
1073
- else :
1074
- condition , x , y = xp .broadcast_arrays (condition , x , y )
1075
- x [condition ] = y [condition ]
1076
- return x
1077
-
1078
-
1079
1044
__all__ = [
1080
1045
"array_namespace" ,
1081
1046
"device" ,
@@ -1100,7 +1065,6 @@ def where(condition, x=None, y=None, /, copy: bool | None = True):
1100
1065
"size" ,
1101
1066
"to_device" ,
1102
1067
"at" ,
1103
- "where" ,
1104
1068
]
1105
1069
1106
1070
_all_ignore = ['inspect' , 'math' , 'operator' , 'warnings' , 'sys' ]
0 commit comments