Skip to content

Commit 6082d94

Browse files
AllenDowneyricardoV94
authored andcommitted
Implement transpose for XTensorVariables
1 parent 56bd328 commit 6082d94

File tree

4 files changed

+215
-6
lines changed

4 files changed

+215
-6
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pytensor.tensor import broadcast_to, join, moveaxis
33
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
44
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
5-
from pytensor.xtensor.shape import Concat, Stack
5+
from pytensor.xtensor.shape import Concat, Stack, Transpose
66

77

88
@register_xcanonicalize
@@ -70,3 +70,19 @@ def lower_concat(fgraph, node):
7070
joined_tensor = join(concat_axis, *bcast_tensor_inputs)
7171
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
7272
return [new_out]
73+
74+
75+
@register_xcanonicalize
76+
@node_rewriter(tracks=[Transpose])
77+
def lower_transpose(fgraph, node):
78+
[x] = node.inputs
79+
# Use the final dimensions that were already computed in make_node
80+
out_dims = node.outputs[0].type.dims
81+
in_dims = x.type.dims
82+
83+
# Compute the permutation based on the final dimensions
84+
perm = tuple(in_dims.index(d) for d in out_dims)
85+
x_tensor = tensor_from_xtensor(x)
86+
x_tensor_transposed = x_tensor.transpose(perm)
87+
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
88+
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import warnings
12
from collections.abc import Sequence
3+
from typing import Literal
24

35
from pytensor import Variable
46
from pytensor.graph import Apply
57
from pytensor.scalar import upcast
68
from pytensor.xtensor.basic import XOp
7-
from pytensor.xtensor.type import as_xtensor, xtensor
9+
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
810

911

1012
class Stack(XOp):
@@ -73,6 +75,97 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
7375
return y
7476

7577

78+
class Transpose(XOp):
79+
__props__ = ("dims",)
80+
81+
def __init__(
82+
self,
83+
dims: tuple[str | Literal[...], ...],
84+
):
85+
super().__init__()
86+
if dims.count(...) > 1:
87+
raise ValueError("an index can only have a single ellipsis ('...')")
88+
self.dims = dims
89+
90+
def make_node(self, x):
91+
x = as_xtensor(x)
92+
93+
transpose_dims = self.dims
94+
x_dims = x.type.dims
95+
96+
if transpose_dims == () or transpose_dims == (...,):
97+
out_dims = tuple(reversed(x_dims))
98+
elif ... in transpose_dims:
99+
# Handle ellipsis expansion
100+
ellipsis_idx = transpose_dims.index(...)
101+
pre = transpose_dims[:ellipsis_idx]
102+
post = transpose_dims[ellipsis_idx + 1 :]
103+
middle = [d for d in x_dims if d not in pre + post]
104+
out_dims = (*pre, *middle, *post)
105+
if set(out_dims) != set(x_dims):
106+
raise ValueError(f"{out_dims} must be a permuted list of {x_dims}")
107+
else:
108+
out_dims = transpose_dims
109+
if set(out_dims) != set(x_dims):
110+
raise ValueError(
111+
f"{out_dims} must be a permuted list of {x_dims}, unless `...` is included"
112+
)
113+
114+
output = xtensor(
115+
dtype=x.type.dtype,
116+
shape=tuple(x.type.shape[x.type.dims.index(d)] for d in out_dims),
117+
dims=out_dims,
118+
)
119+
return Apply(self, [x], [output])
120+
121+
122+
def transpose(
123+
x,
124+
*dims: str | Literal[...],
125+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
126+
) -> XTensorVariable:
127+
"""Transpose dimensions of the tensor.
128+
129+
Parameters
130+
----------
131+
x : XTensorVariable
132+
Input tensor to transpose.
133+
*dims : str
134+
Dimensions to transpose to. Can include ellipsis (...) to represent
135+
remaining dimensions in their original order.
136+
missing_dims : {"raise", "warn", "ignore"}, optional
137+
How to handle dimensions that don't exist in the input tensor:
138+
- "raise": Raise an error if any dimensions don't exist (default)
139+
- "warn": Warn if any dimensions don't exist
140+
- "ignore": Silently ignore any dimensions that don't exist
141+
142+
Returns
143+
-------
144+
XTensorVariable
145+
Transposed tensor with reordered dimensions.
146+
147+
Raises
148+
------
149+
ValueError
150+
If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
151+
"""
152+
# Validate dimensions
153+
x = as_xtensor(x)
154+
all_dims = x.type.dims
155+
invalid_dims = set(dims) - {..., *all_dims}
156+
if invalid_dims:
157+
if missing_dims != "ignore":
158+
msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {all_dims}"
159+
if missing_dims == "raise":
160+
raise ValueError(msg)
161+
else:
162+
warnings.warn(msg)
163+
# Handle missing dimensions if not raising
164+
dims = tuple(d for d in dims if d in all_dims or d is ...)
165+
166+
return Transpose(dims)(x)
167+
168+
76169
class Concat(XOp):
77170
__props__ = ("dim",)
78171

