Skip to content

Commit fde2caa

Browse files
committed
Move reused util funcs to _util and types to constants
1 parent 055533e commit fde2caa

14 files changed

+251
-253
lines changed

src/ptwt/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Differentiable and gpu enabled fast wavelet transforms in PyTorch."""
22

3-
from ._util import Wavelet, WaveletTensorTuple
43
from .constants import (
4+
Wavelet,
55
WaveletCoeff2d,
66
WaveletCoeff2dSeparable,
77
WaveletCoeffNd,
88
WaveletDetailDict,
99
WaveletDetailTuple2d,
10+
WaveletTensorTuple,
1011
)
1112
from .continuous_transform import cwt
1213
from .conv_transform import wavedec, waverec

src/ptwt/_util.py

Lines changed: 159 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,49 @@
1-
"""Utility methods to compute wavelet decompositions from a dataset."""
1+
"""Utility methods to compute wavelet decompositions."""
22

33
from __future__ import annotations
44

55
import typing
66
from collections.abc import Sequence
7-
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload
7+
from typing import Any, Callable, Optional, Union, cast, overload
88

99
import numpy as np
1010
import pywt
1111
import torch
1212

1313
from .constants import (
14+
BoundaryMode,
1415
OrthogonalizeMethod,
16+
Wavelet,
1517
WaveletCoeff2d,
1618
WaveletCoeffNd,
1719
WaveletDetailDict,
1820
WaveletDetailTuple2d,
1921
)
2022

2123

22-
class Wavelet(Protocol):
23-
"""Wavelet object interface, based on the pywt wavelet object."""
24-
25-
name: str
26-
dec_lo: Sequence[float]
27-
dec_hi: Sequence[float]
28-
rec_lo: Sequence[float]
29-
rec_hi: Sequence[float]
30-
dec_len: int
31-
rec_len: int
32-
filter_bank: tuple[
33-
Sequence[float], Sequence[float], Sequence[float], Sequence[float]
34-
]
35-
36-
def __len__(self) -> int:
37-
"""Return the number of filter coefficients."""
38-
return len(self.dec_lo)
39-
40-
41-
class WaveletTensorTuple(NamedTuple):
42-
"""Named tuple containing the wavelet filter bank to use in JIT code."""
43-
44-
dec_lo: torch.Tensor
45-
dec_hi: torch.Tensor
46-
rec_lo: torch.Tensor
47-
rec_hi: torch.Tensor
48-
49-
@property
50-
def dec_len(self) -> int:
51-
"""Length of decomposition filters."""
52-
return len(self.dec_lo)
53-
54-
@property
55-
def rec_len(self) -> int:
56-
"""Length of reconstruction filters."""
57-
return len(self.rec_lo)
58-
59-
@property
60-
def filter_bank(
61-
self,
62-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
63-
"""Filter bank of the wavelet."""
64-
return self
65-
66-
@classmethod
67-
def from_wavelet(cls, wavelet: Wavelet, dtype: torch.dtype) -> WaveletTensorTuple:
68-
"""Construct Wavelet named tuple from wavelet protocol member."""
69-
return cls(
70-
torch.tensor(wavelet.dec_lo, dtype=dtype),
71-
torch.tensor(wavelet.dec_hi, dtype=dtype),
72-
torch.tensor(wavelet.rec_lo, dtype=dtype),
73-
torch.tensor(wavelet.rec_hi, dtype=dtype),
74-
)
24+
def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str:
25+
"""Translate pywt mode strings to PyTorch mode strings.
26+
27+
We support constant, zero, reflect, and periodic.
28+
Unfortunately, "constant" has different meanings in the
29+
Pytorch and PyWavelet communities.
30+
31+
Raises:
32+
ValueError: If the padding mode is not supported.
33+
"""
34+
if pywt_mode == "constant":
35+
return "replicate"
36+
elif pywt_mode == "zero":
37+
return "constant"
38+
elif pywt_mode == "reflect":
39+
return pywt_mode
40+
elif pywt_mode == "periodic":
41+
return "circular"
42+
elif pywt_mode == "symmetric":
43+
# pytorch does not support symmetric mode,
44+
# we have our own implementation.
45+
return pywt_mode
46+
raise ValueError(f"Padding mode not supported: {pywt_mode}")
7547

7648

7749
def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet:
@@ -90,6 +62,65 @@ def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet:
9062
return wavelet
9163

9264

65+
def _get_len(wavelet: Union[tuple[torch.Tensor, ...], str, Wavelet]) -> int:
66+
"""Get number of filter coefficients for various wavelet data types."""
67+
if isinstance(wavelet, tuple):
68+
return wavelet[0].shape[0]
69+
else:
70+
return len(_as_wavelet(wavelet))
71+
72+
73+
def _get_filter_tensors(
    wavelet: Union[Wavelet, str],
    flip: bool,
    device: Union[torch.device, str],
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert input wavelet to filter tensors.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        flip (bool): Flip filters left-right, if true.
        device (torch.device or str): PyTorch target device.
        dtype (torch.dtype): The data type sets the precision of the
            computation. Default: torch.float32.

    Returns:
        A tuple (dec_lo, dec_hi, rec_lo, rec_hi) containing
        the four filter tensors.
    """
    wavelet = _as_wavelet(wavelet)
    device = torch.device(device)

    # NOTE(review): the tuple branch looks unreachable after _as_wavelet
    # unless _as_wavelet passes raw filter tuples through — confirm.
    if isinstance(wavelet, tuple):
        filters = wavelet
    else:
        filters = wavelet.filter_bank
    dec_lo, dec_hi, rec_lo, rec_hi = filters

    return (
        _create_tensor(dec_lo, flip, device, dtype),
        _create_tensor(dec_hi, flip, device, dtype),
        _create_tensor(rec_lo, flip, device, dtype),
        _create_tensor(rec_hi, flip, device, dtype),
    )
