7
7
from itertools import chain , combinations
8
8
9
9
import numpy as np
10
- from xarray import DataArray
11
10
from xarray import concat as xr_concat
12
11
13
12
from pytensor .xtensor .shape import concat , stack
14
13
from pytensor .xtensor .type import xtensor
15
- from tests .xtensor .util import xr_assert_allclose , xr_function , xr_random_like
14
+ from tests .xtensor .util import (
15
+ xr_arange_like ,
16
+ xr_assert_allclose ,
17
+ xr_function ,
18
+ xr_random_like ,
19
+ )
16
20
17
21
18
22
def powerset (iterable , min_group_size = 0 ):
@@ -42,10 +46,7 @@ def test_transpose():
42
46
outs = [transpose (x , * perm ) for perm in permutations ]
43
47
44
48
fn = xr_function ([x ], outs )
45
- x_test = DataArray (
46
- np .arange (np .prod (x .type .shape ), dtype = x .type .dtype ).reshape (x .type .shape ),
47
- dims = x .type .dims ,
48
- )
49
+ x_test = xr_arange_like (x )
49
50
res = fn (x_test )
50
51
expected_res = [x_test .transpose (* perm ) for perm in permutations ]
51
52
for outs_i , res_i , expected_res_i in zip (outs , res , expected_res ):
@@ -61,10 +62,7 @@ def test_stack():
61
62
]
62
63
63
64
fn = xr_function ([x ], outs )
64
- x_test = DataArray (
65
- np .arange (np .prod (x .type .shape ), dtype = x .type .dtype ).reshape (x .type .shape ),
66
- dims = x .type .dims ,
67
- )
65
+ x_test = xr_arange_like (x )
68
66
res = fn (x_test )
69
67
70
68
expected_res = [
@@ -81,10 +79,7 @@ def test_stack_single_dim():
81
79
assert out .type .dims == ("b" , "c" , "d" )
82
80
83
81
fn = xr_function ([x ], out )
84
- x_test = DataArray (
85
- np .arange (np .prod (x .type .shape ), dtype = x .type .dtype ).reshape (x .type .shape ),
86
- dims = x .type .dims ,
87
- )
82
+ x_test = xr_arange_like (x )
88
83
fn .fn .dprint (print_type = True )
89
84
res = fn (x_test )
90
85
expected_res = x_test .stack (d = ["a" ])
@@ -96,10 +91,7 @@ def test_multiple_stacks():
96
91
out = stack (x , new_dim1 = ("a" , "b" ), new_dim2 = ("c" , "d" ))
97
92
98
93
fn = xr_function ([x ], [out ])
99
- x_test = DataArray (
100
- np .arange (np .prod (x .type .shape ), dtype = x .type .dtype ).reshape (x .type .shape ),
101
- dims = x .type .dims ,
102
- )
94
+ x_test = xr_arange_like (x )
103
95
res = fn (x_test )
104
96
expected_res = x_test .stack (new_dim1 = ("a" , "b" ), new_dim2 = ("c" , "d" ))
105
97
xr_assert_allclose (res [0 ], expected_res )
0 commit comments