Skip to content

Commit 78d9a74

Browse files
committed
TST: test is_*_namespace fns
1 parent 733d17c commit 78d9a74

File tree

6 files changed

+58
-7
lines changed

6 files changed

+58
-7
lines changed

tests/test_common.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import array
1111
from numpy.testing import assert_allclose
1212

13-
is_functions = {
13+
is_array_functions = {
1414
'numpy': 'is_numpy_array',
1515
'cupy': 'is_cupy_array',
1616
'torch': 'is_torch_array',
@@ -19,8 +19,18 @@
1919
'sparse': 'is_pydata_sparse_array',
2020
}
2121

22-
@pytest.mark.parametrize('library', is_functions.keys())
23-
@pytest.mark.parametrize('func', is_functions.values())
22+
is_namespace_functions = {
23+
'numpy': 'is_numpy_namespace',
24+
'cupy': 'is_cupy_namespace',
25+
'torch': 'is_torch_namespace',
26+
'dask.array': 'is_dask_namespace',
27+
'jax.numpy': 'is_jax_namespace',
28+
'sparse': 'is_pydata_sparse_namespace',
29+
}
30+
31+
32+
@pytest.mark.parametrize('library', is_array_functions.keys())
33+
@pytest.mark.parametrize('func', is_array_functions.values())
2434
def test_is_xp_array(library, func):
2535
lib = import_(library)
2636
is_func = globals()[func]
@@ -31,6 +41,16 @@ def test_is_xp_array(library, func):
3141

3242
assert is_array_api_obj(x)
3343

44+
45+
@pytest.mark.parametrize('library', is_namespace_functions.keys())
46+
@pytest.mark.parametrize('func', is_namespace_functions.values())
47+
def test_is_xp_namespace(library, func):
48+
lib = import_(library)
49+
is_func = globals()[func]
50+
51+
assert is_func(lib) == (func == is_functions[library])
52+
53+
3454
@pytest.mark.parametrize("library", all_libraries)
3555
def test_device(library):
3656
xp = import_(library, wrapper=True)

tests/test_vendoring.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def test_vendoring_torch():
2020

2121
uses_torch._test_torch()
2222

23+
2324
def test_vendoring_dask():
2425
from vendor_test import uses_dask
2526
uses_dask._test_dask()

vendor_test/uses_cupy.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Basic test that vendoring works
22

3-
from .vendored._compat import cupy as cp_compat
3+
from .vendored._compat import (
4+
cupy as cp_compat,
5+
is_cupy_array,
6+
is_cupy_namespace,
7+
)
48

59
import cupy as cp
610

@@ -16,3 +20,6 @@ def _test_cupy():
1620
assert isinstance(res, cp.ndarray)
1721

1822
cp.testing.assert_allclose(res, [1., 2., 9.])
23+
24+
assert is_cupy_array(res)
25+
assert is_cupy_namespace(cp) and is_cupy_namespace(cp_compat)

vendor_test/uses_dask.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Basic test that vendoring works
22

3-
from .vendored._compat.dask import array as dask_compat
3+
from .vendored._compat import (
4+
array as dask_compat,
5+
is_dask_array,
6+
is_dask_namespace,
7+
)
48

59
import dask.array as da
610
import numpy as np
@@ -17,3 +21,6 @@ def _test_dask():
1721
assert isinstance(res, da.Array)
1822

1923
np.testing.assert_allclose(res, [1., 2., 9.])
24+
25+
assert is_dask_array(res)
26+
assert is_dask_namespace(da) and is_dask_namespace(dask_compat)

vendor_test/uses_numpy.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Basic test that vendoring works
22

3-
from .vendored._compat import numpy as np_compat
3+
from .vendored._compat import (
4+
is_numpy_array,
5+
is_numpy_namespace,
6+
numpy as np_compat,
7+
)
8+
49

510
import numpy as np
611

@@ -16,3 +21,6 @@ def _test_numpy():
1621
assert isinstance(res, np.ndarray)
1722

1823
np.testing.assert_allclose(res, [1., 2., 9.])
24+
25+
assert is_numpy_array(res)
26+
assert is_numpy_namespace(np) and is_numpy_namespace(np_compat)

vendor_test/uses_torch.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Basic test that vendoring works
22

3-
from .vendored._compat import torch as torch_compat
3+
from .vendored._compat import (
4+
is_torch_array,
5+
is_torch_namespace,
6+
torch as torch_compat,
7+
)
48

59
import torch
610

@@ -20,3 +24,7 @@ def _test_torch():
2024
assert isinstance(res, torch.Tensor)
2125

2226
torch.testing.assert_allclose(res, [[1., 2., 3.]])
27+
28+
assert is_torch_array(res)
29+
assert is_torch_namespace(torch) and is_torch_namespace(torch_compat)
30+

0 commit comments

Comments
 (0)