From 0a1b515bb2d27ea872e95d970806aee5828db790 Mon Sep 17 00:00:00 2001
From: Felix Blanke
Date: Mon, 24 Jun 2024 19:28:00 +0200
Subject: [PATCH 01/33] Integrate axis swap into 1d processing funcs

---
 src/ptwt/conv_transform.py       | 63 ++++++++++++++++++--------------
 src/ptwt/matmul_transform.py     | 22 ++---------
 src/ptwt/stationary_transform.py | 32 +++-------------
 3 files changed, 45 insertions(+), 72 deletions(-)

diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py
index de5876b0..d10067fa 100644
--- a/src/ptwt/conv_transform.py
+++ b/src/ptwt/conv_transform.py
@@ -207,17 +207,25 @@ def _adjust_padding_at_reconstruction(


 def _preprocess_tensor_dec1d(
-    data: torch.Tensor,
+    data: torch.Tensor, axis: int
 ) -> tuple[torch.Tensor, list[int]]:
     """Preprocess input tensor dimensions.

     Args:
         data (torch.Tensor): An input tensor of any shape.
+        axis (int): Compute the transform over this axis instead of the
+            last one.

     Returns:
         A tuple (data, ds) where data is a data tensor of shape
         [new_batch, 1, to_process] and ds contains the original shape.
     """
+    if axis != -1:
+        if isinstance(axis, int):
+            data = data.swapaxes(axis, -1)
+        else:
+            raise ValueError("1d transforms operate on a single axis only.")
+
     ds = list(data.shape)
     if len(ds) == 1:
         # assume time series
@@ -249,8 +257,14 @@ def _postprocess_result_list_dec1d(


 def _preprocess_result_list_rec1d(
-    result_lst: Sequence[torch.Tensor],
+    result_lst: Sequence[torch.Tensor], axis: int
 ) -> tuple[Sequence[torch.Tensor], list[int]]:
+    if axis != -1:
+        if isinstance(axis, int):
+            result_lst = [coeff.swapaxes(axis, -1) for coeff in result_lst]
+        else:
+            raise ValueError("1d transforms operate on a single axis only.")
+
     # Fold axes for the wavelets
     ds = list(result_lst[0].shape)
     fold_coeffs: Sequence[torch.Tensor]
@@ -263,6 +277,20 @@
     return fold_coeffs, ds


+def _postprocess_tensor_rec1d(
+    data: torch.Tensor, ds: list[int], axis: int
+) -> torch.Tensor:
+    if len(ds) == 1:
+        data = data.squeeze(0)
+    elif len(ds) > 2:
+        data = _unfold_axes(data, ds, 1)
+
+    if axis != -1:
+        data = data.swapaxes(axis, -1)
+
+    return data
+
+
 def wavedec(
     data: torch.Tensor,
     wavelet: Union[Wavelet, str],
@@ -325,16 +353,10 @@
     >>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
    >>>              mode='zero', level=2)
     """
-    if axis != -1:
-        if isinstance(axis, int):
-            data = data.swapaxes(axis, -1)
-        else:
-            raise ValueError("wavedec transforms a single axis only.")
-
     if not _is_dtype_supported(data.dtype):
         raise ValueError(f"Input dtype {data.dtype} not supported")

-    data, ds = _preprocess_tensor_dec1d(data)
+    data, ds = _preprocess_tensor_dec1d(data, axis=axis)

     dec_lo, dec_hi, _, _ = _get_filter_tensors(
         wavelet, flip=True, device=data.device, dtype=data.dtype
@@ -405,18 +427,8 @@
     elif torch_dtype != coeff.dtype:
         raise ValueError("coefficients must have the same dtype")

-    if axis != -1:
-        swap = []
-        if isinstance(axis, int):
-            for coeff in coeffs:
-                swap.append(coeff.swapaxes(axis, -1))
-            coeffs = swap
-        else:
-            raise ValueError("waverec transforms a single axis only.")
-
-    # fold channels, if necessary.
-    ds = list(coeffs[0].shape)
-    coeffs, ds = _preprocess_result_list_rec1d(coeffs)
+    # fold channels and swap axis, if necessary.
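+    # For example, a coefficient of shape [8, 3, 64] with axis=-1 is
+    # folded to [24, 64]; ds records the original shape so that batch
+    # and channel axes can be unfolded again after the reconstruction.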
+ coeffs, ds = _preprocess_result_list_rec1d(coeffs, axis=axis) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype @@ -441,12 +453,7 @@ def waverec( if padr > 0: res_lo = res_lo[..., :-padr] - if len(ds) == 1: - res_lo = res_lo.squeeze(0) - elif len(ds) > 2: - res_lo = _unfold_axes(res_lo, ds, 1) - - if axis != -1: - res_lo = res_lo.swapaxes(axis, -1) + # undo folding and swapping + res_lo = _postprocess_tensor_rec1d(res_lo, ds=ds, axis=axis) return res_lo diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index d9a9316e..c948c030 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -25,6 +25,7 @@ from .conv_transform import ( _get_filter_tensors, _postprocess_result_list_dec1d, + _postprocess_tensor_rec1d, _preprocess_result_list_rec1d, _preprocess_tensor_dec1d, ) @@ -332,10 +333,7 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: ValueError: If the decomposition level is not a positive integer or if the input signal has not the expected shape. """ - if self.axis != -1: - input_signal = input_signal.swapaxes(self.axis, -1) - - input_signal, ds = _preprocess_tensor_dec1d(input_signal) + input_signal, ds = _preprocess_tensor_dec1d(input_signal, axis=self.axis) input_signal = input_signal.squeeze(1) if not _is_dtype_supported(input_signal.dtype): @@ -594,13 +592,7 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: coefficients are not in the shape as it is returned from a `MatrixWavedec` object. """ - if self.axis != -1: - swap = [] - for coeff in coefficients: - swap.append(coeff.swapaxes(self.axis, -1)) - coefficients = swap - - coefficients, ds = _preprocess_result_list_rec1d(coefficients) + coefficients, ds = _preprocess_result_list_rec1d(coefficients, self.axis) level = len(coefficients) - 1 input_length = coefficients[-1].shape[-1] * 2 @@ -652,12 +644,6 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: res_lo = lo.T - if len(ds) == 1: - res_lo = res_lo.squeeze(0) - elif len(ds) > 2: - res_lo = _unfold_axes(res_lo, ds, 1) - - if self.axis != -1: - res_lo = res_lo.swapaxes(self.axis, -1) + res_lo = _postprocess_tensor_rec1d(res_lo, ds=ds, axis=self.axis) return res_lo diff --git a/src/ptwt/stationary_transform.py b/src/ptwt/stationary_transform.py index c4be3bc7..d9888da2 100644 --- a/src/ptwt/stationary_transform.py +++ b/src/ptwt/stationary_transform.py @@ -11,6 +11,7 @@ from .conv_transform import ( _get_filter_tensors, _postprocess_result_list_dec1d, + _postprocess_tensor_rec1d, _preprocess_result_list_rec1d, _preprocess_tensor_dec1d, ) @@ -73,13 +74,7 @@ def swt( Raises: ValueError: Is the axis argument is not an integer. """ - if axis != -1: - if isinstance(axis, int): - data = data.swapaxes(axis, -1) - else: - raise ValueError("swt transforms a single axis only.") - - data, ds = _preprocess_tensor_dec1d(data) + data, ds = _preprocess_tensor_dec1d(data, axis) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype @@ -111,7 +106,7 @@ def swt( def iswt( coeffs: Sequence[torch.Tensor], wavelet: Union[pywt.Wavelet, str], - axis: Optional[int] = -1, + axis: int = -1, ) -> torch.Tensor: """Invert a 1d stationary wavelet transform. @@ -120,7 +115,7 @@ def iswt( by the swt function. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet, as used in the forward transform. 
-    axis (int, optional): The axis the forward trasform was computed over.
+    axis (int): The axis the forward transform was computed over.
         Defaults to -1.

     Returns:
         A reconstruction of the original swt input.

     Raises:
         ValueError: If the axis argument is not an integer.
     """
-    if axis != -1:
-        swap = []
-        if isinstance(axis, int):
-            for coeff in coeffs:
-                swap.append(coeff.swapaxes(axis, -1))
-            coeffs = swap
-        else:
-            raise ValueError("iswt transforms a single axis only.")
-
-    coeffs, ds = _preprocess_result_list_rec1d(coeffs)
+    coeffs, ds = _preprocess_result_list_rec1d(coeffs, axis=axis)

     wavelet = _as_wavelet(wavelet)
     _, _, rec_lo, rec_hi = _get_filter_tensors(
@@ -161,12 +147,6 @@
         1,
     )

-    if len(ds) == 1:
-        res_lo = res_lo.squeeze(0)
-    elif len(ds) > 2:
-        res_lo = _unfold_axes(res_lo, ds, 1)
-
-    if axis != -1:
-        res_lo = res_lo.swapaxes(axis, -1)
+    res_lo = _postprocess_tensor_rec1d(res_lo, ds=ds, axis=axis)

     return res_lo

From 242606024698d83e1110b95fd0cb0decf91c32f2 Mon Sep 17 00:00:00 2001
From: Felix Blanke
Date: Mon, 24 Jun 2024 19:34:47 +0200
Subject: [PATCH 02/33] Make channel dim addition optional

---
 src/ptwt/conv_transform.py   | 15 +++++++++------
 src/ptwt/matmul_transform.py |  7 +++++--
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py
index d10067fa..4f09fb4d 100644
--- a/src/ptwt/conv_transform.py
+++ b/src/ptwt/conv_transform.py
@@ -207,7 +207,7 @@ def _adjust_padding_at_reconstruction(

 def _preprocess_tensor_dec1d(
-    data: torch.Tensor, axis: int
+    data: torch.Tensor, axis: int, add_channel_dim: bool = True
 ) -> tuple[torch.Tensor, list[int]]:
     """Preprocess input tensor dimensions.

     Args:
         data (torch.Tensor): An input tensor of any shape.
         axis (int): Compute the transform over this axis instead of the
             last one.
+        add_channel_dim (bool): If True, ensures that the return has at
+            least three axes by adding a new axis at dim 1.
+            Defaults to True.

     Returns:
         A tuple (data, ds) where data is a data tensor of shape
@@ -229,13 +232,13 @@
     ds = list(data.shape)
     if len(ds) == 1:
         # assume time series
-        data = data.unsqueeze(0).unsqueeze(0)
-    elif len(ds) == 2:
-        # assume batched time series
-        data = data.unsqueeze(1)
-    else:
+        data = data.unsqueeze(0)
+    elif len(ds) > 2:
         data, ds = _fold_axes(data, 1)
+
+    if add_channel_dim:
         data = data.unsqueeze(1)
+
     return data, ds

diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py
index c948c030..cdd91fb8 100644
--- a/src/ptwt/matmul_transform.py
+++ b/src/ptwt/matmul_transform.py
@@ -333,8 +333,11 @@
             ValueError: If the decomposition level is not a positive integer
                 or if the input signal has not the expected shape.
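
         Example (an illustrative sketch; the wavelet, level, and
         shapes are chosen for this note)::

             >>> import torch, ptwt
             >>> matrix_wavedec = ptwt.MatrixWavedec("haar", level=2, axis=0)
             >>> coeffs = matrix_wavedec(torch.randn(64, 10))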
""" - input_signal, ds = _preprocess_tensor_dec1d(input_signal, axis=self.axis) - input_signal = input_signal.squeeze(1) + input_signal, ds = _preprocess_tensor_dec1d( + input_signal, + axis=self.axis, + add_channel_dim=False, + ) if not _is_dtype_supported(input_signal.dtype): raise ValueError(f"Input dtype {input_signal.dtype} not supported") From eaa2fb6c5aa47037622ec1beefa7f04d39e67fea Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 24 Jun 2024 19:55:00 +0200 Subject: [PATCH 03/33] Refactor; Add processing funcs for 2d --- src/ptwt/conv_transform.py | 9 +- src/ptwt/conv_transform_2.py | 136 +++++++++++++++------------ src/ptwt/matmul_transform_2.py | 51 +++------- src/ptwt/separable_conv_transform.py | 18 +--- 4 files changed, 96 insertions(+), 118 deletions(-) diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 4f09fb4d..f4195bad 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -270,14 +270,11 @@ def _preprocess_result_list_rec1d( # Fold axes for the wavelets ds = list(result_lst[0].shape) - fold_coeffs: Sequence[torch.Tensor] if len(ds) == 1: - fold_coeffs = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst] + result_lst = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst] elif len(ds) > 2: - fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst] - else: - fold_coeffs = result_lst - return fold_coeffs, ds + result_lst = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst] + return result_lst, ds def _postprocess_tensor_rec1d( diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 1136916f..1aa1e732 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -99,32 +99,81 @@ def _fwt_pad2( return data_pad -def _waverec2d_fold_channels_2d_list( - coeffs: WaveletCoeff2d, -) -> tuple[WaveletCoeff2d, list[int]]: - # fold the input coefficients for processing conv2d_transpose. - ds = list(_check_if_tensor(coeffs[0]).shape) - return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds - - def _preprocess_tensor_dec2d( - data: torch.Tensor, -) -> tuple[torch.Tensor, Union[list[int], None]]: + data: torch.Tensor, axes: tuple[int, int], add_channel_dim: bool = True +) -> tuple[torch.Tensor, list[int]]: + if tuple(axes) != (-2, -1): + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + data = _swap_axes(data, list(axes)) + # Preprocess multidimensional input. - ds = None - if len(data.shape) == 2: - data = data.unsqueeze(0).unsqueeze(0) - elif len(data.shape) == 3: - # add a channel dimension for torch. 
- data = data.unsqueeze(1) - elif len(data.shape) >= 4: + ds = list(data.shape) + if len(ds) <= 1: + raise ValueError("More than one input dimension required.") + elif len(ds) == 2: + data = data.unsqueeze(0) + elif len(ds) >= 4: data, ds = _fold_axes(data, 2) + + if add_channel_dim: data = data.unsqueeze(1) - elif len(data.shape) == 1: - raise ValueError("More than one input dimension required.") + return data, ds +def _postprocess_coeffs_dec2d( + coeffs: WaveletCoeff2d, ds: list[int], axes: tuple[int, int] +) -> WaveletCoeff2d: + if len(ds) == 2: + coeffs = _map_result(coeffs, lambda x: x.squeeze(0)) + elif len(ds) > 3: + _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) + coeffs = _map_result(coeffs, _unfold_axes2) + + if tuple(axes) != (-2, -1): + undo_swap_fn = partial(_undo_swap_axes, axes=axes) + coeffs = _map_result(coeffs, undo_swap_fn) + + return coeffs + + +def _preprocess_coeffs_rec2d( + coeffs: WaveletCoeff2d, axes: tuple[int, int] +) -> tuple[WaveletCoeff2d, list[int]]: + # swap axes if necessary + if tuple(axes) != (-2, -1): + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + swap_fn = partial(_swap_axes, axes=axes) + coeffs = _map_result(coeffs, swap_fn) + + # Fold axes for the wavelets + ds = list(coeffs[0].shape) + if len(ds) <= 1: + raise ValueError("2d transforms require at least 2 input dimensions") + elif len(ds) == 2: + coeffs = _map_result(coeffs, lambda x: x.unsqueeze(0)) + elif len(ds) > 3: + coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]) + return coeffs, ds + + +def _postprocess_tensor_rec2d( + data: torch.Tensor, ds: list[int], axes: tuple[int, int] +) -> torch.Tensor: + if len(ds) == 2: + data = data.squeeze(0) + elif len(ds) > 3: + data = _unfold_axes(data, ds, 2) + + if tuple(axes) != (-2, -1): + data = _undo_swap_axes(data, axes) + return data + + def wavedec2( data: torch.Tensor, wavelet: Union[Wavelet, str], @@ -195,14 +244,7 @@ def wavedec2( if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - data = _swap_axes(data, list(axes)) - - wavelet = _as_wavelet(wavelet) - data, ds = _preprocess_tensor_dec2d(data) + data, ds = _preprocess_tensor_dec2d(data, axes=axes) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype ) @@ -226,13 +268,7 @@ def wavedec2( res_ll = res_ll.squeeze(1) result: WaveletCoeff2d = res_ll, *result_lst - if ds: - _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) - result = _map_result(result, _unfold_axes2) - - if axes != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) - result = _map_result(result, undo_swap_fn) + result = _postprocess_coeffs_dec2d(result, ds=ds, axes=axes) return result @@ -278,35 +314,21 @@ def waverec2( >>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar")) """ - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - _check_axes_argument(list(axes)) - swap_fn = partial(_swap_axes, axes=list(axes)) - coeffs = _map_result(coeffs, swap_fn) - - ds = None - wavelet = _as_wavelet(wavelet) - - res_ll = _check_if_tensor(coeffs[0]) - torch_device = res_ll.device - torch_dtype = res_ll.dtype - - if res_ll.dim() >= 4: - # avoid the channel sum, fold the channels into batches. 
- coeffs, ds = _waverec2d_fold_channels_2d_list(coeffs) - res_ll = _check_if_tensor(coeffs[0]) - + _check_if_tensor(coeffs[0]) + torch_device = coeffs[0].device + torch_dtype = coeffs[0].dtype if not _is_dtype_supported(torch_dtype): raise ValueError(f"Input dtype {torch_dtype} not supported") + coeffs, ds = _preprocess_coeffs_rec2d(coeffs, axes=axes) + _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype ) filt_len = rec_lo.shape[-1] rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi) + res_ll = _check_if_tensor(coeffs[0]) for c_pos, coeff_tuple in enumerate(coeffs[1:]): if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3: raise ValueError( @@ -354,10 +376,6 @@ def waverec2( if padr > 0: res_ll = res_ll[..., :-padr] - if ds: - res_ll = _unfold_axes(res_ll, list(ds), 2) - - if axes != (-2, -1): - res_ll = _undo_swap_axes(res_ll, list(axes)) + res_ll = _postprocess_tensor_rec2d(res_ll, ds=ds, axes=axes) return res_ll diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index ca629f14..fec2016a 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -33,8 +33,10 @@ from .conv_transform import _get_filter_tensors from .conv_transform_2 import ( _construct_2d_filt, + _postprocess_coeffs_dec2d, + _postprocess_tensor_rec2d, + _preprocess_coeffs_rec2d, _preprocess_tensor_dec2d, - _waverec2d_fold_channels_2d_list, ) from .matmul_transform import ( BaseMatrixWaveDec, @@ -124,7 +126,6 @@ def _construct_s_2( Returns: The generated fast wavelet synthesis matrix. """ - wavelet = _as_wavelet(wavelet) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=True, device=device, dtype=dtype ) @@ -290,8 +291,8 @@ def __init__( if len(axes) != 2: raise ValueError("2D transforms work with two axes.") else: - _check_axes_argument(list(axes)) - self.axes = tuple(axes) + _check_axes_argument(axes) + self.axes = axes self.level = level self.boundary = boundary self.separable = separable @@ -434,11 +435,9 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeff2d: ValueError: If the decomposition level is not a positive integer or if the input signal has not the expected shape. """ - if self.axes != (-2, -1): - input_signal = _swap_axes(input_signal, list(self.axes)) - - input_signal, ds = _preprocess_tensor_dec2d(input_signal) - input_signal = input_signal.squeeze(1) + input_signal, ds = _preprocess_tensor_dec2d( + input_signal, axes=self.axes, add_channel_dim=False + ) batch_size, height, width = input_signal.shape @@ -537,13 +536,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeff2d: split_list.reverse() result: WaveletCoeff2d = ll, *split_list - if ds: - _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) - result = _map_result(result, _unfold_axes2) - - if self.axes != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) - result = _map_result(result, undo_swap_fn) + result = _postprocess_coeffs_dec2d(result, ds=ds, axes=self.axes) return result @@ -599,7 +592,7 @@ def __init__( if len(axes) != 2: raise ValueError("2D transforms work with two axes.") else: - _check_axes_argument(list(axes)) + _check_axes_argument(axes) self.axes = axes self.ifwt_matrix_list: list[ @@ -736,23 +729,8 @@ def __call__( coefficients are not in the shape as it is returned from a `MatrixWavedec2` object. 
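
         Example (an illustrative round trip; the wavelet and shapes
         are chosen for this note)::

             >>> import torch, ptwt
             >>> matrix_wavedec2 = ptwt.MatrixWavedec2("haar", level=2)
             >>> matrix_waverec2 = ptwt.MatrixWaverec2("haar")
             >>> coeffs = matrix_wavedec2(torch.randn(32, 32))
             >>> reconstruction = matrix_waverec2(coeffs)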
""" - ll = _check_if_tensor(coefficients[0]) - - if tuple(self.axes) != (-2, -1): - swap_fn = partial(_swap_axes, axes=list(self.axes)) - coefficients = _map_result(coefficients, swap_fn) - ll = _check_if_tensor(coefficients[0]) - - ds = None - if ll.dim() == 1: - raise ValueError("2d transforms require more than a single input dim.") - elif ll.dim() == 2: - # add batch dim to unbatched input - ll = ll.unsqueeze(0) - elif ll.dim() >= 4: - # avoid the channel sum, fold the channels into batches. - coefficients, ds = _waverec2d_fold_channels_2d_list(coefficients) - ll = _check_if_tensor(coefficients[0]) + coefficients, ds = _preprocess_coeffs_rec2d(coefficients, axes=self.axes) + ll = coefficients[0] level = len(coefficients) - 1 height, width = tuple(c * 2 for c in coefficients[-1][0].shape[-2:]) @@ -845,9 +823,6 @@ def __call__( if pred_len[1] != next_len[1]: ll = ll[:, :, :-1] - if ds: - ll = _unfold_axes(ll, list(ds), 2) + ll = _postprocess_tensor_rec2d(ll, ds=ds, axes=self.axes) - if self.axes != (-2, -1): - ll = _undo_swap_axes(ll, list(self.axes)) return ll diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index c89b9931..da492ce2 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -29,7 +29,7 @@ ) from .constants import BoundaryMode, WaveletCoeff2dSeparable, WaveletCoeffNd from .conv_transform import wavedec, waverec -from .conv_transform_2 import _preprocess_tensor_dec2d +from .conv_transform_2 import _preprocess_tensor_dec2d, _postprocess_tensor_rec2d def _separable_conv_dwtn_( @@ -228,15 +228,7 @@ def fswavedec2( if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - data = _swap_axes(data, list(axes)) - - wavelet = _as_wavelet(wavelet) - data, ds = _preprocess_tensor_dec2d(data) - data = data.squeeze(1) + data, ds = _preprocess_tensor_dec2d(data, axes=axes, add_channel_dim=False) res = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) if ds: @@ -382,11 +374,7 @@ def fswaverec2( res_ll = _separable_conv_waverecn(coeffs, wavelet) - if ds: - res_ll = _unfold_axes(res_ll, list(ds), 2) - - if axes != (-2, -1): - res_ll = _undo_swap_axes(res_ll, list(axes)) + res_ll = _postprocess_tensor_rec2d(res_ll, ds=ds, axes=axes) return res_ll From 0e4cb279b5d34334f496f23a39f0250dd44e1c6b Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 24 Jun 2024 20:18:45 +0200 Subject: [PATCH 04/33] Move processung funcs to _util module --- src/ptwt/_util.py | 164 +++++++++++++++++++++++++++ src/ptwt/conv_transform.py | 99 ++-------------- src/ptwt/conv_transform_2.py | 94 ++------------- src/ptwt/matmul_transform.py | 21 ++-- src/ptwt/matmul_transform_2.py | 26 ++--- src/ptwt/separable_conv_transform.py | 12 +- src/ptwt/stationary_transform.py | 29 ++--- 7 files changed, 215 insertions(+), 230 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 414323f1..cbbe3bea 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -4,6 +4,7 @@ import typing from collections.abc import Sequence +from functools import partial from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload import numpy as np @@ -253,3 +254,166 @@ def _map_result( Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst ) return approx, *cast_result_lst + + +def _preprocess_coeffs_1d( + result_lst: Sequence[torch.Tensor], 
axis: int +) -> tuple[Sequence[torch.Tensor], list[int]]: + if axis != -1: + if isinstance(axis, int): + result_lst = [coeff.swapaxes(axis, -1) for coeff in result_lst] + else: + raise ValueError("1d transforms operate on a single axis only.") + + # Fold axes for the wavelets + ds = list(result_lst[0].shape) + if len(ds) == 1: + result_lst = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst] + elif len(ds) > 2: + result_lst = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst] + return result_lst, ds + + +def _postprocess_coeffs_1d( + result_list: list[torch.Tensor], ds: list[int], axis: int +) -> list[torch.Tensor]: + if len(ds) == 1: + result_list = [r_el.squeeze(0) for r_el in result_list] + elif len(ds) > 2: + # Unfold axes for the wavelets + result_list = [_unfold_axes(fres, ds, 1) for fres in result_list] + else: + result_list = result_list + + if axis != -1: + result_list = [coeff.swapaxes(axis, -1) for coeff in result_list] + + return result_list + + +def _preprocess_coeffs_2d( + coeffs: WaveletCoeff2d, axes: tuple[int, int] +) -> tuple[WaveletCoeff2d, list[int]]: + # swap axes if necessary + if tuple(axes) != (-2, -1): + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + swap_fn = partial(_swap_axes, axes=axes) + coeffs = _map_result(coeffs, swap_fn) + + # Fold axes for the wavelets + ds = list(coeffs[0].shape) + if len(ds) <= 1: + raise ValueError("2d transforms require at least 2 input dimensions") + elif len(ds) == 2: + coeffs = _map_result(coeffs, lambda x: x.unsqueeze(0)) + elif len(ds) > 3: + coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]) + return coeffs, ds + + +def _postprocess_coeffs_2d( + coeffs: WaveletCoeff2d, ds: list[int], axes: tuple[int, int] +) -> WaveletCoeff2d: + if len(ds) == 2: + coeffs = _map_result(coeffs, lambda x: x.squeeze(0)) + elif len(ds) > 3: + _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) + coeffs = _map_result(coeffs, _unfold_axes2) + + if tuple(axes) != (-2, -1): + undo_swap_fn = partial(_undo_swap_axes, axes=axes) + coeffs = _map_result(coeffs, undo_swap_fn) + + return coeffs + + +def _preprocess_tensor_1d( + data: torch.Tensor, axis: int, add_channel_dim: bool = True +) -> tuple[torch.Tensor, list[int]]: + """Preprocess input tensor dimensions. + + Args: + data (torch.Tensor): An input tensor of any shape. + axis (int): Compute the transform over this axis instead of the + last one. + add_channel_dim (bool): If True, ensures that the return has at + least three axes by adding a new axis at dim 1. + Defaults to True. + + Returns: + A tuple (data, ds) where data is a data tensor of shape + [new_batch, 1, to_process] and ds contains the original shape. + + Raises: + ValueError: if ``axis`` is not a single int. 
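+
+    Example (an illustrative shape walk-through; the sizes are
+    arbitrary)::
+
+        >>> import torch
+        >>> data, ds = _preprocess_tensor_1d(torch.randn(5, 4, 3), axis=-1)
+        >>> data.shape
+        torch.Size([20, 1, 3])
+        >>> ds
+        [5, 4, 3]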
+ """ + if axis != -1: + if isinstance(axis, int): + data = data.swapaxes(axis, -1) + else: + raise ValueError("1d transforms operate on a single axis only.") + + ds = list(data.shape) + if len(ds) == 1: + # assume time series + data = data.unsqueeze(0) + elif len(ds) > 2: + data, ds = _fold_axes(data, 1) + + if add_channel_dim: + data = data.unsqueeze(1) + + return data, ds + + +def _postprocess_tensor_1d( + data: torch.Tensor, ds: list[int], axis: int +) -> torch.Tensor: + if len(ds) == 1: + data = data.squeeze(0) + elif len(ds) > 2: + data = _unfold_axes(data, ds, 1) + + if axis != -1: + data = data.swapaxes(axis, -1) + + return data + + +def _preprocess_tensor_2d( + data: torch.Tensor, axes: tuple[int, int], add_channel_dim: bool = True +) -> tuple[torch.Tensor, list[int]]: + if tuple(axes) != (-2, -1): + if len(axes) != 2: + raise ValueError("2D transforms work with two axes.") + else: + data = _swap_axes(data, list(axes)) + + # Preprocess multidimensional input. + ds = list(data.shape) + if len(ds) <= 1: + raise ValueError("More than one input dimension required.") + elif len(ds) == 2: + data = data.unsqueeze(0) + elif len(ds) >= 4: + data, ds = _fold_axes(data, 2) + + if add_channel_dim: + data = data.unsqueeze(1) + + return data, ds + + +def _postprocess_tensor_2d( + data: torch.Tensor, ds: list[int], axes: tuple[int, int] +) -> torch.Tensor: + if len(ds) == 2: + data = data.squeeze(0) + elif len(ds) > 3: + data = _unfold_axes(data, ds, 2) + + if tuple(axes) != (-2, -1): + data = _undo_swap_axes(data, axes) + return data diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index f4195bad..624f09ad 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -14,11 +14,13 @@ from ._util import ( Wavelet, _as_wavelet, - _fold_axes, _get_len, _is_dtype_supported, _pad_symmetric, - _unfold_axes, + _postprocess_coeffs_1d, + _postprocess_tensor_1d, + _preprocess_coeffs_1d, + _preprocess_tensor_1d, ) from .constants import BoundaryMode, WaveletCoeff2d @@ -206,91 +208,6 @@ def _adjust_padding_at_reconstruction( return pad_end, pad_start -def _preprocess_tensor_dec1d( - data: torch.Tensor, axis: int, add_channel_dim: bool = True -) -> tuple[torch.Tensor, list[int]]: - """Preprocess input tensor dimensions. - - Args: - data (torch.Tensor): An input tensor of any shape. - axis (int): Compute the transform over this axis instead of the - last one. - add_channel_dim (bool): If True, ensures that the return has at - least three axes by adding a new axis at dim 1. - Defaults to True. - - Returns: - A tuple (data, ds) where data is a data tensor of shape - [new_batch, 1, to_process] and ds contains the original shape. 
- """ - if axis != -1: - if isinstance(axis, int): - data = data.swapaxes(axis, -1) - else: - raise ValueError("1d transforms operate on a single axis only.") - - ds = list(data.shape) - if len(ds) == 1: - # assume time series - data = data.unsqueeze(0) - elif len(ds) > 2: - data, ds = _fold_axes(data, 1) - - if add_channel_dim: - data = data.unsqueeze(1) - - return data, ds - - -def _postprocess_result_list_dec1d( - result_list: list[torch.Tensor], ds: list[int], axis: int -) -> list[torch.Tensor]: - if len(ds) == 1: - result_list = [r_el.squeeze(0) for r_el in result_list] - elif len(ds) > 2: - # Unfold axes for the wavelets - result_list = [_unfold_axes(fres, ds, 1) for fres in result_list] - else: - result_list = result_list - - if axis != -1: - result_list = [coeff.swapaxes(axis, -1) for coeff in result_list] - - return result_list - - -def _preprocess_result_list_rec1d( - result_lst: Sequence[torch.Tensor], axis: int -) -> tuple[Sequence[torch.Tensor], list[int]]: - if axis != -1: - if isinstance(axis, int): - result_lst = [coeff.swapaxes(axis, -1) for coeff in result_lst] - else: - raise ValueError("1d transforms operate on a single axis only.") - - # Fold axes for the wavelets - ds = list(result_lst[0].shape) - if len(ds) == 1: - result_lst = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst] - elif len(ds) > 2: - result_lst = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst] - return result_lst, ds - - -def _postprocess_tensor_rec1d( - data: torch.Tensor, ds: list[int], axis: int -) -> torch.Tensor: - if len(ds) == 1: - data = data.squeeze(0) - elif len(ds) > 2: - data = _unfold_axes(data, ds, 1) - - if axis != -1: - data = data.swapaxes(axis, -1) - - return data - - def wavedec( data: torch.Tensor, wavelet: Union[Wavelet, str], @@ -356,7 +273,7 @@ def wavedec( if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") - data, ds = _preprocess_tensor_dec1d(data, axis=axis) + data, ds = _preprocess_tensor_1d(data, axis=axis) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype @@ -377,7 +294,7 @@ def wavedec( result_list.append(res_lo.squeeze(1)) result_list.reverse() - result_list = _postprocess_result_list_dec1d(result_list, ds, axis) + result_list = _postprocess_coeffs_1d(result_list, ds, axis) return result_list @@ -428,7 +345,7 @@ def waverec( raise ValueError("coefficients must have the same dtype") # fold channels and swap axis, if necessary. 
- coeffs, ds = _preprocess_result_list_rec1d(coeffs, axis=axis) + coeffs, ds = _preprocess_coeffs_1d(coeffs, axis=axis) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype @@ -454,6 +371,6 @@ def waverec( res_lo = res_lo[..., :-padr] # undo folding and swapping - res_lo = _postprocess_tensor_rec1d(res_lo, ds=ds, axis=axis) + res_lo = _postprocess_tensor_1d(res_lo, ds=ds, axis=axis) return res_lo diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 1aa1e732..358903a3 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -6,7 +6,6 @@ from __future__ import annotations -from functools import partial from typing import Optional, Union import pywt @@ -15,17 +14,15 @@ from ._util import ( Wavelet, _as_wavelet, - _check_axes_argument, _check_if_tensor, - _fold_axes, _get_len, _is_dtype_supported, - _map_result, _outer, _pad_symmetric, - _swap_axes, - _undo_swap_axes, - _unfold_axes, + _postprocess_coeffs_2d, + _postprocess_tensor_2d, + _preprocess_coeffs_2d, + _preprocess_tensor_2d, ) from .constants import BoundaryMode, WaveletCoeff2d, WaveletDetailTuple2d from .conv_transform import ( @@ -99,81 +96,6 @@ def _fwt_pad2( return data_pad -def _preprocess_tensor_dec2d( - data: torch.Tensor, axes: tuple[int, int], add_channel_dim: bool = True -) -> tuple[torch.Tensor, list[int]]: - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - data = _swap_axes(data, list(axes)) - - # Preprocess multidimensional input. - ds = list(data.shape) - if len(ds) <= 1: - raise ValueError("More than one input dimension required.") - elif len(ds) == 2: - data = data.unsqueeze(0) - elif len(ds) >= 4: - data, ds = _fold_axes(data, 2) - - if add_channel_dim: - data = data.unsqueeze(1) - - return data, ds - - -def _postprocess_coeffs_dec2d( - coeffs: WaveletCoeff2d, ds: list[int], axes: tuple[int, int] -) -> WaveletCoeff2d: - if len(ds) == 2: - coeffs = _map_result(coeffs, lambda x: x.squeeze(0)) - elif len(ds) > 3: - _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) - coeffs = _map_result(coeffs, _unfold_axes2) - - if tuple(axes) != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) - coeffs = _map_result(coeffs, undo_swap_fn) - - return coeffs - - -def _preprocess_coeffs_rec2d( - coeffs: WaveletCoeff2d, axes: tuple[int, int] -) -> tuple[WaveletCoeff2d, list[int]]: - # swap axes if necessary - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - swap_fn = partial(_swap_axes, axes=axes) - coeffs = _map_result(coeffs, swap_fn) - - # Fold axes for the wavelets - ds = list(coeffs[0].shape) - if len(ds) <= 1: - raise ValueError("2d transforms require at least 2 input dimensions") - elif len(ds) == 2: - coeffs = _map_result(coeffs, lambda x: x.unsqueeze(0)) - elif len(ds) > 3: - coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]) - return coeffs, ds - - -def _postprocess_tensor_rec2d( - data: torch.Tensor, ds: list[int], axes: tuple[int, int] -) -> torch.Tensor: - if len(ds) == 2: - data = data.squeeze(0) - elif len(ds) > 3: - data = _unfold_axes(data, ds, 2) - - if tuple(axes) != (-2, -1): - data = _undo_swap_axes(data, axes) - return data - - def wavedec2( data: torch.Tensor, wavelet: Union[Wavelet, str], @@ -244,7 +166,7 @@ def wavedec2( if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") - data, ds = 
_preprocess_tensor_dec2d(data, axes=axes) + data, ds = _preprocess_tensor_2d(data, axes=axes) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype ) @@ -268,7 +190,7 @@ def wavedec2( res_ll = res_ll.squeeze(1) result: WaveletCoeff2d = res_ll, *result_lst - result = _postprocess_coeffs_dec2d(result, ds=ds, axes=axes) + result = _postprocess_coeffs_2d(result, ds=ds, axes=axes) return result @@ -320,7 +242,7 @@ def waverec2( if not _is_dtype_supported(torch_dtype): raise ValueError(f"Input dtype {torch_dtype} not supported") - coeffs, ds = _preprocess_coeffs_rec2d(coeffs, axes=axes) + coeffs, ds = _preprocess_coeffs_2d(coeffs, axes=axes) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype @@ -376,6 +298,6 @@ def waverec2( if padr > 0: res_ll = res_ll[..., :-padr] - res_ll = _postprocess_tensor_rec2d(res_ll, ds=ds, axes=axes) + res_ll = _postprocess_tensor_2d(res_ll, ds=ds, axes=axes) return res_ll diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index cdd91fb8..9d196ccd 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -19,16 +19,13 @@ _as_wavelet, _is_boundary_mode_supported, _is_dtype_supported, - _unfold_axes, + _postprocess_coeffs_1d, + _postprocess_tensor_1d, + _preprocess_coeffs_1d, + _preprocess_tensor_1d, ) from .constants import OrthogonalizeMethod -from .conv_transform import ( - _get_filter_tensors, - _postprocess_result_list_dec1d, - _postprocess_tensor_rec1d, - _preprocess_result_list_rec1d, - _preprocess_tensor_dec1d, -) +from .conv_transform import _get_filter_tensors from .sparse_math import ( _orth_by_gram_schmidt, _orth_by_qr, @@ -333,7 +330,7 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: ValueError: If the decomposition level is not a positive integer or if the input signal has not the expected shape. """ - input_signal, ds = _preprocess_tensor_dec1d( + input_signal, ds = _preprocess_tensor_1d( input_signal, axis=self.axis, add_channel_dim=False, @@ -380,7 +377,7 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: result_list = [s.T for s in split_list[::-1]] # unfold if necessary - result_list = _postprocess_result_list_dec1d(result_list, ds, self.axis) + result_list = _postprocess_coeffs_1d(result_list, ds, self.axis) return result_list @@ -595,7 +592,7 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: coefficients are not in the shape as it is returned from a `MatrixWavedec` object. 
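
         Example (an illustrative sketch; the wavelet and signal length
         are chosen for this note)::

             >>> import torch, ptwt
             >>> matrix_wavedec = ptwt.MatrixWavedec("db2", level=2)
             >>> matrix_waverec = ptwt.MatrixWaverec("db2")
             >>> rec = matrix_waverec(matrix_wavedec(torch.randn(4, 64)))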
""" - coefficients, ds = _preprocess_result_list_rec1d(coefficients, self.axis) + coefficients, ds = _preprocess_coeffs_1d(coefficients, self.axis) level = len(coefficients) - 1 input_length = coefficients[-1].shape[-1] * 2 @@ -647,6 +644,6 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: res_lo = lo.T - res_lo = _postprocess_tensor_rec1d(res_lo, ds=ds, axis=self.axis) + res_lo = _postprocess_tensor_1d(res_lo, ds=ds, axis=self.axis) return res_lo diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index fec2016a..2a2def91 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -6,7 +6,6 @@ from __future__ import annotations import sys -from functools import partial from typing import Optional, Union, cast import numpy as np @@ -16,13 +15,12 @@ Wavelet, _as_wavelet, _check_axes_argument, - _check_if_tensor, _is_boundary_mode_supported, _is_dtype_supported, - _map_result, - _swap_axes, - _undo_swap_axes, - _unfold_axes, + _postprocess_coeffs_2d, + _postprocess_tensor_2d, + _preprocess_coeffs_2d, + _preprocess_tensor_2d, ) from .constants import ( OrthogonalizeMethod, @@ -31,13 +29,7 @@ WaveletDetailTuple2d, ) from .conv_transform import _get_filter_tensors -from .conv_transform_2 import ( - _construct_2d_filt, - _postprocess_coeffs_dec2d, - _postprocess_tensor_rec2d, - _preprocess_coeffs_rec2d, - _preprocess_tensor_dec2d, -) +from .conv_transform_2 import _construct_2d_filt from .matmul_transform import ( BaseMatrixWaveDec, construct_boundary_a, @@ -435,7 +427,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeff2d: ValueError: If the decomposition level is not a positive integer or if the input signal has not the expected shape. """ - input_signal, ds = _preprocess_tensor_dec2d( + input_signal, ds = _preprocess_tensor_2d( input_signal, axes=self.axes, add_channel_dim=False ) @@ -536,7 +528,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeff2d: split_list.reverse() result: WaveletCoeff2d = ll, *split_list - result = _postprocess_coeffs_dec2d(result, ds=ds, axes=self.axes) + result = _postprocess_coeffs_2d(result, ds=ds, axes=self.axes) return result @@ -729,7 +721,7 @@ def __call__( coefficients are not in the shape as it is returned from a `MatrixWavedec2` object. 
""" - coefficients, ds = _preprocess_coeffs_rec2d(coefficients, axes=self.axes) + coefficients, ds = _preprocess_coeffs_2d(coefficients, axes=self.axes) ll = coefficients[0] level = len(coefficients) - 1 @@ -823,6 +815,6 @@ def __call__( if pred_len[1] != next_len[1]: ll = ll[:, :, :-1] - ll = _postprocess_tensor_rec2d(ll, ds=ds, axes=self.axes) + ll = _postprocess_tensor_2d(ll, ds=ds, axes=self.axes) return ll diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index da492ce2..b1724ddc 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -23,13 +23,14 @@ _fold_axes, _is_dtype_supported, _map_result, + _postprocess_tensor_2d, + _preprocess_tensor_2d, _swap_axes, _undo_swap_axes, _unfold_axes, ) from .constants import BoundaryMode, WaveletCoeff2dSeparable, WaveletCoeffNd from .conv_transform import wavedec, waverec -from .conv_transform_2 import _preprocess_tensor_dec2d, _postprocess_tensor_rec2d def _separable_conv_dwtn_( @@ -228,7 +229,7 @@ def fswavedec2( if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") - data, ds = _preprocess_tensor_dec2d(data, axes=axes, add_channel_dim=False) + data, ds = _preprocess_tensor_2d(data, axes=axes, add_channel_dim=False) res = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) if ds: @@ -357,15 +358,12 @@ def fswaverec2( swap_fn = partial(_swap_axes, axes=list(axes)) coeffs = _map_result(coeffs, swap_fn) - ds = None - wavelet = _as_wavelet(wavelet) - res_ll = _check_if_tensor(coeffs[0]) + ds = list(res_ll.shape) torch_dtype = res_ll.dtype if res_ll.dim() >= 4: # avoid the channel sum, fold the channels into batches. - ds = _check_if_tensor(coeffs[0]).shape coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]) res_ll = _check_if_tensor(coeffs[0]) @@ -374,7 +372,7 @@ def fswaverec2( res_ll = _separable_conv_waverecn(coeffs, wavelet) - res_ll = _postprocess_tensor_rec2d(res_ll, ds=ds, axes=axes) + res_ll = _postprocess_tensor_2d(res_ll, ds=ds, axes=axes) return res_ll diff --git a/src/ptwt/stationary_transform.py b/src/ptwt/stationary_transform.py index d9888da2..05e7e7d6 100644 --- a/src/ptwt/stationary_transform.py +++ b/src/ptwt/stationary_transform.py @@ -7,14 +7,15 @@ import torch import torch.nn.functional as F # noqa:N812 -from ._util import Wavelet, _as_wavelet, _unfold_axes -from .conv_transform import ( - _get_filter_tensors, - _postprocess_result_list_dec1d, - _postprocess_tensor_rec1d, - _preprocess_result_list_rec1d, - _preprocess_tensor_dec1d, +from ._util import ( + Wavelet, + _as_wavelet, + _postprocess_coeffs_1d, + _postprocess_tensor_1d, + _preprocess_coeffs_1d, + _preprocess_tensor_1d, ) +from .conv_transform import _get_filter_tensors def _circular_pad(x: torch.Tensor, padding_dimensions: Sequence[int]) -> torch.Tensor: @@ -70,11 +71,8 @@ def swt( Returns: Same as wavedec. Equivalent to pywt.swt with trim_approx=True. - - Raises: - ValueError: Is the axis argument is not an integer. 
""" - data, ds = _preprocess_tensor_dec1d(data, axis) + data, ds = _preprocess_tensor_1d(data, axis) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype @@ -98,7 +96,7 @@ def swt( result_list.append(res_hi.squeeze(1)) result_list.append(res_lo.squeeze(1)) - result_list = _postprocess_result_list_dec1d(result_list, ds, axis) + result_list = _postprocess_coeffs_1d(result_list, ds, axis) return result_list[::-1] @@ -120,11 +118,8 @@ def iswt( Returns: A reconstruction of the original swt input. - - Raises: - ValueError: If the axis argument is not an integer. """ - coeffs, ds = _preprocess_result_list_rec1d(coeffs, axis=axis) + coeffs, ds = _preprocess_coeffs_1d(coeffs, axis=axis) wavelet = _as_wavelet(wavelet) _, _, rec_lo, rec_hi = _get_filter_tensors( @@ -147,6 +142,6 @@ def iswt( 1, ) - res_lo = _postprocess_tensor_rec1d(res_lo, ds=ds, axis=axis) + res_lo = _postprocess_tensor_1d(res_lo, ds=ds, axis=axis) return res_lo From 3554a531e49873b8b57449b99becaf06436cddd1 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 24 Jun 2024 21:04:34 +0200 Subject: [PATCH 05/33] Generalize tensor processing --- src/ptwt/_util.py | 103 ++++++++++++--------------- src/ptwt/conv_transform.py | 8 +-- src/ptwt/conv_transform_2.py | 8 +-- src/ptwt/matmul_transform.py | 11 +-- src/ptwt/matmul_transform_2.py | 10 +-- src/ptwt/separable_conv_transform.py | 17 ++--- src/ptwt/stationary_transform.py | 8 +-- 7 files changed, 74 insertions(+), 91 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index cbbe3bea..370d742c 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -329,76 +329,53 @@ def _postprocess_coeffs_2d( return coeffs -def _preprocess_tensor_1d( - data: torch.Tensor, axis: int, add_channel_dim: bool = True +def _preprocess_tensor( + data: torch.Tensor, + ndim: int, + axes: Union[tuple[int, ...], int], + add_channel_dim: bool = True, ) -> tuple[torch.Tensor, list[int]]: """Preprocess input tensor dimensions. Args: - data (torch.Tensor): An input tensor of any shape. - axis (int): Compute the transform over this axis instead of the - last one. + data (torch.Tensor): An input tensor with at least `ndim` axes. + ndim (int): The number of axes on which the transformation is + applied. + axis (int or tuple of ints): Compute the transform over these axes + instead of the last ones. add_channel_dim (bool): If True, ensures that the return has at least three axes by adding a new axis at dim 1. Defaults to True. Returns: - A tuple (data, ds) where data is a data tensor of shape - [new_batch, 1, to_process] and ds contains the original shape. + A tuple (data, ds) where data is the transformed data tensor + and ds contains the original shape. If `add_channel_dim` is True, + `data` has `ndim` + 2 axes, otherwise `ndim` + 1. Raises: - ValueError: if ``axis`` is not a single int. + ValueError: if `ndim` is not positive, `axes` has not at least + length `ndim` or `data` has not at least `ndim` axes. 
""" - if axis != -1: - if isinstance(axis, int): - data = data.swapaxes(axis, -1) - else: - raise ValueError("1d transforms operate on a single axis only.") - - ds = list(data.shape) - if len(ds) == 1: - # assume time series - data = data.unsqueeze(0) - elif len(ds) > 2: - data, ds = _fold_axes(data, 1) + if isinstance(axes, int): + axes = (axes,) - if add_channel_dim: - data = data.unsqueeze(1) - - return data, ds - - -def _postprocess_tensor_1d( - data: torch.Tensor, ds: list[int], axis: int -) -> torch.Tensor: - if len(ds) == 1: - data = data.squeeze(0) - elif len(ds) > 2: - data = _unfold_axes(data, ds, 1) + if ndim <= 0: + raise ValueError("Number of dimensions must be positive") - if axis != -1: - data = data.swapaxes(axis, -1) - - return data - - -def _preprocess_tensor_2d( - data: torch.Tensor, axes: tuple[int, int], add_channel_dim: bool = True -) -> tuple[torch.Tensor, list[int]]: - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") + if tuple(axes) != tuple(range(-ndim, 0)): + if len(axes) != ndim: + raise ValueError(f"{ndim}D transforms work with {ndim} axes.") else: - data = _swap_axes(data, list(axes)) + data = _swap_axes(data, axes) # Preprocess multidimensional input. ds = list(data.shape) - if len(ds) <= 1: - raise ValueError("More than one input dimension required.") - elif len(ds) == 2: + if len(ds) < ndim: + raise ValueError(f"More than {ndim} input dimensions required.") + elif len(ds) == ndim: data = data.unsqueeze(0) - elif len(ds) >= 4: - data, ds = _fold_axes(data, 2) + elif len(ds) > ndim + 1: + data, ds = _fold_axes(data, ndim) if add_channel_dim: data = data.unsqueeze(1) @@ -406,14 +383,24 @@ def _preprocess_tensor_2d( return data, ds -def _postprocess_tensor_2d( - data: torch.Tensor, ds: list[int], axes: tuple[int, int] +def _postprocess_tensor( + data: torch.Tensor, ndim: int, ds: list[int], axes: Union[tuple[int, ...], int] ) -> torch.Tensor: - if len(ds) == 2: + if isinstance(axes, int): + axes = (axes,) + + if ndim <= 0: + raise ValueError("Number of dimensions must be positive") + + if len(ds) == ndim: data = data.squeeze(0) - elif len(ds) > 3: - data = _unfold_axes(data, ds, 2) + elif len(ds) > ndim + 1: + data = _unfold_axes(data, ds, ndim) + + if tuple(axes) != tuple(range(-ndim, 0)): + if len(axes) != ndim: + raise ValueError(f"{ndim}D transforms work with {ndim} axes.") + else: + data = _undo_swap_axes(data, axes) - if tuple(axes) != (-2, -1): - data = _undo_swap_axes(data, axes) return data diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 624f09ad..b1a8a5cc 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -18,9 +18,9 @@ _is_dtype_supported, _pad_symmetric, _postprocess_coeffs_1d, - _postprocess_tensor_1d, + _postprocess_tensor, _preprocess_coeffs_1d, - _preprocess_tensor_1d, + _preprocess_tensor, ) from .constants import BoundaryMode, WaveletCoeff2d @@ -273,7 +273,7 @@ def wavedec( if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") - data, ds = _preprocess_tensor_1d(data, axis=axis) + data, ds = _preprocess_tensor(data, ndim=1, axes=axis) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype @@ -371,6 +371,6 @@ def waverec( res_lo = res_lo[..., :-padr] # undo folding and swapping - res_lo = _postprocess_tensor_1d(res_lo, ds=ds, axis=axis) + res_lo = _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=axis) return res_lo diff --git 
a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 358903a3..b0c4b9ef 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -20,9 +20,9 @@ _outer, _pad_symmetric, _postprocess_coeffs_2d, - _postprocess_tensor_2d, + _postprocess_tensor, _preprocess_coeffs_2d, - _preprocess_tensor_2d, + _preprocess_tensor, ) from .constants import BoundaryMode, WaveletCoeff2d, WaveletDetailTuple2d from .conv_transform import ( @@ -166,7 +166,7 @@ def wavedec2( if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") - data, ds = _preprocess_tensor_2d(data, axes=axes) + data, ds = _preprocess_tensor(data, ndim=2, axes=axes) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype ) @@ -298,6 +298,6 @@ def waverec2( if padr > 0: res_ll = res_ll[..., :-padr] - res_ll = _postprocess_tensor_2d(res_ll, ds=ds, axes=axes) + res_ll = _postprocess_tensor(res_ll, ndim=2, ds=ds, axes=axes) return res_ll diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 9d196ccd..10493369 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -20,9 +20,9 @@ _is_boundary_mode_supported, _is_dtype_supported, _postprocess_coeffs_1d, - _postprocess_tensor_1d, + _postprocess_tensor, _preprocess_coeffs_1d, - _preprocess_tensor_1d, + _preprocess_tensor, ) from .constants import OrthogonalizeMethod from .conv_transform import _get_filter_tensors @@ -330,9 +330,10 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: ValueError: If the decomposition level is not a positive integer or if the input signal has not the expected shape. """ - input_signal, ds = _preprocess_tensor_1d( + input_signal, ds = _preprocess_tensor( input_signal, - axis=self.axis, + ndim=1, + axes=self.axis, add_channel_dim=False, ) @@ -644,6 +645,6 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: res_lo = lo.T - res_lo = _postprocess_tensor_1d(res_lo, ds=ds, axis=self.axis) + res_lo = _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=self.axis) return res_lo diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 2a2def91..e8cdde13 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -18,9 +18,9 @@ _is_boundary_mode_supported, _is_dtype_supported, _postprocess_coeffs_2d, - _postprocess_tensor_2d, + _postprocess_tensor, _preprocess_coeffs_2d, - _preprocess_tensor_2d, + _preprocess_tensor, ) from .constants import ( OrthogonalizeMethod, @@ -427,8 +427,8 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeff2d: ValueError: If the decomposition level is not a positive integer or if the input signal has not the expected shape. 
""" - input_signal, ds = _preprocess_tensor_2d( - input_signal, axes=self.axes, add_channel_dim=False + input_signal, ds = _preprocess_tensor( + input_signal, ndim=2, axes=self.axes, add_channel_dim=False ) batch_size, height, width = input_signal.shape @@ -815,6 +815,6 @@ def __call__( if pred_len[1] != next_len[1]: ll = ll[:, :, :-1] - ll = _postprocess_tensor_2d(ll, ds=ds, axes=self.axes) + ll = _postprocess_tensor(ll, ndim=2, ds=ds, axes=self.axes) return ll diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index b1724ddc..167a45ef 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -23,8 +23,8 @@ _fold_axes, _is_dtype_supported, _map_result, - _postprocess_tensor_2d, - _preprocess_tensor_2d, + _postprocess_tensor, + _preprocess_tensor, _swap_axes, _undo_swap_axes, _unfold_axes, @@ -229,7 +229,7 @@ def fswavedec2( if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") - data, ds = _preprocess_tensor_2d(data, axes=axes, add_channel_dim=False) + data, ds = _preprocess_tensor(data, ndim=2, axes=axes, add_channel_dim=False) res = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) if ds: @@ -372,7 +372,7 @@ def fswaverec2( res_ll = _separable_conv_waverecn(coeffs, wavelet) - res_ll = _postprocess_tensor_2d(res_ll, ds=ds, axes=axes) + res_ll = _postprocess_tensor(res_ll, ndim=2, ds=ds, axes=axes) return res_ll @@ -417,9 +417,8 @@ def fswaverec3( swap_fn = partial(_swap_axes, axes=list(axes)) coeffs = _map_result(coeffs, swap_fn) - ds = None - wavelet = _as_wavelet(wavelet) res_ll = _check_if_tensor(coeffs[0]) + ds = list(res_ll.shape) torch_dtype = res_ll.dtype if res_ll.dim() >= 5: @@ -433,10 +432,6 @@ def fswaverec3( res_ll = _separable_conv_waverecn(coeffs, wavelet) - if ds: - res_ll = _unfold_axes(res_ll, list(ds), 3) - - if axes != (-3, -2, -1): - res_ll = _undo_swap_axes(res_ll, list(axes)) + res_ll = _postprocess_tensor(res_ll, ndim=3, ds=ds, axes=axes) return res_ll diff --git a/src/ptwt/stationary_transform.py b/src/ptwt/stationary_transform.py index 05e7e7d6..672c6144 100644 --- a/src/ptwt/stationary_transform.py +++ b/src/ptwt/stationary_transform.py @@ -11,9 +11,9 @@ Wavelet, _as_wavelet, _postprocess_coeffs_1d, - _postprocess_tensor_1d, + _postprocess_tensor, _preprocess_coeffs_1d, - _preprocess_tensor_1d, + _preprocess_tensor, ) from .conv_transform import _get_filter_tensors @@ -72,7 +72,7 @@ def swt( Returns: Same as wavedec. Equivalent to pywt.swt with trim_approx=True. 
""" - data, ds = _preprocess_tensor_1d(data, axis) + data, ds = _preprocess_tensor(data, ndim=1, axes=axis) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype @@ -142,6 +142,6 @@ def iswt( 1, ) - res_lo = _postprocess_tensor_1d(res_lo, ds=ds, axis=axis) + res_lo = _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=axis) return res_lo From 4e8ed028fdaaad459e15437f83375f3299cd64ea Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 24 Jun 2024 22:29:38 +0200 Subject: [PATCH 06/33] Adapt 1d cases --- src/ptwt/_util.py | 41 ++++++++++++++++++---------- src/ptwt/conv_transform.py | 4 +-- src/ptwt/matmul_transform.py | 3 +- src/ptwt/separable_conv_transform.py | 2 +- src/ptwt/stationary_transform.py | 5 ++-- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 370d742c..b340b896 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -256,39 +256,52 @@ def _map_result( return approx, *cast_result_lst +def _map_result_1d( + coeffs: Sequence[torch.Tensor], function: Callable[[torch.Tensor], torch.Tensor] +) -> list[torch.Tensor]: + return [function(coeff) for coeff in coeffs] + + def _preprocess_coeffs_1d( - result_lst: Sequence[torch.Tensor], axis: int -) -> tuple[Sequence[torch.Tensor], list[int]]: + coeffs: Sequence[torch.Tensor], axis: int +) -> tuple[list[torch.Tensor], list[int]]: if axis != -1: if isinstance(axis, int): - result_lst = [coeff.swapaxes(axis, -1) for coeff in result_lst] + swap_fn = partial(_swap_axes, axes=(axis,)) + coeffs = _map_result_1d(coeffs, swap_fn) else: raise ValueError("1d transforms operate on a single axis only.") # Fold axes for the wavelets - ds = list(result_lst[0].shape) + ds = list(coeffs[0].shape) if len(ds) == 1: - result_lst = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst] + coeffs = _map_result_1d(coeffs, lambda x: x.unsqueeze(0)) elif len(ds) > 2: - result_lst = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst] - return result_lst, ds + coeffs = _map_result_1d(coeffs, lambda t: _fold_axes(t, 1)[0]) + elif not isinstance(coeffs, list): + coeffs = list(coeffs) + + return coeffs, ds def _postprocess_coeffs_1d( - result_list: list[torch.Tensor], ds: list[int], axis: int + coeffs: Sequence[torch.Tensor], ds: list[int], axis: int ) -> list[torch.Tensor]: if len(ds) == 1: - result_list = [r_el.squeeze(0) for r_el in result_list] + coeffs = _map_result_1d(coeffs, lambda x: x.squeeze(0)) elif len(ds) > 2: # Unfold axes for the wavelets - result_list = [_unfold_axes(fres, ds, 1) for fres in result_list] - else: - result_list = result_list + _unfold_axes1 = partial(_unfold_axes, ds=ds, keep_no=1) + coeffs = _map_result_1d(coeffs, _unfold_axes1) if axis != -1: - result_list = [coeff.swapaxes(axis, -1) for coeff in result_list] + undo_swap_fn = partial(_undo_swap_axes, axes=(axis,)) + coeffs = _map_result_1d(coeffs, undo_swap_fn) - return result_list + if not isinstance(coeffs, list): + coeffs = list(coeffs) + + return coeffs def _preprocess_coeffs_2d( diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index b1a8a5cc..66d70e1f 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -294,9 +294,7 @@ def wavedec( result_list.append(res_lo.squeeze(1)) result_list.reverse() - result_list = _postprocess_coeffs_1d(result_list, ds, axis) - - return result_list + return _postprocess_coeffs_1d(result_list, ds, axis) def waverec( diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 
10493369..31c22060 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -378,8 +378,7 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: result_list = [s.T for s in split_list[::-1]] # unfold if necessary - result_list = _postprocess_coeffs_1d(result_list, ds, self.axis) - return result_list + return _postprocess_coeffs_1d(result_list, ds, self.axis) def construct_boundary_a( diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 167a45ef..f708554b 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -423,7 +423,7 @@ def fswaverec3( if res_ll.dim() >= 5: # avoid the channel sum, fold the channels into batches. - ds = _check_if_tensor(coeffs[0]).shape + ds = list(_check_if_tensor(coeffs[0]).shape) coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 3)[0]) res_ll = _check_if_tensor(coeffs[0]) diff --git a/src/ptwt/stationary_transform.py b/src/ptwt/stationary_transform.py index 672c6144..1616dc30 100644 --- a/src/ptwt/stationary_transform.py +++ b/src/ptwt/stationary_transform.py @@ -95,10 +95,9 @@ def swt( # result_list.append((res_lo.squeeze(1), res_hi.squeeze(1))) result_list.append(res_hi.squeeze(1)) result_list.append(res_lo.squeeze(1)) + result_list.reverse() - result_list = _postprocess_coeffs_1d(result_list, ds, axis) - - return result_list[::-1] + return _postprocess_coeffs_1d(result_list, ds, axis) def iswt( From f8891586b4b08947ab50afb8e5ac2c9d23b21e0e Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 24 Jun 2024 23:13:12 +0200 Subject: [PATCH 07/33] Extend _map_result to 1d case --- src/ptwt/_util.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index b340b896..71e81af3 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -209,6 +209,13 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: return torch.permute(data, restore_sorted) +@overload +def _map_result( + data: list[torch.Tensor], + function: Callable[[torch.Tensor], torch.Tensor], +) -> list[torch.Tensor]: ... 
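
A standalone sketch of the container dispatch that this new list overload feeds into. The name `map_coeffs` and all shapes below are illustrative, hypothetical stand-ins for the private `_map_result` helper and are not part of the patch:

    import torch

    def map_coeffs(data, fn):
        # apply fn to the approximation tensor and to every tensor
        # held by the detail entries, keeping the container structure
        approx = fn(data[0])
        mapped = []
        for element in data[1:]:
            if isinstance(element, tuple):  # 2d detail triples
                mapped.append(tuple(fn(t) for t in element))
            elif isinstance(element, dict):  # nd detail dicts
                mapped.append({k: fn(v) for k, v in element.items()})
            elif isinstance(element, torch.Tensor):  # new: 1d tensor lists
                mapped.append(fn(element))
            else:
                raise ValueError(f"Unexpected input type {type(element)}")
        if all(isinstance(m, torch.Tensor) for m in mapped):
            return [approx, *mapped]  # a list of tensors maps to a list
        return (approx, *mapped)  # pywt-style containers stay tuples

    coeffs_1d = [torch.zeros(4, 8), torch.zeros(4, 8), torch.zeros(4, 16)]
    out = map_coeffs(coeffs_1d, lambda t: t.unsqueeze(0))
    assert isinstance(out, list) and out[0].shape == (1, 4, 8)
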
+ + @overload def _map_result( data: WaveletCoeff2d, @@ -224,12 +231,13 @@ def _map_result( def _map_result( - data: Union[WaveletCoeff2d, WaveletCoeffNd], + data: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd], function: Callable[[torch.Tensor], torch.Tensor], -) -> Union[WaveletCoeff2d, WaveletCoeffNd]: +) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]: approx = function(data[0]) result_lst: list[ Union[ + torch.Tensor, WaveletDetailDict, WaveletDetailTuple2d, ] @@ -246,14 +254,25 @@ def _map_result( elif isinstance(element, dict): new_dict = {key: function(value) for key, value in element.items()} result_lst.append(new_dict) + elif isinstance(element, torch.Tensor): + result_lst.append(function(element)) else: raise ValueError(f"Unexpected input type {type(element)}") - # cast since we assume that the full list is of the same type - cast_result_lst = cast( - Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst - ) - return approx, *cast_result_lst + if not result_lst: + # if only approximation coeff: + # use list iff data is a list + return [approx] if isinstance(data, list) else (approx,) + elif isinstance(result_lst[0], torch.Tensor): + # if the first detail coeff is tensor + # -> all are tensors -> return a list + return [approx] + cast(list[torch.Tensor], result_lst) + else: + # cast since we assume that the full list is of the same type + cast_result_lst = cast( + Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst + ) + return approx, *cast_result_lst def _map_result_1d( From 0ed2896826d9acfc8c7da0d3050b7e3b250d530f Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 24 Jun 2024 23:16:05 +0200 Subject: [PATCH 08/33] Make _preprocess_coeffs general --- src/ptwt/_util.py | 117 +++++++++++++++++++------------ src/ptwt/conv_transform.py | 6 +- src/ptwt/conv_transform_2.py | 4 +- src/ptwt/matmul_transform.py | 6 +- src/ptwt/matmul_transform_2.py | 4 +- src/ptwt/stationary_transform.py | 6 +- 6 files changed, 87 insertions(+), 56 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 71e81af3..74a251b8 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -5,7 +5,7 @@ import typing from collections.abc import Sequence from functools import partial -from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload +from typing import Any, Callable, Literal, NamedTuple, Optional, Protocol, Union, cast, overload import numpy as np import pywt @@ -275,47 +275,94 @@ def _map_result( return approx, *cast_result_lst -def _map_result_1d( - coeffs: Sequence[torch.Tensor], function: Callable[[torch.Tensor], torch.Tensor] -) -> list[torch.Tensor]: - return [function(coeff) for coeff in coeffs] +# 1d case +@overload +def _preprocess_coeffs( + coeffs: list[torch.Tensor], + ndim: Literal[1], + axes: int, + add_channel_dim: bool = False, +) -> tuple[list[torch.Tensor], list[int]]: + ... +# 2d case +@overload +def _preprocess_coeffs( + coeffs: WaveletCoeff2d, + ndim: Literal[2], + axes: tuple[int, int], + add_channel_dim: bool = False, +) -> tuple[WaveletCoeff2d, list[int]]: + ... 
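
For intuition, the tensor bookkeeping behind `_preprocess_coeffs` boils down to an axis fold and its inverse. A minimal standalone sketch with made-up shapes, not part of the patch:

    import torch

    x = torch.randn(5, 6, 7, 32)  # e.g. [batch, group, channel, time]

    # fold: collapse every axis except the last (transformed) one into
    # a single batch axis, remembering the original shape
    ds = list(x.shape)
    folded = x.reshape(-1, ds[-1])  # shape [5 * 6 * 7, 32]

    # ... a 1d transform would process `folded` here ...

    # unfold: restore the original leading axes afterwards
    restored = folded.reshape(ds)
    assert torch.equal(x, restored)
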
-def _preprocess_coeffs_1d( - coeffs: Sequence[torch.Tensor], axis: int -) -> tuple[list[torch.Tensor], list[int]]: - if axis != -1: - if isinstance(axis, int): - swap_fn = partial(_swap_axes, axes=(axis,)) - coeffs = _map_result_1d(coeffs, swap_fn) +# Nd case +@overload +def _preprocess_coeffs( + coeffs: WaveletCoeffNd, + ndim: int, + axes: tuple[int, ...], + add_channel_dim: bool = False, +) -> tuple[WaveletCoeffNd, list[int]]: + ... + +def _preprocess_coeffs( + coeffs: Union[ + list[torch.Tensor], + WaveletCoeff2d, + WaveletCoeffNd, + ], + ndim: int, + axes: Union[tuple[int, ...], int], + add_channel_dim: bool = False, +) -> tuple[ + Union[ + list[torch.Tensor], + WaveletCoeff2d, + WaveletCoeffNd, + ], + list[int], +]: + if isinstance(axes, int): + axes = (axes,) + + if ndim <= 0: + raise ValueError("Number of dimensions must be positive") + + if tuple(axes) != tuple(range(-ndim, 0)): + if len(axes) != ndim: + raise ValueError(f"{ndim}D transforms work with {ndim} axes.") else: - raise ValueError("1d transforms operate on a single axis only.") + swap_fn = partial(_swap_axes, axes=axes) + coeffs = _map_result(coeffs, swap_fn) # Fold axes for the wavelets ds = list(coeffs[0].shape) - if len(ds) == 1: - coeffs = _map_result_1d(coeffs, lambda x: x.unsqueeze(0)) - elif len(ds) > 2: - coeffs = _map_result_1d(coeffs, lambda t: _fold_axes(t, 1)[0]) - elif not isinstance(coeffs, list): - coeffs = list(coeffs) + if len(ds) < ndim: + raise ValueError(f"At least {ndim} input dimensions required.") + elif len(ds) == ndim: + coeffs = _map_result(coeffs, lambda x: x.unsqueeze(0)) + elif len(ds) > ndim + 1: + coeffs = _map_result(coeffs, lambda t: _fold_axes(t, ndim)[0]) + + if add_channel_dim: + coeffs = _map_result(coeffs, lambda x: x.unsqueeze(1)) return coeffs, ds def _postprocess_coeffs_1d( - coeffs: Sequence[torch.Tensor], ds: list[int], axis: int + coeffs: list[torch.Tensor], ds: list[int], axis: int ) -> list[torch.Tensor]: if len(ds) == 1: - coeffs = _map_result_1d(coeffs, lambda x: x.squeeze(0)) + coeffs = _map_result(coeffs, lambda x: x.squeeze(0)) elif len(ds) > 2: # Unfold axes for the wavelets _unfold_axes1 = partial(_unfold_axes, ds=ds, keep_no=1) - coeffs = _map_result_1d(coeffs, _unfold_axes1) + coeffs = _map_result(coeffs, _unfold_axes1) if axis != -1: undo_swap_fn = partial(_undo_swap_axes, axes=(axis,)) - coeffs = _map_result_1d(coeffs, undo_swap_fn) + coeffs = _map_result(coeffs, undo_swap_fn) if not isinstance(coeffs, list): coeffs = list(coeffs) @@ -323,28 +370,6 @@ def _postprocess_coeffs_1d( return coeffs -def _preprocess_coeffs_2d( - coeffs: WaveletCoeff2d, axes: tuple[int, int] -) -> tuple[WaveletCoeff2d, list[int]]: - # swap axes if necessary - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - swap_fn = partial(_swap_axes, axes=axes) - coeffs = _map_result(coeffs, swap_fn) - - # Fold axes for the wavelets - ds = list(coeffs[0].shape) - if len(ds) <= 1: - raise ValueError("2d transforms require at least 2 input dimensions") - elif len(ds) == 2: - coeffs = _map_result(coeffs, lambda x: x.unsqueeze(0)) - elif len(ds) > 3: - coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]) - return coeffs, ds - - def _postprocess_coeffs_2d( coeffs: WaveletCoeff2d, ds: list[int], axes: tuple[int, int] ) -> WaveletCoeff2d: @@ -403,7 +428,7 @@ def _preprocess_tensor( # Preprocess multidimensional input. 
ds = list(data.shape) if len(ds) < ndim: - raise ValueError(f"More than {ndim} input dimensions required.") + raise ValueError(f"At least {ndim} input dimensions required.") elif len(ds) == ndim: data = data.unsqueeze(0) elif len(ds) > ndim + 1: diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 66d70e1f..cf7c5943 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -19,7 +19,7 @@ _pad_symmetric, _postprocess_coeffs_1d, _postprocess_tensor, - _preprocess_coeffs_1d, + _preprocess_coeffs, _preprocess_tensor, ) from .constants import BoundaryMode, WaveletCoeff2d @@ -343,7 +343,9 @@ def waverec( raise ValueError("coefficients must have the same dtype") # fold channels and swap axis, if necessary. - coeffs, ds = _preprocess_coeffs_1d(coeffs, axis=axis) + if not isinstance(coeffs, list): + coeffs = list(coeffs) + coeffs, ds = _preprocess_coeffs(coeffs, ndim=1, axes=axis) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index b0c4b9ef..af84ec0a 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -21,7 +21,7 @@ _pad_symmetric, _postprocess_coeffs_2d, _postprocess_tensor, - _preprocess_coeffs_2d, + _preprocess_coeffs, _preprocess_tensor, ) from .constants import BoundaryMode, WaveletCoeff2d, WaveletDetailTuple2d @@ -242,7 +242,7 @@ def waverec2( if not _is_dtype_supported(torch_dtype): raise ValueError(f"Input dtype {torch_dtype} not supported") - coeffs, ds = _preprocess_coeffs_2d(coeffs, axes=axes) + coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 31c22060..6d590c89 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -21,7 +21,7 @@ _is_dtype_supported, _postprocess_coeffs_1d, _postprocess_tensor, - _preprocess_coeffs_1d, + _preprocess_coeffs, _preprocess_tensor, ) from .constants import OrthogonalizeMethod @@ -592,7 +592,9 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: coefficients are not in the shape as it is returned from a `MatrixWavedec` object. """ - coefficients, ds = _preprocess_coeffs_1d(coefficients, self.axis) + if not isinstance(coefficients, list): + coefficients = list(coefficients) + coefficients, ds = _preprocess_coeffs(coefficients, ndim=1, axes=self.axis) level = len(coefficients) - 1 input_length = coefficients[-1].shape[-1] * 2 diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index e8cdde13..5dc2eb7e 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -19,7 +19,7 @@ _is_dtype_supported, _postprocess_coeffs_2d, _postprocess_tensor, - _preprocess_coeffs_2d, + _preprocess_coeffs, _preprocess_tensor, ) from .constants import ( @@ -721,7 +721,7 @@ def __call__( coefficients are not in the shape as it is returned from a `MatrixWavedec2` object. 
""" - coefficients, ds = _preprocess_coeffs_2d(coefficients, axes=self.axes) + coefficients, ds = _preprocess_coeffs(coefficients, ndim=2, axes=self.axes) ll = coefficients[0] level = len(coefficients) - 1 diff --git a/src/ptwt/stationary_transform.py b/src/ptwt/stationary_transform.py index 1616dc30..b8d2a32c 100644 --- a/src/ptwt/stationary_transform.py +++ b/src/ptwt/stationary_transform.py @@ -12,7 +12,7 @@ _as_wavelet, _postprocess_coeffs_1d, _postprocess_tensor, - _preprocess_coeffs_1d, + _preprocess_coeffs, _preprocess_tensor, ) from .conv_transform import _get_filter_tensors @@ -118,7 +118,9 @@ def iswt( Returns: A reconstruction of the original swt input. """ - coeffs, ds = _preprocess_coeffs_1d(coeffs, axis=axis) + if not isinstance(coeffs, list): + coeffs = list(coeffs) + coeffs, ds = _preprocess_coeffs(coeffs, ndim=1, axes=axis) wavelet = _as_wavelet(wavelet) _, _, rec_lo, rec_hi = _get_filter_tensors( From 03f74b7d8634e34c8b9fa7d2c28e93c0cacfb41d Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 24 Jun 2024 23:24:42 +0200 Subject: [PATCH 09/33] Make postprocessing general --- src/ptwt/_util.py | 82 ++++++++++++++++++++++---------- src/ptwt/conv_transform.py | 4 +- src/ptwt/conv_transform_2.py | 4 +- src/ptwt/matmul_transform.py | 4 +- src/ptwt/matmul_transform_2.py | 4 +- src/ptwt/stationary_transform.py | 4 +- 6 files changed, 68 insertions(+), 34 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 74a251b8..46e129af 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -350,38 +350,72 @@ def _preprocess_coeffs( return coeffs, ds -def _postprocess_coeffs_1d( - coeffs: list[torch.Tensor], ds: list[int], axis: int +# 1d case +@overload +def _postprocess_coeffs( + coeffs: list[torch.Tensor], + ndim: Literal[1], + ds: list[int], + axes: int, ) -> list[torch.Tensor]: - if len(ds) == 1: - coeffs = _map_result(coeffs, lambda x: x.squeeze(0)) - elif len(ds) > 2: - # Unfold axes for the wavelets - _unfold_axes1 = partial(_unfold_axes, ds=ds, keep_no=1) - coeffs = _map_result(coeffs, _unfold_axes1) + ... - if axis != -1: - undo_swap_fn = partial(_undo_swap_axes, axes=(axis,)) - coeffs = _map_result(coeffs, undo_swap_fn) +# 2d case +@overload +def _postprocess_coeffs( + coeffs: WaveletCoeff2d, + ndim: Literal[2], + ds: list[int], + axes: tuple[int, int], +) -> WaveletCoeff2d: + ... - if not isinstance(coeffs, list): - coeffs = list(coeffs) +# Nd case +@overload +def _postprocess_coeffs( + coeffs: WaveletCoeffNd, + ndim: int, + ds: list[int], + axes: tuple[int, ...], +) -> WaveletCoeffNd: + ... 
- return coeffs +def _postprocess_coeffs( + coeffs: Union[ + list[torch.Tensor], + WaveletCoeff2d, + WaveletCoeffNd, + ], + ndim: int, + ds: list[int], + axes: Union[tuple[int, ...], int], +) -> Union[ + list[torch.Tensor], + WaveletCoeff2d, + WaveletCoeffNd, +]: + if isinstance(axes, int): + axes = (axes,) + if ndim <= 0: + raise ValueError("Number of dimensions must be positive") -def _postprocess_coeffs_2d( - coeffs: WaveletCoeff2d, ds: list[int], axes: tuple[int, int] -) -> WaveletCoeff2d: - if len(ds) == 2: + # Fold axes for the wavelets + ds = list(coeffs[0].shape) + if len(ds) < ndim: + raise ValueError(f"At least {ndim} input dimensions required.") + elif len(ds) == ndim: coeffs = _map_result(coeffs, lambda x: x.squeeze(0)) - elif len(ds) > 3: - _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) - coeffs = _map_result(coeffs, _unfold_axes2) + elif len(ds) > ndim + 1: + unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=ndim) + coeffs = _map_result(coeffs, unfold_axes_fn) - if tuple(axes) != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) - coeffs = _map_result(coeffs, undo_swap_fn) + if tuple(axes) != tuple(range(-ndim, 0)): + if len(axes) != ndim: + raise ValueError(f"{ndim}D transforms work with {ndim} axes.") + else: + undo_swap_fn = partial(_undo_swap_axes, axes=axes) + coeffs = _map_result(coeffs, undo_swap_fn) return coeffs diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index cf7c5943..2b6eac69 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -17,7 +17,7 @@ _get_len, _is_dtype_supported, _pad_symmetric, - _postprocess_coeffs_1d, + _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, @@ -294,7 +294,7 @@ def wavedec( result_list.append(res_lo.squeeze(1)) result_list.reverse() - return _postprocess_coeffs_1d(result_list, ds, axis) + return _postprocess_coeffs(result_list, ndim=1, ds=ds, axes=axis) def waverec( diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index af84ec0a..523f55ce 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -19,7 +19,7 @@ _is_dtype_supported, _outer, _pad_symmetric, - _postprocess_coeffs_2d, + _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, @@ -190,7 +190,7 @@ def wavedec2( res_ll = res_ll.squeeze(1) result: WaveletCoeff2d = res_ll, *result_lst - result = _postprocess_coeffs_2d(result, ds=ds, axes=axes) + result = _postprocess_coeffs(result, ndim=2, ds=ds, axes=axes) return result diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 6d590c89..5a5c4f43 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -19,7 +19,7 @@ _as_wavelet, _is_boundary_mode_supported, _is_dtype_supported, - _postprocess_coeffs_1d, + _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, @@ -378,7 +378,7 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: result_list = [s.T for s in split_list[::-1]] # unfold if necessary - return _postprocess_coeffs_1d(result_list, ds, self.axis) + return _postprocess_coeffs(result_list, ndim=1, ds=ds, axes=self.axis) def construct_boundary_a( diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 5dc2eb7e..7ffc1b00 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -17,7 +17,7 @@ _check_axes_argument, _is_boundary_mode_supported, _is_dtype_supported, - _postprocess_coeffs_2d, + _postprocess_coeffs, 
_postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, @@ -528,7 +528,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeff2d: split_list.reverse() result: WaveletCoeff2d = ll, *split_list - result = _postprocess_coeffs_2d(result, ds=ds, axes=self.axes) + result = _postprocess_coeffs(result, ndim=2, ds=ds, axes=self.axes) return result diff --git a/src/ptwt/stationary_transform.py b/src/ptwt/stationary_transform.py index b8d2a32c..416077f2 100644 --- a/src/ptwt/stationary_transform.py +++ b/src/ptwt/stationary_transform.py @@ -10,7 +10,7 @@ from ._util import ( Wavelet, _as_wavelet, - _postprocess_coeffs_1d, + _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, @@ -97,7 +97,7 @@ def swt( result_list.append(res_lo.squeeze(1)) result_list.reverse() - return _postprocess_coeffs_1d(result_list, ds, axis) + return _postprocess_coeffs(result_list, ndim=1, ds=ds, axes=axis) def iswt( From ae7b34525191cf98d2d2aa28bcb8e390c8c407ec Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 24 Jun 2024 23:56:33 +0200 Subject: [PATCH 10/33] Apply process funcs in 3d transforms --- src/ptwt/conv_transform_3.py | 90 +++------------------ src/ptwt/matmul_transform_3.py | 85 +++++--------------- src/ptwt/separable_conv_transform.py | 114 ++++++--------------------- 3 files changed, 60 insertions(+), 229 deletions(-) diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index 1e555cff..b7d2ddb4 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -5,7 +5,6 @@ from __future__ import annotations -from functools import partial from typing import Optional, Union import pywt @@ -14,19 +13,17 @@ from ._util import ( Wavelet, _as_wavelet, - _check_axes_argument, _check_if_tensor, - _fold_axes, _get_len, _is_dtype_supported, - _map_result, _outer, _pad_symmetric, - _swap_axes, - _undo_swap_axes, - _unfold_axes, + _postprocess_coeffs, + _postprocess_tensor, + _preprocess_coeffs, + _preprocess_tensor, ) -from .constants import BoundaryMode, WaveletCoeffNd +from .constants import BoundaryMode, WaveletCoeffNd, WaveletDetailDict from .conv_transform import ( _adjust_padding_at_reconstruction, _get_filter_tensors, @@ -142,21 +139,7 @@ def wavedec3( >>> data = torch.randn(5, 16, 16, 16) >>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect") """ - if tuple(axes) != (-3, -2, -1): - if len(axes) != 3: - raise ValueError("3D transforms work with three axes.") - else: - _check_axes_argument(list(axes)) - data = _swap_axes(data, list(axes)) - - ds = None - if data.dim() < 3: - raise ValueError("At least three dimensions are required for 3d wavedec.") - elif len(data.shape) == 3: - data = data.unsqueeze(1) - else: - data, ds = _fold_axes(data, 3) - data = data.unsqueeze(1) + data, ds = _preprocess_tensor(data, ndim=3, axes=axes) if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") @@ -172,7 +155,7 @@ def wavedec3( [data.shape[-1], data.shape[-2], data.shape[-3]], wavelet ) - result_lst: list[dict[str, torch.Tensor]] = [] + result_lst: list[WaveletDetailDict] = [] res_lll = data for _ in range(level): if len(res_lll.shape) == 4: @@ -194,34 +177,9 @@ def wavedec3( } ) result_lst.reverse() - result: WaveletCoeffNd = res_lll, *result_lst - - if ds: - _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) - result = _map_result(result, _unfold_axes_fn) - - if tuple(axes) != (-3, -2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) - result = _map_result(result, 
undo_swap_fn) - - return result - + coeffs: WaveletCoeffNd = res_lll, *result_lst -def _waverec3d_fold_channels_3d_list( - coeffs: WaveletCoeffNd, -) -> tuple[ - WaveletCoeffNd, - list[int], -]: - # fold the input coefficients for processing conv2d_transpose. - fold_approx_coeff = _fold_axes(coeffs[0], 3)[0] - fold_coeffs: list[dict[str, torch.Tensor]] = [] - ds = list(_check_if_tensor(coeffs[0]).shape) - fold_coeffs = [ - {key: _fold_axes(value, 3)[0] for key, value in coeff.items()} - for coeff in coeffs[1:] - ] - return (fold_approx_coeff, *fold_coeffs), ds + return _postprocess_coeffs(coeffs, ndim=3, ds=ds, axes=axes) def waverec3( @@ -256,32 +214,16 @@ def waverec3( >>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect") >>> reconstruction = ptwt.waverec3(transformed, "haar") """ - if tuple(axes) != (-3, -2, -1): - if len(axes) != 3: - raise ValueError("3D transforms work with three axes") - else: - _check_axes_argument(list(axes)) - swap_axes_fn = partial(_swap_axes, axes=list(axes)) - coeffs = _map_result(coeffs, swap_axes_fn) + coeffs, ds = _preprocess_coeffs(coeffs, ndim=3, axes=axes) wavelet = _as_wavelet(wavelet) - ds = None - # the Union[tensor, dict] idea is coming from pywt. We don't change it here. - res_lll = _check_if_tensor(coeffs[0]) - if res_lll.dim() < 3: - raise ValueError( - "Three dimensional transforms require at least three dimensions." - ) - elif res_lll.dim() >= 5: - coeffs, ds = _waverec3d_fold_channels_3d_list(coeffs) - res_lll = _check_if_tensor(coeffs[0]) + res_lll = _check_if_tensor(coeffs[0]) torch_device = res_lll.device torch_dtype = res_lll.dtype if not _is_dtype_supported(torch_dtype): - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") + raise ValueError(f"Input dtype {torch_dtype} not supported") _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype @@ -351,11 +293,5 @@ def waverec3( res_lll = res_lll[..., padfr:, :, :] if padba > 0: res_lll = res_lll[..., :-padba, :, :] - res_lll = res_lll.squeeze(1) - - if ds: - res_lll = _unfold_axes(res_lll, ds, 3) - if axes != (-3, -2, -1): - res_lll = _undo_swap_axes(res_lll, list(axes)) - return res_lll + return _postprocess_tensor(res_lll, ndim=3, ds=ds, axes=axes) diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index dba3fae0..19346e79 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -14,16 +14,14 @@ _as_wavelet, _check_axes_argument, _check_if_tensor, - _fold_axes, _is_boundary_mode_supported, _is_dtype_supported, - _map_result, - _swap_axes, - _undo_swap_axes, - _unfold_axes, + _postprocess_coeffs, + _postprocess_tensor, + _preprocess_coeffs, + _preprocess_tensor, ) -from .constants import OrthogonalizeMethod, WaveletCoeffNd -from .conv_transform_3 import _waverec3d_fold_channels_3d_list +from .constants import OrthogonalizeMethod, WaveletCoeffNd, WaveletDetailDict from .matmul_transform import construct_boundary_a, construct_boundary_s from .sparse_math import _batch_dim_mm @@ -174,17 +172,9 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffNd: Raises: ValueError: If the input dimensions don't work. 
""" - if self.axes != (-3, -2, -1): - input_signal = _swap_axes(input_signal, list(self.axes)) - - ds = None - if input_signal.dim() < 3: - raise ValueError("At least three dimensions are required for 3d wavedec.") - elif len(input_signal.shape) == 3: - input_signal = input_signal.unsqueeze(1) - else: - input_signal, ds = _fold_axes(input_signal, 3) - + input_signal, ds = _preprocess_tensor( + input_signal, ndim=3, axes=self.axes, add_channel_dim=False + ) _, depth, height, width = input_signal.shape if not _is_dtype_supported(input_signal.dtype): @@ -220,7 +210,7 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffNd: device=input_signal.device, dtype=input_signal.dtype ) - split_list: list[dict[str, torch.Tensor]] = [] + split_list: list[WaveletDetailDict] = [] lll = input_signal for scale, fwt_mats in enumerate(self.fwt_matrix_list): # fwt_depth_matrix, fwt_row_matrix, fwt_col_matrix = fwt_mats @@ -240,17 +230,17 @@ def _split_rec( tensor: torch.Tensor, key: str, depth: int, - dict: dict[str, torch.Tensor], + to_dict: WaveletDetailDict, ) -> None: if key: - dict[key] = tensor + to_dict[key] = tensor if len(key) < depth: dim = len(key) + 1 ca, cd = torch.split(tensor, tensor.shape[-dim] // 2, dim=-dim) - _split_rec(ca, "a" + key, depth, dict) - _split_rec(cd, "d" + key, depth, dict) + _split_rec(ca, "a" + key, depth, to_dict) + _split_rec(cd, "d" + key, depth, to_dict) - coeff_dict: dict[str, torch.Tensor] = {} + coeff_dict: WaveletDetailDict = {} _split_rec(lll, "", 3, coeff_dict) lll = coeff_dict["aaa"] result_keys = list( @@ -262,17 +252,9 @@ def _split_rec( split_list.append(coeff_dict) split_list.reverse() - result: WaveletCoeffNd = lll, *split_list - - if ds: - _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) - result = _map_result(result, _unfold_axes_fn) + coeffs: WaveletCoeffNd = lll, *split_list - if self.axes != (-3, -2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) - result = _map_result(result, undo_swap_fn) - - return result + return _postprocess_coeffs(coeffs, ndim=3, ds=ds, axes=self.axes) class MatrixWaverec3(object): @@ -372,7 +354,7 @@ def _construct_synthesis_matrices( current_width // 2, ) - def _cat_coeff_recursive(self, input_dict: dict[str, torch.Tensor]) -> torch.Tensor: + def _cat_coeff_recursive(self, input_dict: WaveletDetailDict) -> torch.Tensor: done_dict = {} a_initial_keys = list(filter(lambda x: x[0] == "a", input_dict.keys())) for a_key in a_initial_keys: @@ -402,20 +384,7 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: Raises: ValueError: If the data structure is inconsistent. """ - if self.axes != (-3, -2, -1): - swap_axes_fn = partial(_swap_axes, axes=list(self.axes)) - coefficients = _map_result(coefficients, swap_axes_fn) - - ds = None - # the Union[tensor, dict] idea is coming from pywt. We don't change it here. - res_lll = _check_if_tensor(coefficients[0]) - if res_lll.dim() < 3: - raise ValueError( - "Three dimensional transforms require at least three dimensions." 
- ) - elif res_lll.dim() >= 5: - coefficients, ds = _waverec3d_fold_channels_3d_list(coefficients) - res_lll = _check_if_tensor(coefficients[0]) + coefficients, ds = _preprocess_coeffs(coefficients, ndim=3, axes=self.axes) level = len(coefficients) - 1 if type(coefficients[-1]) is dict: @@ -439,18 +408,12 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: self.level = level re_build = True - lll = coefficients[0] - if not isinstance(lll, torch.Tensor): - raise ValueError( - "First element of coeffs must be the approximation coefficient tensor." - ) - + lll = _check_if_tensor(coefficients[0]) torch_device = lll.device torch_dtype = lll.dtype if not _is_dtype_supported(torch_dtype): - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") + raise ValueError(f"Input dtype {torch_dtype} not supported") if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( @@ -484,10 +447,4 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: for dim, mat in enumerate(self.ifwt_matrix_list[level - 1 - c_pos][::-1]): lll = _batch_dim_mm(mat, lll, dim=(-1) * (dim + 1)) - if ds: - lll = _unfold_axes(lll, ds, 3) - - if self.axes != (-3, -2, -1): - lll = _undo_swap_axes(lll, list(self.axes)) - - return lll + return _postprocess_tensor(lll, ndim=3, ds=ds, axes=self.axes) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index f708554b..de24104a 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -9,7 +9,6 @@ from __future__ import annotations -from functools import partial from typing import Optional, Union import numpy as np @@ -18,23 +17,24 @@ from ._util import ( Wavelet, _as_wavelet, - _check_axes_argument, _check_if_tensor, - _fold_axes, _is_dtype_supported, - _map_result, + _postprocess_coeffs, _postprocess_tensor, + _preprocess_coeffs, _preprocess_tensor, - _swap_axes, - _undo_swap_axes, - _unfold_axes, ) -from .constants import BoundaryMode, WaveletCoeff2dSeparable, WaveletCoeffNd +from .constants import ( + BoundaryMode, + WaveletCoeff2dSeparable, + WaveletCoeffNd, + WaveletDetailDict, +) from .conv_transform import wavedec, waverec def _separable_conv_dwtn_( - rec_dict: dict[str, torch.Tensor], + rec_dict: WaveletDetailDict, input_arg: torch.Tensor, wavelet: Union[Wavelet, str], *, @@ -46,6 +46,8 @@ def _separable_conv_dwtn_( All but the first axes are transformed. Args: + rec_dict (WaveletDetailDict): The result will be stored here + in place. input_arg (torch.Tensor): Tensor of shape ``[batch, data_1, ... data_n]``. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. @@ -55,8 +57,6 @@ def _separable_conv_dwtn_( Defaults to "reflect". key (str): The filter application path. Defaults to "". - dict (dict[str, torch.Tensor]): The result will be stored here - in place. Defaults to {}. """ axis_total = len(input_arg.shape) - 1 if len(key) == axis_total: @@ -71,12 +71,12 @@ def _separable_conv_dwtn_( def _separable_conv_idwtn( - in_dict: dict[str, torch.Tensor], wavelet: Union[Wavelet, str] + in_dict: WaveletDetailDict, wavelet: Union[Wavelet, str] ) -> torch.Tensor: """Separable single level inverse fast wavelet transform. Args: - in_dict (dict[str, torch.Tensor]): The dictionary produced + in_dict (WaveletDetailDict): The dictionary produced by _separable_conv_dwtn_ . 
wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet, as used by ``_separable_conv_dwtn_``. @@ -132,7 +132,7 @@ def _separable_conv_wavedecn( 'a' denoting the low pass or approximation filter and 'd' the high-pass or detail filter. """ - result: list[dict[str, torch.Tensor]] = [] + result: list[WaveletDetailDict] = [] approx = input if level is None: @@ -142,7 +142,7 @@ def _separable_conv_wavedecn( ) for _ in range(level): - level_dict: dict[str, torch.Tensor] = {} + level_dict: WaveletDetailDict = {} _separable_conv_dwtn_(level_dict, approx, wavelet, mode=mode, key="") approx_key = "a" * (len(input.shape) - 1) approx = level_dict.pop(approx_key) @@ -230,17 +230,9 @@ def fswavedec2( raise ValueError(f"Input dtype {data.dtype} not supported") data, ds = _preprocess_tensor(data, ndim=2, axes=axes, add_channel_dim=False) - res = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) - - if ds: - _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2) - res = _map_result(res, _unfold_axes2) + coeffs = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) - if axes != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) - res = _map_result(res, undo_swap_fn) - - return res + return _postprocess_coeffs(coeffs, ndim=2, ds=ds, axes=axes) def fswavedec3( @@ -290,30 +282,9 @@ def fswavedec3( if not _is_dtype_supported(data.dtype): raise ValueError(f"Input dtype {data.dtype} not supported") - if tuple(axes) != (-3, -2, -1): - if len(axes) != 3: - raise ValueError("2D transforms work with two axes.") - else: - data = _swap_axes(data, list(axes)) - - wavelet = _as_wavelet(wavelet) - ds = None - if len(data.shape) >= 5: - data, ds = _fold_axes(data, 3) - elif len(data.shape) < 4: - raise ValueError("At lest four input dimensions are required.") - data = data.squeeze(1) - res = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) - - if ds: - _unfold_axes3 = partial(_unfold_axes, ds=ds, keep_no=3) - res = _map_result(res, _unfold_axes3) - - if axes != (-3, -2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) - res = _map_result(res, undo_swap_fn) - - return res + data, ds = _preprocess_tensor(data, ndim=3, axes=axes, add_channel_dim=False) + coeffs = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) + return _postprocess_coeffs(coeffs, ndim=3, ds=ds, axes=axes) def fswaverec2( @@ -350,31 +321,15 @@ def fswaverec2( >>> coeff = ptwt.fswavedec2(data, "haar", level=2) >>> rec = ptwt.fswaverec2(coeff, "haar") """ - if tuple(axes) != (-2, -1): - if len(axes) != 2: - raise ValueError("2D transforms work with two axes.") - else: - _check_axes_argument(list(axes)) - swap_fn = partial(_swap_axes, axes=list(axes)) - coeffs = _map_result(coeffs, swap_fn) - - res_ll = _check_if_tensor(coeffs[0]) - ds = list(res_ll.shape) - torch_dtype = res_ll.dtype - - if res_ll.dim() >= 4: - # avoid the channel sum, fold the channels into batches. 
- coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]) - res_ll = _check_if_tensor(coeffs[0]) + coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes) + torch_dtype = _check_if_tensor(coeffs[0]).dtype if not _is_dtype_supported(torch_dtype): raise ValueError(f"Input dtype {torch_dtype} not supported") res_ll = _separable_conv_waverecn(coeffs, wavelet) - res_ll = _postprocess_tensor(res_ll, ndim=2, ds=ds, axes=axes) - - return res_ll + return _postprocess_tensor(res_ll, ndim=2, ds=ds, axes=axes) def fswaverec3( @@ -409,29 +364,12 @@ def fswaverec3( >>> coeff = ptwt.fswavedec3(data, "haar", level=2) >>> rec = ptwt.fswaverec3(coeff, "haar") """ - if tuple(axes) != (-3, -2, -1): - if len(axes) != 3: - raise ValueError("2D transforms work with two axes.") - else: - _check_axes_argument(list(axes)) - swap_fn = partial(_swap_axes, axes=list(axes)) - coeffs = _map_result(coeffs, swap_fn) - - res_ll = _check_if_tensor(coeffs[0]) - ds = list(res_ll.shape) - torch_dtype = res_ll.dtype - - if res_ll.dim() >= 5: - # avoid the channel sum, fold the channels into batches. - ds = list(_check_if_tensor(coeffs[0]).shape) - coeffs = _map_result(coeffs, lambda t: _fold_axes(t, 3)[0]) - res_ll = _check_if_tensor(coeffs[0]) + coeffs, ds = _preprocess_coeffs(coeffs, ndim=3, axes=axes) + torch_dtype = _check_if_tensor(coeffs[0]).dtype if not _is_dtype_supported(torch_dtype): raise ValueError(f"Input dtype {torch_dtype} not supported") res_ll = _separable_conv_waverecn(coeffs, wavelet) - res_ll = _postprocess_tensor(res_ll, ndim=3, ds=ds, axes=axes) - - return res_ll + return _postprocess_tensor(res_ll, ndim=3, ds=ds, axes=axes) From 21332bfd1462896323147ca492e8e897b475b5da Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 24 Jun 2024 23:57:34 +0200 Subject: [PATCH 11/33] Format --- src/ptwt/_util.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 46e129af..5bb34ec6 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -3,9 +3,9 @@ from __future__ import annotations import typing -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial -from typing import Any, Callable, Literal, NamedTuple, Optional, Protocol, Union, cast, overload +from typing import Any, Literal, NamedTuple, Optional, Protocol, Union, cast, overload import numpy as np import pywt @@ -282,8 +282,8 @@ def _preprocess_coeffs( ndim: Literal[1], axes: int, add_channel_dim: bool = False, -) -> tuple[list[torch.Tensor], list[int]]: - ... +) -> tuple[list[torch.Tensor], list[int]]: ... + # 2d case @overload @@ -292,8 +292,8 @@ def _preprocess_coeffs( ndim: Literal[2], axes: tuple[int, int], add_channel_dim: bool = False, -) -> tuple[WaveletCoeff2d, list[int]]: - ... +) -> tuple[WaveletCoeff2d, list[int]]: ... + # Nd case @overload @@ -302,8 +302,8 @@ def _preprocess_coeffs( ndim: int, axes: tuple[int, ...], add_channel_dim: bool = False, -) -> tuple[WaveletCoeffNd, list[int]]: - ... +) -> tuple[WaveletCoeffNd, list[int]]: ... + def _preprocess_coeffs( coeffs: Union[ @@ -357,8 +357,8 @@ def _postprocess_coeffs( ndim: Literal[1], ds: list[int], axes: int, -) -> list[torch.Tensor]: - ... +) -> list[torch.Tensor]: ... + # 2d case @overload @@ -367,8 +367,8 @@ def _postprocess_coeffs( ndim: Literal[2], ds: list[int], axes: tuple[int, int], -) -> WaveletCoeff2d: - ... +) -> WaveletCoeff2d: ... 
+ # Nd case @overload @@ -377,8 +377,8 @@ def _postprocess_coeffs( ndim: int, ds: list[int], axes: tuple[int, ...], -) -> WaveletCoeffNd: - ... +) -> WaveletCoeffNd: ... + def _postprocess_coeffs( coeffs: Union[ @@ -432,10 +432,10 @@ def _preprocess_tensor( data (torch.Tensor): An input tensor with at least `ndim` axes. ndim (int): The number of axes on which the transformation is applied. - axis (int or tuple of ints): Compute the transform over these axes + axes (int or tuple of ints): Compute the transform over these axes instead of the last ones. add_channel_dim (bool): If True, ensures that the return has at - least three axes by adding a new axis at dim 1. + least `ndim` + 2 axes by potentially adding a new axis at dim 1. Defaults to True. Returns: From b461c979f8d2c0692eaed52f8420fc14496820c1 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 00:14:35 +0200 Subject: [PATCH 12/33] Fix coeff postprocess --- src/ptwt/_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 5bb34ec6..bcdd48fd 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -401,7 +401,6 @@ def _postprocess_coeffs( raise ValueError("Number of dimensions must be positive") # Fold axes for the wavelets - ds = list(coeffs[0].shape) if len(ds) < ndim: raise ValueError(f"At least {ndim} input dimensions required.") elif len(ds) == ndim: From 105a0a8c2e0830a1330a4dcc0547cdac5097ce07 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 00:17:24 +0200 Subject: [PATCH 13/33] Reduce tensor processing to coeff processing --- src/ptwt/_util.py | 72 ++++++++++++++++------------------------------- 1 file changed, 25 insertions(+), 47 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index bcdd48fd..fbf6a658 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -305,6 +305,16 @@ def _preprocess_coeffs( ) -> tuple[WaveletCoeffNd, list[int]]: ... +# list of nd tensors +@overload +def _preprocess_coeffs( + coeffs: list[torch.Tensor], + ndim: int, + axes: Union[tuple[int, ...], int], + add_channel_dim: bool = False, +) -> tuple[list[torch.Tensor], list[int]]: ... + + def _preprocess_coeffs( coeffs: Union[ list[torch.Tensor], @@ -380,6 +390,16 @@ def _postprocess_coeffs( ) -> WaveletCoeffNd: ... +# list of nd tensors +@overload +def _postprocess_coeffs( + coeffs: list[torch.Tensor], + ndim: int, + ds: list[int], + axes: Union[tuple[int, ...], int], +) -> list[torch.Tensor]: ... + + def _postprocess_coeffs( coeffs: Union[ list[torch.Tensor], @@ -441,56 +461,14 @@ def _preprocess_tensor( A tuple (data, ds) where data is the transformed data tensor and ds contains the original shape. If `add_channel_dim` is True, `data` has `ndim` + 2 axes, otherwise `ndim` + 1. - - Raises: - ValueError: if `ndim` is not positive, `axes` has not at least - length `ndim` or `data` has not at least `ndim` axes. """ - if isinstance(axes, int): - axes = (axes,) - - if ndim <= 0: - raise ValueError("Number of dimensions must be positive") - - if tuple(axes) != tuple(range(-ndim, 0)): - if len(axes) != ndim: - raise ValueError(f"{ndim}D transforms work with {ndim} axes.") - else: - data = _swap_axes(data, axes) - - # Preprocess multidimensional input. 
- ds = list(data.shape) - if len(ds) < ndim: - raise ValueError(f"At least {ndim} input dimensions required.") - elif len(ds) == ndim: - data = data.unsqueeze(0) - elif len(ds) > ndim + 1: - data, ds = _fold_axes(data, ndim) - - if add_channel_dim: - data = data.unsqueeze(1) - - return data, ds + data_lst, ds = _preprocess_coeffs( + [data], ndim=ndim, axes=axes, add_channel_dim=add_channel_dim + ) + return data_lst[0], ds def _postprocess_tensor( data: torch.Tensor, ndim: int, ds: list[int], axes: Union[tuple[int, ...], int] ) -> torch.Tensor: - if isinstance(axes, int): - axes = (axes,) - - if ndim <= 0: - raise ValueError("Number of dimensions must be positive") - - if len(ds) == ndim: - data = data.squeeze(0) - elif len(ds) > ndim + 1: - data = _unfold_axes(data, ds, ndim) - - if tuple(axes) != tuple(range(-ndim, 0)): - if len(axes) != ndim: - raise ValueError(f"{ndim}D transforms work with {ndim} axes.") - else: - data = _undo_swap_axes(data, axes) - - return data + return _postprocess_coeffs(coeffs=[data], ndim=ndim, ds=ds, axes=axes)[0] From 24a25f457bdd69e5e297d383035511580eeb22f6 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 00:33:27 +0200 Subject: [PATCH 14/33] Add fully separable transforms for n dims --- src/ptwt/separable_conv_transform.py | 133 +++++++++++++++++++-------- 1 file changed, 95 insertions(+), 38 deletions(-) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index de24104a..31660aea 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -217,22 +217,13 @@ def fswavedec2( as keys. 'a' denotes the low pass or approximation filter and 'd' the high-pass or detail filter. - Raises: - ValueError: If the data is not a batched 2D signal. - Example: >>> import torch >>> import ptwt >>> data = torch.randn(5, 10, 10) >>> coeff = ptwt.fswavedec2(data, "haar", level=2) """ - if not _is_dtype_supported(data.dtype): - raise ValueError(f"Input dtype {data.dtype} not supported") - - data, ds = _preprocess_tensor(data, ndim=2, axes=axes, add_channel_dim=False) - coeffs = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) - - return _postprocess_coeffs(coeffs, ndim=2, ds=ds, axes=axes) + return fswavedecn(data, wavelet, ndim=2, mode=mode, level=level, axes=axes) def fswavedec3( @@ -269,22 +260,13 @@ def fswavedec3( as keys. 'a' denotes the low pass or approximation filter and 'd' the high-pass or detail filter. - Raises: - ValueError: If the input is not a batched 3D signal. - - Example: >>> import torch >>> import ptwt >>> data = torch.randn(5, 10, 10, 10) >>> coeff = ptwt.fswavedec3(data, "haar", level=2) """ - if not _is_dtype_supported(data.dtype): - raise ValueError(f"Input dtype {data.dtype} not supported") - - data, ds = _preprocess_tensor(data, ndim=3, axes=axes, add_channel_dim=False) - coeffs = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level) - return _postprocess_coeffs(coeffs, ndim=3, ds=ds, axes=axes) + return fswavedecn(data, wavelet, ndim=3, mode=mode, level=level, axes=axes) def fswaverec2( @@ -311,9 +293,6 @@ def fswaverec2( Returns: A reconstruction of the signal encoded in the wavelet coefficients. - Raises: - ValueError: If the axes argument is not a tuple of two integers. 
-
-
     Example:
         >>> import torch
         >>> import ptwt
         >>> data = torch.randn(5, 10, 10)
         >>> coeff = ptwt.fswavedec2(data, "haar", level=2)
         >>> rec = ptwt.fswaverec2(coeff, "haar")
     """
-    coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes)
-
-    torch_dtype = _check_if_tensor(coeffs[0]).dtype
-    if not _is_dtype_supported(torch_dtype):
-        raise ValueError(f"Input dtype {torch_dtype} not supported")
-
-    res_ll = _separable_conv_waverecn(coeffs, wavelet)
-
-    return _postprocess_tensor(res_ll, ndim=2, ds=ds, axes=axes)
+    return fswaverecn(coeffs, wavelet, ndim=2, axes=axes)
 
 
 def fswaverec3(
@@ -353,10 +324,6 @@ def fswaverec3(
     Returns:
         A reconstruction of the signal encoded in the wavelet coefficients.
 
-    Raises:
-        ValueError: If the axes argument is not a tuple with
-            three ints.
-
     Example:
         >>> import torch
         >>> import ptwt
@@ -364,12 +331,102 @@ def fswaverec3(
         >>> coeff = ptwt.fswavedec3(data, "haar", level=2)
         >>> rec = ptwt.fswaverec3(coeff, "haar")
     """
-    coeffs, ds = _preprocess_coeffs(coeffs, ndim=3, axes=axes)
+    return fswaverecn(coeffs, wavelet, ndim=3, axes=axes)
+
+
+def fswavedecn(
+    data: torch.Tensor,
+    wavelet: Union[Wavelet, str],
+    ndim: int,
+    *,
+    mode: BoundaryMode = "reflect",
+    level: Optional[int] = None,
+    axes: Optional[tuple[int, ...]] = None,
+) -> WaveletCoeffNd:
+    """Compute a fully separable :math:`N`-dimensional padded FWT.
+
+    Args:
+        data (torch.Tensor): An input signal with at least :math:`N` dimensions.
+        wavelet (Wavelet or str): A pywt wavelet compatible object or
+            the name of a pywt wavelet. Refer to the output of
+            ``pywt.wavelist(kind="discrete")`` for possible choices.
+        ndim (int): The number of dimensions :math:`N`.
+        mode:
+            The desired padding mode for extending the signal along the edges.
+            Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
+        level (int): The number of desired scales. Defaults to None.
+        axes (tuple[int, ...], optional): Compute the transform over these axes
+            instead of the last :math:`N`. If None, the last :math:`N`
+            axes are transformed. Defaults to None.
+
+    Returns:
+        A tuple with the lll coefficients and for each scale a dictionary
+        containing the detail coefficients,
+        see :data:`ptwt.constants.WaveletCoeffNd`.
+
+    Raises:
+        ValueError: if the dtype of `data` is not supported.
+
+    Example:
+        >>> import torch
+        >>> import ptwt
+        >>> data = torch.randn(5, 10, 10, 10)
+        >>> coeff = ptwt.fswavedecn(data, "haar", ndim=3, level=2)
+    """
+    if not _is_dtype_supported(data.dtype):
+        raise ValueError(f"Input dtype {data.dtype} not supported")
+
+    if axes is None:
+        axes = tuple(range(-ndim, 0))
+
+    data, ds = _preprocess_tensor(data, ndim=ndim, axes=axes, add_channel_dim=False)
+    coeffs = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level)
+    return _postprocess_coeffs(coeffs, ndim=ndim, ds=ds, axes=axes)
+
+
+def fswaverecn(
+    coeffs: WaveletCoeffNd,
+    wavelet: Union[Wavelet, str],
+    ndim: int,
+    axes: Optional[tuple[int, ...]] = None,
+) -> torch.Tensor:
+    """Invert a fully separable :math:`N`-dimensional padded FWT.
+
+    Args:
+        coeffs (WaveletCoeffNd):
+            The wavelet coefficients as computed by :func:`fswavedecn`,
+            see :data:`ptwt.constants.WaveletCoeffNd`.
+        wavelet (Wavelet or str): A pywt wavelet compatible object or
+            the name of a pywt wavelet.
+            Refer to the output from ``pywt.wavelist(kind='discrete')``
+            for possible choices.
+        ndim (int): The number of dimensions :math:`N`.
+        axes (tuple[int, ...], optional): Compute the transform over these axes
+            instead of the last :math:`N`. 
If None, the last :math:`N`
+            axes are transformed. Defaults to None.
+
+    Returns:
+        A reconstruction of the signal encoded in the wavelet coefficients.
+
+    Raises:
+        ValueError: if the dtype of `coeffs` is not supported.
+
+    Example:
+        >>> import torch
+        >>> import ptwt
+        >>> data = torch.randn(5, 10, 10, 10)
+        >>> coeff = ptwt.fswavedecn(data, "haar", ndim=3, level=2)
+        >>> rec = ptwt.fswaverecn(coeff, "haar", ndim=3)
+    """
     torch_dtype = _check_if_tensor(coeffs[0]).dtype
     if not _is_dtype_supported(torch_dtype):
         raise ValueError(f"Input dtype {torch_dtype} not supported")
 
+    if axes is None:
+        axes = tuple(range(-ndim, 0))
+
+    coeffs, ds = _preprocess_coeffs(coeffs, ndim=ndim, axes=axes)
+
     res_ll = _separable_conv_waverecn(coeffs, wavelet)
 
-    return _postprocess_tensor(res_ll, ndim=3, ds=ds, axes=axes)
+    return _postprocess_tensor(res_ll, ndim=ndim, ds=ds, axes=axes)

From 169d6cfe7c5baf8cb2b5868b65d65f781f4f5ead Mon Sep 17 00:00:00 2001
From: Felix Blanke
Date: Tue, 25 Jun 2024 00:40:11 +0200
Subject: [PATCH 15/33] Move dtype check to preprocessing

---
 src/ptwt/_util.py | 4 ++++
 src/ptwt/conv_transform.py | 7 -------
 src/ptwt/conv_transform_2.py | 6 ------
 src/ptwt/conv_transform_3.py | 7 -------
 src/ptwt/matmul_transform.py | 7 -------
 src/ptwt/matmul_transform_2.py | 9 +--------
 src/ptwt/matmul_transform_3.py | 7 -------
 src/ptwt/separable_conv_transform.py | 11 -----------
 8 files changed, 5 insertions(+), 53 deletions(-)

diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py
index fbf6a658..0dfcb7b4 100644
--- a/src/ptwt/_util.py
+++ b/src/ptwt/_util.py
@@ -335,6 +335,10 @@ def _preprocess_coeffs(
     if isinstance(axes, int):
         axes = (axes,)
 
+    torch_dtype = _check_if_tensor(coeffs[0]).dtype
+    if not _is_dtype_supported(torch_dtype):
+        raise ValueError(f"Input dtype {torch_dtype} not supported")
+
     if ndim <= 0:
         raise ValueError("Number of dimensions must be positive")
 
diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py
index 2b6eac69..aa0bae00 100644
--- a/src/ptwt/conv_transform.py
+++ b/src/ptwt/conv_transform.py
@@ -15,7 +15,6 @@
     Wavelet,
     _as_wavelet,
     _get_len,
-    _is_dtype_supported,
     _pad_symmetric,
     _postprocess_coeffs,
     _postprocess_tensor,
@@ -270,9 +269,6 @@ def wavedec(
         >>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
         >>> mode='zero', level=2)
     """
-    if not _is_dtype_supported(data.dtype):
-        raise ValueError(f"Input dtype {data.dtype} not supported")
-
     data, ds = _preprocess_tensor(data, ndim=1, axes=axis)
 
     dec_lo, dec_hi, _, _ = _get_filter_tensors(
@@ -333,9 +329,6 @@ def waverec(
     """
     torch_device = coeffs[0].device
     torch_dtype = coeffs[0].dtype
-    if not _is_dtype_supported(torch_dtype):
-        raise ValueError(f"Input dtype {torch_dtype} not supported")
-
     for coeff in coeffs[1:]:
         if torch_device != coeff.device:
             raise ValueError("coefficients must be on the same device")
diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py
index 523f55ce..cd003cef 100644
--- a/src/ptwt/conv_transform_2.py
+++ b/src/ptwt/conv_transform_2.py
@@ -16,7 +16,6 @@
     _as_wavelet,
     _check_if_tensor,
     _get_len,
-    _is_dtype_supported,
     _outer,
     _pad_symmetric,
     _postprocess_coeffs,
@@ -163,9 +162,6 @@ def wavedec2(
         >>> level=2, mode="zero")
 
     """
-    if not _is_dtype_supported(data.dtype):
-        raise ValueError(f"Input dtype {data.dtype} not supported")
-
     data, ds = _preprocess_tensor(data, ndim=2, axes=axes)
     dec_lo, dec_hi, _, _ = _get_filter_tensors(
         wavelet, flip=True, device=data.device, dtype=data.dtype
@@ -239,8 +235,6 @@ def waverec2(
     _check_if_tensor(coeffs[0])
     torch_device = 
coeffs[0].device torch_dtype = coeffs[0].dtype - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes) diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index b7d2ddb4..a8ba1ba2 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -15,7 +15,6 @@ _as_wavelet, _check_if_tensor, _get_len, - _is_dtype_supported, _outer, _pad_symmetric, _postprocess_coeffs, @@ -141,9 +140,6 @@ def wavedec3( """ data, ds = _preprocess_tensor(data, ndim=3, axes=axes) - if not _is_dtype_supported(data.dtype): - raise ValueError(f"Input dtype {data.dtype} not supported") - wavelet = _as_wavelet(wavelet) dec_lo, dec_hi, _, _ = _get_filter_tensors( wavelet, flip=True, device=data.device, dtype=data.dtype @@ -222,9 +218,6 @@ def waverec3( torch_device = res_lll.device torch_dtype = res_lll.dtype - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") - _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype ) diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index 5a5c4f43..d5b6bdbe 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -18,7 +18,6 @@ Wavelet, _as_wavelet, _is_boundary_mode_supported, - _is_dtype_supported, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, @@ -337,9 +336,6 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]: add_channel_dim=False, ) - if not _is_dtype_supported(input_signal.dtype): - raise ValueError(f"Input dtype {input_signal.dtype} not supported") - if input_signal.shape[-1] % 2 != 0: # odd length input # print('input length odd, padding a zero on the right') @@ -613,9 +609,6 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: elif torch_dtype != coeff.dtype: raise ValueError("coefficients must have the same dtype") - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") - if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( device=torch_device, diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 7ffc1b00..813829fd 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -16,7 +16,6 @@ _as_wavelet, _check_axes_argument, _is_boundary_mode_supported, - _is_dtype_supported, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, @@ -433,9 +432,6 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeff2d: batch_size, height, width = input_signal.shape - if not _is_dtype_supported(input_signal.dtype): - raise ValueError(f"Input dtype {input_signal.dtype} not supported") - re_build = False if ( self.input_signal_shape is None @@ -722,7 +718,6 @@ def __call__( `MatrixWavedec2` object. 
""" coefficients, ds = _preprocess_coeffs(coefficients, ndim=2, axes=self.axes) - ll = coefficients[0] level = len(coefficients) - 1 height, width = tuple(c * 2 for c in coefficients[-1][0].shape[-2:]) @@ -740,13 +735,11 @@ def __call__( self.level = level re_build = True + ll = coefficients[0] batch_size = ll.shape[0] torch_device = ll.device torch_dtype = ll.dtype - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") - if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( device=torch_device, diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 19346e79..af2485b9 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -15,7 +15,6 @@ _check_axes_argument, _check_if_tensor, _is_boundary_mode_supported, - _is_dtype_supported, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, @@ -177,9 +176,6 @@ def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffNd: ) _, depth, height, width = input_signal.shape - if not _is_dtype_supported(input_signal.dtype): - raise ValueError(f"Input dtype {input_signal.dtype} not supported") - re_build = False if ( self.input_signal_shape is None @@ -412,9 +408,6 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: torch_device = lll.device torch_dtype = lll.dtype - if not _is_dtype_supported(torch_dtype): - raise ValueError(f"Input dtype {torch_dtype} not supported") - if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( device=torch_device, diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 31660aea..9d1b7e27 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -18,7 +18,6 @@ Wavelet, _as_wavelet, _check_if_tensor, - _is_dtype_supported, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, @@ -364,18 +363,12 @@ def fswavedecn( containing the detail coefficients, see :data:`ptwt.constants.WaveletCoeffNd`. - Raises: - ValueError: if the dtype of `data` is not supported. 
-
     Example:
         >>> import torch
         >>> import ptwt
         >>> data = torch.randn(5, 10, 10, 10)
         >>> coeff = ptwt.fswavedecn(data, "haar", ndim=3, level=2)
     """
-    if not _is_dtype_supported(data.dtype):
-        raise ValueError(f"Input dtype {data.dtype} not supported")
-
     if axes is None:
         axes = tuple(range(-ndim, 0))
 
@@ -418,10 +411,6 @@ def fswaverecn(
         >>> coeff = ptwt.fswavedecn(data, "haar", ndim=3, level=2)
         >>> rec = ptwt.fswaverecn(coeff, "haar", ndim=3)
     """
-    torch_dtype = _check_if_tensor(coeffs[0]).dtype
-    if not _is_dtype_supported(torch_dtype):
-        raise ValueError(f"Input dtype {torch_dtype} not supported")
-
     if axes is None:
         axes = tuple(range(-ndim, 0))
 

From ba1b83da9b23b860983b2f4b81020b96359d860e Mon Sep 17 00:00:00 2001
From: Felix Blanke
Date: Tue, 25 Jun 2024 01:09:58 +0200
Subject: [PATCH 16/33] Encapsulate check for consistent dtype and device

---
 src/ptwt/_util.py | 26 ++++++++++++++++++++++++++
 src/ptwt/conv_transform.py | 19 ++-----------------
 src/ptwt/conv_transform_2.py | 20 ++++----------------
 src/ptwt/conv_transform_3.py | 21 ++++-----------------
 src/ptwt/matmul_transform.py | 14 +++-----------
 src/ptwt/matmul_transform_2.py | 15 +++++----------
 src/ptwt/matmul_transform_3.py | 17 ++++-------------
 src/ptwt/separable_conv_transform.py | 6 ++----
 src/ptwt/stationary_transform.py | 4 +++-
 9 files changed, 53 insertions(+), 89 deletions(-)

diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py
index 0dfcb7b4..6807a014 100644
--- a/src/ptwt/_util.py
+++ b/src/ptwt/_util.py
@@ -187,6 +187,32 @@ def _check_axes_argument(axes: Sequence[int]) -> None:
         raise ValueError("Cant transform the same axis twice.")
 
 
+def _check_same_device(
+    tensor: torch.Tensor, torch_device: torch.device
+) -> torch.Tensor:
+    if torch_device != tensor.device:
+        raise ValueError("coefficients must be on the same device")
+    return tensor
+
+
+def _check_same_dtype(tensor: torch.Tensor, torch_dtype: torch.dtype) -> torch.Tensor:
+    if torch_dtype != tensor.dtype:
+        raise ValueError("coefficients must have the same dtype")
+    return tensor
+
+
+def _check_same_device_dtype(
+    coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd],
+) -> tuple[torch.device, torch.dtype]:
+    c = _check_if_tensor(coeffs[0])
+    torch_device, torch_dtype = c.device, c.dtype
+
+    _map_result(coeffs, partial(_check_same_device, torch_device=torch_device))
+    _map_result(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype))
+
+    return torch_device, torch_dtype
+
+
 def _get_transpose_order(
     axes: Sequence[int], data_shape: Sequence[int]
 ) -> tuple[list[int], list[int]]:
diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py
index aa0bae00..85300e48 100644
--- a/src/ptwt/conv_transform.py
+++ b/src/ptwt/conv_transform.py
@@ -14,6 +14,7 @@
 from ._util import (
     Wavelet,
     _as_wavelet,
+    _check_same_device_dtype,
     _get_len,
     _pad_symmetric,
     _postprocess_coeffs,
@@ -254,10 +255,6 @@ def wavedec(
         containing the wavelet coefficients. A denotes
         approximation and D detail coefficients.
 
-    Raises:
-        ValueError: If the dtype of the input data tensor is unsupported or
-            if more than one axis is provided.
-
     Example:
         >>> import torch
        >>> import ptwt, pywt
@@ -309,11 +306,6 @@ def waverec(
     Returns:
         The reconstructed signal tensor.
 
-    Raises:
-        ValueError: If the dtype of the coeffs tensor is unsupported or if the
-            coefficients have incompatible shapes, dtypes or devices or if
-            more than one axis is provided. 
- Example: >>> import torch >>> import ptwt, pywt @@ -327,18 +319,11 @@ def waverec( >>> pywt.Wavelet('haar')) """ - torch_device = coeffs[0].device - torch_dtype = coeffs[0].dtype - for coeff in coeffs[1:]: - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - # fold channels and swap axis, if necessary. if not isinstance(coeffs, list): coeffs = list(coeffs) coeffs, ds = _preprocess_coeffs(coeffs, ndim=1, axes=axis) + torch_device, torch_dtype = _check_same_device_dtype(coeffs) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index cd003cef..d7600427 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -14,7 +14,7 @@ from ._util import ( Wavelet, _as_wavelet, - _check_if_tensor, + _check_same_device_dtype, _get_len, _outer, _pad_symmetric, @@ -145,11 +145,6 @@ def wavedec2( A tuple containing the wavelet coefficients in pywt order, see :data:`ptwt.constants.WaveletCoeff2d`. - Raises: - ValueError: If the dimensionality or the dtype of the input data tensor - is unsupported or if the provided ``axes`` - input has a length other than two. - Example: >>> import torch >>> import ptwt, pywt @@ -232,11 +227,8 @@ def waverec2( >>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar")) """ - _check_if_tensor(coeffs[0]) - torch_device = coeffs[0].device - torch_dtype = coeffs[0].dtype - coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes) + torch_device, torch_dtype = _check_same_device_dtype(coeffs) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype @@ -244,7 +236,7 @@ def waverec2( filt_len = rec_lo.shape[-1] rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi) - res_ll = _check_if_tensor(coeffs[0]) + res_ll = coeffs[0] for c_pos, coeff_tuple in enumerate(coeffs[1:]): if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3: raise ValueError( @@ -255,11 +247,7 @@ def waverec2( curr_shape = res_ll.shape for coeff in coeff_tuple: - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - elif coeff.shape != curr_shape: + if coeff.shape != curr_shape: raise ValueError( "All coefficients on each level must have the same shape" ) diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index a8ba1ba2..f100ebc2 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -13,7 +13,7 @@ from ._util import ( Wavelet, _as_wavelet, - _check_if_tensor, + _check_same_device_dtype, _get_len, _outer, _pad_symmetric, @@ -128,11 +128,6 @@ def wavedec3( A tuple containing the wavelet coefficients, see :data:`ptwt.constants.WaveletCoeffNd`. - Raises: - ValueError: If the input has fewer than three dimensions or - if the dtype is not supported or - if the provided axes input has length other than three. 
- Example: >>> import ptwt, torch >>> data = torch.randn(5, 16, 16, 16) @@ -211,12 +206,7 @@ def waverec3( >>> reconstruction = ptwt.waverec3(transformed, "haar") """ coeffs, ds = _preprocess_coeffs(coeffs, ndim=3, axes=axes) - - wavelet = _as_wavelet(wavelet) - - res_lll = _check_if_tensor(coeffs[0]) - torch_device = res_lll.device - torch_dtype = res_lll.dtype + torch_device, torch_dtype = _check_same_device_dtype(coeffs) _, _, rec_lo, rec_hi = _get_filter_tensors( wavelet, flip=False, device=torch_device, dtype=torch_dtype @@ -224,6 +214,7 @@ def waverec3( filt_len = rec_lo.shape[-1] rec_filt = _construct_3d_filt(lo=rec_lo, hi=rec_hi) + res_lll = coeffs[0] coeff_dicts = coeffs[1:] for c_pos, coeff_dict in enumerate(coeff_dicts): if not isinstance(coeff_dict, dict) or len(coeff_dict) != 7: @@ -233,11 +224,7 @@ def waverec3( "wavedec3." ) for coeff in coeff_dict.values(): - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - elif res_lll.shape != coeff.shape: + if res_lll.shape != coeff.shape: raise ValueError( "All coefficients on each level must have the same shape" ) diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index d5b6bdbe..c84ba427 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -17,6 +17,7 @@ from ._util import ( Wavelet, _as_wavelet, + _check_same_device_dtype, _is_boundary_mode_supported, _postprocess_coeffs, _postprocess_tensor, @@ -591,6 +592,7 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: if not isinstance(coefficients, list): coefficients = list(coefficients) coefficients, ds = _preprocess_coeffs(coefficients, ndim=1, axes=self.axis) + torch_device, torch_dtype = _check_same_device_dtype(coefficients) level = len(coefficients) - 1 input_length = coefficients[-1].shape[-1] * 2 @@ -601,14 +603,6 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: self.input_length = input_length re_build = True - torch_device = coefficients[0].device - torch_dtype = coefficients[0].dtype - for coeff in coefficients[1:]: - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( device=torch_device, @@ -639,6 +633,4 @@ def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor: res_lo = lo.T - res_lo = _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=self.axis) - - return res_lo + return _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=self.axis) diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 813829fd..5f60c00e 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -15,6 +15,7 @@ Wavelet, _as_wavelet, _check_axes_argument, + _check_same_device_dtype, _is_boundary_mode_supported, _postprocess_coeffs, _postprocess_tensor, @@ -718,6 +719,7 @@ def __call__( `MatrixWavedec2` object. 
""" coefficients, ds = _preprocess_coeffs(coefficients, ndim=2, axes=self.axes) + torch_device, torch_dtype = _check_same_device_dtype(coefficients) level = len(coefficients) - 1 height, width = tuple(c * 2 for c in coefficients[-1][0].shape[-2:]) @@ -735,17 +737,14 @@ def __call__( self.level = level re_build = True - ll = coefficients[0] - batch_size = ll.shape[0] - torch_device = ll.device - torch_dtype = ll.dtype - if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( device=torch_device, dtype=torch_dtype, ) + ll = coefficients[0] + batch_size = ll.shape[0] for c_pos, coeff_tuple in enumerate(coefficients[1:]): if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3: raise ValueError( @@ -756,11 +755,7 @@ def __call__( curr_shape = ll.shape for coeff in coeff_tuple: - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - elif coeff.shape != curr_shape: + if coeff.shape != curr_shape: raise ValueError( "All coefficients on each level must have the same shape" ) diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index af2485b9..7aafaf20 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -13,7 +13,7 @@ Wavelet, _as_wavelet, _check_axes_argument, - _check_if_tensor, + _check_same_device_dtype, _is_boundary_mode_supported, _postprocess_coeffs, _postprocess_tensor, @@ -381,6 +381,7 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: ValueError: If the data structure is inconsistent. """ coefficients, ds = _preprocess_coeffs(coefficients, ndim=3, axes=self.axes) + torch_device, torch_dtype = _check_same_device_dtype(coefficients) level = len(coefficients) - 1 if type(coefficients[-1]) is dict: @@ -404,16 +405,13 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: self.level = level re_build = True - lll = _check_if_tensor(coefficients[0]) - torch_device = lll.device - torch_dtype = lll.dtype - if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( device=torch_device, dtype=torch_dtype, ) + lll = coefficients[0] for c_pos, coeff_dict in enumerate(coefficients[1:]): if not isinstance(coeff_dict, dict) or len(coeff_dict) != 7: raise ValueError( @@ -421,15 +419,8 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: "coefficients must be a dict containing 7 tensors as returned by " "MatrixWavedec3." ) - test_shape = None for coeff in coeff_dict.values(): - if test_shape is None: - test_shape = coeff.shape - if torch_device != coeff.device: - raise ValueError("coefficients must be on the same device") - elif torch_dtype != coeff.dtype: - raise ValueError("coefficients must have the same dtype") - elif test_shape != coeff.shape: + if lll.shape != coeff.shape: raise ValueError( "All coefficients on each level must have the same shape" ) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 9d1b7e27..74e5ff0d 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -17,7 +17,7 @@ from ._util import ( Wavelet, _as_wavelet, - _check_if_tensor, + _check_same_device_dtype, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, @@ -401,9 +401,6 @@ def fswaverecn( Returns: A reconstruction of the signal encoded in the wavelet coefficients. - Raises: - ValueError: if the dtype of `data` is not supported. 
- Example: >>> import torch >>> import ptwt @@ -415,6 +412,7 @@ def fswaverecn( axes = tuple(range(-ndim, 0)) coeffs, ds = _preprocess_coeffs(coeffs, ndim=ndim, axes=axes) + _check_same_device_dtype(coeffs) res_ll = _separable_conv_waverecn(coeffs, wavelet) diff --git a/src/ptwt/stationary_transform.py b/src/ptwt/stationary_transform.py index 416077f2..b93ffcd0 100644 --- a/src/ptwt/stationary_transform.py +++ b/src/ptwt/stationary_transform.py @@ -10,6 +10,7 @@ from ._util import ( Wavelet, _as_wavelet, + _check_same_device_dtype, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, @@ -121,10 +122,11 @@ def iswt( if not isinstance(coeffs, list): coeffs = list(coeffs) coeffs, ds = _preprocess_coeffs(coeffs, ndim=1, axes=axis) + torch_device, torch_dtype = _check_same_device_dtype(coeffs) wavelet = _as_wavelet(wavelet) _, _, rec_lo, rec_hi = _get_filter_tensors( - wavelet, flip=False, dtype=coeffs[0].dtype, device=coeffs[0].device + wavelet, flip=False, dtype=torch_dtype, device=torch_device ) filt_len = rec_lo.shape[-1] rec_filt = torch.stack([rec_lo, rec_hi], 0) From d50a2c25dd0ccd400ce42c71ff5c0870cbd951ba Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 01:29:52 +0200 Subject: [PATCH 17/33] Revert changes to coeff shape check --- src/ptwt/matmul_transform_3.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index 7aafaf20..6d19158c 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -419,8 +419,11 @@ def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: "coefficients must be a dict containing 7 tensors as returned by " "MatrixWavedec3." ) + test_shape = None for coeff in coeff_dict.values(): - if lll.shape != coeff.shape: + if test_shape is None: + test_shape = coeff.shape + elif test_shape != coeff.shape: raise ValueError( "All coefficients on each level must have the same shape" ) From 2d4620aada2e501e4330b2f00002d4999958ab3c Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 10:06:08 +0200 Subject: [PATCH 18/33] Make n-dim separable trafo private --- src/ptwt/separable_conv_transform.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 74e5ff0d..26fd88e0 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -222,7 +222,7 @@ def fswavedec2( >>> data = torch.randn(5, 10, 10) >>> coeff = ptwt.fswavedec2(data, "haar", level=2) """ - return fswavedecn(data, wavelet, ndim=2, mode=mode, level=level, axes=axes) + return _fswavedecn(data, wavelet, ndim=2, mode=mode, level=level, axes=axes) def fswavedec3( @@ -265,7 +265,7 @@ def fswavedec3( >>> data = torch.randn(5, 10, 10, 10) >>> coeff = ptwt.fswavedec3(data, "haar", level=2) """ - return fswavedecn(data, wavelet, ndim=3, mode=mode, level=level, axes=axes) + return _fswavedecn(data, wavelet, ndim=3, mode=mode, level=level, axes=axes) def fswaverec2( @@ -299,7 +299,7 @@ def fswaverec2( >>> coeff = ptwt.fswavedec2(data, "haar", level=2) >>> rec = ptwt.fswaverec2(coeff, "haar") """ - return fswaverecn(coeffs, wavelet, ndim=2, axes=axes) + return _fswaverecn(coeffs, wavelet, ndim=2, axes=axes) def fswaverec3( @@ -330,10 +330,10 @@ def fswaverec3( >>> coeff = ptwt.fswavedec3(data, "haar", level=2) >>> rec = ptwt.fswaverec3(coeff, "haar") """ - return fswaverecn(coeffs, wavelet, ndim=3, axes=axes) + return _fswaverecn(coeffs, 
wavelet, ndim=3, axes=axes)
 
 
-def fswavedecn(
+def _fswavedecn(
     data: torch.Tensor,
     wavelet: Union[Wavelet, str],
     ndim: int,
@@ -377,7 +377,7 @@ def fswavedecn(
     return _postprocess_coeffs(coeffs, ndim=ndim, ds=ds, axes=axes)
 
 
-def fswaverecn(
+def _fswaverecn(
     coeffs: WaveletCoeffNd,
     wavelet: Union[Wavelet, str],
     ndim: int,

From 1b45eae0699a25e8b24815e83b359847867866be Mon Sep 17 00:00:00 2001
From: Felix Blanke
Date: Tue, 25 Jun 2024 10:06:22 +0200
Subject: [PATCH 19/33] Add explanatory comments

---
 src/ptwt/_util.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py
index 6807a014..190a3b8b 100644
--- a/src/ptwt/_util.py
+++ b/src/ptwt/_util.py
@@ -207,7 +207,9 @@ def _check_same_device_dtype(
     c = _check_if_tensor(coeffs[0])
     torch_device, torch_dtype = c.device, c.dtype
 
+    # check for all tensors in `coeffs` that the device matches `torch_device`
     _map_result(coeffs, partial(_check_same_device, torch_device=torch_device))
+    # check for all tensors in `coeffs` that the dtype matches `torch_dtype`
     _map_result(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype))
 
     return torch_device, torch_dtype
@@ -372,6 +374,7 @@ def _preprocess_coeffs(
         if len(axes) != ndim:
             raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
     else:
+        # for all tensors in `coeffs`: swap the axes
         swap_fn = partial(_swap_axes, axes=axes)
         coeffs = _map_result(coeffs, swap_fn)
 
@@ -380,11 +383,14 @@ def _preprocess_coeffs(
     if len(ds) < ndim:
         raise ValueError(f"At least {ndim} input dimensions required.")
     elif len(ds) == ndim:
+        # for all tensors in `coeffs`: unsqueeze(0)
         coeffs = _map_result(coeffs, lambda x: x.unsqueeze(0))
     elif len(ds) > ndim + 1:
+        # for all tensors in `coeffs`: fold leading dims to batch dim
         coeffs = _map_result(coeffs, lambda t: _fold_axes(t, ndim)[0])
 
     if add_channel_dim:
+        # for all tensors in `coeffs`: add channel dim
         coeffs = _map_result(coeffs, lambda x: x.unsqueeze(1))
 
     return coeffs, ds
@@ -454,8 +460,10 @@ def _postprocess_coeffs(
     if len(ds) < ndim:
         raise ValueError(f"At least {ndim} input dimensions required.")
     elif len(ds) == ndim:
+        # for all tensors in `coeffs`: remove batch dim
         coeffs = _map_result(coeffs, lambda x: x.squeeze(0))
     elif len(ds) > ndim + 1:
+        # for all tensors in `coeffs`: unfold batch dim
         unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=ndim)
         coeffs = _map_result(coeffs, unfold_axes_fn)
 
@@ -463,6 +471,7 @@ def _postprocess_coeffs(
         if len(axes) != ndim:
             raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
     else:
+        # for all tensors in `coeffs`: undo axes swapping
         undo_swap_fn = partial(_undo_swap_axes, axes=axes)
         coeffs = _map_result(coeffs, undo_swap_fn)
 
     return coeffs
@@ -492,6 +501,8 @@ def _preprocess_tensor(
         and ds contains the original shape. If `add_channel_dim` is True,
         `data` has `ndim` + 2 axes, otherwise `ndim` + 1.
""" + # interpreting data as the approximation coeffs of a 0-level FWT + # allows us to reuse the `_preprocess_coeffs` code data_lst, ds = _preprocess_coeffs( [data], ndim=ndim, axes=axes, add_channel_dim=add_channel_dim ) @@ -501,4 +512,6 @@ def _preprocess_tensor( def _postprocess_tensor( data: torch.Tensor, ndim: int, ds: list[int], axes: Union[tuple[int, ...], int] ) -> torch.Tensor: + # interpreting data as the approximation coeffs of a 0-level FWT + # allows us to reuse the `_postprocess_coeffs` code return _postprocess_coeffs(coeffs=[data], ndim=ndim, ds=ds, axes=axes)[0] From 15af78be7ddaf11cfcf5555bb801ebe3fe4796a9 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 10:39:26 +0200 Subject: [PATCH 20/33] Add docstrings --- src/ptwt/_util.py | 117 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 110 insertions(+), 7 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 190a3b8b..b2f4812e 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -204,6 +204,22 @@ def _check_same_dtype(tensor: torch.Tensor, torch_dtype: torch.dtype) -> torch.T def _check_same_device_dtype( coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd], ) -> tuple[torch.device, torch.dtype]: + """Check coefficients for dtype and device consistency. + + Check that all coefficient tensors in `coeffs` have the same + device and dtype. + + Args: + coeffs (Wavelet coefficients): The resulting coefficients of + a discrete wavelet transform. Can be either of + `list[torch.Tensor]` (1d case), + :data:`ptwt.constants.WaveletCoeff2d` (2d case) or + :data:`ptwt.constants.WaveletCoeffNd` (Nd case). + + Returns: + A tuple (device, dtype) with the shared device and dtype of + all tensors in coeffs. + """ c = _check_if_tensor(coeffs[0]) torch_device, torch_dtype = c.device, c.dtype @@ -262,6 +278,7 @@ def _map_result( data: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd], function: Callable[[torch.Tensor], torch.Tensor], ) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]: + """Apply `function` to all tensor elements in `data`.""" approx = function(data[0]) result_lst: list[ Union[ @@ -360,6 +377,38 @@ def _preprocess_coeffs( ], list[int], ]: + """Preprocess coeff tensor dimensions. + + For each coefficient tensor in `coeffs` the transformed axes + as specified by `axes` are moved to be the last. + Adds a batch dim if a coefficient tensor has none. + If it has has multiple batch dimensions, they are folded into a single + batch dimension. + + Args: + coeffs (Wavelet coefficients): The resulting coefficients of + a discrete wavelet transform. Can be either of + `list[torch.Tensor]` (1d case), + :data:`ptwt.constants.WaveletCoeff2d` (2d case) or + :data:`ptwt.constants.WaveletCoeffNd` (Nd case). + ndim (int): The number of axes :math:`N` on which the transformation + was applied. + axes (int or tuple of ints): Axes on which the transform was calculated. + add_channel_dim (bool): If True, ensures that all returned coefficients + have at least `:math:`N + 2` axes by potentially adding a new axis at dim 1. + Defaults to False. + + Returns: + A tuple ``(coeffs, ds)`` where ``coeffs`` are the transformed + coefficients and ``ds`` contains the original shape of ``coeffs[0]``. + If `add_channel_dim` is True, all coefficient tensors have + :math:`N + 2` axes ([B, 1, c1, ..., cN]). + otherwise :math:`N + 1` ([B, c1, ..., cN]). + + Raises: + ValueError: If the input dtype is unsupported or `ndim` does not + fit to the passed `axes` or `coeffs` dimensions. 
+ """ if isinstance(axes, int): axes = (axes,) @@ -450,6 +499,34 @@ def _postprocess_coeffs( WaveletCoeff2d, WaveletCoeffNd, ]: + """Postprocess coeff tensor dimensions. + + This revereses the operations of :func:`_preprocess_coeffs`. + + Unfolds potentially folded batch dimensions and removes any added + dimensions. + The transformed axes as specified by `axes` are moved back to their + original position. + + Args: + coeffs (Wavelet coefficients): The preprocessed coefficients of + a discrete wavelet transform. Can be either of + `list[torch.Tensor]` (1d case), + :data:`ptwt.constants.WaveletCoeff2d` (2d case) or + :data:`ptwt.constants.WaveletCoeffNd` (Nd case). + ndim (int): The number of axes :math:`N` on which the transformation was + applied. + ds (list of ints): The shape of the original first coefficient before + preprocessing, i.e. of ``coeffs[0]``. + axes (int or tuple of ints): Axes on which the transform was calculated. + + Returns: + The result of undoing the preprocessing operations on `coeffs`. + + Raises: + ValueError: If `ndim` does not fit to the passed `axes` + or `coeffs` dimensions. + """ if isinstance(axes, int): axes = (axes,) @@ -486,20 +563,26 @@ def _preprocess_tensor( ) -> tuple[torch.Tensor, list[int]]: """Preprocess input tensor dimensions. + The transformed axes as specified by `axes` are moved to be the last. + Adds a batch dim if `data` has none. + If `data` has multiple batch dimensions, they are folded into a single + batch dimension. + Args: data (torch.Tensor): An input tensor with at least `ndim` axes. - ndim (int): The number of axes on which the transformation is + ndim (int): The number of axes :math:`N` on which the transformation is applied. - axes (int or tuple of ints): Compute the transform over these axes - instead of the last ones. + axes (int or tuple of ints): Axes on which the transform is calculated. add_channel_dim (bool): If True, ensures that the return has at - least `ndim` + 2 axes by potentially adding a new axis at dim 1. + least :math:`N + 2` axes by potentially adding a new axis at dim 1. Defaults to True. Returns: - A tuple (data, ds) where data is the transformed data tensor - and ds contains the original shape. If `add_channel_dim` is True, - `data` has `ndim` + 2 axes, otherwise `ndim` + 1. + A tuple ``(data, ds)`` where ``data`` is the transformed data tensor + and ``ds`` contains the original shape. + If `add_channel_dim` is True, + `data` has :math:`N + 2` axes ([B, 1, d1, ..., dN]). + otherwise :math:`N + 1` ([B, d1, ..., dN]). """ # interpreting data as the approximation coeffs of a 0-level FWT # allows us to reuse the `_preprocess_coeffs` code @@ -512,6 +595,26 @@ def _preprocess_tensor( def _postprocess_tensor( data: torch.Tensor, ndim: int, ds: list[int], axes: Union[tuple[int, ...], int] ) -> torch.Tensor: + """Postprocess input tensor dimensions. + + This revereses the operations of :func:`_preprocess_tensor`. + + Unfolds potentially folded batch dimensions and removes any added + dimensions. + The transformed axes as specified by `axes` are moved back to their + original position. + + Args: + data (torch.Tensor): An preprocessed input tensor. + ndim (int): The number of axes :math:`N` on which the transformation is + applied. + ds (list of ints): The shape of the original input tensor before + preprocessing. + axes (int or tuple of ints): Axes on which the transform was calculated. + + Returns: + The result of undoing the preprocessing operations on `data`. 
+ """ # interpreting data as the approximation coeffs of a 0-level FWT # allows us to reuse the `_postprocess_coeffs` code return _postprocess_coeffs(coeffs=[data], ndim=ndim, ds=ds, axes=axes)[0] From 35bab80aa6c79a06b708a3e8d74b23423d92b840 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 10:49:20 +0200 Subject: [PATCH 21/33] Rename _map_result to _apply_to_tensor_elems --- src/ptwt/_util.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index b2f4812e..97d9b7e1 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -224,9 +224,9 @@ def _check_same_device_dtype( torch_device, torch_dtype = c.device, c.dtype # check for all tensors in `coeffs` that the device matches `torch_device` - _map_result(coeffs, partial(_check_same_device, torch_device=torch_device)) + _apply_to_tensor_elems(coeffs, partial(_check_same_device, torch_device=torch_device)) # check for all tensors in `coeffs` that the dtype matches `torch_dtype` - _map_result(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype)) + _apply_to_tensor_elems(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype)) return torch_device, torch_dtype @@ -254,32 +254,32 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: @overload -def _map_result( - data: list[torch.Tensor], +def _apply_to_tensor_elems( + coeffs: list[torch.Tensor], function: Callable[[torch.Tensor], torch.Tensor], ) -> list[torch.Tensor]: ... @overload -def _map_result( - data: WaveletCoeff2d, +def _apply_to_tensor_elems( + coeffs: WaveletCoeff2d, function: Callable[[torch.Tensor], torch.Tensor], ) -> WaveletCoeff2d: ... @overload -def _map_result( - data: WaveletCoeffNd, +def _apply_to_tensor_elems( + coeffs: WaveletCoeffNd, function: Callable[[torch.Tensor], torch.Tensor], ) -> WaveletCoeffNd: ... 
-def _map_result( - data: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd], +def _apply_to_tensor_elems( + coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd], function: Callable[[torch.Tensor], torch.Tensor], ) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]: """Apply `function` to all tensor elements in `data`.""" - approx = function(data[0]) + approx = function(coeffs[0]) result_lst: list[ Union[ torch.Tensor, @@ -287,7 +287,7 @@ def _map_result( WaveletDetailTuple2d, ] ] = [] - for element in data[1:]: + for element in coeffs[1:]: if isinstance(element, tuple): result_lst.append( WaveletDetailTuple2d( @@ -307,7 +307,7 @@ def _map_result( if not result_lst: # if only approximation coeff: # use list iff data is a list - return [approx] if isinstance(data, list) else (approx,) + return [approx] if isinstance(coeffs, list) else (approx,) elif isinstance(result_lst[0], torch.Tensor): # if the first detail coeff is tensor # -> all are tensors -> return a list @@ -425,7 +425,7 @@ def _preprocess_coeffs( else: # for all tensors in `coeffs`: swap the axes swap_fn = partial(_swap_axes, axes=axes) - coeffs = _map_result(coeffs, swap_fn) + coeffs = _apply_to_tensor_elems(coeffs, swap_fn) # Fold axes for the wavelets ds = list(coeffs[0].shape) @@ -433,14 +433,14 @@ def _preprocess_coeffs( raise ValueError(f"At least {ndim} input dimensions required.") elif len(ds) == ndim: # for all tensors in `coeffs`: unsqueeze(0) - coeffs = _map_result(coeffs, lambda x: x.unsqueeze(0)) + coeffs = _apply_to_tensor_elems(coeffs, lambda x: x.unsqueeze(0)) elif len(ds) > ndim + 1: # for all tensors in `coeffs`: fold leading dims to batch dim - coeffs = _map_result(coeffs, lambda t: _fold_axes(t, ndim)[0]) + coeffs = _apply_to_tensor_elems(coeffs, lambda t: _fold_axes(t, ndim)[0]) if add_channel_dim: # for all tensors in `coeffs`: add channel dim - coeffs = _map_result(coeffs, lambda x: x.unsqueeze(1)) + coeffs = _apply_to_tensor_elems(coeffs, lambda x: x.unsqueeze(1)) return coeffs, ds @@ -538,11 +538,11 @@ def _postprocess_coeffs( raise ValueError(f"At least {ndim} input dimensions required.") elif len(ds) == ndim: # for all tensors in `coeffs`: remove batch dim - coeffs = _map_result(coeffs, lambda x: x.squeeze(0)) + coeffs = _apply_to_tensor_elems(coeffs, lambda x: x.squeeze(0)) elif len(ds) > ndim + 1: # for all tensors in `coeffs`: unfold batch dim unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=ndim) - coeffs = _map_result(coeffs, unfold_axes_fn) + coeffs = _apply_to_tensor_elems(coeffs, unfold_axes_fn) if tuple(axes) != tuple(range(-ndim, 0)): if len(axes) != ndim: @@ -550,7 +550,7 @@ def _postprocess_coeffs( else: # for all tensors in `coeffs`: undo axes swapping undo_swap_fn = partial(_undo_swap_axes, axes=axes) - coeffs = _map_result(coeffs, undo_swap_fn) + coeffs = _apply_to_tensor_elems(coeffs, undo_swap_fn) return coeffs From 37db5db12bb546c10396a947bcb1be85abbdfbfc Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 10:51:10 +0200 Subject: [PATCH 22/33] Format --- src/ptwt/_util.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 97d9b7e1..81905d29 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -224,7 +224,9 @@ def _check_same_device_dtype( torch_device, torch_dtype = c.device, c.dtype # check for all tensors in `coeffs` that the device matches `torch_device` - _apply_to_tensor_elems(coeffs, partial(_check_same_device, torch_device=torch_device)) + _apply_to_tensor_elems( 
+ coeffs, partial(_check_same_device, torch_device=torch_device) + ) # check for all tensors in `coeffs` that the dtype matches `torch_dtype` _apply_to_tensor_elems(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype)) From b4c16e298f13d7133de9d24a0566e88c24760f48 Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Wed, 26 Jun 2024 13:46:07 +0200 Subject: [PATCH 23/33] rename tree_map --- src/ptwt/_util.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 81905d29..9bb4f2b7 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -224,11 +224,11 @@ def _check_same_device_dtype( torch_device, torch_dtype = c.device, c.dtype # check for all tensors in `coeffs` that the device matches `torch_device` - _apply_to_tensor_elems( + _coeff_tree_map( coeffs, partial(_check_same_device, torch_device=torch_device) ) # check for all tensors in `coeffs` that the dtype matches `torch_dtype` - _apply_to_tensor_elems(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype)) + _coeff_tree_map(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype)) return torch_device, torch_dtype @@ -256,31 +256,31 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor: @overload -def _apply_to_tensor_elems( +def _coeff_tree_map( coeffs: list[torch.Tensor], function: Callable[[torch.Tensor], torch.Tensor], ) -> list[torch.Tensor]: ... @overload -def _apply_to_tensor_elems( +def _coeff_tree_map( coeffs: WaveletCoeff2d, function: Callable[[torch.Tensor], torch.Tensor], ) -> WaveletCoeff2d: ... @overload -def _apply_to_tensor_elems( +def _coeff_tree_map( coeffs: WaveletCoeffNd, function: Callable[[torch.Tensor], torch.Tensor], ) -> WaveletCoeffNd: ... -def _apply_to_tensor_elems( +def _coeff_tree_map( coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd], function: Callable[[torch.Tensor], torch.Tensor], ) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]: - """Apply `function` to all tensor elements in `data`.""" + """Apply `function` to all tensor elements in `coeffs`.""" approx = function(coeffs[0]) result_lst: list[ Union[ @@ -427,7 +427,7 @@ def _preprocess_coeffs( else: # for all tensors in `coeffs`: swap the axes swap_fn = partial(_swap_axes, axes=axes) - coeffs = _apply_to_tensor_elems(coeffs, swap_fn) + coeffs = _coeff_tree_map(coeffs, swap_fn) # Fold axes for the wavelets ds = list(coeffs[0].shape) @@ -435,14 +435,14 @@ def _preprocess_coeffs( raise ValueError(f"At least {ndim} input dimensions required.") elif len(ds) == ndim: # for all tensors in `coeffs`: unsqueeze(0) - coeffs = _apply_to_tensor_elems(coeffs, lambda x: x.unsqueeze(0)) + coeffs = _coeff_tree_map(coeffs, lambda x: x.unsqueeze(0)) elif len(ds) > ndim + 1: # for all tensors in `coeffs`: fold leading dims to batch dim - coeffs = _apply_to_tensor_elems(coeffs, lambda t: _fold_axes(t, ndim)[0]) + coeffs = _coeff_tree_map(coeffs, lambda t: _fold_axes(t, ndim)[0]) if add_channel_dim: # for all tensors in `coeffs`: add channel dim - coeffs = _apply_to_tensor_elems(coeffs, lambda x: x.unsqueeze(1)) + coeffs = _coeff_tree_map(coeffs, lambda x: x.unsqueeze(1)) return coeffs, ds @@ -540,11 +540,11 @@ def _postprocess_coeffs( raise ValueError(f"At least {ndim} input dimensions required.") elif len(ds) == ndim: # for all tensors in `coeffs`: remove batch dim - coeffs = _apply_to_tensor_elems(coeffs, lambda x: x.squeeze(0)) + coeffs = _coeff_tree_map(coeffs, lambda x: x.squeeze(0)) elif len(ds) > ndim + 1: # for 
all tensors in `coeffs`: unfold batch dim unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=ndim) - coeffs = _apply_to_tensor_elems(coeffs, unfold_axes_fn) + coeffs = _coeff_tree_map(coeffs, unfold_axes_fn) if tuple(axes) != tuple(range(-ndim, 0)): if len(axes) != ndim: @@ -552,7 +552,7 @@ def _postprocess_coeffs( else: # for all tensors in `coeffs`: undo axes swapping undo_swap_fn = partial(_undo_swap_axes, axes=axes) - coeffs = _apply_to_tensor_elems(coeffs, undo_swap_fn) + coeffs = _coeff_tree_map(coeffs, undo_swap_fn) return coeffs From a9569f2e20d262a4aa4600d52a0dc0c8c1e73022 Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Wed, 26 Jun 2024 13:52:20 +0200 Subject: [PATCH 24/33] rename coeff tree map. --- src/ptwt/_util.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 9bb4f2b7..befea358 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -224,9 +224,7 @@ def _check_same_device_dtype( torch_device, torch_dtype = c.device, c.dtype # check for all tensors in `coeffs` that the device matches `torch_device` - _coeff_tree_map( - coeffs, partial(_check_same_device, torch_device=torch_device) - ) + _coeff_tree_map(coeffs, partial(_check_same_device, torch_device=torch_device)) # check for all tensors in `coeffs` that the dtype matches `torch_dtype` _coeff_tree_map(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype)) @@ -280,7 +278,14 @@ def _coeff_tree_map( coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd], function: Callable[[torch.Tensor], torch.Tensor], ) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]: - """Apply `function` to all tensor elements in `coeffs`.""" + """Apply `function` to all tensor elements in `coeffs`. + + The idea here is to save us from having to loop over the coefficient- + data trees during input pre- and post-processing. + + Raises: + ValueError: If the input type is not supported. + """ approx = function(coeffs[0]) result_lst: list[ Union[ From d1339f031e2418b19e164869731416adbf1afd45 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Wed, 26 Jun 2024 16:25:08 +0200 Subject: [PATCH 25/33] Add remark on JAX tree map --- src/ptwt/_util.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index befea358..f4a721e7 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -280,8 +280,14 @@ def _coeff_tree_map( ) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]: """Apply `function` to all tensor elements in `coeffs`. - The idea here is to save us from having to loop over the coefficient- - data trees during input pre- and post-processing. + Applying a function to all tensors in the (potentially nested) + coefficient data structure is a common requirement in coefficient + pre- and postprocessing. This function saves us from having to loop + over the coefficient data structures in processing. + + Conceptually, this function is inspired by the + pytree processing philosophy of the JAX framework, see + https://jax.readthedocs.io/en/latest/working-with-pytrees.html Raises: ValueError: If the input type is not supported. From 75ad167ea01939b9fd717fd220e6131243a229d2 Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Mon, 1 Jul 2024 09:58:24 +0200 Subject: [PATCH 26/33] formatting. 
--- src/ptwt/_util.py | 4 ++-- src/ptwt/conv_transform_2.py | 1 - src/ptwt/matmul_transform.py | 16 ++++------------ src/ptwt/matmul_transform_2.py | 10 +++------- src/ptwt/matmul_transform_3.py | 12 ++++++++---- 5 files changed, 17 insertions(+), 26 deletions(-) diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 91b85dbf..1491c6ad 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -3,10 +3,10 @@ from __future__ import annotations import functools -import warnings import typing -from functools import partial +import warnings from collections.abc import Callable, Sequence +from functools import partial from typing import Any, Literal, NamedTuple, Optional, Protocol, Union, cast, overload import numpy as np diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index 3bf58b8d..ddb96231 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -13,7 +13,6 @@ from ._util import ( Wavelet, - _as_wavelet, _check_same_device_dtype, _get_len, _outer, diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index cd29ff61..f4082e13 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -18,23 +18,15 @@ Wavelet, _as_wavelet, _check_same_device_dtype, + _deprecated_alias, + _is_orthogonalize_method_supported, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, - _deprecated_alias, - _is_orthogonalize_method_supported, -) -from .constants import ( - BoundaryMode, - OrthogonalizeMethod, - OrthogonalizeMethod -) -from .conv_transform import _get_filter_tensors -from .conv_transform import ( - _fwt_pad, - _get_filter_tensors, ) +from .constants import BoundaryMode, OrthogonalizeMethod +from .conv_transform import _fwt_pad, _get_filter_tensors from .sparse_math import ( _orth_by_gram_schmidt, _orth_by_qr, diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 9cb3bb0a..d7696a85 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -16,12 +16,12 @@ _as_wavelet, _check_axes_argument, _check_same_device_dtype, + _deprecated_alias, + _is_orthogonalize_method_supported, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, - _deprecated_alias, - _is_orthogonalize_method_supported, ) from .constants import ( BoundaryMode, @@ -30,11 +30,7 @@ WaveletCoeff2d, WaveletDetailTuple2d, ) -from .conv_transform_2 import ( - _get_filter_tensors, - _construct_2d_filt, - _fwt_pad2, -) +from .conv_transform_2 import _construct_2d_filt, _fwt_pad2, _get_filter_tensors from .matmul_transform import ( BaseMatrixWaveDec, construct_boundary_a, diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index dfad27cd..bd0249eb 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -14,15 +14,19 @@ _as_wavelet, _check_axes_argument, _check_same_device_dtype, + _deprecated_alias, + _is_orthogonalize_method_supported, _postprocess_coeffs, _postprocess_tensor, _preprocess_coeffs, _preprocess_tensor, - _deprecated_alias, - _is_orthogonalize_method_supported, ) -from .constants import OrthogonalizeMethod, WaveletCoeffNd, WaveletDetailDict -from .constants import BoundaryMode, OrthogonalizeMethod, WaveletCoeffNd +from .constants import ( + BoundaryMode, + OrthogonalizeMethod, + WaveletCoeffNd, + WaveletDetailDict, +) from .conv_transform_3 import _fwt_pad3 from .matmul_transform import construct_boundary_a, construct_boundary_s from .sparse_math import _batch_dim_mm From 
e3fe0f4f96b01cef6dcd73e80a3ecabf2f39fe09 Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Mon, 1 Jul 2024 10:10:50 +0200 Subject: [PATCH 27/33] fix typing. --- src/ptwt/matmul_transform_2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index d7696a85..c36b4402 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -30,7 +30,8 @@ WaveletCoeff2d, WaveletDetailTuple2d, ) -from .conv_transform_2 import _construct_2d_filt, _fwt_pad2, _get_filter_tensors +from .conv_transform import _get_filter_tensors +from .conv_transform_2 import _construct_2d_filt, _fwt_pad2 from .matmul_transform import ( BaseMatrixWaveDec, construct_boundary_a, From d6d82598e8e84456e0ca063696b163b3355c761f Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 1 Jul 2024 10:24:47 +0200 Subject: [PATCH 28/33] Fix ndim sep trafo usage comments --- src/ptwt/separable_conv_transform.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 26fd88e0..8bb13d07 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -365,9 +365,9 @@ def _fswavedecn( Example: >>> import torch - >>> import ptwt + >>> from ptwt.separable_conv_transform import _fswavedecn >>> data = torch.randn(5, 10, 10, 10) - >>> coeff = ptwt.fswavedecn(data, "haar", ndim=3, level=2) + >>> coeff = _fswavedecn(data, "haar", ndim=3, level=2) """ if axes is None: axes = tuple(range(-ndim, 0)) @@ -403,10 +403,10 @@ def _fswaverecn( Example: >>> import torch - >>> import ptwt + >>> from ptwt.separable_conv_transform import _fswavedecn, _fswaverecn >>> data = torch.randn(5, 10, 10, 10) - >>> coeff = ptwt.fswavedecn(data, "haar", ndim=3, level=2) - >>> rec = ptwt.fswaverec3(coeff, "haar", ndim=3) + >>> coeff = _fswavedecn(data, "haar", ndim=3, level=2) + >>> rec = _fswaverec3(coeff, "haar", ndim=3) """ if axes is None: axes = tuple(range(-ndim, 0)) From de1b7c8a35f919dc8347dd3f3fe07f4bdf69ff9e Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Mon, 1 Jul 2024 10:25:41 +0200 Subject: [PATCH 29/33] Fix docstr --- src/ptwt/separable_conv_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 8bb13d07..c4d880ed 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -406,7 +406,7 @@ def _fswaverecn( >>> from ptwt.separable_conv_transform import _fswavedecn, _fswaverecn >>> data = torch.randn(5, 10, 10, 10) >>> coeff = _fswavedecn(data, "haar", ndim=3, level=2) - >>> rec = _fswaverec3(coeff, "haar", ndim=3) + >>> rec = _fswaverecn(coeff, "haar", ndim=3) """ if axes is None: axes = tuple(range(-ndim, 0)) From 986692f08c5c0d389df1e35ee293d295eb5726de Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Mon, 1 Jul 2024 10:30:52 +0200 Subject: [PATCH 30/33] nd-transforms are out of scope. --- src/ptwt/separable_conv_transform.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 26fd88e0..2246ec60 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -44,6 +44,10 @@ def _separable_conv_dwtn_( All but the first axes are transformed. + Note: + Please note, that ND-Transforms are generally out + of this project's scope. 
+ Args: rec_dict (WaveletDetailDict): The result will be stored here in place. From 77ca36812e7c13720a136fc415e8fe4418c60475 Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Mon, 1 Jul 2024 10:31:45 +0200 Subject: [PATCH 31/33] short note. --- src/ptwt/separable_conv_transform.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index a4792e24..c3f653c4 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -45,8 +45,7 @@ def _separable_conv_dwtn_( All but the first axes are transformed. Note: - Please note, that ND-Transforms are generally out - of this project's scope. + ND-Transforms are generally out of this project's scope. Args: rec_dict (WaveletDetailDict): The result will be stored here From cef6a0435271b6200ae5edf30c88b854ec5af296 Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Mon, 1 Jul 2024 11:04:25 +0200 Subject: [PATCH 32/33] move note. --- src/ptwt/separable_conv_transform.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index c3f653c4..7b51e505 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -44,9 +44,6 @@ def _separable_conv_dwtn_( All but the first axes are transformed. - Note: - ND-Transforms are generally out of this project's scope. - Args: rec_dict (WaveletDetailDict): The result will be stored here in place. @@ -410,6 +407,9 @@ def _fswaverecn( >>> data = torch.randn(5, 10, 10, 10) >>> coeff = _fswavedecn(data, "haar", ndim=3, level=2) >>> rec = _fswaverecn(coeff, "haar", ndim=3) + + Note: + ND-Transforms are generally out of this project's scope. """ if axes is None: axes = tuple(range(-ndim, 0)) From 998152101671f119cbb6a17cf5a45e9a31d0f5ca Mon Sep 17 00:00:00 2001 From: Moritz Wolter Date: Mon, 1 Jul 2024 11:05:48 +0200 Subject: [PATCH 33/33] add note to forward and backward. --- src/ptwt/separable_conv_transform.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 7b51e505..4e715e0d 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -368,6 +368,9 @@ def _fswavedecn( >>> from ptwt.separable_conv_transform import _fswavedecn >>> data = torch.randn(5, 10, 10, 10) >>> coeff = _fswavedecn(data, "haar", ndim=3, level=2) + + Note: + ND-Transforms are generally out of this project's scope. """ if axes is None: axes = tuple(range(-ndim, 0))
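
The patches above converge on a small set of `_util` helpers: `_preprocess_coeffs` and `_postprocess_coeffs` fold and unfold batch dimensions, `_coeff_tree_map` applies a function to every tensor leaf of a coefficient structure, and `_check_same_device_dtype` replaces the per-call device and dtype loops. A minimal usage sketch follows, assuming the final patched state of `src/ptwt/_util.py` is importable; the shapes and the scaling function are illustrative only and not part of the patch series:

    import torch
    from ptwt._util import (
        _check_same_device_dtype,
        _coeff_tree_map,
        _postprocess_coeffs,
        _preprocess_coeffs,
    )

    # a 1d coefficient list with two leading batch axes
    coeffs = [torch.randn(3, 5, 16), torch.randn(3, 5, 16)]

    # preprocessing folds [3, 5] into one batch axis; ds keeps the original shape
    folded, ds = _preprocess_coeffs(coeffs, ndim=1, axes=-1)
    assert folded[0].shape == (15, 16) and ds == [3, 5, 16]

    # the consolidated check returns the shared device and dtype
    torch_device, torch_dtype = _check_same_device_dtype(folded)
    assert torch_dtype == torch.float32

    # _coeff_tree_map visits every tensor leaf, here applying a simple scaling
    scaled = _coeff_tree_map(folded, lambda t: 2.0 * t)

    # postprocessing reverses the folding
    restored = _postprocess_coeffs(scaled, ndim=1, ds=ds, axes=-1)
    assert restored[0].shape == (3, 5, 16)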