1
- """Utility methods to compute wavelet decompositions from a dataset ."""
1
+ """Utility methods to compute wavelet decompositions."""
2
2
3
3
from __future__ import annotations
4
4
5
5
import typing
6
6
from collections .abc import Sequence
7
- from typing import Any , Callable , NamedTuple , Optional , Protocol , Union , cast , overload
7
+ from typing import Any , Callable , Optional , Union , cast , overload
8
8
9
9
import numpy as np
10
10
import pywt
11
11
import torch
12
12
13
13
from .constants import (
14
+ BoundaryMode ,
14
15
OrthogonalizeMethod ,
16
+ Wavelet ,
15
17
WaveletCoeff2d ,
16
18
WaveletCoeffNd ,
17
19
WaveletDetailDict ,
18
20
WaveletDetailTuple2d ,
19
21
)
20
22
21
23
22
- class Wavelet (Protocol ):
23
- """Wavelet object interface, based on the pywt wavelet object."""
24
-
25
- name : str
26
- dec_lo : Sequence [float ]
27
- dec_hi : Sequence [float ]
28
- rec_lo : Sequence [float ]
29
- rec_hi : Sequence [float ]
30
- dec_len : int
31
- rec_len : int
32
- filter_bank : tuple [
33
- Sequence [float ], Sequence [float ], Sequence [float ], Sequence [float ]
34
- ]
35
-
36
- def __len__ (self ) -> int :
37
- """Return the number of filter coefficients."""
38
- return len (self .dec_lo )
39
-
40
-
41
- class WaveletTensorTuple (NamedTuple ):
42
- """Named tuple containing the wavelet filter bank to use in JIT code."""
43
-
44
- dec_lo : torch .Tensor
45
- dec_hi : torch .Tensor
46
- rec_lo : torch .Tensor
47
- rec_hi : torch .Tensor
48
-
49
- @property
50
- def dec_len (self ) -> int :
51
- """Length of decomposition filters."""
52
- return len (self .dec_lo )
53
-
54
- @property
55
- def rec_len (self ) -> int :
56
- """Length of reconstruction filters."""
57
- return len (self .rec_lo )
58
-
59
- @property
60
- def filter_bank (
61
- self ,
62
- ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
63
- """Filter bank of the wavelet."""
64
- return self
65
-
66
- @classmethod
67
- def from_wavelet (cls , wavelet : Wavelet , dtype : torch .dtype ) -> WaveletTensorTuple :
68
- """Construct Wavelet named tuple from wavelet protocol member."""
69
- return cls (
70
- torch .tensor (wavelet .dec_lo , dtype = dtype ),
71
- torch .tensor (wavelet .dec_hi , dtype = dtype ),
72
- torch .tensor (wavelet .rec_lo , dtype = dtype ),
73
- torch .tensor (wavelet .rec_hi , dtype = dtype ),
74
- )
24
+ def _translate_boundary_strings (pywt_mode : BoundaryMode ) -> str :
25
+ """Translate pywt mode strings to PyTorch mode strings.
26
+
27
+ We support constant, zero, reflect, and periodic.
28
+ Unfortunately, "constant" has different meanings in the
29
+ Pytorch and PyWavelet communities.
30
+
31
+ Raises:
32
+ ValueError: If the padding mode is not supported.
33
+ """
34
+ if pywt_mode == "constant" :
35
+ return "replicate"
36
+ elif pywt_mode == "zero" :
37
+ return "constant"
38
+ elif pywt_mode == "reflect" :
39
+ return pywt_mode
40
+ elif pywt_mode == "periodic" :
41
+ return "circular"
42
+ elif pywt_mode == "symmetric" :
43
+ # pytorch does not support symmetric mode,
44
+ # we have our own implementation.
45
+ return pywt_mode
46
+ raise ValueError (f"Padding mode not supported: { pywt_mode } " )
75
47
76
48
77
49
def _as_wavelet (wavelet : Union [Wavelet , str ]) -> Wavelet :
@@ -90,6 +62,65 @@ def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet:
90
62
return wavelet
91
63
92
64
65
+ def _get_len (wavelet : Union [tuple [torch .Tensor , ...], str , Wavelet ]) -> int :
66
+ """Get number of filter coefficients for various wavelet data types."""
67
+ if isinstance (wavelet , tuple ):
68
+ return wavelet [0 ].shape [0 ]
69
+ else :
70
+ return len (_as_wavelet (wavelet ))
71
+
72
+
73
+ def _get_filter_tensors (
74
+ wavelet : Union [Wavelet , str ],
75
+ flip : bool ,
76
+ device : Union [torch .device , str ],
77
+ dtype : torch .dtype = torch .float32 ,
78
+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
79
+ """Convert input wavelet to filter tensors.
80
+
81
+ Args:
82
+ wavelet (Wavelet or str): A pywt wavelet compatible object or
83
+ the name of a pywt wavelet.
84
+ flip (bool): Flip filters left-right, if true.
85
+ device (torch.device or str): PyTorch target device.
86
+ dtype (torch.dtype): The data type sets the precision of the
87
+ computation. Default: torch.float32.
88
+
89
+ Returns:
90
+ A tuple (dec_lo, dec_hi, rec_lo, rec_hi) containing
91
+ the four filter tensors
92
+ """
93
+ wavelet = _as_wavelet (wavelet )
94
+ device = torch .device (device )
95
+
96
+ if isinstance (wavelet , tuple ):
97
+ dec_lo , dec_hi , rec_lo , rec_hi = wavelet
98
+ else :
99
+ dec_lo , dec_hi , rec_lo , rec_hi = wavelet .filter_bank
100
+ dec_lo_tensor = _create_tensor (dec_lo , flip , device , dtype )
101
+ dec_hi_tensor = _create_tensor (dec_hi , flip , device , dtype )
102
+ rec_lo_tensor = _create_tensor (rec_lo , flip , device , dtype )
103
+ rec_hi_tensor = _create_tensor (rec_hi , flip , device , dtype )
104
+ return dec_lo_tensor , dec_hi_tensor , rec_lo_tensor , rec_hi_tensor
105
+
106
+
107
+ def _create_tensor (
108
+ filter_seq : Sequence [float ], flip : bool , device : torch .device , dtype : torch .dtype
109
+ ) -> torch .Tensor :
110
+ if flip :
111
+ if isinstance (filter_seq , torch .Tensor ):
112
+ return filter_seq .flip (- 1 ).unsqueeze (0 ).to (device = device , dtype = dtype )
113
+ else :
114
+ return torch .tensor (filter_seq [::- 1 ], device = device , dtype = dtype ).unsqueeze (
115
+ 0
116
+ )
117
+ else :
118
+ if isinstance (filter_seq , torch .Tensor ):
119
+ return filter_seq .unsqueeze (0 ).to (device = device , dtype = dtype )
120
+ else :
121
+ return torch .tensor (filter_seq , device = device , dtype = dtype ).unsqueeze (0 )
122
+
123
+
93
124
def _is_boundary_mode_supported (boundary_mode : Optional [OrthogonalizeMethod ]) -> bool :
94
125
return boundary_mode in typing .get_args (OrthogonalizeMethod )
95
126
@@ -107,14 +138,6 @@ def _outer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
107
138
return a_mul * b_mul
108
139
109
140
110
- def _get_len (wavelet : Union [tuple [torch .Tensor , ...], str , Wavelet ]) -> int :
111
- """Get number of filter coefficients for various wavelet data types."""
112
- if isinstance (wavelet , tuple ):
113
- return wavelet [0 ].shape [0 ]
114
- else :
115
- return len (_as_wavelet (wavelet ))
116
-
117
-
118
141
def _pad_symmetric_1d (signal : torch .Tensor , pad_list : tuple [int , int ]) -> torch .Tensor :
119
142
padl , padr = pad_list
120
143
dimlen = signal .shape [0 ]
@@ -150,6 +173,79 @@ def _pad_symmetric(
150
173
return signal
151
174
152
175
176
+ def _get_pad (data_len : int , filt_len : int ) -> tuple [int , int ]:
177
+ """Compute the required padding.
178
+
179
+ Args:
180
+ data_len (int): The length of the input vector.
181
+ filt_len (int): The size of the used filter.
182
+
183
+ Returns:
184
+ A tuple (padr, padl). The first entry specifies how many numbers
185
+ to attach on the right. The second entry covers the left side.
186
+ """
187
+ # pad to ensure we see all filter positions and
188
+ # for pywt compatability.
189
+ # convolution output length:
190
+ # see https://arxiv.org/pdf/1603.07285.pdf section 2.3:
191
+ # floor([data_len - filt_len]/2) + 1
192
+ # should equal pywt output length
193
+ # floor((data_len + filt_len - 1)/2)
194
+ # => floor([data_len + total_pad - filt_len]/2) + 1
195
+ # = floor((data_len + filt_len - 1)/2)
196
+ # (data_len + total_pad - filt_len) + 2 = data_len + filt_len - 1
197
+ # total_pad = 2*filt_len - 3
198
+
199
+ # we pad half of the total requried padding on each side.
200
+ padr = (2 * filt_len - 3 ) // 2
201
+ padl = (2 * filt_len - 3 ) // 2
202
+
203
+ # pad to even singal length.
204
+ padr += data_len % 2
205
+
206
+ return padr , padl
207
+
208
+
209
+ def _adjust_padding_at_reconstruction (
210
+ res_ll_size : int , coeff_size : int , pad_end : int , pad_start : int
211
+ ) -> tuple [int , int ]:
212
+ pred_size = res_ll_size - (pad_start + pad_end )
213
+ next_size = coeff_size
214
+ if next_size == pred_size :
215
+ pass
216
+ elif next_size == pred_size - 1 :
217
+ pad_end += 1
218
+ else :
219
+ raise AssertionError (
220
+ "padding error, please check if dec and rec wavelets are identical."
221
+ )
222
+ return pad_end , pad_start
223
+
224
+
225
+ def _flatten_2d_coeff_lst (
226
+ coeff_lst_2d : WaveletCoeff2d ,
227
+ flatten_tensors : bool = True ,
228
+ ) -> list [torch .Tensor ]:
229
+ """Flattens a sequence of tensor tuples into a single list.
230
+
231
+ Args:
232
+ coeff_lst_2d (WaveletCoeff2d): A pywt-style
233
+ coefficient tuple of torch tensors.
234
+ flatten_tensors (bool): If true, 2d tensors are flattened. Defaults to True.
235
+
236
+ Returns:
237
+ A single 1-d list with all original elements.
238
+ """
239
+
240
+ def _process_tensor (coeff : torch .Tensor ) -> torch .Tensor :
241
+ return coeff .flatten () if flatten_tensors else coeff
242
+
243
+ flat_coeff_lst = [_process_tensor (coeff_lst_2d [0 ])]
244
+ for coeff_tuple in coeff_lst_2d [1 :]:
245
+ flat_coeff_lst .extend (map (_process_tensor , coeff_tuple ))
246
+ return flat_coeff_lst
247
+
248
+
153
249
def _fold_axes (data : torch .Tensor , keep_no : int ) -> tuple [torch .Tensor , list [int ]]:
154
250
"""Fold unchanged leading dimensions into a single batch dimension.
155
251
0 commit comments