Skip to content

Commit 428ee84

Browse files
committed
Support changed asarray behaviour in dask 2023.12.0
1 parent ee25aae commit 428ee84

File tree

2 files changed

+46
-27
lines changed

2 files changed

+46
-27
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -128,23 +128,26 @@ def asarray(
128128
129129
See the corresponding documentation in the array library and/or the array API
130130
specification for more details.
131+
132+
.. note::
133+
copy=True means that if you update the output array the input will never
134+
be affected; however the output array may internally hold references to the
135+
input array, preventing deallocation. This kind of implementation detail should
136+
be left at dask's discretion.
131137
"""
132138
if copy is False:
133139
# copy=False is not yet implemented in dask
134-
raise NotImplementedError("copy=False is not yet implemented")
135-
elif copy is True:
136-
if isinstance(obj, da.Array) and dtype is None:
137-
return obj.copy()
138-
# Go through numpy, since dask copy is no-op by default
139-
obj = np.array(obj, dtype=dtype, copy=True)
140-
return da.array(obj, dtype=dtype)
141-
else:
142-
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
143-
obj = np.asarray(obj, dtype=dtype)
144-
return da.from_array(obj)
145-
return obj
140+
raise NotImplementedError("copy=False can't be implemented in dask")
141+
142+
if (
143+
copy is True
144+
and isinstance(obj, da.Array)
145+
and (dtype is None or dtype == obj.dtype)
146+
):
147+
return obj.copy()
148+
149+
return da.asarray(obj, dtype=dtype)
146150

147-
return da.asarray(obj, dtype=dtype, **kwargs)
148151

149152
from dask.array import (
150153
# Element wise aliases

tests/test_common.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def test_asarray_cross_library(source_library, target_library, request):
112112

113113
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
114114

115+
115116
@pytest.mark.parametrize("library", wrapped_libraries)
116117
def test_asarray_copy(library):
117118
# Note, we have this test here because the test suite currently doesn't
@@ -130,41 +131,57 @@ def test_asarray_copy(library):
130131
else:
131132
supports_copy_false = True
132133

134+
# Tests for copy=True
133135
a = asarray([1])
134136
b = asarray(a, copy=True)
135137
assert is_lib_func(b)
136138
a[0] = 0
137139
assert all(b[0] == 1)
138140
assert all(a[0] == 0)
139141

142+
a = asarray([1])
143+
b = asarray(a, copy=True, dtype=a.dtype)
144+
assert is_lib_func(b)
145+
a[0] = 0
146+
assert all(b[0] == 1)
147+
assert all(a[0] == 0)
148+
149+
# Tests for copy=False
140150
a = asarray([1])
141151
if supports_copy_false:
142152
b = asarray(a, copy=False)
143153
assert is_lib_func(b)
144154
a[0] = 0
145155
assert all(b[0] == 0)
146156
else:
147-
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
157+
with pytest.raises(NotImplementedError):
158+
asarray(a, copy=False)
148159

149160
a = asarray([1])
150161
if supports_copy_false:
151-
pytest.raises(ValueError, lambda: asarray(a, copy=False,
152-
dtype=xp.float64))
162+
with pytest.raises(ValueError):
163+
asarray(a, copy=False, dtype=xp.float64)
153164
else:
154-
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64))
165+
with pytest.raises(NotImplementedError):
166+
asarray(a, copy=False, dtype=xp.float64)
155167

168+
# Tests for copy=None
169+
# Do not test whether the buffer is shared or not after copy=None.
170+
# A library should have the freedom to alter its behaviour
171+
# without treating it as a breaking change.
156172
a = asarray([1])
157173
b = asarray(a, copy=None)
158174
assert is_lib_func(b)
159175
a[0] = 0
160-
assert all(b[0] == 0)
176+
assert all((b[0] == 1.0) | (b[0] == 0.0))
161177

162178
a = asarray([1.0], dtype=xp.float32)
163179
assert a.dtype == xp.float32
164180
b = asarray(a, dtype=xp.float64, copy=None)
165181
assert is_lib_func(b)
166182
assert b.dtype == xp.float64
167183
a[0] = 0.0
184+
# dtype change must always trigger a copy
168185
assert all(b[0] == 1.0)
169186

170187
a = asarray([1.0], dtype=xp.float64)
@@ -173,16 +190,18 @@ def test_asarray_copy(library):
173190
assert is_lib_func(b)
174191
assert b.dtype == xp.float64
175192
a[0] = 0.0
176-
assert all(b[0] == 0.0)
193+
assert all((b[0] == 1.0) | (b[0] == 0.0))
177194

178195
# Python built-in types
179196
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
180197
asarray(obj, copy=True) # No error
181198
asarray(obj, copy=None) # No error
182199
if supports_copy_false:
183-
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
200+
with pytest.raises(ValueError):
201+
asarray(obj, copy=False)
184202
else:
185-
pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
203+
with pytest.raises(NotImplementedError):
204+
asarray(obj, copy=False)
186205

187206
# Use the standard library array to test the buffer protocol
188207
a = array.array('f', [1.0])
@@ -198,14 +217,11 @@ def test_asarray_copy(library):
198217
a[0] = 0.0
199218
assert all(b[0] == 0.0)
200219
else:
201-
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
220+
with pytest.raises(NotImplementedError):
221+
asarray(a, copy=False)
202222

203223
a = array.array('f', [1.0])
204224
b = asarray(a, copy=None)
205225
assert is_lib_func(b)
206226
a[0] = 0.0
207-
if library == 'cupy':
208-
# A copy is required for libraries where the default device is not CPU
209-
assert all(b[0] == 1.0)
210-
else:
211-
assert all(b[0] == 0.0)
227+
assert all((b[0] == 1.0) | (b[0] == 0.0))

0 commit comments

Comments
 (0)