3
3
import numpy as np
4
4
import pytest
5
5
6
- from array_api_compat import array_namespace , at , is_dask_array , is_jax_array , is_writeable_array
6
+ from array_api_compat import array_namespace , at , is_dask_array , is_jax_array , is_pydata_sparse_array , is_writeable_array
7
7
from ._helpers import import_ , all_libraries
8
8
9
9
10
+ def assert_array_equal (a , b ):
11
+ if is_pydata_sparse_array (a ):
12
+ a = a .todense ()
13
+ elif is_dask_array (a ):
14
+ a = a .compute ()
15
+ np .testing .assert_array_equal (a , b )
16
+
17
+
10
18
@contextmanager
11
19
def assert_copy (x , copy : bool | None ):
12
20
# dask arrays are writeable, but writing to them will hot-swap the
@@ -21,10 +29,13 @@ def assert_copy(x, copy: bool | None):
21
29
x_orig = xp .asarray (x , copy = True )
22
30
yield
23
31
24
- expect_copy = (
25
- copy if copy is not None else (not is_writeable_array (x ) or is_dask_array (x ))
26
- )
27
- np .testing .assert_array_equal ((x == x_orig ).all (), expect_copy )
32
+ if is_dask_array (x ):
33
+ expect_copy = True
34
+ elif copy is None :
35
+ expect_copy = not is_writeable_array (x )
36
+ else :
37
+ expect_copy = copy
38
+ assert_array_equal ((x == x_orig ).all (), expect_copy )
28
39
29
40
30
41
@pytest .fixture (params = all_libraries + ["np_readonly" ])
@@ -58,15 +69,15 @@ def test_operations(x, copy, op, arg, expect):
58
69
with assert_copy (x , copy ):
59
70
y = getattr (at (x , slice (1 , None )), op )(arg , copy = copy )
60
71
assert isinstance (y , type (x ))
61
- np . testing . assert_equal (y , expect )
72
+ assert_array_equal (y , expect )
62
73
63
74
64
75
@pytest .mark .parametrize ("copy" , [True , False , None ])
65
76
def test_get (x , copy ):
66
77
with assert_copy (x , copy ):
67
78
y = at (x , slice (2 )).get (copy = copy )
68
79
assert isinstance (y , type (x ))
69
- np . testing . assert_array_equal (y , [10 , 20 ])
80
+ assert_array_equal (y , [10 , 20 ])
70
81
# Let assert_copy test that y is a view or copy
71
82
with suppress ((TypeError , ValueError )):
72
83
y [0 ] = 40
@@ -97,15 +108,15 @@ def test_get_fancy_indices(x, idx, wrap_index):
97
108
with assert_copy (x , True ):
98
109
y = at (x , [0 , 1 ]).get ()
99
110
assert isinstance (y , type (x ))
100
- np . testing . assert_array_equal (y , [10 , 20 ])
111
+ assert_array_equal (y , [10 , 20 ])
101
112
# Let assert_copy test that y is a view or copy
102
113
with suppress ((TypeError , ValueError )):
103
114
y [0 ] = 40
104
115
105
116
with assert_copy (x , True ):
106
117
y = at (x , [0 , 1 ]).get (copy = None )
107
118
assert isinstance (y , type (x ))
108
- np . testing . assert_array_equal (y , [10 , 20 ])
119
+ assert_array_equal (y , [10 , 20 ])
109
120
# Let assert_copy test that y is a view or copy
110
121
with suppress ((TypeError , ValueError )):
111
122
y [0 ] = 40
@@ -119,7 +130,7 @@ def test_variant_index_syntax(x, copy):
119
130
with assert_copy (x , copy ):
120
131
y = at (x )[:2 ].set (40 , copy = copy )
121
132
assert isinstance (y , type (x ))
122
- np . testing . assert_array_equal (y , [40 , 40 , 30 ])
133
+ assert_array_equal (y , [40 , 40 , 30 ])
123
134
with pytest .raises (ValueError ):
124
135
at (x , 1 )[2 ]
125
136
with pytest .raises (ValueError ):
0 commit comments