105+
106+
107+
def _create_tensor(
108+
filter_seq: Sequence[float], flip: bool, device: torch.device, dtype: torch.dtype
109+
) -> torch.Tensor:
110+
if flip:
111+
if isinstance(filter_seq, torch.Tensor):
112+
return filter_seq.flip(-1).unsqueeze(0).to(device=device, dtype=dtype)
113+
else:
114+
return torch.tensor(filter_seq[::-1], device=device, dtype=dtype).unsqueeze(
115+
0
116+
)
117+
else:
118+
if isinstance(filter_seq, torch.Tensor):
119+
return filter_seq.unsqueeze(0).to(device=device, dtype=dtype)
120+
else:
121+
return torch.tensor(filter_seq, device=device, dtype=dtype).unsqueeze(0)
122+
123+
93124
def _is_boundary_mode_supported(boundary_mode: Optional[OrthogonalizeMethod]) -> bool:
    """Check if ``boundary_mode`` is one of the declared ``OrthogonalizeMethod`` values."""
    return boundary_mode in typing.get_args(OrthogonalizeMethod)
95126

@@ -107,14 +138,6 @@ def _outer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
107138
return a_mul * b_mul
108139

109140

110-
def _get_len(wavelet: Union[tuple[torch.Tensor, ...], str, Wavelet]) -> int:
111-
"""Get number of filter coefficients for various wavelet data types."""
112-
if isinstance(wavelet, tuple):
113-
return wavelet[0].shape[0]
114-
else:
115-
return len(_as_wavelet(wavelet))
116-
117-
118141
def _pad_symmetric_1d(signal: torch.Tensor, pad_list: tuple[int, int]) -> torch.Tensor:
119142
padl, padr = pad_list
120143
dimlen = signal.shape[0]
@@ -150,6 +173,79 @@ def _pad_symmetric(
150173
return signal
151174

152175

176+
def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]:
177+
"""Compute the required padding.
178+
179+
Args:
180+
data_len (int): The length of the input vector.
181+
filt_len (int): The size of the used filter.
182+
183+
Returns:
184+
A tuple (padr, padl). The first entry specifies how many numbers
185+
to attach on the right. The second entry covers the left side.
186+
"""
187+
# pad to ensure we see all filter positions and
188+
# for pywt compatability.
189+
# convolution output length:
190+
# see https://arxiv.org/pdf/1603.07285.pdf section 2.3:
191+
# floor([data_len - filt_len]/2) + 1
192+
# should equal pywt output length
193+
# floor((data_len + filt_len - 1)/2)
194+
# => floor([data_len + total_pad - filt_len]/2) + 1
195+
# = floor((data_len + filt_len - 1)/2)
196+
# (data_len + total_pad - filt_len) + 2 = data_len + filt_len - 1
197+
# total_pad = 2*filt_len - 3
198+
199+
# we pad half of the total requried padding on each side.
200+
padr = (2 * filt_len - 3) // 2
201+
padl = (2 * filt_len - 3) // 2
202+
203+
# pad to even singal length.
204+
padr += data_len % 2
205+
206+
return padr, padl
207+
208+
209+
def _adjust_padding_at_reconstruction(
210+
res_ll_size: int, coeff_size: int, pad_end: int, pad_start: int
211+
) -> tuple[int, int]:
212+
pred_size = res_ll_size - (pad_start + pad_end)
213+
next_size = coeff_size
214+
if next_size == pred_size:
215+
pass
216+
elif next_size == pred_size - 1:
217+
pad_end += 1
218+
else:
219+
raise AssertionError(
220+
"padding error, please check if dec and rec wavelets are identical."
221+
)
222+
return pad_end, pad_start
223+
224+
225+
def _flatten_2d_coeff_lst(
226+
coeff_lst_2d: WaveletCoeff2d,
227+
flatten_tensors: bool = True,
228+
) -> list[torch.Tensor]:
229+
"""Flattens a sequence of tensor tuples into a single list.
230+
231+
Args:
232+
coeff_lst_2d (WaveletCoeff2d): A pywt-style
233+
coefficient tuple of torch tensors.
234+
flatten_tensors (bool): If true, 2d tensors are flattened. Defaults to True.
235+
236+
Returns:
237+
A single 1-d list with all original elements.
238+
"""
239+
240+
def _process_tensor(coeff: torch.Tensor) -> torch.Tensor:
241+
return coeff.flatten() if flatten_tensors else coeff
242+
243+
flat_coeff_lst = [_process_tensor(coeff_lst_2d[0])]
244+
for coeff_tuple in coeff_lst_2d[1:]:
245+
flat_coeff_lst.extend(map(_process_tensor, coeff_tuple))
246+
return flat_coeff_lst
247+
248+
153249
def _fold_axes(data: torch.Tensor, keep_no: int) -> tuple[torch.Tensor, list[int]]:
154250
"""Fold unchanged leading dimensions into a single batch dimension.
155251

src/ptwt/constants.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Constants and types used throughout the PyTorch Wavelet Toolbox."""
22

