Make preprocessing and postprocessing consistent across transforms #93
@@ -3,8 +3,9 @@
 from __future__ import annotations

 import typing
-from collections.abc import Sequence
-from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload
+from collections.abc import Callable, Sequence
+from functools import partial
+from typing import Any, Literal, NamedTuple, Optional, Protocol, Union, cast, overload

 import numpy as np
 import pywt
@@ -186,6 +187,32 @@ def _check_axes_argument(axes: Sequence[int]) -> None:
         raise ValueError("Cant transform the same axis twice.")


+def _check_same_device(
+    tensor: torch.Tensor, torch_device: torch.device
+) -> torch.Tensor:
+    if torch_device != tensor.device:
+        raise ValueError("coefficients must be on the same device")
+    return tensor
+
+
+def _check_same_dtype(tensor: torch.Tensor, torch_dtype: torch.dtype) -> torch.Tensor:
+    if torch_dtype != tensor.dtype:
+        raise ValueError("coefficients must have the same dtype")
+    return tensor
+
+
+def _check_same_device_dtype(
+    coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd],
+) -> tuple[torch.device, torch.dtype]:
+    c = _check_if_tensor(coeffs[0])
+    torch_device, torch_dtype = c.device, c.dtype
+
+    _map_result(coeffs, partial(_check_same_device, torch_device=torch_device))
+    _map_result(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype))
+
+    return torch_device, torch_dtype
+
+
 def _get_transpose_order(
     axes: Sequence[int], data_shape: Sequence[int]
 ) -> tuple[list[int], list[int]]:
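For context, `_check_same_device_dtype` takes the first coefficient tensor as the reference and maps the two small check functions over the whole structure, so a mismatch anywhere raises. A minimal self-contained sketch of the same pattern on a plain list (names here are illustrative, not ptwt's API):

    from functools import partial

    import torch


    def check_same_dtype(tensor: torch.Tensor, torch_dtype: torch.dtype) -> torch.Tensor:
        # raise on the first coefficient that deviates from the reference dtype
        if torch_dtype != tensor.dtype:
            raise ValueError("coefficients must have the same dtype")
        return tensor


    coeffs = [torch.zeros(8), torch.zeros(4), torch.zeros(2)]
    ref_dtype = coeffs[0].dtype

    # map the check over every coefficient, as _map_result does for the
    # nested tuple/dict coefficient containers
    list(map(partial(check_same_dtype, torch_dtype=ref_dtype), coeffs))

    # appending torch.zeros(1, dtype=torch.float64) to coeffs would raise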
@@ -208,6 +235,13 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor:
     return torch.permute(data, restore_sorted)


+@overload
+def _map_result(
+    data: list[torch.Tensor],
+    function: Callable[[torch.Tensor], torch.Tensor],
+) -> list[torch.Tensor]: ...
+
+
 @overload
 def _map_result(
     data: WaveletCoeff2d,
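The added overload tells type checkers that mapping over a plain `list[torch.Tensor]` yields a list again, while the existing overloads keep the tuple-based coefficient types. A toy illustration of this overload pattern (not ptwt code):

    from typing import Union, overload

    import torch


    @overload
    def double(data: list[torch.Tensor]) -> list[torch.Tensor]: ...
    @overload
    def double(data: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ...


    def double(
        data: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
    ) -> Union[list[torch.Tensor], tuple[torch.Tensor, ...]]:
        # one implementation serves both; the overloads preserve the
        # container type for the caller's type checker
        if isinstance(data, list):
            return [2 * t for t in data]
        return tuple(2 * t for t in data)


    assert isinstance(double([torch.ones(2)]), list)
    assert isinstance(double((torch.ones(2),)), tuple)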
@@ -223,12 +257,13 @@ def _map_result(


 def _map_result(
-    data: Union[WaveletCoeff2d, WaveletCoeffNd],
+    data: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd],
     function: Callable[[torch.Tensor], torch.Tensor],
-) -> Union[WaveletCoeff2d, WaveletCoeffNd]:
+) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]:
     approx = function(data[0])
     result_lst: list[
         Union[
+            torch.Tensor,
             WaveletDetailDict,
             WaveletDetailTuple2d,
         ]
@@ -245,11 +280,225 @@ def _map_result(
         elif isinstance(element, dict):
             new_dict = {key: function(value) for key, value in element.items()}
             result_lst.append(new_dict)
+        elif isinstance(element, torch.Tensor):
+            result_lst.append(function(element))
         else:
             raise ValueError(f"Unexpected input type {type(element)}")

-    # cast since we assume that the full list is of the same type
-    cast_result_lst = cast(
-        Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst
-    )
-    return approx, *cast_result_lst
+    if not result_lst:
+        # if only approximation coeff:
+        # use list iff data is a list
+        return [approx] if isinstance(data, list) else (approx,)
+    elif isinstance(result_lst[0], torch.Tensor):
+        # if the first detail coeff is tensor
+        # -> all are tensors -> return a list
+        return [approx] + cast(list[torch.Tensor], result_lst)
+    else:
+        # cast since we assume that the full list is of the same type
+        cast_result_lst = cast(
+            Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst
+        )
+        return approx, *cast_result_lst
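In effect, the extended `_map_result` is a tree map over ptwt's coefficient containers: it applies `function` to every tensor leaf, whether that leaf sits at the top level of a list, inside a `WaveletDetailTuple2d`, or inside a detail dict, and rebuilds the matching container. A generic sketch of the concept (illustrative only, not the ptwt implementation):

    # A generic tree-map sketch: apply fn to every tensor leaf of a
    # nested list/tuple/dict structure and rebuild the same structure.
    from typing import Any, Callable

    import torch


    def tree_map(fn: Callable[[torch.Tensor], torch.Tensor], tree: Any) -> Any:
        if isinstance(tree, torch.Tensor):
            return fn(tree)
        elif isinstance(tree, tuple):
            return tuple(tree_map(fn, leaf) for leaf in tree)
        elif isinstance(tree, list):
            return [tree_map(fn, leaf) for leaf in tree]
        elif isinstance(tree, dict):
            return {key: tree_map(fn, value) for key, value in tree.items()}
        raise ValueError(f"Unexpected input type {type(tree)}")


    # a 2d-style coefficient structure: (approx, (lh, hl, hh))
    coeffs = (torch.zeros(1, 4, 4), (torch.zeros(1, 4, 4),) * 3)
    squeezed = tree_map(lambda x: x.squeeze(0), coeffs)
    assert squeezed[1][0].shape == (4, 4)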
+
+
+# 1d case
+@overload
+def _preprocess_coeffs(
+    coeffs: list[torch.Tensor],
+    ndim: Literal[1],
+    axes: int,
+    add_channel_dim: bool = False,
+) -> tuple[list[torch.Tensor], list[int]]: ...
+
+
+# 2d case
+@overload
+def _preprocess_coeffs(
+    coeffs: WaveletCoeff2d,
+    ndim: Literal[2],
+    axes: tuple[int, int],
+    add_channel_dim: bool = False,
+) -> tuple[WaveletCoeff2d, list[int]]: ...
+
+
+# Nd case
+@overload
+def _preprocess_coeffs(
+    coeffs: WaveletCoeffNd,
+    ndim: int,
+    axes: tuple[int, ...],
+    add_channel_dim: bool = False,
+) -> tuple[WaveletCoeffNd, list[int]]: ...
+
+
+# list of nd tensors
+@overload
+def _preprocess_coeffs(
+    coeffs: list[torch.Tensor],
+    ndim: int,
+    axes: Union[tuple[int, ...], int],
+    add_channel_dim: bool = False,
+) -> tuple[list[torch.Tensor], list[int]]: ...
+
+
+def _preprocess_coeffs(
+    coeffs: Union[
+        list[torch.Tensor],
+        WaveletCoeff2d,
+        WaveletCoeffNd,
+    ],
+    ndim: int,
+    axes: Union[tuple[int, ...], int],
+    add_channel_dim: bool = False,
+) -> tuple[
+    Union[
+        list[torch.Tensor],
+        WaveletCoeff2d,
+        WaveletCoeffNd,
+    ],
+    list[int],
+]:
+    if isinstance(axes, int):
+        axes = (axes,)
+
+    torch_dtype = _check_if_tensor(coeffs[0]).dtype
+    if not _is_dtype_supported(torch_dtype):
+        raise ValueError(f"Input dtype {torch_dtype} not supported")
+
+    if ndim <= 0:
+        raise ValueError("Number of dimensions must be positive")
+
+    if tuple(axes) != tuple(range(-ndim, 0)):
+        if len(axes) != ndim:
+            raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
+        else:
+            swap_fn = partial(_swap_axes, axes=axes)
+            coeffs = _map_result(coeffs, swap_fn)
+
+    # Fold axes for the wavelets
+    ds = list(coeffs[0].shape)
+    if len(ds) < ndim:
+        raise ValueError(f"At least {ndim} input dimensions required.")
+    elif len(ds) == ndim:
+        coeffs = _map_result(coeffs, lambda x: x.unsqueeze(0))
+    elif len(ds) > ndim + 1:
+        coeffs = _map_result(coeffs, lambda t: _fold_axes(t, ndim)[0])
+
+    if add_channel_dim:
+        coeffs = _map_result(coeffs, lambda x: x.unsqueeze(1))
+
+    return coeffs, ds
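`_preprocess_coeffs` normalizes the layout in three steps: move the transform axes to the back, fold any extra leading axes into a single batch axis (via `_fold_axes`), and optionally add a channel axis. A rough sketch of the folding step (the `fold_axes` below is a stand-in for illustration, not ptwt's `_fold_axes`):

    import torch


    def fold_axes(data: torch.Tensor, keep_no: int) -> tuple[torch.Tensor, list[int]]:
        # collapse all leading axes into one batch axis, keeping the
        # last keep_no axes (the transform axes) untouched
        ds = list(data.shape)
        return data.reshape(-1, *ds[-keep_no:]), ds


    x = torch.randn(2, 3, 32, 32)           # two batch axes, a 2d signal
    folded, ds = fold_axes(x, keep_no=2)    # -> shape (6, 32, 32)
    assert folded.shape == (6, 32, 32) and ds == [2, 3, 32, 32]

The saved shape list `ds` is what the postprocessing step later uses to undo the fold.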
+
+
+# 1d case
+@overload
+def _postprocess_coeffs(
+    coeffs: list[torch.Tensor],
+    ndim: Literal[1],
+    ds: list[int],
+    axes: int,
+) -> list[torch.Tensor]: ...
+
+
+# 2d case
+@overload
+def _postprocess_coeffs(
+    coeffs: WaveletCoeff2d,
+    ndim: Literal[2],
+    ds: list[int],
+    axes: tuple[int, int],
+) -> WaveletCoeff2d: ...
+
+
+# Nd case
+@overload
+def _postprocess_coeffs(
+    coeffs: WaveletCoeffNd,
+    ndim: int,
+    ds: list[int],
+    axes: tuple[int, ...],
+) -> WaveletCoeffNd: ...
+
+
+# list of nd tensors
+@overload
+def _postprocess_coeffs(
+    coeffs: list[torch.Tensor],
+    ndim: int,
+    ds: list[int],
+    axes: Union[tuple[int, ...], int],
+) -> list[torch.Tensor]: ...
+
+
+def _postprocess_coeffs(
+    coeffs: Union[
+        list[torch.Tensor],
+        WaveletCoeff2d,
+        WaveletCoeffNd,
+    ],
+    ndim: int,
+    ds: list[int],
+    axes: Union[tuple[int, ...], int],
+) -> Union[
+    list[torch.Tensor],
+    WaveletCoeff2d,
+    WaveletCoeffNd,
+]:
+    if isinstance(axes, int):
+        axes = (axes,)
+
+    if ndim <= 0:
+        raise ValueError("Number of dimensions must be positive")
+
+    # Fold axes for the wavelets
+    if len(ds) < ndim:
+        raise ValueError(f"At least {ndim} input dimensions required.")
+    elif len(ds) == ndim:
+        coeffs = _map_result(coeffs, lambda x: x.squeeze(0))
+    elif len(ds) > ndim + 1:
+        unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=ndim)
+        coeffs = _map_result(coeffs, unfold_axes_fn)
+
+    if tuple(axes) != tuple(range(-ndim, 0)):
+        if len(axes) != ndim:
+            raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
+        else:
+            undo_swap_fn = partial(_undo_swap_axes, axes=axes)
+            coeffs = _map_result(coeffs, undo_swap_fn)
+
+    return coeffs

Review thread on the _map_result-based rewrite:

I would really advise against all of these. A call like coeffs = _map_result(coeffs, partial(torch.squeeze, 0)) will always be less readable and understandable than an explicit comprehension when _map_result has lots of hidden functionality.

The snippet applies the function to every tensor in the nested coefficient structure. So using _map_result keeps that traversal in one place instead of repeating it at every call site.

I think the name should have something to do with tree and map.

Here is an interesting intro discussing the pytree processing philosophy: https://jax.readthedocs.io/en/latest/working-with-pytrees.html

I also think @cthoyt has a point since the tree-map concept is not very popular.

I agree that when comparing coeffs = _map_result(coeffs, partial(torch.squeeze, 0)) to coeffs = [coeff.squeeze(0) for coeff in coeffs] the list wins, but what if it's a nested structure?
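The closing question is the crux of the thread: for a flat list the comprehension reads better, but a nested coefficient structure forces the explicit version to spell out every container level. A small sketch of that comparison (illustrative shapes only, not ptwt code):

    import torch

    coeffs_list = [torch.zeros(1, 8), torch.zeros(1, 4)]
    coeffs_2d = (torch.zeros(1, 4, 4), (torch.zeros(1, 4, 4),) * 3)

    # flat list: the comprehension is the clear winner
    squeezed_list = [c.squeeze(0) for c in coeffs_list]

    # nested 2d-style structure: the explicit version must know about
    # every container level, which a tree map like _map_result hides
    squeezed_2d = (
        coeffs_2d[0].squeeze(0),
        *(tuple(c.squeeze(0) for c in detail) for detail in coeffs_2d[1:]),
    )
    assert squeezed_list[0].shape == (8,)
    assert squeezed_2d[1][0].shape == (4, 4)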
+
+
+def _preprocess_tensor(
+    data: torch.Tensor,
+    ndim: int,
+    axes: Union[tuple[int, ...], int],
+    add_channel_dim: bool = True,
+) -> tuple[torch.Tensor, list[int]]:
+    """Preprocess input tensor dimensions.
+
+    Args:
+        data (torch.Tensor): An input tensor with at least `ndim` axes.
+        ndim (int): The number of axes on which the transformation is
+            applied.
+        axes (int or tuple of ints): Compute the transform over these axes
+            instead of the last ones.
+        add_channel_dim (bool): If True, ensures that the return has at
+            least `ndim` + 2 axes by potentially adding a new axis at dim 1.
+            Defaults to True.
+
+    Returns:
+        A tuple (data, ds) where data is the transformed data tensor
+        and ds contains the original shape. If `add_channel_dim` is True,
+        `data` has `ndim` + 2 axes, otherwise `ndim` + 1.
+    """
+    data_lst, ds = _preprocess_coeffs(
+        [data], ndim=ndim, axes=axes, add_channel_dim=add_channel_dim
+    )
+    return data_lst[0], ds
+
+
+def _postprocess_tensor(
+    data: torch.Tensor, ndim: int, ds: list[int], axes: Union[tuple[int, ...], int]
+) -> torch.Tensor:
+    return _postprocess_coeffs(coeffs=[data], ndim=ndim, ds=ds, axes=axes)[0]
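`_preprocess_tensor` and `_postprocess_tensor` are thin wrappers that reuse the coefficient paths by boxing the tensor in a one-element list, which is what makes the pre- and postprocessing consistent across transforms. The intended round trip, sketched with plain tensor ops under assumed shapes:

    import torch

    x = torch.randn(2, 3, 32, 32)
    ds = list(x.shape)

    # preprocessing: fold the two batch axes, then add a channel axis
    folded = x.reshape(-1, 32, 32).unsqueeze(1)    # (6, 1, 32, 32)
    # ... a 2d transform would run on "folded" here ...

    # postprocessing: drop the channel axis and restore the saved shape
    restored = folded.squeeze(1).reshape(*ds)      # back to (2, 3, 32, 32)
    assert restored.shape == x.shape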