diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 06b8c40a32..03deb9a91c 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -2,7 +2,7 @@ from pytensor.tensor import broadcast_to, join, moveaxis from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, Stack +from pytensor.xtensor.shape import Concat, Stack, Transpose @register_xcanonicalize @@ -70,3 +70,19 @@ def lower_concat(fgraph, node): joined_tensor = join(concat_axis, *bcast_tensor_inputs) new_out = xtensor_from_tensor(joined_tensor, dims=out_dims) return [new_out] + + +@register_xcanonicalize +@node_rewriter(tracks=[Transpose]) +def lower_transpose(fgraph, node): + [x] = node.inputs + # Use the final dimensions that were already computed in make_node + out_dims = node.outputs[0].type.dims + in_dims = x.type.dims + + # Compute the permutation based on the final dimensions + perm = tuple(in_dims.index(d) for d in out_dims) + x_tensor = tensor_from_xtensor(x) + x_tensor_transposed = x_tensor.transpose(perm) + new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index f39d495285..cc0a2a2fa6 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -1,10 +1,12 @@ +import warnings from collections.abc import Sequence +from typing import Literal from pytensor import Variable from pytensor.graph import Apply from pytensor.scalar import upcast from pytensor.xtensor.basic import XOp -from pytensor.xtensor.type import as_xtensor, xtensor +from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor class Stack(XOp): @@ -73,6 +75,97 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) return y +class Transpose(XOp): + __props__ = ("dims",) + + def __init__( + self, + dims: tuple[str | Literal[...], ...], + ): + super().__init__() + if dims.count(...) > 1: + raise ValueError("an index can only have a single ellipsis ('...')") + self.dims = dims + + def make_node(self, x): + x = as_xtensor(x) + + transpose_dims = self.dims + x_dims = x.type.dims + + if transpose_dims == () or transpose_dims == (...,): + out_dims = tuple(reversed(x_dims)) + elif ... in transpose_dims: + # Handle ellipsis expansion + ellipsis_idx = transpose_dims.index(...) + pre = transpose_dims[:ellipsis_idx] + post = transpose_dims[ellipsis_idx + 1 :] + middle = [d for d in x_dims if d not in pre + post] + out_dims = (*pre, *middle, *post) + if set(out_dims) != set(x_dims): + raise ValueError(f"{out_dims} must be a permuted list of {x_dims}") + else: + out_dims = transpose_dims + if set(out_dims) != set(x_dims): + raise ValueError( + f"{out_dims} must be a permuted list of {x_dims}, unless `...` is included" + ) + + output = xtensor( + dtype=x.type.dtype, + shape=tuple(x.type.shape[x.type.dims.index(d)] for d in out_dims), + dims=out_dims, + ) + return Apply(self, [x], [output]) + + +def transpose( + x, + *dims: str | Literal[...], + missing_dims: Literal["raise", "warn", "ignore"] = "raise", +) -> XTensorVariable: + """Transpose dimensions of the tensor. + + Parameters + ---------- + x : XTensorVariable + Input tensor to transpose. + *dims : str + Dimensions to transpose to. Can include ellipsis (...) to represent + remaining dimensions in their original order. + missing_dims : {"raise", "warn", "ignore"}, optional + How to handle dimensions that don't exist in the input tensor: + - "raise": Raise an error if any dimensions don't exist (default) + - "warn": Warn if any dimensions don't exist + - "ignore": Silently ignore any dimensions that don't exist + + Returns + ------- + XTensorVariable + Transposed tensor with reordered dimensions. + + Raises + ------ + ValueError + If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise". + """ + # Validate dimensions + x = as_xtensor(x) + all_dims = x.type.dims + invalid_dims = set(dims) - {..., *all_dims} + if invalid_dims: + if missing_dims != "ignore": + msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {all_dims}" + if missing_dims == "raise": + raise ValueError(msg) + else: + warnings.warn(msg) + # Handle missing dimensions if not raising + dims = tuple(d for d in dims if d in all_dims or d is ...) + + return Transpose(dims)(x) + + class Concat(XOp): __props__ = ("dim",) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 5b79e9ae57..5968a8014c 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -10,7 +10,7 @@ XARRAY_AVAILABLE = False from collections.abc import Sequence -from typing import TypeVar +from typing import Literal, TypeVar import numpy as np @@ -357,6 +357,50 @@ def imag(self): def real(self): return px.math.real(self) + def transpose( + self, + *dims: str | Literal[...], + missing_dims: Literal["raise", "warn", "ignore"] = "raise", + ) -> "XTensorVariable": + """Transpose dimensions of the tensor. + + Parameters + ---------- + *dims : str | Ellipsis + Dimensions to transpose. If empty, performs a full transpose. + Can use ellipsis (...) to represent remaining dimensions. + missing_dims : {"raise", "warn", "ignore"}, default="raise" + How to handle dimensions that don't exist in the tensor: + - "raise": Raise an error if any dimensions don't exist + - "warn": Warn if any dimensions don't exist + - "ignore": Silently ignore any dimensions that don't exist + + Returns + ------- + XTensorVariable + Transposed tensor with reordered dimensions. + + Raises + ------ + ValueError + If missing_dims="raise" and any dimensions don't exist. + If multiple ellipsis are provided. + """ + return px.shape.transpose(self, *dims, missing_dims=missing_dims) + + @property + def T(self) -> "XTensorVariable": + """Return the full transpose of the tensor. + + This is equivalent to calling transpose() with no arguments. + + Returns + ------- + XTensorVariable + Fully transposed tensor. + """ + return self.transpose() + # Aggregation # https://docs.xarray.dev/en/latest/api.html#id6 def all(self, dim): @@ -470,8 +514,7 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None): if isinstance(x, Apply): if len(x.outputs) != 1: raise ValueError( - "It is ambiguous which output of a " - "multi-output Op has to be fetched.", + "It is ambiguous which output of a multi-output Op has to be fetched.", x, ) else: diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 79cc2738a2..2fc1b50fd0 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -1,4 +1,6 @@ # ruff: noqa: E402 +import re + import pytest @@ -7,12 +9,16 @@ from itertools import chain, combinations import numpy as np -from xarray import DataArray from xarray import concat as xr_concat -from pytensor.xtensor.shape import concat, stack +from pytensor.xtensor.shape import concat, stack, transpose from pytensor.xtensor.type import xtensor -from tests.xtensor.util import xr_assert_allclose, xr_function, xr_random_like +from tests.xtensor.util import ( + xr_arange_like, + xr_assert_allclose, + xr_function, + xr_random_like, +) def powerset(iterable, min_group_size=0): @@ -24,9 +30,7 @@ def powerset(iterable, min_group_size=0): ) -@pytest.mark.xfail(reason="Not yet implemented") def test_transpose(): - transpose = None a, b, c, d, e = "abcde" x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11)) @@ -42,16 +46,69 @@ def test_transpose(): outs = [transpose(x, *perm) for perm in permutations] fn = xr_function([x], outs) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) res = fn(x_test) expected_res = [x_test.transpose(*perm) for perm in permutations] for outs_i, res_i, expected_res_i in zip(outs, res, expected_res): xr_assert_allclose(res_i, expected_res_i) +def test_xtensor_variable_transpose(): + """Test the transpose() method of XTensorVariable.""" + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) + + # Test basic transpose + out = x.transpose() + fn = xr_function([x], out) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test.transpose()) + + # Test transpose with specific dimensions + out = x.transpose("c", "a", "b") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b")) + + # Test transpose with ellipsis + out = x.transpose("c", ...) + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + # Test error cases + with pytest.raises( + ValueError, + match=re.escape( + "Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')" + ), + ): + x.transpose("d") + + with pytest.raises(ValueError, match="an index can only have a single ellipsis"): + x.transpose("a", ..., "b", ...) + + # Test missing_dims parameter + # Test ignore + out = x.transpose("c", ..., "d", missing_dims="ignore") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + # Test warn + with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"): + out = x.transpose("c", ..., "d", missing_dims="warn") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + +def test_xtensor_variable_T(): + """Test the T property of XTensorVariable.""" + # Test T property with 3D tensor + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) + out = x.T + + fn = xr_function([x], out) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test.T) + + def test_stack(): dims = ("a", "b", "c", "d") x = xtensor("x", dims=dims, shape=(2, 3, 5, 7)) @@ -61,10 +118,7 @@ def test_stack(): ] fn = xr_function([x], outs) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) res = fn(x_test) expected_res = [ @@ -81,10 +135,7 @@ def test_stack_single_dim(): assert out.type.dims == ("b", "c", "d") fn = xr_function([x], out) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) fn.fn.dprint(print_type=True) res = fn(x_test) expected_res = x_test.stack(d=["a"]) @@ -96,10 +147,7 @@ def test_multiple_stacks(): out = stack(x, new_dim1=("a", "b"), new_dim2=("c", "d")) fn = xr_function([x], [out]) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) res = fn(x_test) expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d")) xr_assert_allclose(res[0], expected_res)