Skip to content

Implement Transpose for XTensorVariables #1430

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pytensor.tensor import broadcast_to, join, moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
from pytensor.xtensor.shape import Concat, Stack
from pytensor.xtensor.shape import Concat, Stack, Transpose


@register_xcanonicalize
Expand Down Expand Up @@ -70,3 +70,19 @@ def lower_concat(fgraph, node):
joined_tensor = join(concat_axis, *bcast_tensor_inputs)
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
return [new_out]


@register_xcanonicalize
@node_rewriter(tracks=[Transpose])
def lower_transpose(fgraph, node):
    """Rewrite an xtensor ``Transpose`` node as a plain tensor transpose.

    The dimension order of the result is read from the node's output type
    (it was resolved when the node was built), so this rewrite only has to
    translate that order into positional axes of the input.

    Returns
    -------
    list
        A single-element list with the replacement xtensor output.
    """
    (xvar,) = node.inputs
    target_dims = node.outputs[0].type.dims
    source_dims = xvar.type.dims

    # For each output dim, find the axis it occupies in the input.
    axis_order = tuple(source_dims.index(dim) for dim in target_dims)

    # Drop to the tensor level, permute the axes, then re-attach the dims.
    transposed = tensor_from_xtensor(xvar).transpose(axis_order)
    return [xtensor_from_tensor(transposed, dims=target_dims)]
95 changes: 94 additions & 1 deletion pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import warnings
from collections.abc import Sequence
from typing import Literal

from pytensor import Variable
from pytensor.graph import Apply
from pytensor.scalar import upcast
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor


class Stack(XOp):
Expand Down Expand Up @@ -73,6 +75,97 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
return y


class Transpose(XOp):
    """Op that reorders the named dimensions of an xtensor variable.

    Parameters
    ----------
    dims
        Requested dimension order. May contain at most one ``...``, which
        stands for all input dimensions not named explicitly, kept in their
        original relative order. ``()`` and ``(...,)`` both mean "reverse all
        dimensions".

    Raises
    ------
    ValueError
        From ``__init__`` if more than one ellipsis is given; from
        ``make_node`` if the resolved order is not a permutation of the
        input's dimensions (unknown, missing or repeated names).
    """

    __props__ = ("dims",)

    def __init__(
        self,
        dims: tuple[str | Literal[...], ...],
    ):
        super().__init__()
        # make_node compares ``self.dims`` against the tuple literals ``()``
        # and ``(...,)``; normalising here keeps those comparisons valid when
        # a list or other sequence is passed.
        dims = tuple(dims)
        if dims.count(...) > 1:
            raise ValueError("an index can only have a single ellipsis ('...')")
        self.dims = dims

    def make_node(self, x):
        """Build the Apply node, resolving the output dimension order.

        The output keeps the input dtype; its static shape is the input's
        static shape rearranged to follow the resolved dimension order.
        """
        x = as_xtensor(x)

        transpose_dims = self.dims
        x_dims = x.type.dims

        if transpose_dims == () or transpose_dims == (...,):
            # No explicit order requested: full reversal.
            out_dims = tuple(reversed(x_dims))
        else:
            if ... in transpose_dims:
                # Expand the ellipsis into every input dim that was not named
                # explicitly, preserving the input order.
                ellipsis_idx = transpose_dims.index(...)
                pre = transpose_dims[:ellipsis_idx]
                post = transpose_dims[ellipsis_idx + 1 :]
                middle = [d for d in x_dims if d not in pre + post]
                out_dims = (*pre, *middle, *post)
                hint = ""
            else:
                out_dims = transpose_dims
                hint = ", unless `...` is included"

            # A set comparison alone lets repeated names through, e.g.
            # ("a", "a", "b") for an input with dims ("a", "b"); the length
            # check rejects those as well.
            if set(out_dims) != set(x_dims) or len(out_dims) != len(x_dims):
                raise ValueError(
                    f"{out_dims} must be a permuted list of {x_dims}{hint}"
                )

        output = xtensor(
            dtype=x.type.dtype,
            shape=tuple(x.type.shape[x_dims.index(d)] for d in out_dims),
            dims=out_dims,
        )
        return Apply(self, [x], [output])


