6
6
7
7
from __future__ import annotations
8
8
9
- from functools import partial
10
9
from typing import Optional , Union
11
10
12
11
import pywt
13
12
import torch
14
13
15
14
from ._util import (
16
15
Wavelet ,
17
- _as_wavelet ,
18
- _check_axes_argument ,
19
- _check_if_tensor ,
20
- _fold_axes ,
16
+ _check_same_device_dtype ,
21
17
_get_len ,
22
- _is_dtype_supported ,
23
- _map_result ,
24
18
_outer ,
25
19
_pad_symmetric ,
26
- _swap_axes ,
27
- _undo_swap_axes ,
28
- _unfold_axes ,
20
+ _postprocess_coeffs ,
21
+ _postprocess_tensor ,
22
+ _preprocess_coeffs ,
23
+ _preprocess_tensor ,
29
24
)
30
25
from .constants import BoundaryMode , WaveletCoeff2d , WaveletDetailTuple2d
31
26
from .conv_transform import (
@@ -107,32 +102,6 @@ def _fwt_pad2(
107
102
return data_pad
108
103
109
104
110
- def _waverec2d_fold_channels_2d_list (
111
- coeffs : WaveletCoeff2d ,
112
- ) -> tuple [WaveletCoeff2d , list [int ]]:
113
- # fold the input coefficients for processing conv2d_transpose.
114
- ds = list (_check_if_tensor (coeffs [0 ]).shape )
115
- return _map_result (coeffs , lambda t : _fold_axes (t , 2 )[0 ]), ds
116
-
117
-
118
- def _preprocess_tensor_dec2d (
119
- data : torch .Tensor ,
120
- ) -> tuple [torch .Tensor , Union [list [int ], None ]]:
121
- # Preprocess multidimensional input.
122
- ds = None
123
- if len (data .shape ) == 2 :
124
- data = data .unsqueeze (0 ).unsqueeze (0 )
125
- elif len (data .shape ) == 3 :
126
- # add a channel dimension for torch.
127
- data = data .unsqueeze (1 )
128
- elif len (data .shape ) >= 4 :
129
- data , ds = _fold_axes (data , 2 )
130
- data = data .unsqueeze (1 )
131
- elif len (data .shape ) == 1 :
132
- raise ValueError ("More than one input dimension required." )
133
- return data , ds
134
-
135
-
136
105
def wavedec2 (
137
106
data : torch .Tensor ,
138
107
wavelet : Union [Wavelet , str ],
@@ -183,11 +152,6 @@ def wavedec2(
183
152
A tuple containing the wavelet coefficients in pywt order,
184
153
see :data:`ptwt.constants.WaveletCoeff2d`.
185
154
186
- Raises:
187
- ValueError: If the dimensionality or the dtype of the input data tensor
188
- is unsupported or if the provided ``axes``
189
- input has a length other than two.
190
-
191
155
Example:
192
156
>>> import torch
193
157
>>> import ptwt, pywt
@@ -200,17 +164,7 @@ def wavedec2(
200
164
>>> level=2, mode="zero")
201
165
202
166
"""
203
- if not _is_dtype_supported (data .dtype ):
204
- raise ValueError (f"Input dtype { data .dtype } not supported" )
205
-
206
- if tuple (axes ) != (- 2 , - 1 ):
207
- if len (axes ) != 2 :
208
- raise ValueError ("2D transforms work with two axes." )
209
- else :
210
- data = _swap_axes (data , list (axes ))
211
-
212
- wavelet = _as_wavelet (wavelet )
213
- data , ds = _preprocess_tensor_dec2d (data )
167
+ data , ds = _preprocess_tensor (data , ndim = 2 , axes = axes )
214
168
dec_lo , dec_hi , _ , _ = _get_filter_tensors (
215
169
wavelet , flip = True , device = data .device , dtype = data .dtype
216
170
)
@@ -234,13 +188,7 @@ def wavedec2(
234
188
res_ll = res_ll .squeeze (1 )
235
189
result : WaveletCoeff2d = res_ll , * result_lst
236
190
237
- if ds :
238
- _unfold_axes2 = partial (_unfold_axes , ds = ds , keep_no = 2 )
239
- result = _map_result (result , _unfold_axes2 )
240
-
241
- if axes != (- 2 , - 1 ):
242
- undo_swap_fn = partial (_undo_swap_axes , axes = axes )
243
- result = _map_result (result , undo_swap_fn )
191
+ result = _postprocess_coeffs (result , ndim = 2 , ds = ds , axes = axes )
244
192
245
193
return result
246
194
@@ -286,35 +234,16 @@ def waverec2(
286
234
>>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
287
235
288
236
"""
289
- if tuple (axes ) != (- 2 , - 1 ):
290
- if len (axes ) != 2 :
291
- raise ValueError ("2D transforms work with two axes." )
292
- else :
293
- _check_axes_argument (list (axes ))
294
- swap_fn = partial (_swap_axes , axes = list (axes ))
295
- coeffs = _map_result (coeffs , swap_fn )
296
-
297
- ds = None
298
- wavelet = _as_wavelet (wavelet )
299
-
300
- res_ll = _check_if_tensor (coeffs [0 ])
301
- torch_device = res_ll .device
302
- torch_dtype = res_ll .dtype
303
-
304
- if res_ll .dim () >= 4 :
305
- # avoid the channel sum, fold the channels into batches.
306
- coeffs , ds = _waverec2d_fold_channels_2d_list (coeffs )
307
- res_ll = _check_if_tensor (coeffs [0 ])
308
-
309
- if not _is_dtype_supported (torch_dtype ):
310
- raise ValueError (f"Input dtype { torch_dtype } not supported" )
237
+ coeffs , ds = _preprocess_coeffs (coeffs , ndim = 2 , axes = axes )
238
+ torch_device , torch_dtype = _check_same_device_dtype (coeffs )
311
239
312
240
_ , _ , rec_lo , rec_hi = _get_filter_tensors (
313
241
wavelet , flip = False , device = torch_device , dtype = torch_dtype
314
242
)
315
243
filt_len = rec_lo .shape [- 1 ]
316
244
rec_filt = _construct_2d_filt (lo = rec_lo , hi = rec_hi )
317
245
246
+ res_ll = coeffs [0 ]
318
247
for c_pos , coeff_tuple in enumerate (coeffs [1 :]):
319
248
if not isinstance (coeff_tuple , tuple ) or len (coeff_tuple ) != 3 :
320
249
raise ValueError (
@@ -325,11 +254,7 @@ def waverec2(
325
254
326
255
curr_shape = res_ll .shape
327
256
for coeff in coeff_tuple :
328
- if torch_device != coeff .device :
329
- raise ValueError ("coefficients must be on the same device" )
330
- elif torch_dtype != coeff .dtype :
331
- raise ValueError ("coefficients must have the same dtype" )
332
- elif coeff .shape != curr_shape :
257
+ if coeff .shape != curr_shape :
333
258
raise ValueError (
334
259
"All coefficients on each level must have the same shape"
335
260
)
@@ -362,10 +287,6 @@ def waverec2(
362
287
if padr > 0 :
363
288
res_ll = res_ll [..., :- padr ]
364
289
365
- if ds :
366
- res_ll = _unfold_axes (res_ll , list (ds ), 2 )
367
-
368
- if axes != (- 2 , - 1 ):
369
- res_ll = _undo_swap_axes (res_ll , list (axes ))
290
+ res_ll = _postprocess_tensor (res_ll , ndim = 2 , ds = ds , axes = axes )
370
291
371
292
return res_ll
0 commit comments