Skip to content

Commit ac6a786

Browse files
committed
refactor the indexers strategy
1 parent b14f05e commit ac6a786

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

xarray_array_testing/indexing.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77
from xarray_array_testing.base import DuckArrayTestMixin
88

99

10+
def scalar_indexer(size):
11+
return st.integers(min_value=-size, max_value=size - 1)
12+
13+
1014
@st.composite
11-
def scalar_indexers(draw, sizes):
12-
# TODO: try to define this using builds and flatmap
13-
possible_indexers = {
14-
dim: st.integers(min_value=-size, max_value=size - 1)
15-
for dim, size in sizes.items()
16-
}
17-
indexers = xrst.unique_subset_of(possible_indexers)
18-
return {dim: draw(indexer) for dim, indexer in draw(indexers).items()}
15+
def indexers(draw, sizes, indexer_strategy_fn):
16+
possible_indexers = {dim: indexer_strategy_fn(size) for dim, size in sizes.items()}
17+
indexers = draw(xrst.unique_subset_of(possible_indexers))
18+
return {dim: draw(indexer) for dim, indexer in indexers.items()}
1919

2020

2121
class IndexingTests(DuckArrayTestMixin):
@@ -24,16 +24,14 @@ def expected_errors(op, **parameters):
2424
return nullcontext()
2525

2626
@given(st.data())
27-
def test_variable_scalar_isel(self, data):
27+
def test_variable_isel_scalars(self, data):
2828
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
29-
indexers = data.draw(scalar_indexers(sizes=variable.sizes))
29+
idx = data.draw(indexers(variable.sizes, scalar_indexer))
3030

31-
with self.expected_errors("scalar_isel", variable=variable):
32-
actual = variable.isel(indexers).data
31+
with self.expected_errors("isel_scalars", variable=variable):
32+
actual = variable.isel(idx).data
3333

34-
raw_indexers = {
35-
dim: indexers.get(dim, slice(None)) for dim in variable.dims
36-
}
34+
raw_indexers = {dim: idx.get(dim, slice(None)) for dim in variable.dims}
3735
expected = variable.data[*raw_indexers.values()]
3836

3937
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"

0 commit comments

Comments
 (0)