Skip to content

Commit adab072

Browse files
committed
Abstractions for read-only arrays
1 parent ee25aae commit adab072

File tree

5 files changed

+447
-7
lines changed

5 files changed

+447
-7
lines changed

array_api_compat/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
NumPy Array API compatibility library
33
4-
This is a small wrapper around NumPy and CuPy that is compatible with the
5-
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
6-
https://numpy.org/neps/nep-0047-array-api-standard.html.
4+
This is a small wrapper around NumPy, CuPy, JAX and others that is compatible
5+
with the Array API standard https://data-apis.org/array-api/latest/.
6+
See also NEP 47 https://numpy.org/neps/nep-0047-array-api-standard.html.
77
88
Unlike array_api_strict, this is not a strict minimal implementation of the
99
Array API, but rather just an extension of the main NumPy namespace with

array_api_compat/common/_helpers.py

Lines changed: 255 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
"""
88
from __future__ import annotations
99

10+
import operator
1011
from typing import TYPE_CHECKING
1112

1213
if TYPE_CHECKING:
13-
from typing import Optional, Union, Any
14+
from typing import Callable, Literal, Optional, Union, Any
1415
from ._typing import Array, Device
1516

1617
import sys
@@ -91,7 +92,7 @@ def is_cupy_array(x):
9192
import cupy as cp
9293

9394
# TODO: Should we reject ndarray subclasses?
94-
return isinstance(x, (cp.ndarray, cp.generic))
95+
return isinstance(x, cp.ndarray)
9596

9697
def is_torch_array(x):
9798
"""
@@ -787,6 +788,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
787788
return x
788789
return x.to_device(device, stream=stream)
789790

