7
7
from xarray_array_testing .base import DuckArrayTestMixin
8
8
9
9
10
+ def scalar_indexer (size ):
11
+ return st .integers (min_value = - size , max_value = size - 1 )
12
+
13
+
10
14
@st .composite
11
- def scalar_indexers (draw , sizes ):
12
- # TODO: try to define this using builds and flatmap
13
- possible_indexers = {
14
- dim : st .integers (min_value = - size , max_value = size - 1 )
15
- for dim , size in sizes .items ()
16
- }
17
- indexers = xrst .unique_subset_of (possible_indexers )
18
- return {dim : draw (indexer ) for dim , indexer in draw (indexers ).items ()}
15
+ def indexers (draw , sizes , indexer_strategy_fn ):
16
+ possible_indexers = {dim : indexer_strategy_fn (size ) for dim , size in sizes .items ()}
17
+ indexers = draw (xrst .unique_subset_of (possible_indexers ))
18
+ return {dim : draw (indexer ) for dim , indexer in indexers .items ()}
19
19
20
20
21
21
class IndexingTests (DuckArrayTestMixin ):
@@ -24,16 +24,14 @@ def expected_errors(op, **parameters):
24
24
return nullcontext ()
25
25
26
26
@given (st .data ())
27
- def test_variable_scalar_isel (self , data ):
27
+ def test_variable_isel_scalars (self , data ):
28
28
variable = data .draw (xrst .variables (array_strategy_fn = self .array_strategy_fn ))
29
- indexers = data .draw (scalar_indexers ( sizes = variable .sizes ))
29
+ idx = data .draw (indexers ( variable .sizes , scalar_indexer ))
30
30
31
- with self .expected_errors ("scalar_isel " , variable = variable ):
32
- actual = variable .isel (indexers ).data
31
+ with self .expected_errors ("isel_scalars " , variable = variable ):
32
+ actual = variable .isel (idx ).data
33
33
34
- raw_indexers = {
35
- dim : indexers .get (dim , slice (None )) for dim in variable .dims
36
- }
34
+ raw_indexers = {dim : idx .get (dim , slice (None )) for dim in variable .dims }
37
35
expected = variable .data [* raw_indexers .values ()]
38
36
39
37
assert isinstance (actual , self .array_type ), f"wrong type: { type (actual )} "
0 commit comments