Skip to content

Commit b87482f

Browse files
authored
Merge pull request #97 from v0lta/feature/odd-length-padding-mode
Allow specification of odd length boundary padding in MatrixWavedec
2 parents fa7af3d + f2db5e3 commit b87482f

12 files changed

+403
-145
lines changed

src/ptwt/_util.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22

33
from __future__ import annotations
44

5+
import functools
56
import typing
6-
from collections.abc import Sequence
7-
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload
7+
import warnings
8+
from collections.abc import Callable, Sequence
9+
from typing import Any, NamedTuple, Optional, Protocol, Union, cast, overload
810

911
import numpy as np
1012
import pywt
1113
import torch
14+
from typing_extensions import ParamSpec, TypeVar
1215

1316
from .constants import (
1417
OrthogonalizeMethod,
@@ -90,8 +93,10 @@ def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet:
9093
return wavelet
9194

9295

93-
def _is_boundary_mode_supported(boundary_mode: Optional[OrthogonalizeMethod]) -> bool:
94-
return boundary_mode in typing.get_args(OrthogonalizeMethod)
96+
def _is_orthogonalize_method_supported(
97+
orthogonalization: Optional[OrthogonalizeMethod],
98+
) -> bool:
99+
return orthogonalization in typing.get_args(OrthogonalizeMethod)
95100

96101

97102
def _is_dtype_supported(dtype: torch.dtype) -> bool:
@@ -253,3 +258,55 @@ def _map_result(
253258
Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst
254259
)
255260
return approx, *cast_result_lst
261+
262+
263+
Param = ParamSpec("Param")
264+
RetType = TypeVar("RetType")
265+
266+
267+
def _deprecated_alias(
268+
**aliases: str,
269+
) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]:
270+
"""Handle deprecated function and method arguments.
271+
272+
Use as follows::
273+
274+
@_deprecated_alias(old_arg='new_arg')
275+
def myfunc(new_arg):
276+
...
277+
278+
Adapted from https://stackoverflow.com/a/49802489
279+
"""
280+
281+
def rename_kwargs(
282+
func_name: str,
283+
kwargs: Param.kwargs,
284+
aliases: dict[str, str],
285+
) -> None:
286+
"""Rename deprecated kwarg."""
287+
for alias, new in aliases.items():
288+
if alias in kwargs:
289+
if new in kwargs:
290+
raise TypeError(
291+
f"{func_name} received both {alias} and {new} as arguments!"
292+
f" {alias} is deprecated, use {new} instead."
293+
)
294+
warnings.warn(
295+
message=(
296+
f"`{alias}` is deprecated as an argument to `{func_name}`; use"
297+
f" `{new}` instead."
298+
),
299+
category=DeprecationWarning,
300+
stacklevel=3,
301+
)
302+
kwargs[new] = kwargs.pop(alias)
303+
304+
def deco(f: Callable[Param, RetType]) -> Callable[Param, RetType]:
305+
@functools.wraps(f)
306+
def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType:
307+
rename_kwargs(f.__name__, kwargs, aliases)
308+
return f(*args, **kwargs)
309+
310+
return wrapper
311+
312+
return deco

src/ptwt/conv_transform.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def _fwt_pad(
135135
wavelet: Union[Wavelet, str],
136136
*,
137137
mode: Optional[BoundaryMode] = None,
138+
padding: Optional[tuple[int, int]] = None,
138139
) -> torch.Tensor:
139140
"""Pad the input signal to make the fwt matrix work.
140141
@@ -144,21 +145,25 @@ def _fwt_pad(
144145
data (torch.Tensor): Input data ``[batch_size, 1, time]``
145146
wavelet (Wavelet or str): A pywt wavelet compatible object or
146147
the name of a pywt wavelet.
147-
mode :
148-
The desired padding mode for extending the signal along the edges.
148+
mode: The desired padding mode for extending the signal along the edges.
149149
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
150+
padding (tuple[int, int], optional): A tuple (padl, padr) with the
151+
number of padded values on the left and right side of the last
152+
axes of `data`. If None, the padding values are computed based
153+
on the signal shape and the wavelet length. Defaults to None.
150154
151155
Returns:
152156
A PyTorch tensor with the padded input data
153157
"""
154-
wavelet = _as_wavelet(wavelet)
155-
156158
# convert pywt to pytorch convention.
157159
if mode is None:
158160
mode = "reflect"
159161
pytorch_mode = _translate_boundary_strings(mode)
160162

