diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7403759f..a93e7182 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,7 +9,7 @@ jobs: strategy: matrix: os: [ ubuntu-latest ] - python-version: [3.9, 3.11] + python-version: [3.12, 3.11] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -26,7 +26,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9, 3.11] + python-version: [3.12, 3.11] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -42,7 +42,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9, 3.11] + python-version: [3.12, 3.11] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -58,7 +58,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.11] + python-version: [3.12] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/noxfile.py b/noxfile.py index 011e123e..25f5bf30 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,4 +1,5 @@ """This module implements our CI function calls.""" + import nox diff --git a/setup.cfg b/setup.cfg index 6b6c4b3d..5b72a4ef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,8 +40,8 @@ classifiers = Intended Audience :: Science/Research Operating System :: OS Independent Programming Language :: Python - Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 Programming Language :: Python :: 3 :: Only Topic :: Scientific/Engineering :: Artificial Intelligence @@ -56,7 +56,7 @@ install_requires = pytest nox -python_requires = >=3.9 +python_requires = >=3.11 packages = find: package_dir = diff --git a/src/ptwt/__init__.py b/src/ptwt/__init__.py index 03a6f6da..34808704 100644 --- a/src/ptwt/__init__.py +++ b/src/ptwt/__init__.py @@ -1,4 +1,5 @@ """Differentiable and gpu enabled fast wavelet transforms in PyTorch.""" + from ._util import Wavelet from .continuous_transform import cwt from .conv_transform import wavedec, waverec diff --git a/src/ptwt/_stationary_transform.py b/src/ptwt/_stationary_transform.py index 5d2028b7..2beaf9f8 100644 --- a/src/ptwt/_stationary_transform.py +++ b/src/ptwt/_stationary_transform.py @@ -1,4 +1,5 @@ """This module implements stationary wavelet transforms.""" + # Created by moritz wolter, in 2024 from typing import List, Optional, Union diff --git a/src/ptwt/_util.py b/src/ptwt/_util.py index 50277b42..f76f973a 100644 --- a/src/ptwt/_util.py +++ b/src/ptwt/_util.py @@ -1,4 +1,5 @@ """Utility methods to compute wavelet decompositions from a dataset.""" + from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union import numpy as np diff --git a/src/ptwt/continuous_transform.py b/src/ptwt/continuous_transform.py index d91ccb1f..db95d06b 100644 --- a/src/ptwt/continuous_transform.py +++ b/src/ptwt/continuous_transform.py @@ -2,6 +2,7 @@ This module is based on pywt's cwt implementation. """ + # Written by the Pytorch wavelet toolbox team in 2024 from typing import Any, Tuple, Union @@ -195,7 +196,7 @@ def _integrate( if type(arr) is np.ndarray: integral = np.cumsum(arr) elif type(arr) is torch.Tensor: - integral = torch.cumsum(arr, -1) + integral = torch.cumsum(arr, -1) # type: ignore else: raise TypeError("Only ndarrays or tensors are integratable.") integral *= step @@ -271,8 +272,8 @@ def wavefun( """Define a grid and evaluate the wavelet on it.""" length = 2**precision # load the bounds from untyped pywt code. - lower_bound: float = float(self.lower_bound) - upper_bound: float = float(self.upper_bound) + lower_bound: float = float(self.lower_bound) # type: ignore + upper_bound: float = float(self.upper_bound) # type: ignore grid = torch.linspace( lower_bound, upper_bound, @@ -291,10 +292,10 @@ def __call__(self, grid_values: torch.Tensor) -> torch.Tensor: shannon = ( torch.sqrt(self.bandwidth) * ( - torch.sin(torch.pi * self.bandwidth * grid_values) # type: ignore + torch.sin(torch.pi * self.bandwidth * grid_values) / (torch.pi * self.bandwidth * grid_values) ) - * torch.exp(1j * 2 * torch.pi * self.center * grid_values) # type: ignore + * torch.exp(1j * 2 * torch.pi * self.center * grid_values) ) return shannon @@ -306,8 +307,8 @@ def __call__(self, grid_values: torch.Tensor) -> torch.Tensor: """Return numerical values for the wavelet on a grid.""" morlet = ( 1.0 - / torch.sqrt(torch.pi * self.bandwidth) # type: ignore + / torch.sqrt(torch.pi * self.bandwidth) * torch.exp(-(grid_values**2) / self.bandwidth) - * torch.exp(1j * 2 * torch.pi * self.center * grid_values) # type: ignore + * torch.exp(1j * 2 * torch.pi * self.center * grid_values) ) return morlet diff --git a/src/ptwt/conv_transform.py b/src/ptwt/conv_transform.py index 475f3aa3..e34f9c2f 100644 --- a/src/ptwt/conv_transform.py +++ b/src/ptwt/conv_transform.py @@ -2,6 +2,7 @@ This module treats boundaries with edge-padding. """ + # Created by moritz wolter, 14.04.20 from typing import List, Optional, Sequence, Tuple, Union diff --git a/src/ptwt/conv_transform_2.py b/src/ptwt/conv_transform_2.py index b933f727..f0091612 100644 --- a/src/ptwt/conv_transform_2.py +++ b/src/ptwt/conv_transform_2.py @@ -3,6 +3,7 @@ The implementation relies on torch.nn.functional.conv2d and torch.nn.functional.conv_transpose2d under the hood. """ + # Written by the Pytorch wavelet toolbox team in 2024 @@ -215,7 +216,7 @@ def wavedec2( result_lst = _map_result(result_lst, _unfold_axes2) if axes != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) + undo_swap_fn = partial(_undo_swap_axes, axes=list(axes)) result_lst = _map_result(result_lst, undo_swap_fn) return result_lst diff --git a/src/ptwt/conv_transform_3.py b/src/ptwt/conv_transform_3.py index dc82cee4..8744e2cf 100644 --- a/src/ptwt/conv_transform_3.py +++ b/src/ptwt/conv_transform_3.py @@ -2,6 +2,7 @@ The functions here are based on torch.nn.functional.conv3d and it's transpose. """ + # Written by the Pytorch wavelet toolbox team in 2024 from functools import partial @@ -204,7 +205,7 @@ def wavedec3( result_lst = _map_result(result_lst, _unfold_axes_fn) if tuple(axes) != (-3, -2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) + undo_swap_fn = partial(_undo_swap_axes, axes=list(axes)) result_lst = _map_result(result_lst, undo_swap_fn) return result_lst @@ -212,7 +213,10 @@ def wavedec3( def _waverec3d_fold_channels_3d_list( coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], -) -> Tuple[List[Union[torch.Tensor, Dict[str, torch.Tensor]]], List[int],]: +) -> Tuple[ + List[Union[torch.Tensor, Dict[str, torch.Tensor]]], + List[int], +]: # fold the input coefficients for processing conv2d_transpose. fold_coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = [] ds = list(_check_if_tensor(coeffs[0]).shape) diff --git a/src/ptwt/matmul_transform.py b/src/ptwt/matmul_transform.py index a31191d9..ce7b1876 100644 --- a/src/ptwt/matmul_transform.py +++ b/src/ptwt/matmul_transform.py @@ -6,6 +6,7 @@ in Strang Nguyen (p. 32), as well as the description of boundary filters in "Ripples in Mathematics" section 10.3 . """ + # Created by moritz (wolter@cs.uni-bonn.de) at 14.04.20 import sys from typing import List, Optional, Union @@ -281,7 +282,7 @@ def _construct_analysis_matrices( f"level {curr_level}, the current signal length {curr_length} is " f"smaller than the filter length {filt_len}. Therefore, the " "transformation is only computed up to the decomposition level " - f"{curr_level-1}.\n" + f"{curr_level - 1}.\n" ) break @@ -563,7 +564,7 @@ def _construct_synthesis_matrices( f"level {curr_level}, the current signal length {curr_length} is " f"smaller than the filter length {filt_len}. Therefore, the " "transformation is only computed up to the decomposition level " - f"{curr_level-1}.\n" + f"{curr_level - 1}.\n" ) break diff --git a/src/ptwt/matmul_transform_2.py b/src/ptwt/matmul_transform_2.py index 1a1642ab..084eaa44 100644 --- a/src/ptwt/matmul_transform_2.py +++ b/src/ptwt/matmul_transform_2.py @@ -2,6 +2,7 @@ This module uses boundary filters to minimize padding. """ + # Written by moritz ( @ wolter.tech ) in 2021 import sys from functools import partial @@ -364,7 +365,8 @@ def _construct_analysis_matrices( f". At level {curr_level}, at least one of the current signal " f"height and width ({current_height}, {current_width}) is smaller " f"then the filter length {filt_len}. Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level-1}.\n" + f"is only computed up to the decomposition " + f" level {curr_level - 1}.\n" ) break # the conv matrices require even length inputs. @@ -544,7 +546,7 @@ def __call__( split_list = _map_result(split_list, _unfold_axes2) if self.axes != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) + undo_swap_fn = partial(_undo_swap_axes, axes=list(self.axes)) split_list = _map_result(split_list, undo_swap_fn) return split_list[::-1] @@ -678,7 +680,8 @@ def _construct_synthesis_matrices( f". At level {curr_level}, at least one of the current signal " f"height and width ({current_height}, {current_width}) is smaller " f"then the filter length {filt_len}. Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level-1}.\n" + f"is only computed up to the " + f" decomposition level {curr_level - 1}.\n" ) break current_height, current_width, pad_tuple = _matrix_pad_2( diff --git a/src/ptwt/matmul_transform_3.py b/src/ptwt/matmul_transform_3.py index cbf02de8..a7898668 100644 --- a/src/ptwt/matmul_transform_3.py +++ b/src/ptwt/matmul_transform_3.py @@ -1,4 +1,5 @@ """Implement 3D separable boundary transforms.""" + # Written by the Pytorch wavelet toolbox team in 2024 import sys @@ -122,7 +123,8 @@ def _construct_analysis_matrices( f"depth, height, and width ({current_depth}, {current_height}," f"{current_width}) is smaller " f"then the filter length {filt_len}. Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level-1}.\n" + f"is only computed up to the " + f" decomposition level {curr_level - 1}.\n" ) break # the conv matrices require even length inputs. @@ -137,7 +139,7 @@ def _construct_analysis_matrices( matrix_construction_fun = partial( construct_boundary_a, wavelet=self.wavelet, - boundary=self.boundary, + boundary=self.boundary, # type: ignore device=device, dtype=dtype, ) @@ -265,7 +267,7 @@ def _split_rec( split_list = _map_result(split_list, _unfold_axes_fn) if self.axes != (-3, -2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) + undo_swap_fn = partial(_undo_swap_axes, axes=list(self.axes)) split_list = _map_result(split_list, undo_swap_fn) return split_list[::-1] @@ -340,7 +342,8 @@ def _construct_synthesis_matrices( f" depth, height and width ({current_depth}, {current_height}, " f"{current_width}) is smaller than the filter length {filt_len}." f" Therefore, the transformation " - f"is only computed up to the decomposition level {curr_level-1}.\n" + f"is only computed up to the " + f"decomposition level {curr_level - 1}.\n" ) break # the conv matrices require even length inputs. diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index b1cecc66..4fabeba9 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -1,4 +1,5 @@ """Compute analysis wavelet packet representations.""" + # Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de) import collections @@ -104,7 +105,7 @@ def __init__( if len(data.shape) == 1: # add a batch dimension. data = data.unsqueeze(0) - self.transform(data, maxlevel) # type: ignore + self.transform(data, maxlevel) else: self.data = {} @@ -176,7 +177,11 @@ def _get_wavedec( return self._matrix_wavedec_dict[length] else: return partial( - wavedec, wavelet=self.wavelet, level=1, mode=self.mode, axis=self.axis + wavedec, + wavelet=self.wavelet, + level=1, + mode=self.mode, # type: ignore + axis=self.axis, ) def _get_waverec( @@ -382,9 +387,7 @@ def get_natural_order(self, level: int) -> List[str]: """ return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)] - def _get_wavedec( - self, shape: Tuple[int, ...] - ) -> Callable[ + def _get_wavedec(self, shape: Tuple[int, ...]) -> Callable[ [torch.Tensor], List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], ]: @@ -415,9 +418,7 @@ def _get_wavedec( wavedec2, wavelet=self.wavelet, level=1, mode=self.mode, axes=self.axes ) - def _get_waverec( - self, shape: Tuple[int, ...] - ) -> Callable[ + def _get_waverec(self, shape: Tuple[int, ...]) -> Callable[ [List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]], torch.Tensor, ]: @@ -471,7 +472,7 @@ def _transform_tuple_to_fsdict_func( def _fsdict_func( coeffs: List[ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] - ] + ], ) -> torch.Tensor: a, (h, v, d) = coeffs return fsdict_func([cast(torch.Tensor, a), {"ad": h, "da": v, "dd": d}]) diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index 1cd1a322..ddaea945 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -6,6 +6,7 @@ Under the hood, code in this module transforms all dimensions using torch.nn.functional.conv1d and it's transpose. """ + # Written by the Pytorch wavelet toolbox team in 2024 from functools import partial from typing import Dict, List, Optional, Tuple, Union @@ -164,9 +165,9 @@ def _separable_conv_waverecn( approx: torch.Tensor = coeffs[0] for level_dict in coeffs[1:]: - keys = list(level_dict.keys()) - level_dict["a" * max(map(len, keys))] = approx - approx = _separable_conv_idwtn(level_dict, wavelet) + keys = list(level_dict.keys()) # type: ignore + level_dict["a" * max(map(len, keys))] = approx # type: ignore + approx = _separable_conv_idwtn(level_dict, wavelet) # type: ignore return approx @@ -235,7 +236,7 @@ def fswavedec2( res = _map_result(res, _unfold_axes2) if axes != (-2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) + undo_swap_fn = partial(_undo_swap_axes, axes=list(axes)) res = _map_result(res, undo_swap_fn) return res @@ -307,7 +308,7 @@ def fswavedec3( res = _map_result(res, _unfold_axes3) if axes != (-3, -2, -1): - undo_swap_fn = partial(_undo_swap_axes, axes=axes) + undo_swap_fn = partial(_undo_swap_axes, axes=list(axes)) res = _map_result(res, undo_swap_fn) return res diff --git a/src/ptwt/sparse_math.py b/src/ptwt/sparse_math.py index aa1ec232..e4eb91d0 100644 --- a/src/ptwt/sparse_math.py +++ b/src/ptwt/sparse_math.py @@ -1,4 +1,5 @@ """Efficiently construct fwt operations using sparse matrices.""" + # Written by moritz ( @ wolter.tech ) 17.09.21 from itertools import product from typing import List diff --git a/src/ptwt/wavelets_learnable.py b/src/ptwt/wavelets_learnable.py index 196b90c0..240abd26 100644 --- a/src/ptwt/wavelets_learnable.py +++ b/src/ptwt/wavelets_learnable.py @@ -2,6 +2,7 @@ See https://arxiv.org/pdf/2004.09569.pdf for more information. """ + # Created by moritz wolter@cs.uni-bonn.de, 14.05.20 # Inspired by Ripples in Mathematics, Jensen and La Cour-Harbo, Chapter 7.7 # import pywt diff --git a/tests/_mackey_glass.py b/tests/_mackey_glass.py index fb8e9d8c..63f17ee3 100644 --- a/tests/_mackey_glass.py +++ b/tests/_mackey_glass.py @@ -1,4 +1,5 @@ """Generate artificial time-series data for debugging purposes.""" + from typing import Optional, Union import torch diff --git a/tests/test_convolution_fwt.py b/tests/test_convolution_fwt.py index d7663c28..001a41e6 100644 --- a/tests/test_convolution_fwt.py +++ b/tests/test_convolution_fwt.py @@ -1,4 +1,5 @@ """Test the conv-fwt code.""" + # Written by moritz ( @ wolter.tech ) in 2021 import numpy as np import pytest diff --git a/tests/test_cwt.py b/tests/test_cwt.py index 01979d36..44148bb3 100644 --- a/tests/test_cwt.py +++ b/tests/test_cwt.py @@ -1,4 +1,5 @@ """Test the continuous transformation code.""" + from typing import Union import numpy as np diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py new file mode 100644 index 00000000..a6c84d41 --- /dev/null +++ b/tests/test_dtypes.py @@ -0,0 +1,48 @@ +"""Test dtype support for the fwt code.""" + +# Written by moritz ( @ wolter.tech ) in 2025 +import numpy as np +import pytest +import pywt +import torch +from scipy import datasets + +from src.ptwt.conv_transform import _flatten_2d_coeff_lst +from src.ptwt.conv_transform_2 import wavedec2, waverec2 + + +@pytest.mark.slow +@pytest.mark.parametrize( + "dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16] +) +def test_2d_wavedec_rec(dtype): + """Ensure pywt.wavedec2 and ptwt.wavedec2 produce the same coefficients. + + wavedec2 and waverec2 must invert each other. + """ + mode = "reflect" + level = 2 + size = (32, 32) + face = np.transpose( + datasets.face()[256 : (512 + size[0]), 256 : (512 + size[1])], [2, 0, 1] + ).astype(np.float32) + wavelet = pywt.Wavelet("db2") + to_transform = torch.from_numpy(face).to(torch.float32) + coeff2d = wavedec2(to_transform, wavelet, mode=mode, level=level) + pywt_coeff2d = pywt.wavedec2(face, wavelet, mode=mode, level=level) + for pos, coeffs in enumerate(pywt_coeff2d): + if type(coeffs) is tuple: + for tuple_pos, tuple_el in enumerate(coeffs): + assert ( + tuple_el.shape == coeff2d[pos][tuple_pos].shape + ), "pywt and ptwt should produce the same shapes." + else: + assert ( + coeffs.shape == coeff2d[pos].shape + ), "pywt and ptwt should produce the same shapes." + flat_coeff_list_pywt = np.concatenate(_flatten_2d_coeff_lst(pywt_coeff2d), -1) + flat_coeff_list_ptwt = torch.cat(_flatten_2d_coeff_lst(coeff2d), -1) + assert np.allclose(flat_coeff_list_pywt, flat_coeff_list_ptwt.numpy(), atol=1e-3) + rec = waverec2(coeff2d, wavelet) + rec = rec.numpy().squeeze().astype(np.float32) + assert np.allclose(face, rec[:, : face.shape[1], : face.shape[2]], atol=1e-3) diff --git a/tests/test_jit.py b/tests/test_jit.py index 1f34e482..b9a28725 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -1,4 +1,5 @@ """Ensure pytorch's torch.jit.trace feature works properly.""" + from typing import NamedTuple import numpy as np diff --git a/tests/test_matrix_fwt.py b/tests/test_matrix_fwt.py index 0cc82484..edbd12bd 100644 --- a/tests/test_matrix_fwt.py +++ b/tests/test_matrix_fwt.py @@ -1,4 +1,5 @@ """Test the fwt and ifwt matrices.""" + # Written by moritz ( @ wolter.tech ) in 2021 from typing import List diff --git a/tests/test_matrix_fwt_2.py b/tests/test_matrix_fwt_2.py index d975a32e..9ccdb010 100644 --- a/tests/test_matrix_fwt_2.py +++ b/tests/test_matrix_fwt_2.py @@ -1,4 +1,5 @@ """Test code for the 2d boundary wavelets.""" + # Created by moritz ( wolter@cs.uni-bonn.de ), 08.09.21 import numpy as np import pytest diff --git a/tests/test_matrix_fwt_3.py b/tests/test_matrix_fwt_3.py index 09790810..9d555043 100644 --- a/tests/test_matrix_fwt_3.py +++ b/tests/test_matrix_fwt_3.py @@ -1,4 +1,5 @@ """Test the 3d matrix-fwt code.""" + from typing import List import numpy as np diff --git a/tests/test_packets.py b/tests/test_packets.py index 488b8035..6b2c0e90 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -1,4 +1,5 @@ """Test the wavelet packet code.""" + # Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de) from itertools import product diff --git a/tests/test_sparse_math.py b/tests/test_sparse_math.py index 3cc45679..06068269 100644 --- a/tests/test_sparse_math.py +++ b/tests/test_sparse_math.py @@ -1,4 +1,5 @@ """Test the sparse math code from ptwt.sparse_math.""" + # Written by moritz ( @ wolter.tech ) in 2021 import numpy as np import pytest diff --git a/tests/test_util.py b/tests/test_util.py index a6b6016c..9123acc9 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,4 +1,5 @@ """Test the util methods.""" + from typing import Tuple import numpy as np diff --git a/tests/test_wavelet.py b/tests/test_wavelet.py index eea5baf2..b9f8a8bc 100644 --- a/tests/test_wavelet.py +++ b/tests/test_wavelet.py @@ -1,4 +1,5 @@ """Test the adaptive wavelet cost functions.""" + import pytest import pywt import torch