Skip to content

Commit dfe80f2

Browse files
committed
fixes
1 parent a1f1b0f commit dfe80f2

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

tests/test_at.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,18 @@
33
import numpy as np
44
import pytest
55

6-
from array_api_compat import array_namespace, at, is_dask_array, is_jax_array, is_writeable_array
6+
from array_api_compat import array_namespace, at, is_dask_array, is_jax_array, is_pydata_sparse_array, is_writeable_array
77
from ._helpers import import_, all_libraries
88

99

10+
def assert_array_equal(a, b):
11+
if is_pydata_sparse_array(a):
12+
a = a.todense()
13+
elif is_dask_array(a):
14+
a = a.compute()
15+
np.testing.assert_array_equal(a, b)
16+
17+
1018
@contextmanager
1119
def assert_copy(x, copy: bool | None):
1220
# dask arrays are writeable, but writing to them will hot-swap the
@@ -21,10 +29,13 @@ def assert_copy(x, copy: bool | None):
2129
x_orig = xp.asarray(x, copy=True)
2230
yield
2331

24-
expect_copy = (
25-
copy if copy is not None else (not is_writeable_array(x) or is_dask_array(x))
26-
)
27-
np.testing.assert_array_equal((x == x_orig).all(), expect_copy)
32+
if is_dask_array(x):
33+
expect_copy = True
34+
elif copy is None:
35+
expect_copy = not is_writeable_array(x)
36+
else:
37+
expect_copy = copy
38+
assert_array_equal((x == x_orig).all(), expect_copy)
2839

2940

3041
@pytest.fixture(params=all_libraries + ["np_readonly"])
@@ -58,15 +69,15 @@ def test_operations(x, copy, op, arg, expect):
5869
with assert_copy(x, copy):
5970
y = getattr(at(x, slice(1, None)), op)(arg, copy=copy)
6071
assert isinstance(y, type(x))
61-
np.testing.assert_equal(y, expect)
72+
assert_array_equal(y, expect)
6273

6374

6475
@pytest.mark.parametrize("copy", [True, False, None])
6576
def test_get(x, copy):
6677
with assert_copy(x, copy):
6778
y = at(x, slice(2)).get(copy=copy)
6879
assert isinstance(y, type(x))
69-
np.testing.assert_array_equal(y, [10, 20])
80+
assert_array_equal(y, [10, 20])
7081
# Let assert_copy test that y is a view or copy
7182
with suppress((TypeError, ValueError)):
7283
y[0] = 40
@@ -97,15 +108,15 @@ def test_get_fancy_indices(x, idx, wrap_index):
97108
with assert_copy(x, True):
98109
y = at(x, [0, 1]).get()
99110
assert isinstance(y, type(x))
100-
np.testing.assert_array_equal(y, [10, 20])
111+
assert_array_equal(y, [10, 20])
101112
# Let assert_copy test that y is a view or copy
102113
with suppress((TypeError, ValueError)):
103114
y[0] = 40
104115

105116
with assert_copy(x, True):
106117
y = at(x, [0, 1]).get(copy=None)
107118
assert isinstance(y, type(x))
108-
np.testing.assert_array_equal(y, [10, 20])
119+
assert_array_equal(y, [10, 20])
109120
# Let assert_copy test that y is a view or copy
110121
with suppress((TypeError, ValueError)):
111122
y[0] = 40
@@ -119,7 +130,7 @@ def test_variant_index_syntax(x, copy):
119130
with assert_copy(x, copy):
120131
y = at(x)[:2].set(40, copy=copy)
121132
assert isinstance(y, type(x))
122-
np.testing.assert_array_equal(y, [40, 40, 30])
133+
assert_array_equal(y, [40, 40, 30])
123134
with pytest.raises(ValueError):
124135
at(x, 1)[2]
125136
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)