Skip to content

Commit 20b959e

Browse files
committed
Tweak test_asarray_cross_library
1 parent 8a60892 commit 20b959e

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

tests/test_common.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,24 +138,24 @@ def test_to_device_host(library):
138138
@pytest.mark.parametrize("target_library", is_array_functions.keys())
139139
@pytest.mark.parametrize("source_library", is_array_functions.keys())
140140
def test_asarray_cross_library(source_library, target_library, request):
141-
if source_library == "dask.array" and target_library == "torch":
141+
def _xfail(reason: str) -> None:
142142
# Allow rest of test to execute instead of immediately xfailing
143143
# xref https://github.com/pandas-dev/pandas/issues/38902
144+
request.node.add_marker(pytest.mark.xfail(reason=reason))
144145

146+
if source_library == "dask.array" and target_library == "torch":
145147
# TODO: remove xfail once
146148
# https://github.com/dask/dask/issues/8260 is resolved
147-
request.node.add_marker(
148-
pytest.mark.xfail(reason="Bug in dask raising error on conversion")
149-
)
149+
_xfail(reason="Bug in dask raising error on conversion")
150150
elif (
151151
source_library == "ndonnx"
152152
and target_library not in ("array_api_strict", "ndonnx", "numpy")
153153
):
154-
request.node.add_marker(
155-
pytest.mark.xfail(
156-
reason="The truth value of lazy Array Array(dtype=Boolean) is unknown"
157-
)
158-
)
154+
_xfail(reason="The truth value of lazy Array Array(dtype=Boolean) is unknown")
155+
elif source_library == "ndonnx" and target_library == "numpy":
156+
_xfail(reason="produces numpy array of ndonnx scalar arrays")
157+
elif source_library == "jax.numpy" and target_library == "torch":
158+
_xfail(reason="casts int to float")
159159
elif source_library == "cupy" and target_library != "cupy":
160160
# cupy explicitly disallows implicit conversions to CPU
161161
pytest.skip(reason="cupy does not support implicit conversion to CPU")
@@ -166,10 +166,12 @@ def test_asarray_cross_library(source_library, target_library, request):
166166
tgt_lib = import_(target_library, wrapper=True)
167167
is_tgt_type = globals()[is_array_functions[target_library]]
168168

169-
a = src_lib.asarray([1, 2, 3])
169+
a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32)
170170
b = tgt_lib.asarray(a)
171171

172172
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
173+
assert b.dtype == tgt_lib.int32
174+
173175

174176
@pytest.mark.parametrize("library", wrapped_libraries)
175177
def test_asarray_copy(library):

0 commit comments

Comments
 (0)