Skip to content

Commit b14f05e

Browse files
committed
check the indexing behavior for scalars
1 parent 72d7635 commit b14f05e

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

xarray_array_testing/indexing.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from contextlib import nullcontext
2+
3+
import hypothesis.strategies as st
4+
import xarray.testing.strategies as xrst
5+
from hypothesis import given
6+
7+
from xarray_array_testing.base import DuckArrayTestMixin
8+
9+
10+
@st.composite
11+
def scalar_indexers(draw, sizes):
12+
# TODO: try to define this using builds and flatmap
13+
possible_indexers = {
14+
dim: st.integers(min_value=-size, max_value=size - 1)
15+
for dim, size in sizes.items()
16+
}
17+
indexers = xrst.unique_subset_of(possible_indexers)
18+
return {dim: draw(indexer) for dim, indexer in draw(indexers).items()}
19+
20+
21+
class IndexingTests(DuckArrayTestMixin):
22+
@staticmethod
23+
def expected_errors(op, **parameters):
24+
return nullcontext()
25+
26+
@given(st.data())
27+
def test_variable_scalar_isel(self, data):
28+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
29+
indexers = data.draw(scalar_indexers(sizes=variable.sizes))
30+
31+
with self.expected_errors("scalar_isel", variable=variable):
32+
actual = variable.isel(indexers).data
33+
34+
raw_indexers = {
35+
dim: indexers.get(dim, slice(None)) for dim in variable.dims
36+
}
37+
expected = variable.data[*raw_indexers.values()]
38+
39+
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
40+
self.assert_equal(actual, expected)

xarray_array_testing/tests/test_numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from xarray_array_testing.base import DuckArrayTestMixin
77
from xarray_array_testing.creation import CreationTests
8+
from xarray_array_testing.indexing import IndexingTests
89
from xarray_array_testing.reduction import ReductionTests
910

1011

@@ -32,3 +33,7 @@ class TestCreationNumpy(CreationTests, NumpyTestMixin):
3233

3334
class TestReductionNumpy(ReductionTests, NumpyTestMixin):
3435
pass
36+
37+
38+
class TestIndexingNumpy(IndexingTests, NumpyTestMixin):
39+
pass

0 commit comments

Comments
 (0)