Skip to content

Commit 85b898a

Browse files
authored
Merge pull request #93 from v0lta/fix/keep-ndims-Nd
Make preprocessing and postprocessing consistent accross transforms
2 parents b87482f + 9981521 commit 85b898a

9 files changed

+626
-685
lines changed

src/ptwt/_util.py

+393-14
Large diffs are not rendered by default.

src/ptwt/conv_transform.py

+14-112
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
from ._util import (
1515
Wavelet,
1616
_as_wavelet,
17-
_fold_axes,
17+
_check_same_device_dtype,
1818
_get_len,
19-
_is_dtype_supported,
2019
_pad_symmetric,
21-
_unfold_axes,
20+
_postprocess_coeffs,
21+
_postprocess_tensor,
22+
_preprocess_coeffs,
23+
_preprocess_tensor,
2224
)
2325
from .constants import BoundaryMode, WaveletCoeff2d
2426

@@ -211,63 +213,6 @@ def _adjust_padding_at_reconstruction(
211213
return pad_end, pad_start
212214

213215

214-
def _preprocess_tensor_dec1d(
215-
data: torch.Tensor,
216-
) -> tuple[torch.Tensor, list[int]]:
217-
"""Preprocess input tensor dimensions.
218-
219-
Args:
220-
data (torch.Tensor): An input tensor of any shape.
221-
222-
Returns:
223-
A tuple (data, ds) where data is a data tensor of shape
224-
[new_batch, 1, to_process] and ds contains the original shape.
225-
"""
226-
ds = list(data.shape)
227-
if len(ds) == 1:
228-
# assume time series
229-
data = data.unsqueeze(0).unsqueeze(0)
230-
elif len(ds) == 2:
231-
# assume batched time series
232-
data = data.unsqueeze(1)
233-
else:
234-
data, ds = _fold_axes(data, 1)
235-
data = data.unsqueeze(1)
236-
return data, ds
237-
238-
239-
def _postprocess_result_list_dec1d(
240-
result_list: list[torch.Tensor], ds: list[int], axis: int
241-
) -> list[torch.Tensor]:
242-
if len(ds) == 1:
243-
result_list = [r_el.squeeze(0) for r_el in result_list]
244-
elif len(ds) > 2:
245-
# Unfold axes for the wavelets
246-
result_list = [_unfold_axes(fres, ds, 1) for fres in result_list]
247-
else:
248-
result_list = result_list
249-
250-
if axis != -1:
251-
result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]
252-
253-
return result_list
254-
255-
256-
def _preprocess_result_list_rec1d(
257-
result_lst: Sequence[torch.Tensor],
258-
) -> tuple[Sequence[torch.Tensor], list[int]]:
259-
# Fold axes for the wavelets
260-
ds = list(result_lst[0].shape)
261-
fold_coeffs: Sequence[torch.Tensor]
262-
if len(ds) == 1:
263-
fold_coeffs = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst]
264-
elif len(ds) > 2:
265-
fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst]
266-
else:
267-
fold_coeffs = result_lst
268-
return fold_coeffs, ds
269-
270-
271216
def wavedec(
272217
data: torch.Tensor,
273218
wavelet: Union[Wavelet, str],
@@ -315,10 +260,6 @@ def wavedec(
315260
containing the wavelet coefficients. A denotes
316261
approximation and D detail coefficients.
317262
318-
Raises:
319-
ValueError: If the dtype of the input data tensor is unsupported or
320-
if more than one axis is provided.
321-
322263
Example:
323264
>>> import torch
324265
>>> import ptwt, pywt
@@ -330,16 +271,7 @@ def wavedec(
330271
>>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
331272
>>> mode='zero', level=2)
332273
"""
333-
if axis != -1:
334-
if isinstance(axis, int):
335-
data = data.swapaxes(axis, -1)
336-
else:
337-
raise ValueError("wavedec transforms a single axis only.")
338-
339-
if not _is_dtype_supported(data.dtype):
340-
raise ValueError(f"Input dtype {data.dtype} not supported")
341-
342-
data, ds = _preprocess_tensor_dec1d(data)
274+
data, ds = _preprocess_tensor(data, ndim=1, axes=axis)
343275

344276
dec_lo, dec_hi, _, _ = _get_filter_tensors(
345277
wavelet, flip=True, device=data.device, dtype=data.dtype
@@ -360,9 +292,7 @@ def wavedec(
360292
result_list.append(res_lo.squeeze(1))
361293
result_list.reverse()
362294

363-
result_list = _postprocess_result_list_dec1d(result_list, ds, axis)
364-
365-
return result_list
295+
return _postprocess_coeffs(result_list, ndim=1, ds=ds, axes=axis)
366296

367297

368298
def waverec(
@@ -381,11 +311,6 @@ def waverec(
381311
Returns:
382312
The reconstructed signal tensor.
383313
384-
Raises:
385-
ValueError: If the dtype of the coeffs tensor is unsupported or if the
386-
coefficients have incompatible shapes, dtypes or devices or if
387-
more than one axis is provided.
388-
389314
Example:
390315
>>> import torch
391316
>>> import ptwt, pywt
@@ -399,29 +324,11 @@ def waverec(
399324
>>> pywt.Wavelet('haar'))
400325
401326
"""
402-
torch_device = coeffs[0].device
403-
torch_dtype = coeffs[0].dtype
404-
if not _is_dtype_supported(torch_dtype):
405-
raise ValueError(f"Input dtype {torch_dtype} not supported")
406-
407-
for coeff in coeffs[1:]:
408-
if torch_device != coeff.device:
409-
raise ValueError("coefficients must be on the same device")
410-
elif torch_dtype != coeff.dtype:
411-
raise ValueError("coefficients must have the same dtype")
412-
413-
if axis != -1:
414-
swap = []
415-
if isinstance(axis, int):
416-
for coeff in coeffs:
417-
swap.append(coeff.swapaxes(axis, -1))
418-
coeffs = swap
419-
else:
420-
raise ValueError("waverec transforms a single axis only.")
421-
422-
# fold channels, if necessary.
423-
ds = list(coeffs[0].shape)
424-
coeffs, ds = _preprocess_result_list_rec1d(coeffs)
327+
# fold channels and swap axis, if necessary.
328+
if not isinstance(coeffs, list):
329+
coeffs = list(coeffs)
330+
coeffs, ds = _preprocess_coeffs(coeffs, ndim=1, axes=axis)
331+
torch_device, torch_dtype = _check_same_device_dtype(coeffs)
425332

426333
_, _, rec_lo, rec_hi = _get_filter_tensors(
427334
wavelet, flip=False, device=torch_device, dtype=torch_dtype
@@ -446,12 +353,7 @@ def waverec(
446353
if padr > 0:
447354
res_lo = res_lo[..., :-padr]
448355

449-
if len(ds) == 1:
450-
res_lo = res_lo.squeeze(0)
451-
elif len(ds) > 2:
452-
res_lo = _unfold_axes(res_lo, ds, 1)
453-
454-
if axis != -1:
455-
res_lo = res_lo.swapaxes(axis, -1)
356+
# undo folding and swapping
357+
res_lo = _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=axis)
456358

457359
return res_lo

src/ptwt/conv_transform_2.py

+12-91
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,21 @@
66

77
from __future__ import annotations
88

9-
from functools import partial
109
from typing import Optional, Union
1110

1211
import pywt
1312
import torch
1413

1514
from ._util import (
1615
Wavelet,
17-
_as_wavelet,
18-
_check_axes_argument,
19-
_check_if_tensor,
20-
_fold_axes,
16+
_check_same_device_dtype,
2117
_get_len,
22-
_is_dtype_supported,
23-
_map_result,
2418
_outer,
2519
_pad_symmetric,
26-
_swap_axes,
27-
_undo_swap_axes,
28-
_unfold_axes,
20+
_postprocess_coeffs,
21+
_postprocess_tensor,
22+
_preprocess_coeffs,
23+
_preprocess_tensor,
2924
)
3025
from .constants import BoundaryMode, WaveletCoeff2d, WaveletDetailTuple2d
3126
from .conv_transform import (
@@ -107,32 +102,6 @@ def _fwt_pad2(
107102
return data_pad
108103

109104

110-
def _waverec2d_fold_channels_2d_list(
111-
coeffs: WaveletCoeff2d,
112-
) -> tuple[WaveletCoeff2d, list[int]]:
113-
# fold the input coefficients for processing conv2d_transpose.
114-
ds = list(_check_if_tensor(coeffs[0]).shape)
115-
return _map_result(coeffs, lambda t: _fold_axes(t, 2)[0]), ds
116-
117-
118-
def _preprocess_tensor_dec2d(
119-
data: torch.Tensor,
120-
) -> tuple[torch.Tensor, Union[list[int], None]]:
121-
# Preprocess multidimensional input.
122-
ds = None
123-
if len(data.shape) == 2:
124-
data = data.unsqueeze(0).unsqueeze(0)
125-
elif len(data.shape) == 3:
126-
# add a channel dimension for torch.
127-
data = data.unsqueeze(1)
128-
elif len(data.shape) >= 4:
129-
data, ds = _fold_axes(data, 2)
130-
data = data.unsqueeze(1)
131-
elif len(data.shape) == 1:
132-
raise ValueError("More than one input dimension required.")
133-
return data, ds
134-
135-
136105
def wavedec2(
137106
data: torch.Tensor,
138107
wavelet: Union[Wavelet, str],
@@ -183,11 +152,6 @@ def wavedec2(
183152
A tuple containing the wavelet coefficients in pywt order,
184153
see :data:`ptwt.constants.WaveletCoeff2d`.
185154
186-
Raises:
187-
ValueError: If the dimensionality or the dtype of the input data tensor
188-
is unsupported or if the provided ``axes``
189-
input has a length other than two.
190-
191155
Example:
192156
>>> import torch
193157
>>> import ptwt, pywt
@@ -200,17 +164,7 @@ def wavedec2(
200164
>>> level=2, mode="zero")
201165
202166
"""
203-
if not _is_dtype_supported(data.dtype):
204-
raise ValueError(f"Input dtype {data.dtype} not supported")
205-
206-
if tuple(axes) != (-2, -1):
207-
if len(axes) != 2:
208-
raise ValueError("2D transforms work with two axes.")
209-
else:
210-
data = _swap_axes(data, list(axes))
211-
212-
wavelet = _as_wavelet(wavelet)
213-
data, ds = _preprocess_tensor_dec2d(data)
167+
data, ds = _preprocess_tensor(data, ndim=2, axes=axes)
214168
dec_lo, dec_hi, _, _ = _get_filter_tensors(
215169
wavelet, flip=True, device=data.device, dtype=data.dtype
216170
)
@@ -234,13 +188,7 @@ def wavedec2(
234188
res_ll = res_ll.squeeze(1)
235189
result: WaveletCoeff2d = res_ll, *result_lst
236190

237-
if ds:
238-
_unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2)
239-
result = _map_result(result, _unfold_axes2)
240-
241-
if axes != (-2, -1):
242-
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
243-
result = _map_result(result, undo_swap_fn)
191+
result = _postprocess_coeffs(result, ndim=2, ds=ds, axes=axes)
244192

245193
return result
246194

@@ -286,35 +234,16 @@ def waverec2(
286234
>>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
287235
288236
"""
289-
if tuple(axes) != (-2, -1):
290-
if len(axes) != 2:
291-
raise ValueError("2D transforms work with two axes.")
292-
else:
293-
_check_axes_argument(list(axes))
294-
swap_fn = partial(_swap_axes, axes=list(axes))
295-
coeffs = _map_result(coeffs, swap_fn)
296-
297-
ds = None
298-
wavelet = _as_wavelet(wavelet)
299-
300-
res_ll = _check_if_tensor(coeffs[0])
301-
torch_device = res_ll.device
302-
torch_dtype = res_ll.dtype
303-
304-
if res_ll.dim() >= 4:
305-
# avoid the channel sum, fold the channels into batches.
306-
coeffs, ds = _waverec2d_fold_channels_2d_list(coeffs)
307-
res_ll = _check_if_tensor(coeffs[0])
308-
309-
if not _is_dtype_supported(torch_dtype):
310-
raise ValueError(f"Input dtype {torch_dtype} not supported")
237+
coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes)
238+
torch_device, torch_dtype = _check_same_device_dtype(coeffs)
311239

312240
_, _, rec_lo, rec_hi = _get_filter_tensors(
313241
wavelet, flip=False, device=torch_device, dtype=torch_dtype
314242
)
315243
filt_len = rec_lo.shape[-1]
316244
rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi)
317245

246+
res_ll = coeffs[0]
318247
for c_pos, coeff_tuple in enumerate(coeffs[1:]):
319248
if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3:
320249
raise ValueError(
@@ -325,11 +254,7 @@ def waverec2(
325254

326255
curr_shape = res_ll.shape
327256
for coeff in coeff_tuple:
328-
if torch_device != coeff.device:
329-
raise ValueError("coefficients must be on the same device")
330-
elif torch_dtype != coeff.dtype:
331-
raise ValueError("coefficients must have the same dtype")
332-
elif coeff.shape != curr_shape:
257+
if coeff.shape != curr_shape:
333258
raise ValueError(
334259
"All coefficients on each level must have the same shape"
335260
)
@@ -362,10 +287,6 @@ def waverec2(
362287
if padr > 0:
363288
res_ll = res_ll[..., :-padr]
364289

365-
if ds:
366-
res_ll = _unfold_axes(res_ll, list(ds), 2)
367-
368-
if axes != (-2, -1):
369-
res_ll = _undo_swap_axes(res_ll, list(axes))
290+
res_ll = _postprocess_tensor(res_ll, ndim=2, ds=ds, axes=axes)
370291

371292
return res_ll

0 commit comments

Comments
 (0)