Skip to content

Commit 56bd328

Browse files
committed
Use xr_arange_like in existing tests
1 parent 201ee95 commit 56bd328

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

tests/xtensor/test_shape.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77
from itertools import chain, combinations
88

99
import numpy as np
10-
from xarray import DataArray
1110
from xarray import concat as xr_concat
1211

1312
from pytensor.xtensor.shape import concat, stack
1413
from pytensor.xtensor.type import xtensor
15-
from tests.xtensor.util import xr_assert_allclose, xr_function, xr_random_like
14+
from tests.xtensor.util import (
15+
xr_arange_like,
16+
xr_assert_allclose,
17+
xr_function,
18+
xr_random_like,
19+
)
1620

1721

1822
def powerset(iterable, min_group_size=0):
@@ -42,10 +46,7 @@ def test_transpose():
4246
outs = [transpose(x, *perm) for perm in permutations]
4347

4448
fn = xr_function([x], outs)
45-
x_test = DataArray(
46-
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
47-
dims=x.type.dims,
48-
)
49+
x_test = xr_arange_like(x)
4950
res = fn(x_test)
5051
expected_res = [x_test.transpose(*perm) for perm in permutations]
5152
for outs_i, res_i, expected_res_i in zip(outs, res, expected_res):
@@ -61,10 +62,7 @@ def test_stack():
6162
]
6263

6364
fn = xr_function([x], outs)
64-
x_test = DataArray(
65-
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
66-
dims=x.type.dims,
67-
)
65+
x_test = xr_arange_like(x)
6866
res = fn(x_test)
6967

7068
expected_res = [
@@ -81,10 +79,7 @@ def test_stack_single_dim():
8179
assert out.type.dims == ("b", "c", "d")
8280

8381
fn = xr_function([x], out)
84-
x_test = DataArray(
85-
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
86-
dims=x.type.dims,
87-
)
82+
x_test = xr_arange_like(x)
8883
fn.fn.dprint(print_type=True)
8984
res = fn(x_test)
9085
expected_res = x_test.stack(d=["a"])
@@ -96,10 +91,7 @@ def test_multiple_stacks():
9691
out = stack(x, new_dim1=("a", "b"), new_dim2=("c", "d"))
9792

9893
fn = xr_function([x], [out])
99-
x_test = DataArray(
100-
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
101-
dims=x.type.dims,
102-
)
94+
x_test = xr_arange_like(x)
10395
res = fn(x_test)
10496
expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
10597
xr_assert_allclose(res[0], expected_res)

0 commit comments

Comments
 (0)