Skip to content

Make preprocessing and postprocessing consistent across transforms #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
0a1b515
Integrate axis swap into 1d processing funcs
felixblanke Jun 24, 2024
2426060
Make channel dim addition optional
felixblanke Jun 24, 2024
eaa2fb6
Refactor; Add processing funcs for 2d
felixblanke Jun 24, 2024
0e4cb27
Move processing funcs to _util module
felixblanke Jun 24, 2024
3554a53
Generalize tensor processing
felixblanke Jun 24, 2024
4e8ed02
Adapt 1d cases
felixblanke Jun 24, 2024
f889158
Extend _map_result to 1d case
felixblanke Jun 24, 2024
0ed2896
Make _preprocess_coeffs general
felixblanke Jun 24, 2024
03f74b7
Make postprocessing general
felixblanke Jun 24, 2024
ae7b345
Apply process funcs in 3d transforms
felixblanke Jun 24, 2024
21332bf
Format
felixblanke Jun 24, 2024
b461c97
Fix coeff postprocess
felixblanke Jun 24, 2024
105a0a8
Reduce tensor processing to coeff processing
felixblanke Jun 24, 2024
24a25f4
Add fully separable transforms for n dims
felixblanke Jun 24, 2024
169d6cf
Move dtype check to preprocessing
felixblanke Jun 24, 2024
ba1b83d
Encapsulate check for consistent dtype and device
felixblanke Jun 24, 2024
d50a2c2
Revert changes to coeff shape check
felixblanke Jun 24, 2024
2d4620a
Make n-dim separable trafo private
felixblanke Jun 25, 2024
1b45eae
Add explanatory comments
felixblanke Jun 25, 2024
15af78b
Add docstrings
felixblanke Jun 25, 2024
35bab80
Rename _map_result to _apply_to_tensor_elems
felixblanke Jun 25, 2024
37db5db
Format
felixblanke Jun 25, 2024
b4c16e2
rename tree_map
v0lta Jun 26, 2024
a9569f2
rename coeff tree map.
v0lta Jun 26, 2024
d1339f0
Add remark on JAX tree map
felixblanke Jun 26, 2024
6254f13
Merge branch 'main' into fix/keep-ndims-Nd
v0lta Jul 1, 2024
e281c85
merge.
v0lta Jul 1, 2024
75ad167
formatting.
v0lta Jul 1, 2024
e3fe0f4
fix typing.
v0lta Jul 1, 2024
d6d8259
Fix ndim sep trafo usage comments
felixblanke Jul 1, 2024
de1b7c8
Fix docstr
felixblanke Jul 1, 2024
986692f
nd-transforms are out of scope.
v0lta Jul 1, 2024
f008cfa
Merge branch 'fix/keep-ndims-Nd' of github.com:v0lta/PyTorch-Wavelet-…
v0lta Jul 1, 2024
77ca368
short note.
v0lta Jul 1, 2024
cef6a04
move note.
v0lta Jul 1, 2024
9981521
add note to forward and backward.
v0lta Jul 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 257 additions & 8 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from __future__ import annotations

import typing
from collections.abc import Sequence
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any, Literal, NamedTuple, Optional, Protocol, Union, cast, overload

import numpy as np
import pywt
Expand Down Expand Up @@ -186,6 +187,32 @@ def _check_axes_argument(axes: Sequence[int]) -> None:
raise ValueError("Cant transform the same axis twice.")


def _check_same_device(
tensor: torch.Tensor, torch_device: torch.device
) -> torch.Tensor:
if torch_device != tensor.device:
raise ValueError("coefficients must be on the same device")
return tensor


def _check_same_dtype(tensor: torch.Tensor, torch_dtype: torch.dtype) -> torch.Tensor:
if torch_dtype != tensor.dtype:
raise ValueError("coefficients must have the same dtype")
return tensor


