Skip to content

Commit 98181a8

Browse files
committed
Unit tests WIP
1 parent 53a4ac9 commit 98181a8

File tree

3 files changed

+181
-27
lines changed

3 files changed

+181
-27
lines changed

array_api_compat/common/_helpers.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -900,7 +900,7 @@ def __getitem__(self, idx):
900900
and feels more intuitive coming from the JAX documentation.
901901
"""
902902
if self.idx is not _undef:
903-
raise TypeError("Index has already been set")
903+
raise ValueError("Index has already been set")
904904
self.idx = idx
905905
return self
906906

@@ -911,14 +911,12 @@ def _common(
911911
copy: bool | None | Literal["_force_false"] = True,
912912
**kwargs,
913913
):
914-
"""Validate kwargs and perform common prepocessing.
914+
"""Perform common prepocessing.
915915
916916
Returns
917917
-------
918-
If the operation can be resolved by at[],
919-
(return value, None)
920-
Otherwise,
921-
(None, preprocessed x)
918+
If the operation can be resolved by at[], (return value, None)
919+
Otherwise, (None, preprocessed x)
922920
"""
923921
if self.idx is _undef:
924922
raise TypeError(
@@ -929,39 +927,46 @@ def _common(
929927
" at(x)[idx].set(value)\n"
930928
"(same for all other methods)."
931929
)
930+
931+
x = self.x
932932

933933
if copy is False:
934-
if not is_writeable_array(self.x):
935-
raise ValueError("Cannot avoid modifying parameter in place")
934+
if not is_writeable_array(x) or is_dask_array(x):
935+
raise ValueError("Cannot modify parameter in place")
936936
elif copy is None:
937-
copy = not is_writeable_array(self.x)
937+
copy = not is_writeable_array(x)
938938
elif copy == "_force_false":
939939
copy = False
940940
elif copy is not True:
941941
raise ValueError(f"Invalid value for copy: {copy!r}")
942942

943-
if copy and is_jax_array(self.x):
943+
if is_jax_array(x):
944944
# Use JAX's at[]
945-
at_ = self.x.at[self.idx]
945+
at_ = x.at[self.idx]
946946
args = (y,) if y is not _undef else ()
947947
return getattr(at_, at_op)(*args, **kwargs), None
948948

949949
# Emulate at[] behaviour for non-JAX arrays
950-
# FIXME We blindly expect the output of x.copy() to be always writeable.
951-
# This holds true for read-only numpy arrays, but not necessarily for
952-
# other backends.
953-
x = self.x.copy() if copy else self.x
950+
if copy:
951+
# FIXME We blindly expect the output of x.copy() to be always writeable.
952+
# This holds true for read-only numpy arrays, but not necessarily for
953+
# other backends.
954+
xp = get_namespace(x)
955+
x = xp.asarray(x, copy=True)
956+
954957
return None, x
955958

956959
def get(self, copy: bool | None = True, **kwargs):
957-
"""Return x[idx]. In addition to plain __getitem__, this allows ensuring
958-
that the output is (not) a copy and kwargs are passed to the backend."""
959-
# Special case when xp=numpy and idx is a fancy index
960-
# If copy is not False, avoid an unnecessary double copy.
961-
# if copy is forced to False, raise.
962-
if is_numpy_array(self.x) and (
960+
"""
961+
Return x[idx]. In addition to plain __getitem__, this allows ensuring
962+
that the output is (not) a copy and kwargs are passed to the backend.
963+
"""
964+
# __getitem__ with a fancy index always returns a copy.
965+
# Avoid an unnecessary double copy.
966+
# If copy is forced to False, raise.
967+
if (
963968
isinstance(self.idx, (list, tuple))
964-
or (is_numpy_array(self.idx) and self.idx.dtype.kind in "biu")
969+
or (is_array_api_obj(self.idx) and self.idx.dtype.kind in "biu")
965970
):
966971
if copy is False:
967972
raise ValueError(
@@ -1032,13 +1037,15 @@ def power(self, y, /, **kwargs):
10321037

10331038
def min(self, y, /, **kwargs):
10341039
"""x[idx] = minimum(x[idx], y)"""
1035-
xp = array_namespace(self.x)
1036-
return self._iop("min", xp.minimum, y, **kwargs)
1040+
import numpy as np
1041+
1042+
return self._iop("min", np.minimum, y, **kwargs)
10371043

10381044
def max(self, y, /, **kwargs):
10391045
"""x[idx] = maximum(x[idx], y)"""
1040-
xp = array_namespace(self.x)
1041-
return self._iop("max", xp.maximum, y, **kwargs)
1046+
import numpy as np
1047+
1048+
return self._iop("max", np.maximum, y, **kwargs)
10421049

10431050

10441051
__all__ = [

tests/test_at.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from contextlib import contextmanager, suppress
2+
3+
import numpy as np
4+
import pytest
5+
6+
from array_api_compat import array_namespace, at, is_dask_array, is_jax_array, is_writeable_array
7+
from ._helpers import import_, all_libraries
8+
9+
10+
@contextmanager
11+
def assert_copy(x, copy: bool | None):
12+
# dask arrays are writeable, but writing to them will hot-swap the
13+
# dask graph inside the collection so that anything that references
14+
# the original graph, i.e. the input collection, won't be mutated.
15+
if copy is False and (not is_writeable_array(x) or is_dask_array(x)):
16+
with pytest.raises((TypeError, ValueError)):
17+
yield
18+
return
19+
20+
xp = array_namespace(x)
21+
x_orig = xp.asarray(x, copy=True)
22+
yield
23+
24+
expect_copy = (
25+
copy if copy is not None else (not is_writeable_array(x) or is_dask_array(x))
26+
)
27+
np.testing.assert_array_equal((x == x_orig).all(), expect_copy)
28+
29+
30+
@pytest.fixture(params=all_libraries + ["np_readonly"])
31+
def x(request):
32+
library = request.param
33+
if library == "np_readonly":
34+
x = np.asarray([10, 20, 30])
35+
x.flags.writeable = False
36+
else:
37+
lib = import_(library)
38+
x = lib.asarray([10, 20, 30])
39+
return x
40+
41+
42+
@pytest.mark.parametrize("copy", [True, False, None])
43+
@pytest.mark.parametrize(
44+
"op,arg,expect",
45+
[
46+
("apply", np.negative, [10, -20, 30]),
47+
("set", 40, [10, 40, 30]),
48+
("add", 40, [10, 60, 30]),
49+
("subtract", 100, [10, -80, 30]),
50+
("multiply", 2, [10, 40, 30]),
51+
("divide", 3, [10, 6, 30]),
52+
("power", 2, [10, 400, 30]),
53+
("min", 15, [10, 15, 30]),
54+
("min", 25, [10, 20, 30]),
55+
("max", 15, [10, 20, 30]),
56+
("max", 25, [10, 25, 30]),
57+
],
58+
)
59+
def test_operations(x, copy, op, arg, expect):
60+
with assert_copy(x, copy):
61+
y = getattr(at(x, 1), op)(arg, copy=copy)
62+
assert isinstance(y, type(x))
63+
np.testing.assert_equal(y, expect)
64+
65+
66+
@pytest.mark.parametrize("copy", [True, False, None])
67+
def test_get(x, copy):
68+
with assert_copy(x, copy):
69+
y = at(x, slice(2)).get(copy=copy)
70+
assert isinstance(y, type(x))
71+
np.testing.assert_array_equal(y, [10, 20])
72+
# Let assert_copy test that y is a view or copy
73+
with suppress((TypeError, ValueError)):
74+
y[0] = 40
75+
76+
77+
@pytest.mark.parametrize(
78+
"idx",
79+
[
80+
[0, 1],
81+
((0, 1), ),
82+
np.array([0, 1], dtype="i1"),
83+
np.array([0, 1], dtype="u1"),
84+
[True, True, False],
85+
(True, True, False),
86+
np.array([True, True, False]),
87+
],
88+
)
89+
@pytest.mark.parametrize("wrap_index", [True, False])
90+
def test_get_fancy_indices(x, idx, wrap_index):
91+
"""get() with a fancy index always returns a copy"""
92+
if not wrap_index and is_jax_array(x) and isinstance(idx, (list, tuple)):
93+
pytest.skip("JAX fancy indices must always be arrays")
94+
95+
if wrap_index:
96+
xp = array_namespace(x)
97+
idx = xp.asarray(idx)
98+
99+
with assert_copy(x, True):
100+
y = at(x, [0, 1]).get()
101+
assert isinstance(y, type(x))
102+
np.testing.assert_array_equal(y, [10, 20])
103+
# Let assert_copy test that y is a view or copy
104+
with suppress((TypeError, ValueError)):
105+
y[0] = 40
106+
107+
with assert_copy(x, True):
108+
y = at(x, [0, 1]).get(copy=None)
109+
assert isinstance(y, type(x))
110+
np.testing.assert_array_equal(y, [10, 20])
111+
# Let assert_copy test that y is a view or copy
112+
with suppress((TypeError, ValueError)):
113+
y[0] = 40
114+
115+
with pytest.raises(ValueError, match="fancy index"):
116+
at(x, [0, 1]).get(copy=False)
117+
118+
119+
@pytest.mark.parametrize("copy", [True, False, None])
120+
def test_variant_index_syntax(x, copy):
121+
with assert_copy(x, copy):
122+
y = at(x)[:2].set(40, copy=copy)
123+
assert isinstance(y, type(x))
124+
np.testing.assert_array_equal(y, [40, 40, 30])
125+
with pytest.raises(ValueError):
126+
at(x, 1)[2]
127+
with pytest.raises(ValueError):
128+
at(x)[1][2]

tests/test_common.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
66
)
77

8-
from array_api_compat import is_array_api_obj, device, to_device
8+
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
99

1010
from ._helpers import import_, wrapped_libraries, all_libraries
1111

@@ -55,6 +55,25 @@ def test_is_xp_namespace(library, func):
5555
assert is_func(lib) == (func == is_namespace_functions[library])
5656

5757

58+
@pytest.mark.parametrize("library", all_libraries)
59+
def test_is_writeable_array(library):
60+
lib = import_(library)
61+
x = lib.asarray([1, 2, 3])
62+
if is_writeable_array(x):
63+
x[1] = 4
64+
np.testing.assert_equal(np.asarray(x), [1, 4, 3])
65+
else:
66+
with pytest.raises((TypeError, ValueError)):
67+
x[1] = 4
68+
69+
70+
def test_is_writeable_array_numpy():
71+
x = np.asarray([1, 2, 3])
72+
assert is_writeable_array(x)
73+
x.flags.writeable = False
74+
assert not is_writeable_array(x)
75+
76+
5877
@pytest.mark.parametrize("library", all_libraries)
5978
def test_device(library):
6079
xp = import_(library, wrapper=True)

0 commit comments

Comments
 (0)