def transpose(
    x,
    *dims: str | Literal[...],
    missing_dims: Literal["raise", "warn", "ignore"] = "raise",
) -> XTensorVariable:
    """Transpose dimensions of the tensor.

    Parameters
    ----------
    x : XTensorVariable
        Input tensor to transpose.
    *dims : str | Ellipsis
        Dimensions to transpose to. Can include ellipsis (...) to represent
        remaining dimensions in their original order. If no dimensions are
        given, the order of all dimensions is reversed.
    missing_dims : {"raise", "warn", "ignore"}, optional
        How to handle dimensions that don't exist in the input tensor:
        - "raise": Raise an error if any dimensions don't exist (default)
        - "warn": Warn if any dimensions don't exist
        - "ignore": Silently ignore any dimensions that don't exist

    Returns
    -------
    XTensorVariable
        Transposed tensor with reordered dimensions.

    Raises
    ------
    ValueError
        If `missing_dims` is not one of the accepted options.
        If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
    """
    # Reject unknown policies up front; otherwise any unrecognised value
    # would silently fall through to the "warn" behaviour below.
    if missing_dims not in ("raise", "warn", "ignore"):
        raise ValueError(
            f"Unrecognised option {missing_dims!r} for missing_dims argument"
        )

    # Validate dimensions
    x = as_xtensor(x)
    all_dims = x.type.dims
    invalid_dims = set(dims) - {..., *all_dims}
    if invalid_dims:
        if missing_dims != "ignore":
            msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {all_dims}"
            if missing_dims == "raise":
                raise ValueError(msg)
            # stacklevel=2 attributes the warning to the caller, not this module.
            warnings.warn(msg, stacklevel=2)
        # Drop the unknown names but keep the ellipsis placeholder.
        # NOTE(review): if every requested name is dropped, the empty tuple is
        # passed on and Transpose.make_node treats it as a full reversal -
        # confirm that is the intended outcome.
        dims = tuple(d for d in dims if d in all_dims or d is ...)

    return Transpose(dims)(x)


class Concat(XOp):
__props__ = ("dim",)

Expand Down
49 changes: 46 additions & 3 deletions pytensor/xtensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
XARRAY_AVAILABLE = False

from collections.abc import Sequence
from typing import TypeVar
from typing import Literal, TypeVar

import numpy as np

Expand Down Expand Up @@ -357,6 +357,50 @@ def imag(self):
def real(self):
return px.math.real(self)

def transpose(
    self,
    *dims: str | Literal[...],
    missing_dims: Literal["raise", "warn", "ignore"] = "raise",
) -> "XTensorVariable":
    """Return this variable with its dimensions reordered.

    Method form of ``px.shape.transpose``; every argument is forwarded
    unchanged.

    Parameters
    ----------
    *dims : str | Ellipsis
        Desired dimension order. With no names a full transpose is
        performed. An ellipsis (...) stands for the remaining dimensions.
    missing_dims : {"raise", "warn", "ignore"}, default="raise"
        Policy for names that are not dimensions of this variable:
        - "raise": raise an error
        - "warn": emit a warning
        - "ignore": drop them silently

    Returns
    -------
    XTensorVariable
        The transposed variable.

    Raises
    ------
    ValueError
        If missing_dims="raise" and any dimensions don't exist.
        If multiple ellipsis are provided.
    """
    transposed = px.shape.transpose(self, *dims, missing_dims=missing_dims)
    return transposed

@property
def T(self) -> "XTensorVariable":
    """Full transpose of this variable.

    Shorthand for ``self.transpose()`` called without any dimension names.

    Returns
    -------
    XTensorVariable
        The fully transposed variable.
    """
    fully_transposed = self.transpose()
    return fully_transposed

# Aggregation
# https://docs.xarray.dev/en/latest/api.html#id6
def all(self, dim):
Expand Down Expand Up @@ -470,8 +514,7 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
if isinstance(x, Apply):
if len(x.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a "
"multi-output Op has to be fetched.",
"It is ambiguous which output of a multi-output Op has to be fetched.",
x,
)
else:
Expand Down
90 changes: 69 additions & 21 deletions tests/xtensor/test_shape.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# ruff: noqa: E402
import re