def _check_same_device_dtype(
    coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd],
) -> tuple[torch.device, torch.dtype]:
    """Verify that all coefficient tensors share one device and one dtype.

    The device and dtype of the approximation coefficients (``coeffs[0]``)
    serve as the reference. Returns the common ``(device, dtype)`` pair;
    a ``ValueError`` is raised by the per-tensor checks on any mismatch.
    """
    approx = _check_if_tensor(coeffs[0])
    ref_device, ref_dtype = approx.device, approx.dtype

    # walk all tensors in the (possibly nested) coefficient structure
    _map_result(coeffs, lambda t: _check_same_device(t, torch_device=ref_device))
    _map_result(coeffs, lambda t: _check_same_dtype(t, torch_dtype=ref_dtype))

    return ref_device, ref_dtype


def _get_transpose_order(
axes: Sequence[int], data_shape: Sequence[int]
) -> tuple[list[int], list[int]]:
Expand All @@ -208,6 +235,13 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor:
return torch.permute(data, restore_sorted)


# 1d case: a flat list of coefficient tensors maps to a list again
@overload
def _map_result(
    data: list[torch.Tensor],
    function: Callable[[torch.Tensor], torch.Tensor],
) -> list[torch.Tensor]: ...


@overload
def _map_result(
data: WaveletCoeff2d,
Expand All @@ -223,12 +257,13 @@ def _map_result(


def _map_result(
data: Union[WaveletCoeff2d, WaveletCoeffNd],
data: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd],
function: Callable[[torch.Tensor], torch.Tensor],
) -> Union[WaveletCoeff2d, WaveletCoeffNd]:
) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]:
approx = function(data[0])
result_lst: list[
Union[
torch.Tensor,
WaveletDetailDict,
WaveletDetailTuple2d,
]
Expand All @@ -245,11 +280,225 @@ def _map_result(
elif isinstance(element, dict):
new_dict = {key: function(value) for key, value in element.items()}
result_lst.append(new_dict)
elif isinstance(element, torch.Tensor):
result_lst.append(function(element))
else:
raise ValueError(f"Unexpected input type {type(element)}")

# cast since we assume that the full list is of the same type
cast_result_lst = cast(
Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst
if not result_lst:
# if only approximation coeff:
# use list iff data is a list
return [approx] if isinstance(data, list) else (approx,)
elif isinstance(result_lst[0], torch.Tensor):
# if the first detail coeff is tensor
# -> all are tensors -> return a list
return [approx] + cast(list[torch.Tensor], result_lst)
else:
# cast since we assume that the full list is of the same type
cast_result_lst = cast(
Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst
)
return approx, *cast_result_lst


# 1d case: list of tensors, a single transform axis
@overload
def _preprocess_coeffs(
    coeffs: list[torch.Tensor],
    ndim: Literal[1],
    axes: int,
    add_channel_dim: bool = False,
) -> tuple[list[torch.Tensor], list[int]]: ...


# 2d case: 2d coefficient tuple, two transform axes
@overload
def _preprocess_coeffs(
    coeffs: WaveletCoeff2d,
    ndim: Literal[2],
    axes: tuple[int, int],
    add_channel_dim: bool = False,
) -> tuple[WaveletCoeff2d, list[int]]: ...


# Nd case: Nd coefficient tuple, ndim transform axes
@overload
def _preprocess_coeffs(
    coeffs: WaveletCoeffNd,
    ndim: int,
    axes: tuple[int, ...],
    add_channel_dim: bool = False,
) -> tuple[WaveletCoeffNd, list[int]]: ...


# list of nd tensors (e.g. fully separable transforms)
@overload
def _preprocess_coeffs(
    coeffs: list[torch.Tensor],
    ndim: int,
    axes: Union[tuple[int, ...], int],
    add_channel_dim: bool = False,
) -> tuple[list[torch.Tensor], list[int]]: ...


def _preprocess_coeffs(
    coeffs: Union[
        list[torch.Tensor],
        WaveletCoeff2d,
        WaveletCoeffNd,
    ],
    ndim: int,
    axes: Union[tuple[int, ...], int],
    add_channel_dim: bool = False,
) -> tuple[
    Union[
        list[torch.Tensor],
        WaveletCoeff2d,
        WaveletCoeffNd,
    ],
    list[int],
]:
    """Move the transform axes to the back and fold all leading axes.

    Args:
        coeffs: Wavelet coefficients -- a list of tensors or a 2d/Nd
            coefficient tuple. Only tensor shapes are modified.
        ndim (int): Number of axes the transform operates on; must be > 0.
        axes (int or tuple of ints): Compute the transform over these axes
            instead of the last ones.
        add_channel_dim (bool): If True, unsqueeze an extra axis at dim 1
            after folding. Defaults to False.

    Returns:
        A tuple ``(coeffs, ds)`` with the reshaped coefficients and the
        shape ``ds`` of the (possibly axis-swapped) approximation
        coefficients, which is needed to undo the folding later.

    Raises:
        ValueError: If the coefficient dtype is unsupported, ``ndim`` is
            not positive, or the number of axes does not match ``ndim``.
    """
    if isinstance(axes, int):
        axes = (axes,)  # normalize to a tuple

    # dtype support is checked on the approximation coefficients only
    torch_dtype = _check_if_tensor(coeffs[0]).dtype
    if not _is_dtype_supported(torch_dtype):
        raise ValueError(f"Input dtype {torch_dtype} not supported")

    if ndim <= 0:
        raise ValueError("Number of dimensions must be positive")

    # nothing to swap if the transform already runs on the trailing axes
    if tuple(axes) != tuple(range(-ndim, 0)):
        if len(axes) != ndim:
            raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
        else:
            swap_fn = partial(_swap_axes, axes=axes)
            coeffs = _map_result(coeffs, swap_fn)

    # Fold axes for the wavelets
    ds = list(coeffs[0].shape)
    if len(ds) < ndim:
        raise ValueError(f"At least {ndim} input dimensions required.")
    elif len(ds) == ndim:
        # no batch axis present: add one so downstream code sees ndim + 1 axes
        coeffs = _map_result(coeffs, lambda x: x.unsqueeze(0))
    elif len(ds) > ndim + 1:
        # multiple leading axes: fold them into a single batch axis
        coeffs = _map_result(coeffs, lambda t: _fold_axes(t, ndim)[0])

    if add_channel_dim:
        coeffs = _map_result(coeffs, lambda x: x.unsqueeze(1))

    return coeffs, ds


# 1d case: list of tensors, a single transform axis
@overload
def _postprocess_coeffs(
    coeffs: list[torch.Tensor],
    ndim: Literal[1],
    ds: list[int],
    axes: int,
) -> list[torch.Tensor]: ...


# 2d case: 2d coefficient tuple, two transform axes
@overload
def _postprocess_coeffs(
    coeffs: WaveletCoeff2d,
    ndim: Literal[2],
    ds: list[int],
    axes: tuple[int, int],
) -> WaveletCoeff2d: ...


# Nd case: Nd coefficient tuple, ndim transform axes
@overload
def _postprocess_coeffs(
    coeffs: WaveletCoeffNd,
    ndim: int,
    ds: list[int],
    axes: tuple[int, ...],
) -> WaveletCoeffNd: ...


# list of nd tensors (e.g. fully separable transforms)
@overload
def _postprocess_coeffs(
    coeffs: list[torch.Tensor],
    ndim: int,
    ds: list[int],
    axes: Union[tuple[int, ...], int],
) -> list[torch.Tensor]: ...


def _postprocess_coeffs(
    coeffs: Union[
        list[torch.Tensor],
        WaveletCoeff2d,
        WaveletCoeffNd,
    ],
    ndim: int,
    ds: list[int],
    axes: Union[tuple[int, ...], int],
) -> Union[
    list[torch.Tensor],
    WaveletCoeff2d,
    WaveletCoeffNd,
]:
    """Undo the axis folding and axis swapping of ``_preprocess_coeffs``.

    Args:
        coeffs: The coefficients to postprocess.
        ndim (int): Number of transformed axes; must be > 0.
        ds (list of int): The original shape recorded during preprocessing.
        axes (int or tuple of ints): The axes the transform was computed
            over instead of the last ones.

    Returns:
        The coefficients with their original axis layout restored.

    Raises:
        ValueError: If ``ndim`` is not positive or the number of axes
            does not match ``ndim``.
    """
    if isinstance(axes, int):
        axes = (axes,)  # normalize to a tuple

    if ndim <= 0:
        raise ValueError("Number of dimensions must be positive")

    # Undo the axis folding of the preprocessing step
    if len(ds) < ndim:
        raise ValueError(f"At least {ndim} input dimensions required.")
    elif len(ds) == ndim:
        # preprocessing added a batch axis -> drop it again
        coeffs = _map_result(coeffs, lambda x: x.squeeze(0))
    elif len(ds) > ndim + 1:
        # preprocessing folded leading axes -> restore them from ds
        unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=ndim)
        coeffs = _map_result(coeffs, unfold_axes_fn)

    if tuple(axes) != tuple(range(-ndim, 0)):
        if len(axes) != ndim:
            raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
        else:
            # move the transform axes back to their original positions
            undo_swap_fn = partial(_undo_swap_axes, axes=axes)
            coeffs = _map_result(coeffs, undo_swap_fn)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would really advise against all of these _map_result calls - have one function that does the processing that can be reused, then just do list comprehensions for all successive function calls.

coeffs = _map_result(coeffs, lambda x: x.squeeze(0))

will always be less readable and understandable than

coeffs = _map_result(coeffs)
coeffs = [coeff.squeeze(0) for coeff in coeffs]

when _map_result has lots of hidden functionality

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The snippet

coeffs = _map_result(coeffs, lambda x: x.squeeze(0))

applies the function x.squeeze(0) to all tensors in coeffs. In the 1d case (where coeffs is of type list[Tensor]) this is equivalent to the list comprehension, as you said. However, coeffs might also be

  • (Tensor, dict[str, Tensor], ...)
  • (Tensor, (Tensor, Tensor, Tensor), ...)

So using _map_result allows to write the function once for all possible coefficient types. Would it perhaps help to rename _map_result or add documentation for it?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the name should have something to do with tree and map. The new _apply_to_tensor_elems name really hides how general the concept is. I would take a page from https://jax.readthedocs.io/en/latest/_autosummary/jax.tree.map.html#jax.tree.map and also use their type hinting. The concept does not exist in torch, but I think it makes sense here, since we save on a lot of boilerplate-code. Perhaps we should include the link and explain what's going on?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is an interesting intro discussing the pytree processing philosophy: https://jax.readthedocs.io/en/latest/working-with-pytrees.html .

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think @cthoyt has a point since the tree-map concept is not very popular.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree when comparing

coeffs = _map_result(coeffs, partial(torch.squeeze, 0))

to

coeffs = [coeff.squeeze(0) for coeff in coeffs]

The list wins, but what if it's a nested structure?


return coeffs


def _preprocess_tensor(
    data: torch.Tensor,
    ndim: int,
    axes: Union[tuple[int, ...], int],
    add_channel_dim: bool = True,
) -> tuple[torch.Tensor, list[int]]:
    """Preprocess input tensor dimensions.

    Args:
        data (torch.Tensor): An input tensor with at least `ndim` axes.
        ndim (int): The number of axes on which the transformation is
            applied.
        axes (int or tuple of ints): Compute the transform over these axes
            instead of the last ones.
        add_channel_dim (bool): If True, ensures that the return has at
            least `ndim` + 2 axes by potentially adding a new axis at dim 1.
            Defaults to True.

    Returns:
        A tuple (data, ds) where data is the transformed data tensor
        and ds contains the original shape. If `add_channel_dim` is True,
        `data` has `ndim` + 2 axes, otherwise `ndim` + 1.
    """
    # Reuse the general coefficient preprocessing by wrapping the single
    # tensor into a one-element list and unwrapping the result.
    data_lst, ds = _preprocess_coeffs(
        [data], ndim=ndim, axes=axes, add_channel_dim=add_channel_dim
    )
    return data_lst[0], ds


def _postprocess_tensor(
    data: torch.Tensor, ndim: int, ds: list[int], axes: Union[tuple[int, ...], int]
) -> torch.Tensor:
    """Undo the dimension preprocessing for a single tensor.

    Wraps the tensor into a one-element coefficient list, delegates to
    ``_postprocess_coeffs``, and unwraps the single result tensor.
    """
    coeff_lst = _postprocess_coeffs(coeffs=[data], ndim=ndim, ds=ds, axes=axes)
    return coeff_lst[0]
Loading
Loading