diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 70465d85..1491c6ad 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -6,7 +6,8 @@ import typing import warnings from collections.abc import Callable, Sequence -from typing import Any, NamedTuple, Optional, Protocol, Union, cast, overload +from functools import partial +from typing import Any, Literal, NamedTuple, Optional, Protocol, Union, cast, overload import numpy as np import pywt @@ -191,6 +192,50 @@ def _check_axes_argument(axes: Sequence[int]) -> None: raise ValueError("Cant transform the same axis twice.") +def _check_same_device( + tensor: torch.Tensor, torch_device: torch.device +) -> torch.Tensor: + if torch_device != tensor.device: + raise ValueError("coefficients must be on the same device") + return tensor + + +def _check_same_dtype(tensor: torch.Tensor, torch_dtype: torch.dtype) -> torch.Tensor: + if torch_dtype != tensor.dtype: + raise ValueError("coefficients must have the same dtype") + return tensor + + +def _check_same_device_dtype( + coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd], +) -> tuple[torch.device, torch.dtype]: + """Check coefficients for dtype and device consistency. + + Check that all coefficient tensors in `coeffs` have the same + device and dtype. + + Args: + coeffs (Wavelet coefficients): The resulting coefficients of + a discrete wavelet transform. Can be either of + `list[torch.Tensor]` (1d case), + :data:`ptwt.constants.WaveletCoeff2d` (2d case) or + :data:`ptwt.constants.WaveletCoeffNd` (Nd case). + + Returns: + A tuple (device, dtype) with the shared device and dtype of + all tensors in coeffs. + """ + c = _check_if_tensor(coeffs[0]) + torch_device, torch_dtype = c.device, c.dtype + + # check for all tensors in `coeffs` that the device matches `torch_device` + _coeff_tree_map(coeffs, partial(_check_same_device, torch_device=torch_device)) + # check for all tensors in `coeffs` that the dtype matches `torch_dtype` + _coeff_tree_map(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype)) + + return torch_device, torch_dtype + + def _get_transpose_order( axes: Sequence[int], data_shape: Sequence[int] ) -> tuple[list[int], list[int]]: @@ -214,31 +259,53 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: @overload -def _map_result( - data: WaveletCoeff2d, +def _coeff_tree_map( + coeffs: list[torch.Tensor], + function: Callable[[torch.Tensor], torch.Tensor], +) -> list[torch.Tensor]: ... + + +@overload +def _coeff_tree_map( + coeffs: WaveletCoeff2d, function: Callable[[torch.Tensor], torch.Tensor], ) -> WaveletCoeff2d: ... @overload -def _map_result( - data: WaveletCoeffNd, +def _coeff_tree_map( + coeffs: WaveletCoeffNd, function: Callable[[torch.Tensor], torch.Tensor], ) -> WaveletCoeffNd: ... -def _map_result( - data: Union[WaveletCoeff2d, WaveletCoeffNd], +def _coeff_tree_map( + coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd], function: Callable[[torch.Tensor], torch.Tensor], -) -> Union[WaveletCoeff2d, WaveletCoeffNd]: - approx = function(data[0]) +) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]: + """Apply `function` to all tensor elements in `coeffs`. + + Applying a function to all tensors in the (potentially nested) + coefficient data structure is a common requirement in coefficient + pre- and postprocessing. This function saves us from having to loop + over the coefficient data structures in processing. 
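+
+    For instance, a single call can apply a function to every tensor in
+    a decomposition (a hypothetical sketch; ``coeffs`` is assumed to
+    come from one of the ``wavedec`` routines)::
+
+        >>> detached = _coeff_tree_map(coeffs, lambda t: t.detach())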
+
+    Conceptually, this function is inspired by the
+    pytree processing philosophy of the JAX framework, see
+    https://jax.readthedocs.io/en/latest/working-with-pytrees.html
+
+    Raises:
+        ValueError: If the input type is not supported.
+    """
+    approx = function(coeffs[0])
     result_lst: list[
         Union[
+            torch.Tensor,
             WaveletDetailDict,
             WaveletDetailTuple2d,
         ]
     ] = []
-    for element in data[1:]:
+    for element in coeffs[1:]:
         if isinstance(element, tuple):
             result_lst.append(
                 WaveletDetailTuple2d(
@@ -250,14 +317,326 @@
         elif isinstance(element, dict):
             new_dict = {key: function(value) for key, value in element.items()}
             result_lst.append(new_dict)
+        elif isinstance(element, torch.Tensor):
+            result_lst.append(function(element))
         else:
             raise ValueError(f"Unexpected input type {type(element)}")
 
-    # cast since we assume that the full list is of the same type
-    cast_result_lst = cast(
-        Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst
+    if not result_lst:
+        # if only approximation coeff:
+        # use list iff data is a list
+        return [approx] if isinstance(coeffs, list) else (approx,)
+    elif isinstance(result_lst[0], torch.Tensor):
+        # if the first detail coeff is tensor
+        # -> all are tensors -> return a list
+        return [approx] + cast(list[torch.Tensor], result_lst)
+    else:
+        # cast since we assume that the full list is of the same type
+        cast_result_lst = cast(
+            Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst
+        )
+        return approx, *cast_result_lst
+
+
+# 1d case
+@overload
+def _preprocess_coeffs(
+    coeffs: list[torch.Tensor],
+    ndim: Literal[1],
+    axes: int,
+    add_channel_dim: bool = False,
+) -> tuple[list[torch.Tensor], list[int]]: ...
+
+
+# 2d case
+@overload
+def _preprocess_coeffs(
+    coeffs: WaveletCoeff2d,
+    ndim: Literal[2],
+    axes: tuple[int, int],
+    add_channel_dim: bool = False,
+) -> tuple[WaveletCoeff2d, list[int]]: ...
+
+
+# Nd case
+@overload
+def _preprocess_coeffs(
+    coeffs: WaveletCoeffNd,
+    ndim: int,
+    axes: tuple[int, ...],
+    add_channel_dim: bool = False,
+) -> tuple[WaveletCoeffNd, list[int]]: ...
+
+
+# list of nd tensors
+@overload
+def _preprocess_coeffs(
+    coeffs: list[torch.Tensor],
+    ndim: int,
+    axes: Union[tuple[int, ...], int],
+    add_channel_dim: bool = False,
+) -> tuple[list[torch.Tensor], list[int]]: ...
+
+
+def _preprocess_coeffs(
+    coeffs: Union[
+        list[torch.Tensor],
+        WaveletCoeff2d,
+        WaveletCoeffNd,
+    ],
+    ndim: int,
+    axes: Union[tuple[int, ...], int],
+    add_channel_dim: bool = False,
+) -> tuple[
+    Union[
+        list[torch.Tensor],
+        WaveletCoeff2d,
+        WaveletCoeffNd,
+    ],
+    list[int],
+]:
+    """Preprocess coeff tensor dimensions.
+
+    For each coefficient tensor in `coeffs` the transformed axes
+    as specified by `axes` are moved to be the last.
+    Adds a batch dim if a coefficient tensor has none.
+    If it has multiple batch dimensions, they are folded into a single
+    batch dimension.
+
+    Args:
+        coeffs (Wavelet coefficients): The resulting coefficients of
+            a discrete wavelet transform. Can be either of
+            `list[torch.Tensor]` (1d case),
+            :data:`ptwt.constants.WaveletCoeff2d` (2d case) or
+            :data:`ptwt.constants.WaveletCoeffNd` (Nd case).
+        ndim (int): The number of axes :math:`N` on which the transformation
+            was applied.
+        axes (int or tuple of ints): Axes on which the transform was calculated.
+        add_channel_dim (bool): If True, ensures that all returned coefficients
+            have at least :math:`N + 2` axes by potentially adding a new axis
+            at dim 1. Defaults to False.
+
+    Returns:
+        A tuple ``(coeffs, ds)`` where ``coeffs`` are the transformed
+        coefficients and ``ds`` contains the original shape of ``coeffs[0]``.
+        If `add_channel_dim` is True, all coefficient tensors have
+        :math:`N + 2` axes ([B, 1, c1, ..., cN]),
+        otherwise :math:`N + 1` ([B, c1, ..., cN]).
+
+    Raises:
+        ValueError: If the input dtype is unsupported or `ndim` does not
+            match the passed `axes` or `coeffs` dimensions.
+    """
+    if isinstance(axes, int):
+        axes = (axes,)
+
+    torch_dtype = _check_if_tensor(coeffs[0]).dtype
+    if not _is_dtype_supported(torch_dtype):
+        raise ValueError(f"Input dtype {torch_dtype} not supported")
+
+    if ndim <= 0:
+        raise ValueError("Number of dimensions must be positive")
+
+    if tuple(axes) != tuple(range(-ndim, 0)):
+        if len(axes) != ndim:
+            raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
+        else:
+            # for all tensors in `coeffs`: swap the axes
+            swap_fn = partial(_swap_axes, axes=axes)
+            coeffs = _coeff_tree_map(coeffs, swap_fn)
+
+    # Fold axes for the wavelets
+    ds = list(coeffs[0].shape)
+    if len(ds) < ndim:
+        raise ValueError(f"At least {ndim} input dimensions required.")
+    elif len(ds) == ndim:
+        # for all tensors in `coeffs`: unsqueeze(0)
+        coeffs = _coeff_tree_map(coeffs, lambda x: x.unsqueeze(0))
+    elif len(ds) > ndim + 1:
+        # for all tensors in `coeffs`: fold leading dims to batch dim
+        coeffs = _coeff_tree_map(coeffs, lambda t: _fold_axes(t, ndim)[0])
+
+    if add_channel_dim:
+        # for all tensors in `coeffs`: add channel dim
+        coeffs = _coeff_tree_map(coeffs, lambda x: x.unsqueeze(1))
+
+    return coeffs, ds
+
+
+# 1d case
+@overload
+def _postprocess_coeffs(
+    coeffs: list[torch.Tensor],
+    ndim: Literal[1],
+    ds: list[int],
+    axes: int,
+) -> list[torch.Tensor]: ...
+
+
+# 2d case
+@overload
+def _postprocess_coeffs(
+    coeffs: WaveletCoeff2d,
+    ndim: Literal[2],
+    ds: list[int],
+    axes: tuple[int, int],
+) -> WaveletCoeff2d: ...
+
+
+# Nd case
+@overload
+def _postprocess_coeffs(
+    coeffs: WaveletCoeffNd,
+    ndim: int,
+    ds: list[int],
+    axes: tuple[int, ...],
+) -> WaveletCoeffNd: ...
+
+
+# list of nd tensors
+@overload
+def _postprocess_coeffs(
+    coeffs: list[torch.Tensor],
+    ndim: int,
+    ds: list[int],
+    axes: Union[tuple[int, ...], int],
+) -> list[torch.Tensor]: ...
+
+
+def _postprocess_coeffs(
+    coeffs: Union[
+        list[torch.Tensor],
+        WaveletCoeff2d,
+        WaveletCoeffNd,
+    ],
+    ndim: int,
+    ds: list[int],
+    axes: Union[tuple[int, ...], int],
+) -> Union[
+    list[torch.Tensor],
+    WaveletCoeff2d,
+    WaveletCoeffNd,
+]:
+    """Postprocess coeff tensor dimensions.
+
+    This reverses the operations of :func:`_preprocess_coeffs`.
+
+    Unfolds potentially folded batch dimensions and removes any added
+    dimensions.
+    The transformed axes as specified by `axes` are moved back to their
+    original position.
+
+    Args:
+        coeffs (Wavelet coefficients): The preprocessed coefficients of
+            a discrete wavelet transform. Can be either of
+            `list[torch.Tensor]` (1d case),
+            :data:`ptwt.constants.WaveletCoeff2d` (2d case) or
+            :data:`ptwt.constants.WaveletCoeffNd` (Nd case).
+        ndim (int): The number of axes :math:`N` on which the transformation was
+            applied.
+        ds (list of ints): The shape of the original first coefficient before
+            preprocessing, i.e. of ``coeffs[0]``.
+        axes (int or tuple of ints): Axes on which the transform was calculated.
+
+    Returns:
+        The result of undoing the preprocessing operations on `coeffs`.
+
+    Raises:
+        ValueError: If `ndim` does not match the passed `axes`
+            or `coeffs` dimensions.
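+
+    Example:
+        A hypothetical round trip, assuming 1d coefficients from
+        :func:`ptwt.wavedec` computed over axis 0::
+
+        >>> coeffs, ds = _preprocess_coeffs(coeffs, ndim=1, axes=0)
+        >>> coeffs = _postprocess_coeffs(coeffs, ndim=1, ds=ds, axes=0)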
+ """ + if isinstance(axes, int): + axes = (axes,) + + if ndim <= 0: + raise ValueError("Number of dimensions must be positive") + + # Fold axes for the wavelets + if len(ds) < ndim: + raise ValueError(f"At least {ndim} input dimensions required.") + elif len(ds) == ndim: + # for all tensors in `coeffs`: remove batch dim + coeffs = _coeff_tree_map(coeffs, lambda x: x.squeeze(0)) + elif len(ds) > ndim + 1: + # for all tensors in `coeffs`: unfold batch dim + unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=ndim) + coeffs = _coeff_tree_map(coeffs, unfold_axes_fn) + + if tuple(axes) != tuple(range(-ndim, 0)): + if len(axes) != ndim: + raise ValueError(f"{ndim}D transforms work with {ndim} axes.") + else: + # for all tensors in `coeffs`: undo axes swapping + undo_swap_fn = partial(_undo_swap_axes, axes=axes) + coeffs = _coeff_tree_map(coeffs, undo_swap_fn) + + return coeffs + + +def _preprocess_tensor( + data: torch.Tensor, + ndim: int, + axes: Union[tuple[int, ...], int], + add_channel_dim: bool = True, +) -> tuple[torch.Tensor, list[int]]: + """Preprocess input tensor dimensions. + + The transformed axes as specified by `axes` are moved to be the last. + Adds a batch dim if `data` has none. + If `data` has multiple batch dimensions, they are folded into a single + batch dimension. + + Args: + data (torch.Tensor): An input tensor with at least `ndim` axes. + ndim (int): The number of axes :math:`N` on which the transformation is + applied. + axes (int or tuple of ints): Axes on which the transform is calculated. + add_channel_dim (bool): If True, ensures that the return has at + least :math:`N + 2` axes by potentially adding a new axis at dim 1. + Defaults to True. + + Returns: + A tuple ``(data, ds)`` where ``data`` is the transformed data tensor + and ``ds`` contains the original shape. + If `add_channel_dim` is True, + `data` has :math:`N + 2` axes ([B, 1, d1, ..., dN]). + otherwise :math:`N + 1` ([B, d1, ..., dN]). + """ + # interpreting data as the approximation coeffs of a 0-level FWT + # allows us to reuse the `_preprocess_coeffs` code + data_lst, ds = _preprocess_coeffs( + [data], ndim=ndim, axes=axes, add_channel_dim=add_channel_dim ) - return approx, *cast_result_lst + return data_lst[0], ds + + +def _postprocess_tensor( + data: torch.Tensor, ndim: int, ds: list[int], axes: Union[tuple[int, ...], int] +) -> torch.Tensor: + """Postprocess input tensor dimensions. + + This revereses the operations of :func:`_preprocess_tensor`. + + Unfolds potentially folded batch dimensions and removes any added + dimensions. + The transformed axes as specified by `axes` are moved back to their + original position. + + Args: + data (torch.Tensor): An preprocessed input tensor. + ndim (int): The number of axes :math:`N` on which the transformation is + applied. + ds (list of ints): The shape of the original input tensor before + preprocessing. + axes (int or tuple of ints): Axes on which the transform was calculated. + + Returns: + The result of undoing the preprocessing operations on `data`. 
+ """ + # interpreting data as the approximation coeffs of a 0-level FWT + # allows us to reuse the `_postprocess_coeffs` code + # return approx, *cast_result_lst + return _postprocess_coeffs(coeffs=[data], ndim=ndim, ds=ds, axes=axes)[0] Param = ParamSpec("Param") diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 644aca80..c6fd4a39 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -14,11 +14,13 @@ from ._util import ( Wavelet, _as_wavelet, - _fold_axes, + _check_same_device_dtype, _get_len, - _is_dtype_supported, _pad_symmetric, - _unfold_axes, + _postprocess_coeffs, + _postprocess_tensor, + _preprocess_coeffs, + _preprocess_tensor, ) from .constants import BoundaryMode, WaveletCoeff2d @@ -211,63 +213,6 @@ def _adjust_padding_at_reconstruction( return pad_end, pad_start -def _preprocess_tensor_dec1d( - data: torch.Tensor, -) -> tuple[torch.Tensor, list[int]]: - """Preprocess input tensor dimensions. - - Args: - data (torch.Tensor): An input tensor of any shape. - - Returns: - A tuple (data, ds) where data is a data tensor of shape - [new_batch, 1, to_process] and ds contains the original shape. - """ - ds = list(data.shape) - if len(ds) == 1: - # assume time series - data = data.unsqueeze(0).unsqueeze(0) - elif len(ds) == 2: - # assume batched time series - data = data.unsqueeze(1) - else: - data, ds = _fold_axes(data, 1) - data = data.unsqueeze(1) - return data, ds - - -def _postprocess_result_list_dec1d( - result_list: list[torch.Tensor], ds: list[int], axis: int -) -> list[torch.Tensor]: - if len(ds) == 1: - result_list = [r_el.squeeze(0) for r_el in result_list] - elif len(ds) > 2: - # Unfold axes for the wavelets - result_list = [_unfold_axes(fres, ds, 1) for fres in result_list] - else: - result_list = result_list - - if axis != -1: - result_list = [coeff.swapaxes(axis, -1) for coeff in result_list] - - return result_list - - -def _preprocess_result_list_rec1d( - result_lst: Sequence[torch.Tensor], -) -> tuple[Sequence[torch.Tensor], list[int]]: - # Fold axes for the wavelets - ds = list(result_lst[0].shape) - fold_coeffs: Sequence[torch.Tensor] - if len(ds) == 1: - fold_coeffs = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst] - elif len(ds) > 2: - fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst] - else: - fold_coeffs = result_lst - return fold_coeffs, ds - - def wavedec( data: torch.Tensor, wavelet: Union[Wavelet, str], @@ -315,10 +260,6 @@ def wavedec( containing the wavelet coefficients. A denotes approximation and D detail coefficients. - Raises: - ValueError: If the dtype of the input data tensor is unsupported or - if more than one axis is provided. 
- Example: >>> import torch >>> import ptwt, pywt @@ -330,16 +271,7 @@ def wavedec( >>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'), >>> mode='zero', level=2) """ - if axis != -1: - if isinstance(axis, int): - data = data.swapaxes(axis, -1) - else: - raise ValueError("wavedec transforms a single axis only.") - - if not _is_dtype_supported(data.dtype): - raise ValueError(f"Input dtype {data.dtype} not supported") - - data, ds = _preprocess_tensor_dec1d(data) + data, ds = _preprocess_tensor(data, ndim=1, axes=axis) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype @@ -360,9 +292,7 @@ def wavedec( result_list.append(res_lo.squeeze(1)) result_list.reverse() - result_list = _postprocess_result_list_dec1d(result_list, ds, axis) - - return result_list + return _postprocess_coeffs(result_list, ndim=1, ds=ds, axes=axis) def waverec( @@ -381,11 +311,6 @@ def waverec( Returns: The reconstructed signal tensor. - Raises: - ValueError: If the dtype of the coeffs tensor is unsupported or if the - coefficients have incompatible shapes, dtypes or devices or if - more than one axis is provided. - Example: >>> import torch >>> import ptwt, pywt @@ -399,29 +324,11 @@ def waverec( >>> pywt.Wavelet('haar')) """ - torch_device = coeffs[0].device - torch_dtype = coeffs[0].dtype - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") - - for coeff in coeffs[1:]: - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - - if axis != -1: - swap = [] - if isinstance(axis, int): - for coeff in coeffs: - swap.append(coeff.swapaxes(axis, -1)) - coeffs = swap - else: - raise ValueError("waverec transforms a single axis only.") - - # fold channels, if necessary. - ds = list(coeffs[0].shape) - coeffs, ds = _preprocess_result_list_rec1d(coeffs) + # fold channels and swap axis, if necessary. 
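+    # _preprocess_coeffs expects a mutable list of coefficient tensors
+    # in the 1d case, so convert tuple-like sequences first.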
+ if not isinstance(coeffs, list): + coeffs = list(coeffs) + coeffs, ds = _preprocess_coeffs(coeffs, ndim=1, axes=axis) + torch_device, torch_dtype = _check_same_device_dtype(coeffs) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype @@ -446,12 +353,7 @@ def waverec( if padr > 0: res_lo = res_lo[..., :-padr] - if len(ds) == 1: - res_lo = res_lo.squeeze(0) - elif len(ds) > 2: - res_lo = _unfold_axes(res_lo, ds, 1) - - if axis != -1: - res_lo = res_lo.swapaxes(axis, -1) + # undo folding and swapping + res_lo = _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=axis) return res_lo diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 30b6fcb2..ddb96231 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -6,7 +6,6 @@ from __future__ import annotations -from functools import partial from typing import Optional, Union import pywt @@ -14,18 +13,14 @@ from ._util import ( Wavelet, - _as_wavelet, - _check_axes_argument, - _check_if_tensor, - _fold_axes, + _check_same_device_dtype, _get_len, - _is_dtype_supported, - _map_result, _outer, _pad_symmetric, - _swap_axes, - _undo_swap_axes, - _unfold_axes, + _postprocess_coeffs, + _postprocess_tensor, + _preprocess_coeffs, + _preprocess_tensor, ) from .constants import BoundaryMode, WaveletCoeff2d, WaveletDetailTuple2d from .conv_transform import ( @@ -107,32 +102,6 @@ def _fwt_pad2( return data_pad -def _waverec2d_fold_channels_2d_list( - coeffs: WaveletCoeff2d, -) -> tuple[WaveletCoeff2d, list[int]]: - # fold the input coefficients for processing conv2d_transpose. - ds = list(_check_if_tensor(coeffs[0]).shape) - return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds - - -def _preprocess_tensor_dec2d( - data: torch.Tensor, -) -> tuple[torch.Tensor, Union[list[int], None]]: - # Preprocess multidimensional input. - ds = None - if len(data.shape) == 2: - data = data.unsqueeze(0).unsqueeze(0) - elif len(data.shape) == 3: - # add a channel dimension for torch. - data = data.unsqueeze(1) - elif len(data.shape) >= 4: - data, ds = _fold_axes(data, 2) - data = data.unsqueeze(1) - elif len(data.shape) == 1: - raise ValueError("More than one input dimension required.") - return data, ds - - def wavedec2( data: torch.Tensor, wavelet: Union[Wavelet, str], @@ -183,11 +152,6 @@ def wavedec2( A tuple containing the wavelet coefficients in pywt order, see :data:`ptwt.constants.WaveletCoeff2d`. - Raises: - ValueError: If the dimensionality or the dtype of the input data tensor - is unsupported or if the provided ``axes`` - input has a length other than two. 
- Example: >>> import torch >>> import ptwt, pywt @@ -200,17 +164,7 @@ def wavedec2( >>> level=2, mode="zero") """ - if not _is_dtype_supported(data.dtype): - raise ValueError(f"Input dtype {data.dtype} not supported") - - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - data = _swap_axes(data, list(axes)) - - wavelet = _as_wavelet(wavelet) - data, ds = _preprocess_tensor_dec2d(data) + data, ds = _preprocess_tensor(data, ndim=2, axes=axes) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype ) @@ -234,13 +188,7 @@ def wavedec2( res_ll = res_ll.squeeze(1) result: WaveletCoeff2d = res_ll, *result_lst - if ds: - _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) - result = _map_result(result, _unfold_axes2) - - if axes != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) - result = _map_result(result, undo_swap_fn) + result = _postprocess_coeffs(result, ndim=2, ds=ds, axes=axes) return result @@ -286,28 +234,8 @@ def waverec2( >>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar")) """ - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - _check_axes_argument(list(axes)) - swap_fn = partial(_swap_axes, axes=list(axes)) - coeffs = _map_result(coeffs, swap_fn) - - ds = None - wavelet = _as_wavelet(wavelet) - - res_ll = _check_if_tensor(coeffs[0]) - torch_device = res_ll.device - torch_dtype = res_ll.dtype - - if res_ll.dim() >= 4: - # avoid the channel sum, fold the channels into batches. - coeffs, ds = _waverec2d_fold_channels_2d_list(coeffs) - res_ll = _check_if_tensor(coeffs[0]) - - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") + coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes) + torch_device, torch_dtype = _check_same_device_dtype(coeffs) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype @@ -315,6 +243,7 @@ def waverec2( filt_len = rec_lo.shape[-1] rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi) + res_ll = coeffs[0] for c_pos, coeff_tuple in enumerate(coeffs[1:]): if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3: raise ValueError( @@ -325,11 +254,7 @@ def waverec2( curr_shape = res_ll.shape for coeff in coeff_tuple: - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - elif coeff.shape != curr_shape: + if coeff.shape != curr_shape: raise ValueError( "All coefficients on each level must have the same shape" ) @@ -362,10 +287,6 @@ def waverec2( if padr > 0: res_ll = res_ll[..., :-padr] - if ds: - res_ll = _unfold_axes(res_ll, list(ds), 2) - - if axes != (-2, -1): - res_ll = _undo_swap_axes(res_ll, list(axes)) + res_ll = _postprocess_tensor(res_ll, ndim=2, ds=ds, axes=axes) return res_ll diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 2a527c20..ee9e75a6 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -5,7 +5,6 @@ from __future__ import annotations -from functools import partial from typing import Optional, Union import pywt @@ -14,19 +13,16 @@ from ._util import ( Wavelet, _as_wavelet, - _check_axes_argument, - _check_if_tensor, - _fold_axes, + _check_same_device_dtype, _get_len, - _is_dtype_supported, - _map_result, _outer, _pad_symmetric, - 
_swap_axes, - _undo_swap_axes, - _unfold_axes, + _postprocess_coeffs, + _postprocess_tensor, + _preprocess_coeffs, + _preprocess_tensor, ) -from .constants import BoundaryMode, WaveletCoeffNd +from .constants import BoundaryMode, WaveletCoeffNd, WaveletDetailDict from .conv_transform import ( _adjust_padding_at_reconstruction, _get_filter_tensors, @@ -143,34 +139,12 @@ def wavedec3( A tuple containing the wavelet coefficients, see :data:`ptwt.constants.WaveletCoeffNd`. - Raises: - ValueError: If the input has fewer than three dimensions or - if the dtype is not supported or - if the provided axes input has length other than three. - Example: >>> import ptwt, torch >>> data = torch.randn(5, 16, 16, 16) >>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect") """ - if tuple(axes) != (-3, -2, -1): - if len(axes) != 3: - raise ValueError("3D transforms work with three axes.") - else: - _check_axes_argument(list(axes)) - data = _swap_axes(data, list(axes)) - - ds = None - if data.dim() < 3: - raise ValueError("At least three dimensions are required for 3d wavedec.") - elif len(data.shape) == 3: - data = data.unsqueeze(1) - else: - data, ds = _fold_axes(data, 3) - data = data.unsqueeze(1) - - if not _is_dtype_supported(data.dtype): - raise ValueError(f"Input dtype {data.dtype} not supported") + data, ds = _preprocess_tensor(data, ndim=3, axes=axes) wavelet = _as_wavelet(wavelet) dec_lo, dec_hi, _, _ = _get_filter_tensors( @@ -183,7 +157,7 @@ def wavedec3( [data.shape[-1], data.shape[-2], data.shape[-3]], wavelet ) - result_lst: list[dict[str, torch.Tensor]] = [] + result_lst: list[WaveletDetailDict] = [] res_lll = data for _ in range(level): if len(res_lll.shape) == 4: @@ -205,34 +179,9 @@ def wavedec3( } ) result_lst.reverse() - result: WaveletCoeffNd = res_lll, *result_lst + coeffs: WaveletCoeffNd = res_lll, *result_lst - if ds: - _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) - result = _map_result(result, _unfold_axes_fn) - - if tuple(axes) != (-3, -2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) - result = _map_result(result, undo_swap_fn) - - return result - - -def _waverec3d_fold_channels_3d_list( - coeffs: WaveletCoeffNd, -) -> tuple[ - WaveletCoeffNd, - list[int], -]: - # fold the input coefficients for processing conv2d_transpose. - fold_approx_coeff = _fold_axes(coeffs[0], 3)[0] - fold_coeffs: list[dict[str, torch.Tensor]] = [] - ds = list(_check_if_tensor(coeffs[0]).shape) - fold_coeffs = [ - {key: _fold_axes(value, 3)[0] for key, value in coeff.items()} - for coeff in coeffs[1:] - ] - return (fold_approx_coeff, *fold_coeffs), ds + return _postprocess_coeffs(coeffs, ndim=3, ds=ds, axes=axes) def waverec3( @@ -267,32 +216,8 @@ def waverec3( >>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect") >>> reconstruction = ptwt.waverec3(transformed, "haar") """ - if tuple(axes) != (-3, -2, -1): - if len(axes) != 3: - raise ValueError("3D transforms work with three axes") - else: - _check_axes_argument(list(axes)) - swap_axes_fn = partial(_swap_axes, axes=list(axes)) - coeffs = _map_result(coeffs, swap_axes_fn) - - wavelet = _as_wavelet(wavelet) - ds = None - # the Union[tensor, dict] idea is coming from pywt. We don't change it here. - res_lll = _check_if_tensor(coeffs[0]) - if res_lll.dim() < 3: - raise ValueError( - "Three dimensional transforms require at least three dimensions." 
- ) - elif res_lll.dim() >= 5: - coeffs, ds = _waverec3d_fold_channels_3d_list(coeffs) - res_lll = _check_if_tensor(coeffs[0]) - - torch_device = res_lll.device - torch_dtype = res_lll.dtype - - if not _is_dtype_supported(torch_dtype): - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") + coeffs, ds = _preprocess_coeffs(coeffs, ndim=3, axes=axes) + torch_device, torch_dtype = _check_same_device_dtype(coeffs) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype @@ -300,6 +225,7 @@ def waverec3( filt_len = rec_lo.shape[-1] rec_filt = _construct_3d_filt(lo=rec_lo, hi=rec_hi) + res_lll = coeffs[0] coeff_dicts = coeffs[1:] for c_pos, coeff_dict in enumerate(coeff_dicts): if not isinstance(coeff_dict, dict) or len(coeff_dict) != 7: @@ -309,11 +235,7 @@ def waverec3( "wavedec3." ) for coeff in coeff_dict.values(): - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - elif res_lll.shape != coeff.shape: + if res_lll.shape != coeff.shape: raise ValueError( "All coefficients on each level must have the same shape" ) @@ -362,11 +284,5 @@ def waverec3( res_lll = res_lll[..., padfr:, :, :] if padba > 0: res_lll = res_lll[..., :-padba, :, :] - res_lll = res_lll.squeeze(1) - - if ds: - res_lll = _unfold_axes(res_lll, ds, 3) - if axes != (-3, -2, -1): - res_lll = _undo_swap_axes(res_lll, list(axes)) - return res_lll + return _postprocess_tensor(res_lll, ndim=3, ds=ds, axes=axes) diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 71597564..f4082e13 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -17,19 +17,16 @@ from ._util import ( Wavelet, _as_wavelet, + _check_same_device_dtype, _deprecated_alias, - _is_dtype_supported, _is_orthogonalize_method_supported, - _unfold_axes, + _postprocess_coeffs, + _postprocess_tensor, + _preprocess_coeffs, + _preprocess_tensor, ) from .constants import BoundaryMode, OrthogonalizeMethod -from .conv_transform import ( - _fwt_pad, - _get_filter_tensors, - _postprocess_result_list_dec1d, - _preprocess_result_list_rec1d, - _preprocess_tensor_dec1d, -) +from .conv_transform import _fwt_pad, _get_filter_tensors from .sparse_math import ( _orth_by_gram_schmidt, _orth_by_qr, @@ -347,14 +344,12 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: ValueError: If the decomposition level is not a positive integer or if the input signal has not the expected shape. 
""" - if self.axis != -1: - input_signal = input_signal.swapaxes(self.axis, -1) - - input_signal, ds = _preprocess_tensor_dec1d(input_signal) - input_signal = input_signal.squeeze(1) - - if not _is_dtype_supported(input_signal.dtype): - raise ValueError(f"Input dtype {input_signal.dtype} not supported") + input_signal, ds = _preprocess_tensor( + input_signal, + ndim=1, + axes=self.axis, + add_channel_dim=False, + ) if input_signal.shape[-1] % 2 != 0: # odd length input @@ -405,8 +400,7 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: result_list = [s.T for s in split_list[::-1]] # unfold if necessary - result_list = _postprocess_result_list_dec1d(result_list, ds, self.axis) - return result_list + return _postprocess_coeffs(result_list, ndim=1, ds=ds, axes=self.axis) @_deprecated_alias(boundary="orthogonalization") @@ -638,13 +632,10 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: coefficients are not in the shape as it is returned from a `MatrixWavedec` object. """ - if self.axis != -1: - swap = [] - for coeff in coefficients: - swap.append(coeff.swapaxes(self.axis, -1)) - coefficients = swap - - coefficients, ds = _preprocess_result_list_rec1d(coefficients) + if not isinstance(coefficients, list): + coefficients = list(coefficients) + coefficients, ds = _preprocess_coeffs(coefficients, ndim=1, axes=self.axis) + torch_device, torch_dtype = _check_same_device_dtype(coefficients) level = len(coefficients) - 1 input_length = coefficients[-1].shape[-1] * 2 @@ -655,17 +646,6 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: self.input_length = input_length re_build = True - torch_device = coefficients[0].device - torch_dtype = coefficients[0].dtype - for coeff in coefficients[1:]: - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") - if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( device=torch_device, @@ -696,12 +676,4 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: res_lo = lo.T - if len(ds) == 1: - res_lo = res_lo.squeeze(0) - elif len(ds) > 2: - res_lo = _unfold_axes(res_lo, ds, 1) - - if self.axis != -1: - res_lo = res_lo.swapaxes(self.axis, -1) - - return res_lo + return _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=self.axis) diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 70ef68fc..c36b4402 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -6,7 +6,6 @@ from __future__ import annotations import sys -from functools import partial from typing import Optional, Union, cast import numpy as np @@ -16,14 +15,13 @@ Wavelet, _as_wavelet, _check_axes_argument, - _check_if_tensor, + _check_same_device_dtype, _deprecated_alias, - _is_dtype_supported, _is_orthogonalize_method_supported, - _map_result, - _swap_axes, - _undo_swap_axes, - _unfold_axes, + _postprocess_coeffs, + _postprocess_tensor, + _preprocess_coeffs, + _preprocess_tensor, ) from .constants import ( BoundaryMode, @@ -33,12 +31,7 @@ WaveletDetailTuple2d, ) from .conv_transform import _get_filter_tensors -from .conv_transform_2 import ( - _construct_2d_filt, - _fwt_pad2, - _preprocess_tensor_dec2d, - _waverec2d_fold_channels_2d_list, -) +from .conv_transform_2 import _construct_2d_filt, 
_fwt_pad2 from .matmul_transform import ( BaseMatrixWaveDec, construct_boundary_a, @@ -127,7 +120,6 @@ def _construct_s_2( Returns: The generated fast wavelet synthesis matrix. """ - wavelet = _as_wavelet(wavelet) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=True, device=device, dtype=dtype ) @@ -319,8 +311,8 @@ def __init__( if len(axes) != 2: raise ValueError("2D transforms work with two axes.") else: - _check_axes_argument(list(axes)) - self.axes = tuple(axes) + _check_axes_argument(axes) + self.axes = axes self.level = level self.orthogonalization = orthogonalization self.odd_coeff_padding_mode = odd_coeff_padding_mode @@ -464,17 +456,12 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeff2d: ValueError: If the decomposition level is not a positive integer or if the input signal has not the expected shape. """ - if self.axes != (-2, -1): - input_signal = _swap_axes(input_signal, list(self.axes)) - - input_signal, ds = _preprocess_tensor_dec2d(input_signal) - input_signal = input_signal.squeeze(1) + input_signal, ds = _preprocess_tensor( + input_signal, ndim=2, axes=self.axes, add_channel_dim=False + ) batch_size, height, width = input_signal.shape - if not _is_dtype_supported(input_signal.dtype): - raise ValueError(f"Input dtype {input_signal.dtype} not supported") - re_build = False if ( self.input_signal_shape is None @@ -563,13 +550,7 @@ def _add_padding(signal: torch.Tensor, pad: tuple[bool, bool]) -> torch.Tensor: split_list.reverse() result: WaveletCoeff2d = ll, *split_list - if ds: - _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) - result = _map_result(result, _unfold_axes2) - - if self.axes != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) - result = _map_result(result, undo_swap_fn) + result = _postprocess_coeffs(result, ndim=2, ds=ds, axes=self.axes) return result @@ -631,7 +612,7 @@ def __init__( if len(axes) != 2: raise ValueError("2D transforms work with two axes.") else: - _check_axes_argument(list(axes)) + _check_axes_argument(axes) self.axes = axes self.ifwt_matrix_list: list[ @@ -768,23 +749,8 @@ def __call__( coefficients are not in the shape as it is returned from a `MatrixWavedec2` object. """ - ll = _check_if_tensor(coefficients[0]) - - if tuple(self.axes) != (-2, -1): - swap_fn = partial(_swap_axes, axes=list(self.axes)) - coefficients = _map_result(coefficients, swap_fn) - ll = _check_if_tensor(coefficients[0]) - - ds = None - if ll.dim() == 1: - raise ValueError("2d transforms require more than a single input dim.") - elif ll.dim() == 2: - # add batch dim to unbatched input - ll = ll.unsqueeze(0) - elif ll.dim() >= 4: - # avoid the channel sum, fold the channels into batches. 
- coefficients, ds = _waverec2d_fold_channels_2d_list(coefficients) - ll = _check_if_tensor(coefficients[0]) + coefficients, ds = _preprocess_coeffs(coefficients, ndim=2, axes=self.axes) + torch_device, torch_dtype = _check_same_device_dtype(coefficients) level = len(coefficients) - 1 height, width = tuple(c * 2 for c in coefficients[-1][0].shape[-2:]) @@ -802,19 +768,14 @@ def __call__( self.level = level re_build = True - batch_size = ll.shape[0] - torch_device = ll.device - torch_dtype = ll.dtype - - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") - if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( device=torch_device, dtype=torch_dtype, ) + ll = coefficients[0] + batch_size = ll.shape[0] for c_pos, coeff_tuple in enumerate(coefficients[1:]): if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3: raise ValueError( @@ -825,11 +786,7 @@ def __call__( curr_shape = ll.shape for coeff in coeff_tuple: - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - elif coeff.shape != curr_shape: + if coeff.shape != curr_shape: raise ValueError( "All coefficients on each level must have the same shape" ) @@ -877,9 +834,6 @@ def __call__( if pred_len[1] != next_len[1]: ll = ll[:, :, :-1] - if ds: - ll = _unfold_axes(ll, list(ds), 2) + ll = _postprocess_tensor(ll, ndim=2, ds=ds, axes=self.axes) - if self.axes != (-2, -1): - ll = _undo_swap_axes(ll, list(self.axes)) return ll diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 0624b363..bd0249eb 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -13,18 +13,21 @@ Wavelet, _as_wavelet, _check_axes_argument, - _check_if_tensor, + _check_same_device_dtype, _deprecated_alias, - _fold_axes, - _is_dtype_supported, _is_orthogonalize_method_supported, - _map_result, - _swap_axes, - _undo_swap_axes, - _unfold_axes, + _postprocess_coeffs, + _postprocess_tensor, + _preprocess_coeffs, + _preprocess_tensor, ) -from .constants import BoundaryMode, OrthogonalizeMethod, WaveletCoeffNd -from .conv_transform_3 import _fwt_pad3, _waverec3d_fold_channels_3d_list +from .constants import ( + BoundaryMode, + OrthogonalizeMethod, + WaveletCoeffNd, + WaveletDetailDict, +) +from .conv_transform_3 import _fwt_pad3 from .matmul_transform import construct_boundary_a, construct_boundary_s from .sparse_math import _batch_dim_mm @@ -187,22 +190,11 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffNd: Raises: ValueError: If the input dimensions don't work. 
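+
+        Example:
+            A hypothetical sketch; the instance name and the dyadic
+            signal shape are illustrative only::
+
+            >>> matrix_wavedec_3 = ptwt.MatrixWavedec3("haar", level=1)
+            >>> coeffs = matrix_wavedec_3(torch.randn(2, 16, 16, 16))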
""" - if self.axes != (-3, -2, -1): - input_signal = _swap_axes(input_signal, list(self.axes)) - - ds = None - if input_signal.dim() < 3: - raise ValueError("At least three dimensions are required for 3d wavedec.") - elif len(input_signal.shape) == 3: - input_signal = input_signal.unsqueeze(1) - else: - input_signal, ds = _fold_axes(input_signal, 3) - + input_signal, ds = _preprocess_tensor( + input_signal, ndim=3, axes=self.axes, add_channel_dim=False + ) _, depth, height, width = input_signal.shape - if not _is_dtype_supported(input_signal.dtype): - raise ValueError(f"Input dtype {input_signal.dtype} not supported") - re_build = False if ( self.input_signal_shape is None @@ -233,7 +225,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffNd: device=input_signal.device, dtype=input_signal.dtype ) - split_list: list[dict[str, torch.Tensor]] = [] + split_list: list[WaveletDetailDict] = [] lll = input_signal for scale, fwt_mats in enumerate(self.fwt_matrix_list): pad_tuple = self.pad_list[scale] @@ -254,17 +246,17 @@ def _split_rec( tensor: torch.Tensor, key: str, depth: int, - dict: dict[str, torch.Tensor], + to_dict: WaveletDetailDict, ) -> None: if key: - dict[key] = tensor + to_dict[key] = tensor if len(key) < depth: dim = len(key) + 1 ca, cd = torch.split(tensor, tensor.shape[-dim] // 2, dim=-dim) - _split_rec(ca, "a" + key, depth, dict) - _split_rec(cd, "d" + key, depth, dict) + _split_rec(ca, "a" + key, depth, to_dict) + _split_rec(cd, "d" + key, depth, to_dict) - coeff_dict: dict[str, torch.Tensor] = {} + coeff_dict: WaveletDetailDict = {} _split_rec(lll, "", 3, coeff_dict) lll = coeff_dict["aaa"] result_keys = list( @@ -276,17 +268,9 @@ def _split_rec( split_list.append(coeff_dict) split_list.reverse() - result: WaveletCoeffNd = lll, *split_list + coeffs: WaveletCoeffNd = lll, *split_list - if ds: - _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) - result = _map_result(result, _unfold_axes_fn) - - if self.axes != (-3, -2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) - result = _map_result(result, undo_swap_fn) - - return result + return _postprocess_coeffs(coeffs, ndim=3, ds=ds, axes=self.axes) class MatrixWaverec3(object): @@ -392,7 +376,7 @@ def _construct_synthesis_matrices( current_width // 2, ) - def _cat_coeff_recursive(self, input_dict: dict[str, torch.Tensor]) -> torch.Tensor: + def _cat_coeff_recursive(self, input_dict: WaveletDetailDict) -> torch.Tensor: done_dict = {} a_initial_keys = list(filter(lambda x: x[0] == "a", input_dict.keys())) for a_key in a_initial_keys: @@ -422,20 +406,8 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: Raises: ValueError: If the data structure is inconsistent. """ - if self.axes != (-3, -2, -1): - swap_axes_fn = partial(_swap_axes, axes=list(self.axes)) - coefficients = _map_result(coefficients, swap_axes_fn) - - ds = None - # the Union[tensor, dict] idea is coming from pywt. We don't change it here. - res_lll = _check_if_tensor(coefficients[0]) - if res_lll.dim() < 3: - raise ValueError( - "Three dimensional transforms require at least three dimensions." 
- ) - elif res_lll.dim() >= 5: - coefficients, ds = _waverec3d_fold_channels_3d_list(coefficients) - res_lll = _check_if_tensor(coefficients[0]) + coefficients, ds = _preprocess_coeffs(coefficients, ndim=3, axes=self.axes) + torch_device, torch_dtype = _check_same_device_dtype(coefficients) level = len(coefficients) - 1 if type(coefficients[-1]) is dict: @@ -459,25 +431,13 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: self.level = level re_build = True - lll = coefficients[0] - if not isinstance(lll, torch.Tensor): - raise ValueError( - "First element of coeffs must be the approximation coefficient tensor." - ) - - torch_device = lll.device - torch_dtype = lll.dtype - - if not _is_dtype_supported(torch_dtype): - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") - if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( device=torch_device, dtype=torch_dtype, ) + lll = coefficients[0] for c_pos, coeff_dict in enumerate(coefficients[1:]): if not isinstance(coeff_dict, dict) or len(coeff_dict) != 7: raise ValueError( @@ -489,10 +449,6 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: for coeff in coeff_dict.values(): if test_shape is None: test_shape = coeff.shape - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") elif test_shape != coeff.shape: raise ValueError( "All coefficients on each level must have the same shape" @@ -504,10 +460,4 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: for dim, mat in enumerate(self.ifwt_matrix_list[level - 1 - c_pos][::-1]): lll = _batch_dim_mm(mat, lll, dim=(-1) * (dim + 1)) - if ds: - lll = _unfold_axes(lll, ds, 3) - - if self.axes != (-3, -2, -1): - lll = _undo_swap_axes(lll, list(self.axes)) - - return lll + return _postprocess_tensor(lll, ndim=3, ds=ds, axes=self.axes) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index c89b9931..4e715e0d 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -9,7 +9,6 @@ from __future__ import annotations -from functools import partial from typing import Optional, Union import numpy as np @@ -18,22 +17,23 @@ from ._util import ( Wavelet, _as_wavelet, - _check_axes_argument, - _check_if_tensor, - _fold_axes, - _is_dtype_supported, - _map_result, - _swap_axes, - _undo_swap_axes, - _unfold_axes, + _check_same_device_dtype, + _postprocess_coeffs, + _postprocess_tensor, + _preprocess_coeffs, + _preprocess_tensor, +) +from .constants import ( + BoundaryMode, + WaveletCoeff2dSeparable, + WaveletCoeffNd, + WaveletDetailDict, ) -from .constants import BoundaryMode, WaveletCoeff2dSeparable, WaveletCoeffNd from .conv_transform import wavedec, waverec -from .conv_transform_2 import _preprocess_tensor_dec2d def _separable_conv_dwtn_( - rec_dict: dict[str, torch.Tensor], + rec_dict: WaveletDetailDict, input_arg: torch.Tensor, wavelet: Union[Wavelet, str], *, @@ -45,6 +45,8 @@ def _separable_conv_dwtn_( All but the first axes are transformed. Args: + rec_dict (WaveletDetailDict): The result will be stored here + in place. input_arg (torch.Tensor): Tensor of shape ``[batch, data_1, ... data_n]``. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. @@ -54,8 +56,6 @@ def _separable_conv_dwtn_( Defaults to "reflect". 
key (str): The filter application path. Defaults to "". - dict (dict[str, torch.Tensor]): The result will be stored here - in place. Defaults to {}. """ axis_total = len(input_arg.shape) - 1 if len(key) == axis_total: @@ -70,12 +70,12 @@ def _separable_conv_dwtn_( def _separable_conv_idwtn( - in_dict: dict[str, torch.Tensor], wavelet: Union[Wavelet, str] + in_dict: WaveletDetailDict, wavelet: Union[Wavelet, str] ) -> torch.Tensor: """Separable single level inverse fast wavelet transform. Args: - in_dict (dict[str, torch.Tensor]): The dictionary produced + in_dict (WaveletDetailDict): The dictionary produced by _separable_conv_dwtn_ . wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet, as used by ``_separable_conv_dwtn_``. @@ -131,7 +131,7 @@ def _separable_conv_wavedecn( 'a' denoting the low pass or approximation filter and 'd' the high-pass or detail filter. """ - result: list[dict[str, torch.Tensor]] = [] + result: list[WaveletDetailDict] = [] approx = input if level is None: @@ -141,7 +141,7 @@ def _separable_conv_wavedecn( ) for _ in range(level): - level_dict: dict[str, torch.Tensor] = {} + level_dict: WaveletDetailDict = {} _separable_conv_dwtn_(level_dict, approx, wavelet, mode=mode, key="") approx_key = "a" * (len(input.shape) - 1) approx = level_dict.pop(approx_key) @@ -216,38 +216,13 @@ def fswavedec2( as keys. 'a' denotes the low pass or approximation filter and 'd' the high-pass or detail filter. - Raises: - ValueError: If the data is not a batched 2D signal. - Example: >>> import torch >>> import ptwt >>> data = torch.randn(5, 10, 10) >>> coeff = ptwt.fswavedec2(data, "haar", level=2) """ - if not _is_dtype_supported(data.dtype): - raise ValueError(f"Input dtype {data.dtype} not supported") - - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - data = _swap_axes(data, list(axes)) - - wavelet = _as_wavelet(wavelet) - data, ds = _preprocess_tensor_dec2d(data) - data = data.squeeze(1) - res = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) - - if ds: - _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) - res = _map_result(res, _unfold_axes2) - - if axes != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) - res = _map_result(res, undo_swap_fn) - - return res + return _fswavedecn(data, wavelet, ndim=2, mode=mode, level=level, axes=axes) def fswavedec3( @@ -284,43 +259,13 @@ def fswavedec3( as keys. 'a' denotes the low pass or approximation filter and 'd' the high-pass or detail filter. - Raises: - ValueError: If the input is not a batched 3D signal. 
- - Example: >>> import torch >>> import ptwt >>> data = torch.randn(5, 10, 10, 10) >>> coeff = ptwt.fswavedec3(data, "haar", level=2) """ - if not _is_dtype_supported(data.dtype): - raise ValueError(f"Input dtype {data.dtype} not supported") - - if tuple(axes) != (-3, -2, -1): - if len(axes) != 3: - raise ValueError("2D transforms work with two axes.") - else: - data = _swap_axes(data, list(axes)) - - wavelet = _as_wavelet(wavelet) - ds = None - if len(data.shape) >= 5: - data, ds = _fold_axes(data, 3) - elif len(data.shape) < 4: - raise ValueError("At lest four input dimensions are required.") - data = data.squeeze(1) - res = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) - - if ds: - _unfold_axes3 = partial(_unfold_axes, ds=ds, keep_no=3) - res = _map_result(res, _unfold_axes3) - - if axes != (-3, -2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) - res = _map_result(res, undo_swap_fn) - - return res + return _fswavedecn(data, wavelet, ndim=3, mode=mode, level=level, axes=axes) def fswaverec2( @@ -347,9 +292,6 @@ def fswaverec2( Returns: A reconstruction of the signal encoded in the wavelet coefficients. - Raises: - ValueError: If the axes argument is not a tuple of two integers. - Example: >>> import torch >>> import ptwt @@ -357,38 +299,7 @@ def fswaverec2( >>> coeff = ptwt.fswavedec2(data, "haar", level=2) >>> rec = ptwt.fswaverec2(coeff, "haar") """ - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - _check_axes_argument(list(axes)) - swap_fn = partial(_swap_axes, axes=list(axes)) - coeffs = _map_result(coeffs, swap_fn) - - ds = None - wavelet = _as_wavelet(wavelet) - - res_ll = _check_if_tensor(coeffs[0]) - torch_dtype = res_ll.dtype - - if res_ll.dim() >= 4: - # avoid the channel sum, fold the channels into batches. - ds = _check_if_tensor(coeffs[0]).shape - coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]) - res_ll = _check_if_tensor(coeffs[0]) - - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") - - res_ll = _separable_conv_waverecn(coeffs, wavelet) - - if ds: - res_ll = _unfold_axes(res_ll, list(ds), 2) - - if axes != (-2, -1): - res_ll = _undo_swap_axes(res_ll, list(axes)) - - return res_ll + return _fswaverecn(coeffs, wavelet, ndim=2, axes=axes) def fswaverec3( @@ -412,10 +323,6 @@ def fswaverec3( Returns: A reconstruction of the signal encoded in the wavelet coefficients. - Raises: - ValueError: If the axes argument is not a tuple with - three ints. - Example: >>> import torch >>> import ptwt @@ -423,34 +330,96 @@ def fswaverec3( >>> coeff = ptwt.fswavedec3(data, "haar", level=2) >>> rec = ptwt.fswaverec3(coeff, "haar") """ - if tuple(axes) != (-3, -2, -1): - if len(axes) != 3: - raise ValueError("2D transforms work with two axes.") - else: - _check_axes_argument(list(axes)) - swap_fn = partial(_swap_axes, axes=list(axes)) - coeffs = _map_result(coeffs, swap_fn) + return _fswaverecn(coeffs, wavelet, ndim=3, axes=axes) + + +def _fswavedecn( + data: torch.Tensor, + wavelet: Union[Wavelet, str], + ndim: int, + *, + mode: BoundaryMode = "reflect", + level: Optional[int] = None, + axes: Optional[tuple[int, ...]] = None, +) -> WaveletCoeffNd: + """Compute a fully separable :math:`N`-dimensional padded FWT. + + Args: + data (torch.Tensor): An input signal with at least :math:`N` dimensions. + wavelet (Wavelet or str): A pywt wavelet compatible object or + the name of a pywt wavelet. 
Refer to the output of
+            ``pywt.wavelist(kind="discrete")`` for possible choices.
+        ndim (int): The number of dimensions :math:`N`.
+        mode:
+            The desired padding mode for extending the signal along the edges.
+            Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
+        level (int): The number of desired scales. Defaults to None.
+        axes (tuple[int, ...], optional): Compute the transform over these axes
+            instead of the last :math:`N`. If None, the last :math:`N`
+            axes are transformed. Defaults to None.
+
+    Returns:
+        A tuple with the lll coefficients and for each scale a dictionary
+        containing the detail coefficients,
+        see :data:`ptwt.constants.WaveletCoeffNd`.
+
+    Example:
+        >>> import torch
+        >>> from ptwt.separable_conv_transform import _fswavedecn
+        >>> data = torch.randn(5, 10, 10, 10)
+        >>> coeff = _fswavedecn(data, "haar", ndim=3, level=2)
 
-    ds = None
-    wavelet = _as_wavelet(wavelet)
-    res_ll = _check_if_tensor(coeffs[0])
-    torch_dtype = res_ll.dtype
+    Note:
+        ND-Transforms are generally out of this project's scope.
+    """
+    if axes is None:
+        axes = tuple(range(-ndim, 0))
 
-    if res_ll.dim() >= 5:
-        # avoid the channel sum, fold the channels into batches.
-        ds = _check_if_tensor(coeffs[0]).shape
-        coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 3)[0])
-        res_ll = _check_if_tensor(coeffs[0])
+    data, ds = _preprocess_tensor(data, ndim=ndim, axes=axes, add_channel_dim=False)
+    coeffs = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level)
+    return _postprocess_coeffs(coeffs, ndim=ndim, ds=ds, axes=axes)
 
-    if not _is_dtype_supported(torch_dtype):
-        raise ValueError(f"Input dtype {torch_dtype} not supported")
 
-    res_ll = _separable_conv_waverecn(coeffs, wavelet)
+def _fswaverecn(
+    coeffs: WaveletCoeffNd,
+    wavelet: Union[Wavelet, str],
+    ndim: int,
+    axes: Optional[tuple[int, ...]] = None,
+) -> torch.Tensor:
+    """Invert a fully separable :math:`N`-dimensional padded FWT.
+
+    Args:
+        coeffs (WaveletCoeffNd):
+            The wavelet coefficients as computed by :func:`fswavedecn`,
+            see :data:`ptwt.constants.WaveletCoeffNd`.
+        wavelet (Wavelet or str): A pywt wavelet compatible object or
+            the name of a pywt wavelet.
+            Refer to the output from ``pywt.wavelist(kind='discrete')``
+            for possible choices.
+        ndim (int): The number of dimensions :math:`N`.
+        axes (tuple[int, ...], optional): Compute the transform over these axes
+            instead of the last :math:`N`. If None, the last :math:`N`
+            axes are transformed. Defaults to None.
 
-    if ds:
-        res_ll = _unfold_axes(res_ll, list(ds), 3)
+    Returns:
+        A reconstruction of the signal encoded in the wavelet coefficients.
+
+    Example:
+        >>> import torch
+        >>> from ptwt.separable_conv_transform import _fswavedecn, _fswaverecn
+        >>> data = torch.randn(5, 10, 10, 10)
+        >>> coeff = _fswavedecn(data, "haar", ndim=3, level=2)
+        >>> rec = _fswaverecn(coeff, "haar", ndim=3)
+
+    Note:
+        ND-Transforms are generally out of this project's scope.
+ """ + if axes is None: + axes = tuple(range(-ndim, 0)) - if axes != (-3, -2, -1): - res_ll = _undo_swap_axes(res_ll, list(axes)) + coeffs, ds = _preprocess_coeffs(coeffs, ndim=ndim, axes=axes) + _check_same_device_dtype(coeffs) + + res_ll = _separable_conv_waverecn(coeffs, wavelet) - return res_ll + return _postprocess_tensor(res_ll, ndim=ndim, ds=ds, axes=axes) diff --git a/src/ptwt/stationary_transform.py b/src/ptwt/stationary_transform.py index c4be3bc7..b93ffcd0 100644 --- a/src/ptwt/stationary_transform.py +++ b/src/ptwt/stationary_transform.py @@ -7,13 +7,16 @@ import torch import torch.nn.functional as F # noqa:N812 -from ._util import Wavelet, _as_wavelet, _unfold_axes -from .conv_transform import ( - _get_filter_tensors, - _postprocess_result_list_dec1d, - _preprocess_result_list_rec1d, - _preprocess_tensor_dec1d, +from ._util import ( + Wavelet, + _as_wavelet, + _check_same_device_dtype, + _postprocess_coeffs, + _postprocess_tensor, + _preprocess_coeffs, + _preprocess_tensor, ) +from .conv_transform import _get_filter_tensors def _circular_pad(x: torch.Tensor, padding_dimensions: Sequence[int]) -> torch.Tensor: @@ -69,17 +72,8 @@ def swt( Returns: Same as wavedec. Equivalent to pywt.swt with trim_approx=True. - - Raises: - ValueError: Is the axis argument is not an integer. """ - if axis != -1: - if isinstance(axis, int): - data = data.swapaxes(axis, -1) - else: - raise ValueError("swt transforms a single axis only.") - - data, ds = _preprocess_tensor_dec1d(data) + data, ds = _preprocess_tensor(data, ndim=1, axes=axis) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype @@ -102,16 +96,15 @@ def swt( # result_list.append((res_lo.squeeze(1), res_hi.squeeze(1))) result_list.append(res_hi.squeeze(1)) result_list.append(res_lo.squeeze(1)) + result_list.reverse() - result_list = _postprocess_result_list_dec1d(result_list, ds, axis) - - return result_list[::-1] + return _postprocess_coeffs(result_list, ndim=1, ds=ds, axes=axis) def iswt( coeffs: Sequence[torch.Tensor], wavelet: Union[pywt.Wavelet, str], - axis: Optional[int] = -1, + axis: int = -1, ) -> torch.Tensor: """Invert a 1d stationary wavelet transform. @@ -120,29 +113,20 @@ def iswt( by the swt function. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet, as used in the forward transform. - axis (int, optional): The axis the forward trasform was computed over. + axis (int): The axis the forward trasform was computed over. Defaults to -1. Returns: A reconstruction of the original swt input. - - Raises: - ValueError: If the axis argument is not an integer. 
""" - if axis != -1: - swap = [] - if isinstance(axis, int): - for coeff in coeffs: - swap.append(coeff.swapaxes(axis, -1)) - coeffs = swap - else: - raise ValueError("iswt transforms a single axis only.") - - coeffs, ds = _preprocess_result_list_rec1d(coeffs) + if not isinstance(coeffs, list): + coeffs = list(coeffs) + coeffs, ds = _preprocess_coeffs(coeffs, ndim=1, axes=axis) + torch_device, torch_dtype = _check_same_device_dtype(coeffs) wavelet = _as_wavelet(wavelet) _, _, rec_lo, rec_hi = _get_filter_tensors( - wavelet, flip=False, dtype=coeffs[0].dtype, device=coeffs[0].device + wavelet, flip=False, dtype=torch_dtype, device=torch_device ) filt_len = rec_lo.shape[-1] rec_filt = torch.stack([rec_lo, rec_hi], 0) @@ -161,12 +145,6 @@ def iswt( 1, ) - if len(ds) == 1: - res_lo = res_lo.squeeze(0) - elif len(ds) > 2: - res_lo = _unfold_axes(res_lo, ds, 1) - - if axis != -1: - res_lo = res_lo.swapaxes(axis, -1) + res_lo = _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=axis) return res_lo