import pytest


Expand All @@ -7,12 +9,16 @@
from itertools import chain, combinations

import numpy as np
from xarray import DataArray
from xarray import concat as xr_concat

from pytensor.xtensor.shape import concat, stack
from pytensor.xtensor.shape import concat, stack, transpose
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function, xr_random_like
from tests.xtensor.util import (
xr_arange_like,
xr_assert_allclose,
xr_function,
xr_random_like,
)


def powerset(iterable, min_group_size=0):
Expand All @@ -24,9 +30,7 @@ def powerset(iterable, min_group_size=0):
)


@pytest.mark.xfail(reason="Not yet implemented")
def test_transpose():
transpose = None
a, b, c, d, e = "abcde"

x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11))
Expand All @@ -42,16 +46,69 @@ def test_transpose():
outs = [transpose(x, *perm) for perm in permutations]

fn = xr_function([x], outs)
x_test = DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)
x_test = xr_arange_like(x)
res = fn(x_test)
expected_res = [x_test.transpose(*perm) for perm in permutations]
for outs_i, res_i, expected_res_i in zip(outs, res, expected_res):
xr_assert_allclose(res_i, expected_res_i)


def test_xtensor_variable_transpose():
    """Test the transpose() method of XTensorVariable against xarray."""
    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))

    # Test basic transpose
    out = x.transpose()
    fn = xr_function([x], out)
    x_test = xr_arange_like(x)
    xr_assert_allclose(fn(x_test), x_test.transpose())

    # Test transpose with specific dimensions
    out = x.transpose("c", "a", "b")
    fn = xr_function([x], out)
    xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b"))

    # Test transpose with ellipsis
    out = x.transpose("c", ...)
    fn = xr_function([x], out)
    xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))

    # Test error cases
    with pytest.raises(
        ValueError,
        match=re.escape(
            "Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')"
        ),
    ):
        x.transpose("d")

    with pytest.raises(ValueError, match="an index can only have a single ellipsis"):
        x.transpose("a", ..., "b", ...)

    # Test missing_dims parameter
    # Test ignore
    out = x.transpose("c", ..., "d", missing_dims="ignore")
    fn = xr_function([x], out)
    xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))

    # Test warn
    # `match` is a regular expression; escape the message so the braces are
    # matched literally, as in the "raise" case above.
    with pytest.warns(
        UserWarning, match=re.escape("Dimensions {'d'} do not exist")
    ):
        out = x.transpose("c", ..., "d", missing_dims="warn")
    fn = xr_function([x], out)
    xr_assert_allclose(fn(x_test), x_test.transpose("c", ...))


def test_xtensor_variable_T():
    """Check that the ``T`` property agrees with ``.T`` on the test input."""
    # 3D input, so a full transpose actually changes the dimension order.
    x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
    transposed = x.T

    compiled = xr_function([x], transposed)
    x_test = xr_arange_like(x)
    result = compiled(x_test)
    xr_assert_allclose(result, x_test.T)


def test_stack():
dims = ("a", "b", "c", "d")
x = xtensor("x", dims=dims, shape=(2, 3, 5, 7))
Expand All @@ -61,10 +118,7 @@ def test_stack():
]

fn = xr_function([x], outs)
x_test = DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)
x_test = xr_arange_like(x)
res = fn(x_test)

expected_res = [
Expand All @@ -81,10 +135,7 @@ def test_stack_single_dim():
assert out.type.dims == ("b", "c", "d")

fn = xr_function([x], out)
x_test = DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)
x_test = xr_arange_like(x)
fn.fn.dprint(print_type=True)
res = fn(x_test)
expected_res = x_test.stack(d=["a"])
Expand All @@ -96,10 +147,7 @@ def test_multiple_stacks():
out = stack(x, new_dim1=("a", "b"), new_dim2=("c", "d"))

fn = xr_function([x], [out])
x_test = DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)
x_test = xr_arange_like(x)
res = fn(x_test)
expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
xr_assert_allclose(res[0], expected_res)
Expand Down
Loading