From aacf9ba6c03ecebc85a333f8987405d7de7bd5ba Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 02:44:50 +0200 Subject: [PATCH 01/13] Fix axes for 1d packets --- src/ptwt/packets.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 879f5eff..c17b6b35 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -134,7 +134,7 @@ def transform( """ self.data = {} if maxlevel is None: - maxlevel = pywt.dwt_max_level(data.shape[-1], self.wavelet.dec_len) + maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len) self.maxlevel = maxlevel self._recursive_dwt(data, level=0, path="") return self @@ -166,13 +166,15 @@ def reconstruct(self) -> WaveletPacket: for node in self.get_level(level): data_a = self[node + "a"] data_b = self[node + "d"] - rec = self._get_waverec(data_a.shape[-1])([data_a, data_b]) + rec = self._get_waverec(data_a.shape[self.axis])([data_a, data_b]) if level > 0: - if rec.shape[-1] != self[node].shape[-1]: + if rec.shape[self.axis] != self[node].shape[self.axis]: assert ( - rec.shape[-1] == self[node].shape[-1] + 1 + rec.shape[self.axis] == self[node].shape[self.axis] + 1 ), "padding error, please open an issue on github" - rec = rec[..., :-1] + rec = rec.swapaxes(self.axis, -1)[..., :-1].swapaxes( + -1, self.axis + ) self[node] = rec return self @@ -232,7 +234,7 @@ def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None: self.data[path] = data if level < self.maxlevel: - res_lo, res_hi = self._get_wavedec(data.shape[-1])(data) + res_lo, res_hi = self._get_wavedec(data.shape[self.axis])(data) self._recursive_dwt(res_lo, level + 1, path + "a") self._recursive_dwt(res_hi, level + 1, path + "d") From 43fa6fbf113c49ba3d6c1b7bdfa2e9174f31e88e Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 02:45:09 +0200 Subject: [PATCH 02/13] Remove unsqueezing as this is done in transform --- src/ptwt/packets.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index c17b6b35..4666f5f3 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -113,9 +113,6 @@ def __init__( self.maxlevel: Optional[int] = None self.axis = axis if data is not None: - if len(data.shape) == 1: - # add a batch dimension. - data = data.unsqueeze(0) self.transform(data, maxlevel) else: self.data = {} From 67b94a6ede4f78d56808d1c55fa5ab64e4654d46 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 02:56:27 +0200 Subject: [PATCH 03/13] Fix Packet2d --- src/ptwt/packets.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 4666f5f3..1520178b 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -12,7 +12,7 @@ import pywt import torch -from ._util import Wavelet, _as_wavelet +from ._util import Wavelet, _as_wavelet, _swap_axes, _undo_swap_axes from .constants import ( ExtendedBoundaryMode, OrthogonalizeMethod, @@ -339,13 +339,10 @@ def transform( """ self.data = {} if maxlevel is None: - maxlevel = pywt.dwt_max_level(min(data.shape[-2:]), self.wavelet.dec_len) + min_transform_size = min(_swap_axes(data, self.axes).shape[-2:]) + maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len) self.maxlevel = maxlevel - if data.dim() == 2: - # add batch dim to unbatched input - data = data.unsqueeze(0) - self._recursive_dwt2d(data, level=0, path="") return self @@ -358,9 +355,8 @@ def reconstruct(self) -> WaveletPacket2D: a reconstruction from the leaves. """ if self.maxlevel is None: - self.maxlevel = pywt.dwt_max_level( - min(self[""].shape[-2:]), self.wavelet.dec_len - ) + min_transform_size = min(_swap_axes(self[""], self.axes).shape[-2:]) + self.maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len) for level in reversed(range(self.maxlevel)): for node in WaveletPacket2D.get_natural_order(level): @@ -368,20 +364,24 @@ def reconstruct(self) -> WaveletPacket2D: data_h = self[node + "h"] data_v = self[node + "v"] data_d = self[node + "d"] - rec = self._get_waverec(data_a.shape[-2:])( + transform_size = _swap_axes(data_a, self.axes).shape[-2:] + rec = self._get_waverec(transform_size)( (data_a, WaveletDetailTuple2d(data_h, data_v, data_d)) ) if level > 0: - if rec.shape[-1] != self[node].shape[-1]: + rec = _swap_axes(rec, self.axes) + swapped_node = _swap_axes(self[node], self.axes) + if rec.shape[-1] != swapped_node.shape[-1]: assert ( - rec.shape[-1] == self[node].shape[-1] + 1 + rec.shape[-1] == swapped_node.shape[-1] + 1 ), "padding error, please open an issue on GitHub" rec = rec[..., :-1] - if rec.shape[-2] != self[node].shape[-2]: + if rec.shape[-2] != swapped_node.shape[-2]: assert ( - rec.shape[-2] == self[node].shape[-2] + 1 + rec.shape[-2] == swapped_node.shape[-2] + 1 ), "padding error, please open an issue on GitHub" rec = rec[..., :-1, :] + rec = _undo_swap_axes(rec, self.axes) self[node] = rec return self @@ -472,7 +472,8 @@ def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None: self.data[path] = data if level < self.maxlevel: - result = self._get_wavedec(data.shape[-2:])(data) + transform_size = _swap_axes(data, self.axes).shape[-2:] + result = self._get_wavedec(transform_size)(data) # assert for type checking assert len(result) == 2 From 4bd1f2873ff38d100f4daeaaeaafbc00af5bfd3b Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 08:53:22 +0200 Subject: [PATCH 04/13] Refactor _compare_trees1 with axis arg --- tests/test_packets.py | 50 +++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/tests/test_packets.py b/tests/test_packets.py index 00cf6c4b..743a4841 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -24,46 +24,44 @@ def _compare_trees1( batch_size: int = 1, transform_mode: bool = False, multiple_transforms: bool = False, + axis: int = -1, ) -> None: data = np.random.rand(batch_size, length) - wavelet = pywt.Wavelet(wavelet_str) - if transform_mode: twp = WaveletPacket( - None, wavelet, mode=ptwt_boundary, maxlevel=max_lev + None, + wavelet_str, + mode=ptwt_boundary, + axis=axis, ).transform(torch.from_numpy(data), maxlevel=max_lev) else: twp = WaveletPacket( - torch.from_numpy(data), wavelet, mode=ptwt_boundary, maxlevel=max_lev + torch.from_numpy(data), + wavelet_str, + mode=ptwt_boundary, + maxlevel=max_lev, + axis=axis, ) # if multiple_transform flag is set, recalculcate the packets if multiple_transforms: twp.transform(torch.from_numpy(data), maxlevel=max_lev) - nodes = twp.get_level(twp.maxlevel) - twp_lst = [] - for node in nodes: - twp_lst.append(twp[node]) - torch_res = torch.cat(twp_lst, -1).numpy() - - np_batches = [] - for batch_index in range(batch_size): - wp = pywt.WaveletPacket( - data=data[batch_index], - wavelet=wavelet, - mode=pywt_boundary, - maxlevel=max_lev, - ) - nodes = [node.path for node in wp.get_level(wp.maxlevel, "freq")] - np_lst = [] - for node in nodes: - np_lst.append(wp[node].data) - np_res = np.concatenate(np_lst, -1) - np_batches.append(np_res) - np_batches = np.stack(np_batches, 0) + torch_res = torch.cat([twp[node] for node in twp.get_level(twp.maxlevel)], axis) + + wp = pywt.WaveletPacket( + data=data, + wavelet=wavelet_str, + mode=pywt_boundary, + maxlevel=max_lev, + axis=axis, + ) + np_res = np.concatenate( + [node.data for node in wp.get_level(wp.maxlevel, "freq")], axis + ) + assert wp.maxlevel == twp.maxlevel - assert np.allclose(torch_res, np_batches) + assert np.allclose(torch_res.numpy(), np_res) def _compare_trees2( From de4cc559a5967c4311d5f62c732e1b614103ffc4 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 09:04:39 +0200 Subject: [PATCH 05/13] Refactor _compare_trees2 with axis arg --- tests/test_packets.py | 70 +++++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 32 deletions(-) diff --git a/tests/test_packets.py b/tests/test_packets.py index 743a4841..10efa468 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -74,51 +74,57 @@ def _compare_trees2( batch_size: int = 1, transform_mode: bool = False, multiple_transforms: bool = False, + axes: tuple[int, int] = (-2, -1), ) -> None: - face = datasets.face()[:height, :width] - face = np.mean(face, axis=-1).astype(np.float64) - wavelet = pywt.Wavelet(wavelet_str) - batch_list = [] - for _ in range(batch_size): - wp_tree = pywt.WaveletPacket2D( - data=face, - wavelet=wavelet, - mode=pywt_boundary, - maxlevel=max_lev, - ) - # Get the full decomposition - wp_keys = list(product(["a", "h", "v", "d"], repeat=wp_tree.maxlevel)) - np_packets = [] - for node in wp_keys: - np_packet = wp_tree["".join(node)].data - np_packets.append(np_packet) - np_packets = np.stack(np_packets, 0) - batch_list.append(np_packets) - batch_np_packets = np.stack(batch_list, 0) + face = datasets.face()[:height, :width].astype(np.float64).mean(-1) + data = np.stack([face] * batch_size, 0) - # get the PyTorch decomposition - pt_data = torch.stack([torch.from_numpy(face)] * batch_size, 0) + wp_tree = pywt.WaveletPacket2D( + data=data, + wavelet=wavelet_str, + mode=pywt_boundary, + maxlevel=max_lev, + axes=axes, + ) + np_packets = np.stack( + [ + node.data + for node in wp_tree.get_level(level=wp_tree.maxlevel, order="natural") + ], + 1, + ) + # get the PyTorch decomposition if transform_mode: ptwt_wp_tree = WaveletPacket2D( - None, wavelet=wavelet, mode=ptwt_boundary - ).transform(pt_data, maxlevel=max_lev) + None, + wavelet=wavelet_str, + mode=ptwt_boundary, + axes=axes, + ).transform(torch.from_numpy(data), maxlevel=max_lev) else: ptwt_wp_tree = WaveletPacket2D( - pt_data, wavelet=wavelet, mode=ptwt_boundary, maxlevel=max_lev + torch.from_numpy(data), + wavelet=wavelet_str, + mode=ptwt_boundary, + maxlevel=max_lev, + axes=axes, ) # if multiple_transform flag is set, recalculcate the packets if multiple_transforms: - ptwt_wp_tree.transform(pt_data, maxlevel=max_lev) + ptwt_wp_tree.transform(torch.from_numpy(data), maxlevel=max_lev) + + packets_pt = torch.stack( + [ + ptwt_wp_tree[node] + for node in ptwt_wp_tree.get_natural_order(ptwt_wp_tree.maxlevel) + ], + 1, + ) - packets = [] - for node in wp_keys: - packet = ptwt_wp_tree["".join(node)] - packets.append(packet) - packets_pt = torch.stack(packets, 1).numpy() assert wp_tree.maxlevel == ptwt_wp_tree.maxlevel - assert np.allclose(packets_pt, batch_np_packets) + assert np.allclose(packets_pt.numpy(), np_packets) @pytest.mark.slow From bcc497a41dd9c0368410b175c4098b1c668c1919 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 09:10:31 +0200 Subject: [PATCH 06/13] Add axes to tests --- tests/test_packets.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_packets.py b/tests/test_packets.py index 10efa468..a3184781 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -136,6 +136,7 @@ def _compare_trees2( @pytest.mark.parametrize("batch_size", [2, 1]) @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) +@pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (1, 2), (2, 0), (0, 2)]) def test_2d_packets( max_lev: Optional[int], wavelet_str: str, @@ -143,6 +144,7 @@ def test_2d_packets( batch_size: int, transform_mode: bool, multiple_transforms: bool, + axes: tuple[int, int], ) -> None: """Ensure pywt and ptwt produce equivalent wavelet 2d packet trees.""" _compare_trees2( @@ -153,6 +155,7 @@ def test_2d_packets( batch_size=batch_size, transform_mode=transform_mode, multiple_transforms=multiple_transforms, + axes=axes, ) @@ -161,11 +164,13 @@ def test_2d_packets( @pytest.mark.parametrize("batch_size", [1, 2]) @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) +@pytest.mark.parametrize("axes", [(-2, -1), (-1, -2), (1, 2), (2, 0), (0, 2)]) def test_boundary_matrix_packets2( max_lev: Optional[int], batch_size: int, transform_mode: bool, multiple_transforms: bool, + axes: tuple[int, int], ) -> None: """Ensure the 2d - sparse matrix haar tree and pywt-tree are the same.""" _compare_trees2( @@ -176,6 +181,7 @@ def test_boundary_matrix_packets2( batch_size=batch_size, transform_mode=transform_mode, multiple_transforms=multiple_transforms, + axes=axes, ) @@ -188,6 +194,7 @@ def test_boundary_matrix_packets2( @pytest.mark.parametrize("batch_size", [2, 1]) @pytest.mark.parametrize("transform_mode", [False, True]) @pytest.mark.parametrize("multiple_transforms", [False, True]) +@pytest.mark.parametrize("axis", [0, -1]) def test_1d_packets( max_lev: int, wavelet_str: str, @@ -195,6 +202,7 @@ def test_1d_packets( batch_size: int, transform_mode: bool, multiple_transforms: bool, + axis: int, ) -> None: """Ensure pywt and ptwt produce equivalent wavelet 1d packet trees.""" _compare_trees1( @@ -205,6 +213,7 @@ def test_1d_packets( batch_size=batch_size, transform_mode=transform_mode, multiple_transforms=multiple_transforms, + axis=axis, ) From efafd458837419b3b5302e2952745899de672220 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 09:10:59 +0200 Subject: [PATCH 07/13] Ensure that test data fits the axes args --- tests/test_packets.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_packets.py b/tests/test_packets.py index a3184781..43d67da5 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -11,6 +11,7 @@ import torch from scipy import datasets +from ptwt._util import _check_axes_argument from ptwt.constants import ExtendedBoundaryMode from ptwt.packets import WaveletPacket, WaveletPacket2D @@ -27,6 +28,8 @@ def _compare_trees1( axis: int = -1, ) -> None: data = np.random.rand(batch_size, length) + data = data.swapaxes(axis, -1) + if transform_mode: twp = WaveletPacket( None, @@ -79,6 +82,10 @@ def _compare_trees2( face = datasets.face()[:height, :width].astype(np.float64).mean(-1) data = np.stack([face] * batch_size, 0) + _check_axes_argument(axes) + data = data.swapaxes(axes[0], -2) + data = data.swapaxes(axes[1], -1) + wp_tree = pywt.WaveletPacket2D( data=data, wavelet=wavelet_str, From 4a9715e85fa0ae6a380149e4eb54d602c9597bbc Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 09:11:11 +0200 Subject: [PATCH 08/13] Fix freq order test --- tests/test_packets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_packets.py b/tests/test_packets.py index 43d67da5..52f1349e 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -263,10 +263,10 @@ def test_freq_order(level: int, wavelet_str: str, pywt_boundary: str) -> None: print( level, order_el.path, - "".join(tree_el), - order_el.path == "".join(tree_el), + tree_el, + order_el.path == tree_el, ) - assert order_el.path == "".join(tree_el) + assert order_el.path == tree_el def test_packet_harbo_lvl3() -> None: From be10ad9ae4713c0353da40ea1369a228ad07fa76 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 09:33:30 +0200 Subject: [PATCH 09/13] Refactor --- tests/test_packets.py | 35 +++++++++++------------------------ 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/tests/test_packets.py b/tests/test_packets.py index 52f1349e..5d04dcc1 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -245,7 +245,7 @@ def test_boundary_matrix_packets1( @pytest.mark.parametrize("level", [1, 2, 3, 4]) @pytest.mark.parametrize("wavelet_str", ["db2"]) @pytest.mark.parametrize("pywt_boundary", ["zero"]) -def test_freq_order(level: int, wavelet_str: str, pywt_boundary: str) -> None: +def test_freq_order_2d(level: int, wavelet_str: str, pywt_boundary: str) -> None: """Test the packets in frequency order.""" face = datasets.face() wavelet = pywt.Wavelet(wavelet_str) @@ -255,18 +255,12 @@ def test_freq_order(level: int, wavelet_str: str, pywt_boundary: str) -> None: mode=pywt_boundary, ) # Get the full decomposition - freq_tree = wp_tree.get_level(level, "freq") - freq_order = WaveletPacket2D.get_freq_order(level) - - for order_list, tree_list in zip(freq_tree, freq_order): - for order_el, tree_el in zip(order_list, tree_list): - print( - level, - order_el.path, - tree_el, - order_el.path == tree_el, - ) - assert order_el.path == tree_el + order_pywt = wp_tree.get_level(level, "freq") + order_ptwt = WaveletPacket2D.get_freq_order(level) + + for node_list, path_list in zip(order_pywt, order_ptwt): + for order_el, order_path in zip(node_list, path_list): + assert order_el.path == order_path def test_packet_harbo_lvl3() -> None: @@ -287,18 +281,11 @@ def filter_bank(self) -> tuple[list[float], ...]: wavelet = pywt.Wavelet("unscaled Haar Wavelet", filter_bank=_MyHaarFilterBank()) twp = WaveletPacket(torch.from_numpy(data), wavelet, mode="reflect") - twp_nodes = twp.get_level(3) - twp_lst = [] - for node in twp_nodes: - twp_lst.append(torch.squeeze(twp[node])) - torch_res = torch.stack(twp_lst).numpy() + torch_res = torch.cat([twp[node] for node in twp.get_level(3)], 0) + wp = pywt.WaveletPacket(data=data, wavelet=wavelet, mode="reflect") - pywt_nodes = [node.path for node in wp.get_level(3, "freq")] - np_lst = [] - for node in pywt_nodes: - np_lst.append(wp[node].data) - np_res = np.concatenate(np_lst) - assert np.allclose(torch_res, np_res) + np_res = np.concatenate([node.data for node in wp.get_level(3, "freq")], 0) + assert np.allclose(torch_res.numpy(), np_res) def test_access_errors_1d() -> None: From 2fa9ec30dd4e50f58d474c73c598abb59ad6f20f Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 11:37:05 +0200 Subject: [PATCH 10/13] Fix 2d data preparation in test --- tests/test_packets.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/test_packets.py b/tests/test_packets.py index 5d04dcc1..c43ab190 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -11,7 +11,7 @@ import torch from scipy import datasets -from ptwt._util import _check_axes_argument +from ptwt._util import _check_axes_argument, _undo_swap_axes from ptwt.constants import ExtendedBoundaryMode from ptwt.packets import WaveletPacket, WaveletPacket2D @@ -80,14 +80,13 @@ def _compare_trees2( axes: tuple[int, int] = (-2, -1), ) -> None: face = datasets.face()[:height, :width].astype(np.float64).mean(-1) - data = np.stack([face] * batch_size, 0) + data = torch.stack([torch.from_numpy(face)] * batch_size, 0) _check_axes_argument(axes) - data = data.swapaxes(axes[0], -2) - data = data.swapaxes(axes[1], -1) + data = _undo_swap_axes(data, axes) wp_tree = pywt.WaveletPacket2D( - data=data, + data=data.numpy(), wavelet=wavelet_str, mode=pywt_boundary, maxlevel=max_lev, @@ -108,10 +107,10 @@ def _compare_trees2( wavelet=wavelet_str, mode=ptwt_boundary, axes=axes, - ).transform(torch.from_numpy(data), maxlevel=max_lev) + ).transform(data, maxlevel=max_lev) else: ptwt_wp_tree = WaveletPacket2D( - torch.from_numpy(data), + data, wavelet=wavelet_str, mode=ptwt_boundary, maxlevel=max_lev, @@ -120,7 +119,7 @@ def _compare_trees2( # if multiple_transform flag is set, recalculcate the packets if multiple_transforms: - ptwt_wp_tree.transform(torch.from_numpy(data), maxlevel=max_lev) + ptwt_wp_tree.transform(data, maxlevel=max_lev) packets_pt = torch.stack( [ From b4f16028dc4069d55a27eb5dc0c31cbd0b0ebfbc Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 11:37:23 +0200 Subject: [PATCH 11/13] Allow case maxlevel=0 --- src/ptwt/packets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 1520178b..23eb9c0d 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -467,7 +467,7 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor: return _fsdict_func def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None: - if not self.maxlevel: + if self.maxlevel is None: raise AssertionError self.data[path] = data From 4777c39208ce04605bebf644a2b524595341f3bc Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 11:38:42 +0200 Subject: [PATCH 12/13] Remove unused import --- tests/test_packets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_packets.py b/tests/test_packets.py index c43ab190..cbc9907d 100644 --- a/tests/test_packets.py +++ b/tests/test_packets.py @@ -2,7 +2,6 @@ # Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de) -from itertools import product from typing import Optional import numpy as np From 48ad03b75d12bff4da09293ff3e2db6bfc0b6c12 Mon Sep 17 00:00:00 2001 From: Felix Blanke Date: Tue, 25 Jun 2024 11:54:41 +0200 Subject: [PATCH 13/13] Change another check to is None --- src/ptwt/packets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index 23eb9c0d..153af538 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -226,7 +226,7 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[st return graycode_order def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None: - if not self.maxlevel: + if self.maxlevel is None: raise AssertionError self.data[path] = data