
Make preprocessing and postprocessing consistent across transforms #93


Merged · 36 commits merged on Jul 1, 2024
Changes from 34 commits

Commits (36)
0a1b515
Integrate axis swap into 1d processing funcs
felixblanke Jun 24, 2024
2426060
Make channel dim addition optional
felixblanke Jun 24, 2024
eaa2fb6
Refactor; Add processing funcs for 2d
felixblanke Jun 24, 2024
0e4cb27
Move processung funcs to _util module
felixblanke Jun 24, 2024
3554a53
Generalize tensor processing
felixblanke Jun 24, 2024
4e8ed02
Adapt 1d cases
felixblanke Jun 24, 2024
f889158
Extend _map_result to 1d case
felixblanke Jun 24, 2024
0ed2896
Make _preprocess_coeffs general
felixblanke Jun 24, 2024
03f74b7
Make postprocessing general
felixblanke Jun 24, 2024
ae7b345
Apply process funcs in 3d transforms
felixblanke Jun 24, 2024
21332bf
Format
felixblanke Jun 24, 2024
b461c97
Fix coeff postprocess
felixblanke Jun 24, 2024
105a0a8
Reduce tensor processing to coeff processing
felixblanke Jun 24, 2024
24a25f4
Add fully separable transforms for n dims
felixblanke Jun 24, 2024
169d6cf
Move dtype check to preprocessing
felixblanke Jun 24, 2024
ba1b83d
Encapsulate check for consistent dtype and device
felixblanke Jun 24, 2024
d50a2c2
Revert changes to coeff shape check
felixblanke Jun 24, 2024
2d4620a
Make n-dim separable trafo private
felixblanke Jun 25, 2024
1b45eae
Add explainatory comments
felixblanke Jun 25, 2024
15af78b
Add docstrings
felixblanke Jun 25, 2024
35bab80
Rename _map_result to _apply_to_tensor_elems
felixblanke Jun 25, 2024
37db5db
Format
felixblanke Jun 25, 2024
b4c16e2
rename tree_map
v0lta Jun 26, 2024
a9569f2
rename coeff tree map.
v0lta Jun 26, 2024
d1339f0
Add remark on JAX tree map
felixblanke Jun 26, 2024
6254f13
Merge branch 'main' into fix/keep-ndims-Nd
v0lta Jul 1, 2024
e281c85
merge.
v0lta Jul 1, 2024
75ad167
formatting.
v0lta Jul 1, 2024
e3fe0f4
fix typing.
v0lta Jul 1, 2024
d6d8259
Fix ndim sep trafo usage comments
felixblanke Jul 1, 2024
de1b7c8
Fix docstr
felixblanke Jul 1, 2024
986692f
nd-transforms are out of scope.
v0lta Jul 1, 2024
f008cfa
Merge branch 'fix/keep-ndims-Nd' of github.com:v0lta/PyTorch-Wavelet-…
v0lta Jul 1, 2024
77ca368
short note.
v0lta Jul 1, 2024
cef6a04
move note.
v0lta Jul 1, 2024
9981521
add note to forward and backward.
v0lta Jul 1, 2024
407 changes: 393 additions & 14 deletions src/ptwt/_util.py

Large diffs are not rendered by default.
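
Since the `_util.py` diff is not rendered, here is a minimal sketch of what the new generalized tensor helpers appear to do, judging from their call sites in the transforms below: move the transform axes to the trailing positions, fold every leading axis into one batch axis, and undo both steps afterwards. The bodies below (including the use of `torch.movedim` and `reshape`) are assumptions for illustration, not the actual `_util.py` implementation, which additionally covers dtype checks, optional channel dimensions, and the coefficient-container variants (`_preprocess_coeffs`, `_postprocess_coeffs`, `_check_same_device_dtype`).

```python
from __future__ import annotations

from typing import Union

import torch


def _preprocess_tensor(
    data: torch.Tensor, ndim: int, axes: Union[int, tuple[int, ...]] = -1
) -> tuple[torch.Tensor, list[int]]:
    """Sketch: move the transform axes last and fold leading axes into a batch.

    Returns the folded tensor of shape [batch, 1, *spatial] together with the
    shape ``ds`` needed to undo the folding later.
    """
    if isinstance(axes, int):
        axes = (axes,)
    # move the axes to be transformed to the trailing positions
    data = torch.movedim(data, axes, tuple(range(-ndim, 0)))
    ds = list(data.shape)
    # fold every leading axis into one batch axis, add a channel axis for conv
    data = data.reshape(-1, *ds[-ndim:]).unsqueeze(1)
    return data, ds


def _postprocess_tensor(
    data: torch.Tensor,
    ndim: int,
    ds: list[int],
    axes: Union[int, tuple[int, ...]] = -1,
) -> torch.Tensor:
    """Sketch: undo the folding and axis moves of ``_preprocess_tensor``."""
    # unfold the batch axis back into the original leading axes
    data = data.reshape(*ds[:-ndim], *data.shape[-ndim:])
    if isinstance(axes, int):
        axes = (axes,)
    # move the transformed axes back to their original positions
    return torch.movedim(data, tuple(range(-ndim, 0)), axes)
```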

126 changes: 14 additions & 112 deletions src/ptwt/conv_transform.py
@@ -14,11 +14,13 @@
from ._util import (
Wavelet,
_as_wavelet,
_fold_axes,
_check_same_device_dtype,
_get_len,
_is_dtype_supported,
_pad_symmetric,
_unfold_axes,
_postprocess_coeffs,
_postprocess_tensor,
_preprocess_coeffs,
_preprocess_tensor,
)
from .constants import BoundaryMode, WaveletCoeff2d

@@ -211,63 +213,6 @@ def _adjust_padding_at_reconstruction(
return pad_end, pad_start


def _preprocess_tensor_dec1d(
data: torch.Tensor,
) -> tuple[torch.Tensor, list[int]]:
"""Preprocess input tensor dimensions.

Args:
data (torch.Tensor): An input tensor of any shape.

Returns:
A tuple (data, ds) where data is a data tensor of shape
[new_batch, 1, to_process] and ds contains the original shape.
"""
ds = list(data.shape)
if len(ds) == 1:
# assume time series
data = data.unsqueeze(0).unsqueeze(0)
elif len(ds) == 2:
# assume batched time series
data = data.unsqueeze(1)
else:
data, ds = _fold_axes(data, 1)
data = data.unsqueeze(1)
return data, ds


def _postprocess_result_list_dec1d(
result_list: list[torch.Tensor], ds: list[int], axis: int
) -> list[torch.Tensor]:
if len(ds) == 1:
result_list = [r_el.squeeze(0) for r_el in result_list]
elif len(ds) > 2:
# Unfold axes for the wavelets
result_list = [_unfold_axes(fres, ds, 1) for fres in result_list]
else:
result_list = result_list

if axis != -1:
result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]

return result_list


def _preprocess_result_list_rec1d(
result_lst: Sequence[torch.Tensor],
) -> tuple[Sequence[torch.Tensor], list[int]]:
# Fold axes for the wavelets
ds = list(result_lst[0].shape)
fold_coeffs: Sequence[torch.Tensor]
if len(ds) == 1:
fold_coeffs = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst]
elif len(ds) > 2:
fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst]
else:
fold_coeffs = result_lst
return fold_coeffs, ds


def wavedec(
data: torch.Tensor,
wavelet: Union[Wavelet, str],
@@ -315,10 +260,6 @@ def wavedec(
containing the wavelet coefficients. A denotes
approximation and D detail coefficients.

Raises:
ValueError: If the dtype of the input data tensor is unsupported or
if more than one axis is provided.

Example:
>>> import torch
>>> import ptwt, pywt
@@ -330,16 +271,7 @@
>>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
>>> mode='zero', level=2)
"""
if axis != -1:
if isinstance(axis, int):
data = data.swapaxes(axis, -1)
else:
raise ValueError("wavedec transforms a single axis only.")

if not _is_dtype_supported(data.dtype):
raise ValueError(f"Input dtype {data.dtype} not supported")

data, ds = _preprocess_tensor_dec1d(data)
data, ds = _preprocess_tensor(data, ndim=1, axes=axis)

dec_lo, dec_hi, _, _ = _get_filter_tensors(
wavelet, flip=True, device=data.device, dtype=data.dtype
@@ -360,9 +292,7 @@
result_list.append(res_lo.squeeze(1))
result_list.reverse()

result_list = _postprocess_result_list_dec1d(result_list, ds, axis)

return result_list
return _postprocess_coeffs(result_list, ndim=1, ds=ds, axes=axis)


def waverec(
Expand All @@ -381,11 +311,6 @@ def waverec(
Returns:
The reconstructed signal tensor.

Raises:
ValueError: If the dtype of the coeffs tensor is unsupported or if the
coefficients have incompatible shapes, dtypes or devices or if
more than one axis is provided.

Example:
>>> import torch
>>> import ptwt, pywt
@@ -399,29 +324,11 @@
>>> pywt.Wavelet('haar'))

"""
torch_device = coeffs[0].device
torch_dtype = coeffs[0].dtype
if not _is_dtype_supported(torch_dtype):
raise ValueError(f"Input dtype {torch_dtype} not supported")

for coeff in coeffs[1:]:
if torch_device != coeff.device:
raise ValueError("coefficients must be on the same device")
elif torch_dtype != coeff.dtype:
raise ValueError("coefficients must have the same dtype")

if axis != -1:
swap = []
if isinstance(axis, int):
for coeff in coeffs:
swap.append(coeff.swapaxes(axis, -1))
coeffs = swap
else:
raise ValueError("waverec transforms a single axis only.")

# fold channels, if necessary.
ds = list(coeffs[0].shape)
coeffs, ds = _preprocess_result_list_rec1d(coeffs)
# fold channels and swap axis, if necessary.
if not isinstance(coeffs, list):
coeffs = list(coeffs)
coeffs, ds = _preprocess_coeffs(coeffs, ndim=1, axes=axis)
torch_device, torch_dtype = _check_same_device_dtype(coeffs)

_, _, rec_lo, rec_hi = _get_filter_tensors(
wavelet, flip=False, device=torch_device, dtype=torch_dtype
Expand All @@ -446,12 +353,7 @@ def waverec(
if padr > 0:
res_lo = res_lo[..., :-padr]

if len(ds) == 1:
res_lo = res_lo.squeeze(0)
elif len(ds) > 2:
res_lo = _unfold_axes(res_lo, ds, 1)

if axis != -1:
res_lo = res_lo.swapaxes(axis, -1)
# undo folding and swapping
res_lo = _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=axis)

return res_lo
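
As a usage sketch of the refactored 1d API (the tensor shape and the non-default `axis` below are illustrative, not taken from this PR's tests): after this change, extra leading and trailing batch dimensions are preserved through `wavedec`/`waverec` regardless of which axis is transformed.

```python
import pywt
import torch

import ptwt

# two batch-like dimensions around the signal axis; transform along axis 1
data = torch.randn(5, 64, 3, dtype=torch.float64)
coeffs = ptwt.wavedec(data, pywt.Wavelet("haar"), mode="zero", level=2, axis=1)
rec = ptwt.waverec(coeffs, pywt.Wavelet("haar"), axis=1)

# the batch dimensions and the transformed axis come back in their original places
assert rec.shape == data.shape
torch.testing.assert_close(rec, data)
```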
103 changes: 12 additions & 91 deletions src/ptwt/conv_transform_2.py
@@ -6,26 +6,21 @@

from __future__ import annotations

from functools import partial
from typing import Optional, Union

import pywt
import torch

from ._util import (
Wavelet,
_as_wavelet,
_check_axes_argument,
_check_if_tensor,
_fold_axes,
_check_same_device_dtype,
_get_len,
_is_dtype_supported,
_map_result,
_outer,
_pad_symmetric,
_swap_axes,
_undo_swap_axes,
_unfold_axes,
_postprocess_coeffs,
_postprocess_tensor,
_preprocess_coeffs,
_preprocess_tensor,
)
from .constants import BoundaryMode, WaveletCoeff2d, WaveletDetailTuple2d
from .conv_transform import (
@@ -107,32 +102,6 @@ def _fwt_pad2(
return data_pad


def _waverec2d_fold_channels_2d_list(
coeffs: WaveletCoeff2d,
) -> tuple[WaveletCoeff2d, list[int]]:
# fold the input coefficients for processing conv2d_transpose.
ds = list(_check_if_tensor(coeffs[0]).shape)
return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds


def _preprocess_tensor_dec2d(
data: torch.Tensor,
) -> tuple[torch.Tensor, Union[list[int], None]]:
# Preprocess multidimensional input.
ds = None
if len(data.shape) == 2:
data = data.unsqueeze(0).unsqueeze(0)
elif len(data.shape) == 3:
# add a channel dimension for torch.
data = data.unsqueeze(1)
elif len(data.shape) >= 4:
data, ds = _fold_axes(data, 2)
data = data.unsqueeze(1)
elif len(data.shape) == 1:
raise ValueError("More than one input dimension required.")
return data, ds


def wavedec2(
data: torch.Tensor,
wavelet: Union[Wavelet, str],
@@ -183,11 +152,6 @@ def wavedec2(
A tuple containing the wavelet coefficients in pywt order,
see :data:`ptwt.constants.WaveletCoeff2d`.

Raises:
ValueError: If the dimensionality or the dtype of the input data tensor
is unsupported or if the provided ``axes``
input has a length other than two.

Example:
>>> import torch
>>> import ptwt, pywt
@@ -200,17 +164,7 @@
>>> level=2, mode="zero")

"""
if not _is_dtype_supported(data.dtype):
raise ValueError(f"Input dtype {data.dtype} not supported")

if tuple(axes) != (-2, -1):
if len(axes) != 2:
raise ValueError("2D transforms work with two axes.")
else:
data = _swap_axes(data, list(axes))

wavelet = _as_wavelet(wavelet)
data, ds = _preprocess_tensor_dec2d(data)
data, ds = _preprocess_tensor(data, ndim=2, axes=axes)
dec_lo, dec_hi, _, _ = _get_filter_tensors(
wavelet, flip=True, device=data.device, dtype=data.dtype
)
@@ -234,13 +188,7 @@
res_ll = res_ll.squeeze(1)
result: WaveletCoeff2d = res_ll, *result_lst

if ds:
_unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2)
result = _map_result(result, _unfold_axes2)

if axes != (-2, -1):
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
result = _map_result(result, undo_swap_fn)
result = _postprocess_coeffs(result, ndim=2, ds=ds, axes=axes)

return result

@@ -286,35 +234,16 @@ def waverec2(
>>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))

"""
if tuple(axes) != (-2, -1):
if len(axes) != 2:
raise ValueError("2D transforms work with two axes.")
else:
_check_axes_argument(list(axes))
swap_fn = partial(_swap_axes, axes=list(axes))
coeffs = _map_result(coeffs, swap_fn)

ds = None
wavelet = _as_wavelet(wavelet)

res_ll = _check_if_tensor(coeffs[0])
torch_device = res_ll.device
torch_dtype = res_ll.dtype

if res_ll.dim() >= 4:
# avoid the channel sum, fold the channels into batches.
coeffs, ds = _waverec2d_fold_channels_2d_list(coeffs)
res_ll = _check_if_tensor(coeffs[0])

if not _is_dtype_supported(torch_dtype):
raise ValueError(f"Input dtype {torch_dtype} not supported")
coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes)
torch_device, torch_dtype = _check_same_device_dtype(coeffs)

_, _, rec_lo, rec_hi = _get_filter_tensors(
wavelet, flip=False, device=torch_device, dtype=torch_dtype
)
filt_len = rec_lo.shape[-1]
rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi)

res_ll = coeffs[0]
for c_pos, coeff_tuple in enumerate(coeffs[1:]):
if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3:
raise ValueError(
@@ -325,11 +254,7 @@

curr_shape = res_ll.shape
for coeff in coeff_tuple:
if torch_device != coeff.device:
raise ValueError("coefficients must be on the same device")
elif torch_dtype != coeff.dtype:
raise ValueError("coefficients must have the same dtype")
elif coeff.shape != curr_shape:
if coeff.shape != curr_shape:
raise ValueError(
"All coefficients on each level must have the same shape"
)
@@ -362,10 +287,6 @@ def waverec2(
if padr > 0:
res_ll = res_ll[..., :-padr]

if ds:
res_ll = _unfold_axes(res_ll, list(ds), 2)

if axes != (-2, -1):
res_ll = _undo_swap_axes(res_ll, list(axes))
res_ll = _postprocess_tensor(res_ll, ndim=2, ds=ds, axes=axes)

return res_ll
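
Analogously for the 2d case, a usage sketch (the input shape is illustrative): batched, multi-channel images go through `wavedec2`/`waverec2` without any manual reshaping, since the leading dimensions are folded and unfolded by the shared helpers.

```python
import pywt
import torch

import ptwt

# batch and channel dimensions in front of the two spatial axes (-2, -1)
images = torch.randn(4, 3, 32, 32, dtype=torch.float64)
coeffs = ptwt.wavedec2(images, pywt.Wavelet("haar"), mode="zero", level=2)
rec = ptwt.waverec2(coeffs, pywt.Wavelet("haar"))
assert rec.shape == images.shape
```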