pytensor/xtensor/type.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
XARRAY_AVAILABLE = False
1111

1212
from collections.abc import Sequence
13-
from typing import TypeVar
13+
from typing import Literal, TypeVar
1414

1515
import numpy as np
1616

@@ -357,6 +357,50 @@ def imag(self):
357357
def real(self):
358358
return px.math.real(self)
359359

360+
def transpose(
361+
self,
362+
*dims: str | Literal[...],
363+
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
364+
) -> "XTensorVariable":
365+
"""Transpose dimensions of the tensor.
366+
367+
Parameters
368+
----------
369+
*dims : str | Ellipsis
370+
Dimensions to transpose. If empty, performs a full transpose.
371+
Can use ellipsis (...) to represent remaining dimensions.
372+
missing_dims : {"raise", "warn", "ignore"}, default="raise"
373+
How to handle dimensions that don't exist in the tensor:
374+
- "raise": Raise an error if any dimensions don't exist
375+
- "warn": Warn if any dimensions don't exist
376+
- "ignore": Silently ignore any dimensions that don't exist
377+
378+
Returns
379+
-------
380+
XTensorVariable
381+
Transposed tensor with reordered dimensions.
382+
383+
Raises
384+
------
385+
ValueError
386+
If missing_dims="raise" and any dimensions don't exist.
387+
If multiple ellipsis are provided.
388+
"""
389+
return px.shape.transpose(self, *dims, missing_dims=missing_dims)
390+
391+
@property
392+
def T(self) -> "XTensorVariable":
393+
"""Return the full transpose of the tensor.
394+
395+
This is equivalent to calling transpose() with no arguments.
396+
397+
Returns
398+
-------
399+
XTensorVariable
400+
Fully transposed tensor.
401+
"""
402+
return self.transpose()
403+
360404
# Aggregation
361405
# https://docs.xarray.dev/en/latest/api.html#id6
362406
def all(self, dim):

tests/xtensor/test_shape.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# ruff: noqa: E402
2+
import re
3+
24
import pytest
35

46

@@ -9,7 +11,7 @@
911
import numpy as np
1012
from xarray import concat as xr_concat
1113

12-
from pytensor.xtensor.shape import concat, stack
14+
from pytensor.xtensor.shape import concat, stack, transpose
1315
from pytensor.xtensor.type import xtensor
1416
from tests.xtensor.util import (
1517
xr_arange_like,
@@ -28,9 +30,7 @@ def powerset(iterable, min_group_size=0):
2830
)
2931

3032

31-
@pytest.mark.xfail(reason="Not yet implemented")
3233
def test_transpose():
33-
transpose = None
3434
a, b, c, d, e = "abcde"
3535

3636
x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11))
@@ -53,6 +53,62 @@ def test_transpose():
5353
xr_assert_allclose(res_i, expected_res_i)
5454

5555

56+
def test_xtensor_variable_transpose():
57+
"""Test the transpose() method of XTensorVariable."""
58+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
59+
60+
# Test basic transpose
61+
out = x.transpose()
62+
fn = xr_function([x], out)
63+
x_test = xr_arange_like(x)
64+
xr_assert_allclose(fn(x_test), x_test.transpose())
65+
66+
# Test transpose with specific dimensions
67+
out = x.transpose("c", "a", "b")
68+
fn = xr_function([x], out)
69+
xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b"))
70+
71+
# Test transpose with ellipsis
72+
out = x.transpose("c", ...)
73+
fn = xr_function([x], out)
74+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
75+
76+
# Test error cases
77+
with pytest.raises(
78+
ValueError,
79+
match=re.escape(
80+
"Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')"
81+
),
82+
):
83+
x.transpose("d")
84+
85+
with pytest.raises(ValueError, match="an index can only have a single ellipsis"):
86+
x.transpose("a", ..., "b", ...)
87+
88+
# Test missing_dims parameter
89+
# Test ignore
90+
out = x.transpose("c", ..., "d", missing_dims="ignore")
91+
fn = xr_function([x], out)
92+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
93+
94+
# Test warn
95+
with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"):
96+
out = x.transpose("c", ..., "d", missing_dims="warn")
97+
fn = xr_function([x], out)
98+
xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))
99+
100+
101+
def test_xtensor_variable_T():
102+
"""Test the T property of XTensorVariable."""
103+
# Test T property with 3D tensor
104+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
105+
out = x.T
106+
107+
fn = xr_function([x], out)
108+
x_test = xr_arange_like(x)
109+
xr_assert_allclose(fn(x_test), x_test.T)
110+
111+
56112
def test_stack():
57113
dims = ("a", "b", "c", "d")
58114
x = xtensor("x", dims=dims, shape=(2, 3, 5, 7))

0 commit comments

Comments
 (0)