Skip to content

Commit 153f9f2

Browse files
committed
lint.
1 parent 1f7a6f0 commit 153f9f2

27 files changed

+50
-21
lines changed

noxfile.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This module implements our CI function calls."""
2+
23
import nox
34

45

src/ptwt/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Differentiable and gpu enabled fast wavelet transforms in PyTorch."""
2+
23
from ._util import Wavelet
34
from .continuous_transform import cwt
45
from .conv_transform import wavedec, waverec

src/ptwt/_stationary_transform.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This module implements stationary wavelet transforms."""
2+
23
# Created by moritz wolter, in 2024
34

45
from typing import List, Optional, Union

src/ptwt/_util.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Utility methods to compute wavelet decompositions from a dataset."""
2+
23
from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union
34

45
import numpy as np

src/ptwt/continuous_transform.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
This module is based on pywt's cwt implementation.
44
"""
5+
56
# Written by the Pytorch wavelet toolbox team in 2024
67
from typing import Any, Tuple, Union
78

src/ptwt/conv_transform.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
This module treats boundaries with edge-padding.
44
"""
5+
56
# Created by moritz wolter, 14.04.20
67
from typing import List, Optional, Sequence, Tuple, Union
78

src/ptwt/conv_transform_2.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
The implementation relies on torch.nn.functional.conv2d and
44
torch.nn.functional.conv_transpose2d under the hood.
55
"""
6+
67
# Written by the Pytorch wavelet toolbox team in 2024
78

89

src/ptwt/conv_transform_3.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
The functions here are based on torch.nn.functional.conv3d and it's transpose.
44
"""
5+
56
# Written by the Pytorch wavelet toolbox team in 2024
67

78
from functools import partial
@@ -212,7 +213,10 @@ def wavedec3(
212213

213214
def _waverec3d_fold_channels_3d_list(
214215
coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
215-
) -> Tuple[List[Union[torch.Tensor, Dict[str, torch.Tensor]]], List[int],]:
216+
) -> Tuple[
217+
List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
218+
List[int],
219+
]:
216220
# fold the input coefficients for processing conv2d_transpose.
217221
fold_coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = []
218222
ds = list(_check_if_tensor(coeffs[0]).shape)

