@@ -112,6 +112,7 @@ def test_asarray_cross_library(source_library, target_library, request):
112
112
113
113
assert is_tgt_type (b ), f"Expected { b } to be a { tgt_lib .ndarray } , but was { type (b )} "
114
114
115
+
115
116
@pytest .mark .parametrize ("library" , wrapped_libraries )
116
117
def test_asarray_copy (library ):
117
118
# Note, we have this test here because the test suite currently doesn't
@@ -130,41 +131,57 @@ def test_asarray_copy(library):
130
131
else :
131
132
supports_copy_false = True
132
133
134
+ # Tests for copy=True
133
135
a = asarray ([1 ])
134
136
b = asarray (a , copy = True )
135
137
assert is_lib_func (b )
136
138
a [0 ] = 0
137
139
assert all (b [0 ] == 1 )
138
140
assert all (a [0 ] == 0 )
139
141
142
+ a = asarray ([1 ])
143
+ b = asarray (a , copy = True , dtype = a .dtype )
144
+ assert is_lib_func (b )
145
+ a [0 ] = 0
146
+ assert all (b [0 ] == 1 )
147
+ assert all (a [0 ] == 0 )
148
+
149
+ # Tests for copy=False
140
150
a = asarray ([1 ])
141
151
if supports_copy_false :
142
152
b = asarray (a , copy = False )
143
153
assert is_lib_func (b )
144
154
a [0 ] = 0
145
155
assert all (b [0 ] == 0 )
146
156
else :
147
- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False ))
157
+ with pytest .raises (NotImplementedError ):
158
+ asarray (a , copy = False )
148
159
149
160
a = asarray ([1 ])
150
161
if supports_copy_false :
151
- pytest .raises (ValueError , lambda : asarray ( a , copy = False ,
152
- dtype = xp .float64 ) )
162
+ with pytest .raises (ValueError ):
163
+ asarray ( a , copy = False , dtype = xp .float64 )
153
164
else :
154
- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False , dtype = xp .float64 ))
165
+ with pytest .raises (NotImplementedError ):
166
+ asarray (a , copy = False , dtype = xp .float64 )
155
167
168
+ # Tests for copy=None
169
+ # Do not test whether the buffer is shared or not after copy=None.
170
+ # A library should have the freedom to alter its behaviour
171
+ # without treating it as a breaking change.
156
172
a = asarray ([1 ])
157
173
b = asarray (a , copy = None )
158
174
assert is_lib_func (b )
159
175
a [0 ] = 0
160
- assert all (b [0 ] == 0 )
176
+ assert all (( b [0 ] == 1.0 ) | ( b [ 0 ] == 0.0 ) )
161
177
162
178
a = asarray ([1.0 ], dtype = xp .float32 )
163
179
assert a .dtype == xp .float32
164
180
b = asarray (a , dtype = xp .float64 , copy = None )
165
181
assert is_lib_func (b )
166
182
assert b .dtype == xp .float64
167
183
a [0 ] = 0.0
184
+ # dtype change must always trigger a copy
168
185
assert all (b [0 ] == 1.0 )
169
186
170
187
a = asarray ([1.0 ], dtype = xp .float64 )
@@ -173,16 +190,18 @@ def test_asarray_copy(library):
173
190
assert is_lib_func (b )
174
191
assert b .dtype == xp .float64
175
192
a [0 ] = 0.0
176
- assert all (b [0 ] == 0.0 )
193
+ assert all (( b [0 ] == 1.0 ) | ( b [ 0 ] == 0.0 ) )
177
194
178
195
# Python built-in types
179
196
for obj in [True , 0 , 0.0 , 0j , [0 ], [[0 ]]]:
180
197
asarray (obj , copy = True ) # No error
181
198
asarray (obj , copy = None ) # No error
182
199
if supports_copy_false :
183
- pytest .raises (ValueError , lambda : asarray (obj , copy = False ))
200
+ with pytest .raises (ValueError ):
201
+ asarray (obj , copy = False )
184
202
else :
185
- pytest .raises (NotImplementedError , lambda : asarray (obj , copy = False ))
203
+ with pytest .raises (NotImplementedError ):
204
+ asarray (obj , copy = False )
186
205
187
206
# Use the standard library array to test the buffer protocol
188
207
a = array .array ('f' , [1.0 ])
@@ -198,14 +217,11 @@ def test_asarray_copy(library):
198
217
a [0 ] = 0.0
199
218
assert all (b [0 ] == 0.0 )
200
219
else :
201
- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False ))
220
+ with pytest .raises (NotImplementedError ):
221
+ asarray (a , copy = False )
202
222
203
223
a = array .array ('f' , [1.0 ])
204
224
b = asarray (a , copy = None )
205
225
assert is_lib_func (b )
206
226
a [0 ] = 0.0
207
- if library == 'cupy' :
208
- # A copy is required for libraries where the default device is not CPU
209
- assert all (b [0 ] == 1.0 )
210
- else :
211
- assert all (b [0 ] == 0.0 )
227
+ assert all ((b [0 ] == 1.0 ) | (b [0 ] == 0.0 ))
0 commit comments