791+
790792
def size(x):
791793
"""
792794
Return the total number of elements of x.
@@ -801,6 +803,254 @@ def size(x):
801803
return None
802804
return math.prod(x.shape)
803805

806+
807+
def is_writeable_array(x) -> bool:
808+
"""
809+
Return False if x.__setitem__ is expected to raise; True otherwise
810+
"""
811+
if is_numpy_array(x):
812+
return x.flags.writeable
813+
if is_jax_array(x) or is_pydata_sparse_array(x):
814+
return False
815+
return True
816+
817+
818+
def _is_fancy_index(idx) -> bool:
819+
if not isinstance(idx, tuple):
820+
idx = (idx,)
821+
return any(
822+
isinstance(i, (list, tuple)) or is_array_api_obj(i)
823+
for i in idx
824+
)
825+
826+
827+
_undef = object()
828+
829+
830+
class at:
831+
"""
832+
Update operations for read-only arrays.
833+
834+
This implements ``jax.numpy.ndarray.at`` for all backends.
835+
Writeable arrays may be updated in place; you should not rely on it.
836+
837+
Keyword arguments (e.g. ``indices_are_sorted``) are passed to JAX and are
838+
quietly ignored for backends that don't support them.
839+
840+
Additionally, this introduces support for the `copy` keyword for all backends:
841+
842+
None
843+
x *may* be modified in place if it is possible and beneficial
844+
for performance. You should not use x after calling this function.
845+
True
846+
Ensure that the inputs are not modified. This is the default.
847+
False
848+
Raise ValueError if a copy cannot be avoided.
849+
850+
Examples
851+
--------
852+
Given either of these equivalent expressions::
853+
854+
x = at(x)[1].add(2, copy=None)
855+
x = at(x, 1).add(2, copy=None)
856+
857+
If x is a JAX array, they are the same as::
858+
859+
x = x.at[1].add(2)
860+
861+
If x is a read-only numpy array, they are the same as::
862+
863+
x = x.copy()
864+
x[1] += 2
865+
866+
Otherwise, they are the same as::
867+
868+
x[1] += 2
869+
870+
Warning
871+
-------
872+
When you use copy=None, you should always immediately overwrite
873+
the parameter array::
874+
875+
x = at(x, 0).set(2, copy=None)
876+
877+
The anti-pattern below must be avoided, as it will result in different behaviour
878+
on read-only versus writeable arrays:
879+
880+
x = xp.asarray([0, 0, 0])
881+
y = at(x, 0).set(2, copy=None)
882+
z = at(x, 1).set(3, copy=None)
883+
884+
In the above example, y == [2, 0, 0] and z == [0, 3, 0] when x is read-only,
885+
whereas y == z == [2, 3, 0] when x is writeable!
886+
887+
Caveat
888+
------
889+
Sparse does not support update methods yet.
890+
891+
Caveat
892+
------
893+
The behaviour of update methods when the index is an array of integers which
894+
contains multiple occurrences of the same index is undefined.
895+
896+
**Undefined behaviour:** ``at(x, [0, 0]).set(2)``
897+
898+
See Also
899+
--------
900+
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
901+
"""
902+
903+
__slots__ = ("x", "idx")
904+
905+
def __init__(self, x, idx=_undef):
906+
self.x = x
907+
self.idx = idx
908+
909+
def __getitem__(self, idx):
910+
"""
911+
Allow for the alternate syntax ``at(x)[start:stop:step]``,
912+
which looks prettier than ``at(x, slice(start, stop, step))``
913+
and feels more intuitive coming from the JAX documentation.
914+
"""
915+
if self.idx is not _undef:
916+
raise ValueError("Index has already been set")
917+
self.idx = idx
918+
return self
919+
920+
def _common(
921+
self,
922+
at_op: str,
923+
y=_undef,
924+
copy: bool | None | Literal["_force_false"] = True,
925+
**kwargs,
926+
):
927+
"""Perform common prepocessing.
928+
929+
Returns
930+
-------
931+
If the operation can be resolved by at[], (return value, None)
932+
Otherwise, (None, preprocessed x)
933+
"""
934+
if self.idx is _undef:
935+
raise TypeError(
936+
"Index has not been set.\n"
937+
"Usage: either\n"
938+
" at(x, idx).set(value)\n"
939+
"or\n"
940+
" at(x)[idx].set(value)\n"
941+
"(same for all other methods)."
942+
)
943+
944+
x = self.x
945+
946+
if copy is False:
947+
if not is_writeable_array(x) or is_dask_array(x):
948+
raise ValueError("Cannot modify parameter in place")
949+
elif copy is None:
950+
copy = not is_writeable_array(x)
951+
elif copy == "_force_false":
952+
copy = False
953+
elif copy is not True:
954+
raise ValueError(f"Invalid value for copy: {copy!r}")
955+
956+
if is_jax_array(x):
957+
# Use JAX's at[]
958+
at_ = x.at[self.idx]
959+
args = (y,) if y is not _undef else ()
960+
return getattr(at_, at_op)(*args, **kwargs), None
961+
962+
# Emulate at[] behaviour for non-JAX arrays
963+
if copy:
964+
# FIXME We blindly expect the output of x.copy() to be always writeable.
965+
# This holds true for read-only numpy arrays, but not necessarily for
966+
# other backends.
967+
xp = array_namespace(x)
968+
x = xp.asarray(x, copy=True)
969+
970+
return None, x
971+
972+
def get(self, copy: bool | None = True, **kwargs):
973+
"""
974+
Return x[idx]. In addition to plain __getitem__, this allows ensuring
975+
that the output is either a copy or a view; it also allows passing
976+
kwargs to the backend.
977+
"""
978+
# __getitem__ with a fancy index always returns a copy.
979+
# Avoid an unnecessary double copy.
980+
# If copy is forced to False, raise.
981+
if _is_fancy_index(self.idx):
982+
if copy is False:
983+
raise TypeError(
984+
"Indexing a numpy array with a fancy index always "
985+
"results in a copy"
986+
)
987+
# Skip copy inside _common, even if array is not writeable
988+
copy = "_force_false" # type: ignore
989+
990+
res, x = self._common("get", copy=copy, **kwargs)
991+
if res is not None:
992+
return res
993+
return x[self.idx]
994+
995+
def set(self, y, /, **kwargs):
996+
"""x[idx] = y"""
997+
res, x = self._common("set", y, **kwargs)
998+
if res is not None:
999+
return res
1000+
x[self.idx] = y
1001+
return x
1002+
1003+
def _iop(
1004+
self, at_op: str, elwise_op: Callable[[Array, Array], Array], y: Array, **kwargs
1005+
):
1006+
"""x[idx] += y or equivalent in-place operation on a subset of x
1007+
1008+
which is the same as saying
1009+
x[idx] = x[idx] + y
1010+
Note that this is not the same as
1011+
operator.iadd(x[idx], y)
1012+
Consider for example when x is a numpy array and idx is a fancy index, which
1013+
triggers a deep copy on __getitem__.
1014+
"""
1015+
res, x = self._common(at_op, y, **kwargs)
1016+
if res is not None:
1017+
return res
1018+
x[self.idx] = elwise_op(x[self.idx], y)
1019+
return x
1020+
1021+
def add(self, y, /, **kwargs):
1022+
"""x[idx] += y"""
1023+
return self._iop("add", operator.add, y, **kwargs)
1024+
1025+
def subtract(self, y, /, **kwargs):
1026+
"""x[idx] -= y"""
1027+
return self._iop("subtract", operator.sub, y, **kwargs)
1028+
1029+
def multiply(self, y, /, **kwargs):
1030+
"""x[idx] *= y"""
1031+
return self._iop("multiply", operator.mul, y, **kwargs)
1032+
1033+
def divide(self, y, /, **kwargs):
1034+
"""x[idx] /= y"""
1035+
return self._iop("divide", operator.truediv, y, **kwargs)
1036+
1037+
def power(self, y, /, **kwargs):
1038+
"""x[idx] **= y"""
1039+
return self._iop("power", operator.pow, y, **kwargs)
1040+
1041+
def min(self, y, /, **kwargs):
1042+
"""x[idx] = minimum(x[idx], y)"""
1043+
xp = array_namespace(self.x)
1044+
y = xp.asarray(y)
1045+
return self._iop("min", xp.minimum, y, **kwargs)
1046+
1047+
def max(self, y, /, **kwargs):
1048+
"""x[idx] = maximum(x[idx], y)"""
1049+
xp = array_namespace(self.x)
1050+
y = xp.asarray(y)
1051+
return self._iop("max", xp.maximum, y, **kwargs)
1052+
1053+
8041054
__all__ = [
8051055
"array_namespace",
8061056
"device",
@@ -821,8 +1071,10 @@ def size(x):
8211071
"is_ndonnx_namespace",
8221072
"is_pydata_sparse_array",
8231073
"is_pydata_sparse_namespace",
1074+
"is_writeable_array",
8241075
"size",
8251076
"to_device",
1077+
"at",
8261078
]
8271079

828-
_all_ignore = ['sys', 'math', 'inspect', 'warnings']
1080+
_all_ignore = ['inspect', 'math', 'operator', 'warnings', 'sys']

docs/helper-functions.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ instead, which would be wrapped.
3636
.. autofunction:: device
3737
.. autofunction:: to_device
3838
.. autofunction:: size
39+
.. autofunction:: at
3940

4041
Inspection Helpers
4142
------------------
@@ -51,6 +52,7 @@ yet.
5152
.. autofunction:: is_jax_array
5253
.. autofunction:: is_pydata_sparse_array
5354
.. autofunction:: is_ndonnx_array
55+
.. autofunction:: is_writeable_array
5456
.. autofunction:: is_numpy_namespace
5557
.. autofunction:: is_cupy_namespace
5658
.. autofunction:: is_torch_namespace

0 commit comments

Comments
 (0)