@@ -788,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788
788
return x
789
789
return x .to_device (device , stream = stream )
790
790
791
+
791
792
def size (x ):
792
793
"""
793
794
Return the total number of elements of x.
@@ -802,6 +803,7 @@ def size(x):
802
803
return None
803
804
return math .prod (x .shape )
804
805
806
+
805
807
def is_writeable_array (x ):
806
808
"""
807
809
Return False if x.__setitem__ is expected to raise; True otherwise
@@ -812,6 +814,7 @@ def is_writeable_array(x):
812
814
return False
813
815
return True
814
816
817
+
815
818
def _parse_copy_param (x , copy : bool | None | Literal ["_force_false" ]) -> bool :
816
819
"""Preprocess and validate a copy parameter, in line with the same
817
820
parameter in np.asarray(), np.astype(), etc.
@@ -827,8 +830,10 @@ def _parse_copy_param(x, copy: bool | None | Literal["_force_false"]) -> bool:
827
830
raise ValueError (f"Invalid value for copy: { copy !r} " )
828
831
return copy
829
832
833
+
830
834
_undef = object ()
831
835
836
+
832
837
class at :
833
838
"""
834
839
Update operations for read-only arrays.
@@ -897,6 +902,7 @@ class at:
897
902
--------
898
903
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
899
904
"""
905
+
900
906
__slots__ = ("x" , "idx" )
901
907
902
908
def __init__ (self , x , idx = _undef ):
@@ -945,7 +951,7 @@ def _common(
945
951
if copy and is_jax_array (self .x ):
946
952
# Use JAX's at[]
947
953
at_ = self .x .at [self .idx ]
948
- args = (y , ) if y is not _undef else ()
954
+ args = (y ,) if y is not _undef else ()
949
955
return getattr (at_ , at_op )(* args , ** kwargs ), None
950
956
951
957
# Emulate at[] behaviour for non-JAX arrays
@@ -958,12 +964,9 @@ def get(self, copy: bool | None = True, **kwargs):
958
964
# Special case when xp=numpy and idx is a fancy index
959
965
# If copy is not False, avoid an unnecessary double copy.
960
966
# if copy is forced to False, raise.
961
- if (
962
- is_numpy_array (self .x )
963
- and (
964
- isinstance (self .idx , (list , tuple ))
965
- or (is_numpy_array (self .idx ) and self .idx .dtype .kind in "biu" )
966
- )
967
+ if is_numpy_array (self .x ) and (
968
+ isinstance (self .idx , (list , tuple ))
969
+ or (is_numpy_array (self .idx ) and self .idx .dtype .kind in "biu" )
967
970
):
968
971
if copy is False :
969
972
raise ValueError (
@@ -994,12 +997,14 @@ def apply(self, ufunc, /, **kwargs):
994
997
ufunc .at (x , self .idx )
995
998
return x
996
999
997
- def _iop (self , at_op : str , elwise_op : Callable [[Array , Array ], Array ], y : Array , ** kwargs ):
1000
+ def _iop (
1001
+ self , at_op : str , elwise_op : Callable [[Array , Array ], Array ], y : Array , ** kwargs
1002
+ ):
998
1003
"""x[idx] += y or equivalent in-place operation on a subset of x
999
1004
1000
1005
which is the same as saying
1001
1006
x[idx] = x[idx] + y
1002
- Note that this is not the same as
1007
+ Note that this is not the same as
1003
1008
operator.iadd(x[idx], y)
1004
1009
Consider for example when x is a numpy array and idx is a fancy index, which
1005
1010
triggers a deep copy on __getitem__.
@@ -1017,11 +1022,11 @@ def add(self, y, /, **kwargs):
1017
1022
def subtract (self , y , / , ** kwargs ):
1018
1023
"""x[idx] -= y"""
1019
1024
return self ._iop ("subtract" , operator .sub , y , ** kwargs )
1020
-
1025
+
1021
1026
def multiply (self , y , / , ** kwargs ):
1022
1027
"""x[idx] *= y"""
1023
1028
return self ._iop ("multiply" , operator .mul , y , ** kwargs )
1024
-
1029
+
1025
1030
def divide (self , y , / , ** kwargs ):
1026
1031
"""x[idx] /= y"""
1027
1032
return self ._iop ("divide" , operator .truediv , y , ** kwargs )
@@ -1040,9 +1045,10 @@ def max(self, y, /, **kwargs):
1040
1045
xp = array_namespace (self .x )
1041
1046
return self ._iop ("max" , xp .maximum , y , ** kwargs )
1042
1047
1048
+
1043
1049
def where (condition , x = None , y = None , / , copy : bool | None = True ):
1044
1050
"""Return elements from x when condition is True and from y when
1045
- it is False.
1051
+ it is False.
1046
1052
1047
1053
This is a wrapper around xp.where that adds the copy parameter:
1048
1054
0 commit comments