3+
from __future__ import annotations
4+
35
from collections.abc import Sequence
4-
from typing import Literal, NamedTuple, Union
6+
from typing import Literal, NamedTuple, Protocol, Union
57

68
import torch
79
from typing_extensions import TypeAlias, Unpack
@@ -18,6 +20,62 @@
1820
"WaveletDetailDict",
1921
]
2022

23+
24+
class Wavelet(Protocol):
    """Wavelet object interface, based on the pywt wavelet object.

    This is a structural (``Protocol``) type: any object exposing these
    attributes — e.g. a ``pywt.Wavelet`` — is accepted wherever a
    ``Wavelet`` is expected, without inheriting from this class.
    """

    name: str  # wavelet name as used by pywt
    dec_lo: Sequence[float]  # decomposition low-pass filter coefficients
    dec_hi: Sequence[float]  # decomposition high-pass filter coefficients
    rec_lo: Sequence[float]  # reconstruction low-pass filter coefficients
    rec_hi: Sequence[float]  # reconstruction high-pass filter coefficients
    dec_len: int  # length of the decomposition filters
    rec_len: int  # length of the reconstruction filters
    # The four filters bundled as (dec_lo, dec_hi, rec_lo, rec_hi).
    filter_bank: tuple[
        Sequence[float], Sequence[float], Sequence[float], Sequence[float]
    ]

    def __len__(self) -> int:
        """Return the number of filter coefficients."""
        return len(self.dec_lo)
41+
42+
43+
class WaveletTensorTuple(NamedTuple):
    """Named tuple containing the wavelet filter bank to use in JIT code."""

    dec_lo: torch.Tensor  # decomposition low-pass filter
    dec_hi: torch.Tensor  # decomposition high-pass filter
    rec_lo: torch.Tensor  # reconstruction low-pass filter
    rec_hi: torch.Tensor  # reconstruction high-pass filter

    @property
    def dec_len(self) -> int:
        """Length of decomposition filters."""
        return len(self.dec_lo)

    @property
    def rec_len(self) -> int:
        """Length of reconstruction filters."""
        return len(self.rec_lo)

    @property
    def filter_bank(
        self,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Filter bank of the wavelet.

        The tuple itself already is the (dec_lo, dec_hi, rec_lo, rec_hi)
        bank, so this simply returns ``self``, mirroring the
        ``Wavelet.filter_bank`` attribute.
        """
        return self

    @classmethod
    def from_wavelet(cls, wavelet: Wavelet, dtype: torch.dtype) -> WaveletTensorTuple:
        """Construct Wavelet named tuple from wavelet protocol member."""
        return cls(
            torch.tensor(wavelet.dec_lo, dtype=dtype),
            torch.tensor(wavelet.dec_hi, dtype=dtype),
            torch.tensor(wavelet.rec_lo, dtype=dtype),
            torch.tensor(wavelet.rec_hi, dtype=dtype),
        )
77+
78+
2179
BoundaryMode = Literal["constant", "zero", "reflect", "periodic", "symmetric"]
2280
"""
2381
This is a type literal for the way of padding used at boundaries.
@@ -37,6 +95,7 @@
3795
for padding options) or ``boundary`` to use boundary wavelets.
3896
"""
3997

98+
4099
# TODO: Add documentation on the different values of PaddingMode
41100

42101
PaddingMode = Literal["full", "valid", "same", "sameshift"]

0 commit comments

Comments
 (0)