161-
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
163+
if padding is None:
164+
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
165+
else:
166+
padl, padr = padding
162167
if pytorch_mode == "symmetric":
163168
data_pad = _pad_symmetric(data, [(padl, padr)])
164169
else:

src/ptwt/conv_transform_2.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def _fwt_pad2(
6565
wavelet: Union[Wavelet, str],
6666
*,
6767
mode: Optional[BoundaryMode] = None,
68+
padding: Optional[tuple[int, int, int, int]] = None,
6869
) -> torch.Tensor:
6970
"""Pad data for the 2d FWT.
7071
@@ -76,9 +77,13 @@ def _fwt_pad2(
7677
the name of a pywt wavelet.
7778
Refer to the output from ``pywt.wavelist(kind='discrete')``
7879
for possible choices.
79-
mode :
80-
The desired padding mode for extending the signal along the edges.
80+
mode: The desired padding mode for extending the signal along the edges.
8181
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
82+
padding (tuple[int, int, int, int], optional): A tuple
83+
(padl, padr, padt, padb) with the number of padded values
84+
on the left, right, top and bottom side of the last two
85+
axes of `data`. If None, the padding values are computed based
86+
on the signal shape and the wavelet length. Defaults to None.
8287
8388
Returns:
8489
The padded output tensor.
@@ -87,9 +92,12 @@ def _fwt_pad2(
8792
if mode is None:
8893
mode = "reflect"
8994
pytorch_mode = _translate_boundary_strings(mode)
90-
wavelet = _as_wavelet(wavelet)
91-
padb, padt = _get_pad(data.shape[-2], _get_len(wavelet))
92-
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
95+
96+
if padding is None:
97+
padb, padt = _get_pad(data.shape[-2], _get_len(wavelet))
98+
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
99+
else:
100+
padl, padr, padt, padb = padding
93101
if pytorch_mode == "symmetric":
94102
data_pad = _pad_symmetric(data, [(padt, padb), (padl, padr)])
95103
else:

src/ptwt/conv_transform_3.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@ def _construct_3d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
6565

6666

6767
def _fwt_pad3(
68-
data: torch.Tensor, wavelet: Union[Wavelet, str], *, mode: BoundaryMode
68+
data: torch.Tensor,
69+
wavelet: Union[Wavelet, str],
70+
*,
71+
mode: BoundaryMode,
72+
padding: Optional[tuple[int, int, int, int, int, int]] = None,
6973
) -> torch.Tensor:
7074
"""Pad data for the 3d-FWT.
7175
@@ -77,19 +81,26 @@ def _fwt_pad3(
7781
the name of a pywt wavelet.
7882
Refer to the output from ``pywt.wavelist(kind='discrete')``
7983
for possible choices.
80-
mode :
81-
The desired padding mode for extending the signal along the edges.
84+
mode: The desired padding mode for extending the signal along the edges.
8285
See :data:`ptwt.constants.BoundaryMode`.
86+
padding (tuple[int, int, int, int, int, int], optional): A tuple
87+
(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back)
88+
with the number of padded values on the respective side of the
89+
last three axes of `data`.
90+
If None, the padding values are computed based
91+
on the signal shape and the wavelet length. Defaults to None.
8392
8493
Returns:
8594
The padded output tensor.
8695
"""
8796
pytorch_mode = _translate_boundary_strings(mode)
8897

89-
wavelet = _as_wavelet(wavelet)
90-
pad_back, pad_front = _get_pad(data.shape[-3], _get_len(wavelet))
91-
pad_bottom, pad_top = _get_pad(data.shape[-2], _get_len(wavelet))
92-
pad_right, pad_left = _get_pad(data.shape[-1], _get_len(wavelet))
98+
if padding is None:
99+
pad_back, pad_front = _get_pad(data.shape[-3], _get_len(wavelet))
100+
pad_bottom, pad_top = _get_pad(data.shape[-2], _get_len(wavelet))
101+
pad_right, pad_left = _get_pad(data.shape[-1], _get_len(wavelet))
102+
else:
103+
pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back = padding
93104
if pytorch_mode == "symmetric":
94105
data_pad = _pad_symmetric(
95106
data, [(pad_front, pad_back), (pad_top, pad_bottom), (pad_left, pad_right)]

0 commit comments

Comments
 (0)