src/ptwt/matmul_transform.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
in Strang Nguyen (p. 32), as well as the description
77
of boundary filters in "Ripples in Mathematics" section 10.3 .
88
"""
9+
910
# Created by moritz (wolter@cs.uni-bonn.de) at 14.04.20
1011
import sys
1112
from typing import List, Optional, Union
@@ -281,7 +282,7 @@ def _construct_analysis_matrices(
281282
f"level {curr_level}, the current signal length {curr_length} is "
282283
f"smaller than the filter length {filt_len}. Therefore, the "
283284
"transformation is only computed up to the decomposition level "
284-
f"{curr_level-1}.\n"
285+
f"{curr_level - 1}.\n"
285286
)
286287
break
287288

@@ -563,7 +564,7 @@ def _construct_synthesis_matrices(
563564
f"level {curr_level}, the current signal length {curr_length} is "
564565
f"smaller than the filter length {filt_len}. Therefore, the "
565566
"transformation is only computed up to the decomposition level "
566-
f"{curr_level-1}.\n"
567+
f"{curr_level - 1}.\n"
567568
)
568569
break
569570

src/ptwt/matmul_transform_2.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
This module uses boundary filters to minimize padding.
44
"""
5+
56
# Written by moritz ( @ wolter.tech ) in 2021
67
import sys
78
from functools import partial
@@ -364,7 +365,8 @@ def _construct_analysis_matrices(
364365
f". At level {curr_level}, at least one of the current signal "
365366
f"height and width ({current_height}, {current_width}) is smaller "
366367
f"then the filter length {filt_len}. Therefore, the transformation "
367-
f"is only computed up to the decomposition level {curr_level-1}.\n"
368+
f"is only computed up to the decomposition "
369+
f" level {curr_level - 1}.\n"
368370
)
369371
break
370372
# the conv matrices require even length inputs.
@@ -678,7 +680,8 @@ def _construct_synthesis_matrices(
678680
f". At level {curr_level}, at least one of the current signal "
679681
f"height and width ({current_height}, {current_width}) is smaller "
680682
f"then the filter length {filt_len}. Therefore, the transformation "
681-
f"is only computed up to the decomposition level {curr_level-1}.\n"
683+
f"is only computed up to the "
684+
f" decomposition level {curr_level - 1}.\n"
682685
)
683686
break
684687
current_height, current_width, pad_tuple = _matrix_pad_2(

src/ptwt/matmul_transform_3.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Implement 3D separable boundary transforms."""
2+
23
# Written by the Pytorch wavelet toolbox team in 2024
34

45
import sys
@@ -122,7 +123,8 @@ def _construct_analysis_matrices(
122123
f"depth, height, and width ({current_depth}, {current_height},"
123124
f"{current_width}) is smaller "
124125
f"then the filter length {filt_len}. Therefore, the transformation "
125-
f"is only computed up to the decomposition level {curr_level-1}.\n"
126+
f"is only computed up to the "
127+
f" decomposition level {curr_level - 1}.\n"
126128
)
127129
break
128130
# the conv matrices require even length inputs.
@@ -340,7 +342,8 @@ def _construct_synthesis_matrices(
340342
f" depth, height and width ({current_depth}, {current_height}, "
341343
f"{current_width}) is smaller than the filter length {filt_len}."
342344
f" Therefore, the transformation "
343-
f"is only computed up to the decomposition level {curr_level-1}.\n"
345+
f"is only computed up to the "
346+
f"decomposition level {curr_level - 1}.\n"
344347
)
345348
break
346349
# the conv matrices require even length inputs.

src/ptwt/packets.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Compute analysis wavelet packet representations."""
2+
23
# Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de)
34

45
import collections
@@ -382,9 +383,7 @@ def get_natural_order(self, level: int) -> List[str]:
382383
"""
383384
return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)]
384385

385-
def _get_wavedec(
386-
self, shape: Tuple[int, ...]
387-
) -> Callable[
386+
def _get_wavedec(self, shape: Tuple[int, ...]) -> Callable[
388387
[torch.Tensor],
389388
List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
390389
]:
@@ -415,9 +414,7 @@ def _get_wavedec(
415414
wavedec2, wavelet=self.wavelet, level=1, mode=self.mode, axes=self.axes
416415
)
417416

418-
def _get_waverec(
419-
self, shape: Tuple[int, ...]
420-
) -> Callable[
417+
def _get_waverec(self, shape: Tuple[int, ...]) -> Callable[
421418
[List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]],
422419
torch.Tensor,
423420
]:
@@ -471,7 +468,7 @@ def _transform_tuple_to_fsdict_func(
471468
def _fsdict_func(
472469
coeffs: List[
473470
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
474-
]
471+
],
475472
) -> torch.Tensor:
476473
a, (h, v, d) = coeffs
477474
return fsdict_func([cast(torch.Tensor, a), {"ad": h, "da": v, "dd": d}])

src/ptwt/separable_conv_transform.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Under the hood, code in this module transforms all dimensions
77
using torch.nn.functional.conv1d and it's transpose.
88
"""
9+
910
# Written by the Pytorch wavelet toolbox team in 2024
1011
from functools import partial
1112
from typing import Dict, List, Optional, Tuple, Union

src/ptwt/sparse_math.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Efficiently construct fwt operations using sparse matrices."""
2+
23
# Written by moritz ( @ wolter.tech ) 17.09.21
34
from itertools import product
45
from typing import List

src/ptwt/wavelets_learnable.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
See https://arxiv.org/pdf/2004.09569.pdf for more information.
44
"""
5+
56
# Created by moritz wolter@cs.uni-bonn.de, 14.05.20
67
# Inspired by Ripples in Mathematics, Jensen and La Cour-Harbo, Chapter 7.7
78
# import pywt

tests/_mackey_glass.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Generate artificial time-series data for debugging purposes."""
2+
23
from typing import Optional, Union
34

45
import torch

tests/test_convolution_fwt.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test the conv-fwt code."""
2+
23
# Written by moritz ( @ wolter.tech ) in 2021
34
import numpy as np
45
import pytest

tests/test_cwt.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test the continuous transformation code."""
2+
23
from typing import Union
34

45
import numpy as np

tests/test_dtypes.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
"""Test dtype support for the fwt code."""
2+
23
# Written by moritz ( @ wolter.tech ) in 2025
34
import numpy as np
45
import pytest
56
import pywt
67
import torch
78
from scipy import datasets
89

9-
from src.ptwt.conv_transform import (
10-
_flatten_2d_coeff_lst,
11-
wavedec,
12-
waverec,
13-
)
10+
from src.ptwt.conv_transform import _flatten_2d_coeff_lst
1411
from src.ptwt.conv_transform_2 import wavedec2, waverec2
1512

13+
1614
@pytest.mark.slow
17-
@pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16])
15+
@pytest.mark.parametrize(
16+
"dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16]
17+
)
1818
def test_2d_wavedec_rec(dtype):
1919
"""Ensure pywt.wavedec2 and ptwt.wavedec2 produce the same coefficients.
2020
@@ -45,4 +45,4 @@ def test_2d_wavedec_rec(dtype):
4545
assert np.allclose(flat_coeff_list_pywt, flat_coeff_list_ptwt.numpy(), atol=1e-3)
4646
rec = waverec2(coeff2d, wavelet)
4747
rec = rec.numpy().squeeze().astype(np.float32)
48-
assert np.allclose(face, rec[:, : face.shape[1], : face.shape[2]], atol=1e-3)
48+
assert np.allclose(face, rec[:, : face.shape[1], : face.shape[2]], atol=1e-3)

tests/test_jit.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Ensure pytorch's torch.jit.trace feature works properly."""
2+
23
from typing import NamedTuple
34

45
import numpy as np

tests/test_matrix_fwt.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test the fwt and ifwt matrices."""
2+
23
# Written by moritz ( @ wolter.tech ) in 2021
34
from typing import List
45

tests/test_matrix_fwt_2.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test code for the 2d boundary wavelets."""
2+
23
# Created by moritz ( wolter@cs.uni-bonn.de ), 08.09.21
34
import numpy as np
45
import pytest

tests/test_matrix_fwt_3.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test the 3d matrix-fwt code."""
2+
23
from typing import List
34

45
import numpy as np

tests/test_packets.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test the wavelet packet code."""
2+
23
# Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de)
34
from itertools import product
45

tests/test_sparse_math.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test the sparse math code from ptwt.sparse_math."""
2+
23
# Written by moritz ( @ wolter.tech ) in 2021
34
import numpy as np
45
import pytest

tests/test_util.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test the util methods."""
2+
23
from typing import Tuple
34

45
import numpy as np

tests/test_wavelet.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test the adaptive wavelet cost functions."""
2+
23
import pytest
34
import pywt
45
import torch

0 commit comments

Comments
 (0)