Skip to content

Commit 91726be

Browse files
committed
Rename boundary arg to boundary_orthogonalization
1 parent 5862bb5 commit 91726be

File tree

6 files changed

+158
-69
lines changed

6 files changed

+158
-69
lines changed

src/ptwt/_util.py

+57-2
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22

33
from __future__ import annotations
44

5+
import functools
56
import typing
6-
from collections.abc import Sequence
7-
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload
7+
import warnings
8+
from collections.abc import Callable, Sequence
9+
from typing import Any, NamedTuple, Optional, Protocol, Union, cast, overload
810

911
import numpy as np
1012
import pywt
1113
import torch
14+
from typing_extensions import ParamSpec, TypeVar
1215

1316
from .constants import (
1417
OrthogonalizeMethod,
@@ -253,3 +256,55 @@ def _map_result(
253256
Union[list[WaveletDetailDict], list[WaveletDetailTuple2d]], result_lst
254257
)
255258
return approx, *cast_result_lst
259+
260+
261+
Param = ParamSpec("Param")
262+
RetType = TypeVar("RetType")
263+
264+
265+
def _deprecated_alias(
266+
**aliases: str,
267+
) -> Callable[[Callable[Param, RetType]], Callable[Param, RetType]]:
268+
"""Handle deprecated function and method arguments.
269+
270+
Use as follows::
271+
272+
@deprecated_alias(old_arg='new_arg')
273+
def myfunc(new_arg):
274+
...
275+
276+
Adapted from https://stackoverflow.com/a/49802489
277+
"""
278+
279+
def rename_kwargs(
280+
func_name: str,
281+
kwargs: Param.kwargs,
282+
aliases: dict[str, str],
283+
) -> None:
284+
"""Rename deprecated kwarg."""
285+
for alias, new in aliases.items():
286+
if alias in kwargs:
287+
if new in kwargs:
288+
raise TypeError(
289+
f"{func_name} received both {alias} and {new} as arguments!"
290+
f" {alias} is deprecated, use {new} instead."
291+
)
292+
warnings.warn(
293+
message=(
294+
f"`{alias}` is deprecated as an argument to `{func_name}`; use"
295+
f" `{new}` instead."
296+
),
297+
category=DeprecationWarning,
298+
stacklevel=3,
299+
)
300+
kwargs[new] = kwargs.pop(alias)
301+
302+
def deco(f: Callable[Param, RetType]) -> Callable[Param, RetType]:
303+
@functools.wraps(f)
304+
def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType:
305+
rename_kwargs(f.__name__, kwargs, aliases)
306+
return f(*args, **kwargs)
307+
308+
return wrapper
309+
310+
return deco

src/ptwt/matmul_transform.py

+31-20
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ._util import (
1818
Wavelet,
1919
_as_wavelet,
20+
_deprecated_alias,
2021
_is_boundary_mode_supported,
2122
_is_dtype_supported,
2223
_unfold_axes,
@@ -182,12 +183,13 @@ class MatrixWavedec(BaseMatrixWaveDec):
182183
>>> coefficients = matrix_wavedec(data_torch)
183184
"""
184185

186+
@_deprecated_alias(boundary="boundary_orthogonalization")
185187
def __init__(
186188
self,
187189
wavelet: Union[Wavelet, str],
188190
level: Optional[int] = None,
189191
axis: Optional[int] = -1,
190-
boundary: OrthogonalizeMethod = "qr",
192+
boundary_orthogonalization: OrthogonalizeMethod = "qr",
191193
odd_coeff_padding_mode: BoundaryMode = "zero",
192194
) -> None:
193195
"""Create a sparse matrix fast wavelet transform object.
@@ -202,8 +204,9 @@ def __init__(
202204
None.
203205
axis (int, optional): The axis we would like to transform.
204206
Defaults to -1.
205-
boundary : The method used for boundary filter treatment,
206-
see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'.
207+
boundary_orthogonalization: The method used to orthogonalize
208+
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
209+
Defaults to 'qr'.
207210
odd_coeff_padding_mode: The constructed FWT matrices require inputs
208211
with even lengths. Thus, any odd-length approximation coefficients
209212
are padded to an even length using this mode,
@@ -217,8 +220,8 @@ def __init__(
217220
"""
218221
self.wavelet = _as_wavelet(wavelet)
219222
self.level = level
220-
self.boundary = boundary
221223
self.odd_coeff_padding_mode = odd_coeff_padding_mode
224+
self.boundary_orthogonalization = boundary_orthogonalization
222225

223226
if isinstance(axis, int):
224227
self.axis = axis
@@ -231,7 +234,7 @@ def __init__(
231234
self.padded = False
232235
self.size_list: list[int] = []
233236

234-
if not _is_boundary_mode_supported(self.boundary):
237+
if not _is_boundary_mode_supported(self.boundary_orthogonalization):
235238
raise NotImplementedError
236239

237240
if self.wavelet.dec_len != self.wavelet.rec_len:
@@ -311,7 +314,7 @@ def _construct_analysis_matrices(
311314
an = construct_boundary_a(
312315
self.wavelet,
313316
curr_length,
314-
boundary=self.boundary,
317+
boundary_orthogonalization=self.boundary_orthogonalization,
315318
device=device,
316319
dtype=dtype,
317320
)
@@ -402,11 +405,12 @@ def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]:
402405
return result_list
403406

404407

408+
@_deprecated_alias(boundary="boundary_orthogonalization")
405409
def construct_boundary_a(
406410
wavelet: Union[Wavelet, str],
407411
length: int,
408412
device: Union[torch.device, str] = "cpu",
409-
boundary: OrthogonalizeMethod = "qr",
413+
boundary_orthogonalization: OrthogonalizeMethod = "qr",
410414
dtype: torch.dtype = torch.float64,
411415
) -> torch.Tensor:
412416
"""Construct a boundary-wavelet filter 1d-analysis matrix.
@@ -415,8 +419,9 @@ def construct_boundary_a(
415419
wavelet (Wavelet or str): A pywt wavelet compatible object or
416420
the name of a pywt wavelet.
417421
length (int): The number of entries in the input signal.
418-
boundary : The method used for boundary filter treatment,
419-
see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'.
422+
boundary_orthogonalization: The method used to orthogonalize
423+
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
424+
Defaults to 'qr'.
420425
device: Where to place the matrix. Choose cpu or cuda.
421426
Defaults to cpu.
422427
dtype: Choose float32 or float64.
@@ -426,15 +431,16 @@ def construct_boundary_a(
426431
"""
427432
wavelet = _as_wavelet(wavelet)
428433
a_full = _construct_a(wavelet, length, dtype=dtype, device=device)
429-
a_orth = orthogonalize(a_full, wavelet.dec_len, method=boundary)
434+
a_orth = orthogonalize(a_full, wavelet.dec_len, method=boundary_orthogonalization)
430435
return a_orth
431436

432437

438+
@_deprecated_alias(boundary="boundary_orthogonalization")
433439
def construct_boundary_s(
434440
wavelet: Union[Wavelet, str],
435441
length: int,
436442
device: Union[torch.device, str] = "cpu",
437-
boundary: OrthogonalizeMethod = "qr",
443+
boundary_orthogonalization: OrthogonalizeMethod = "qr",
438444
dtype: torch.dtype = torch.float64,
439445
) -> torch.Tensor:
440446
"""Construct a boundary-wavelet filter 1d-synthesis matarix.
@@ -445,8 +451,9 @@ def construct_boundary_s(
445451
length (int): The number of entries in the input signal.
446452
device (torch.device): Where to place the matrix.
447453
Choose cpu or cuda. Defaults to cpu.
448-
boundary : The method used for boundary filter treatment,
449-
see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'.
454+
boundary_orthogonalization: The method used to orthogonalize
455+
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
456+
Defaults to 'qr'.
450457
dtype: Choose torch.float32 or torch.float64.
451458
Defaults to torch.float64.
452459
@@ -455,7 +462,9 @@ def construct_boundary_s(
455462
"""
456463
wavelet = _as_wavelet(wavelet)
457464
s_full = _construct_s(wavelet, length, dtype=dtype, device=device)
458-
s_orth = orthogonalize(s_full.transpose(1, 0), wavelet.rec_len, method=boundary)
465+
s_orth = orthogonalize(
466+
s_full.transpose(1, 0), wavelet.rec_len, method=boundary_orthogonalization
467+
)
459468
return s_orth.transpose(1, 0)
460469

461470

@@ -476,11 +485,12 @@ class MatrixWaverec(object):
476485
>>> reconstruction = matrix_waverec(coefficients)
477486
"""
478487

488+
@_deprecated_alias(boundary="boundary_orthogonalization")
479489
def __init__(
480490
self,
481491
wavelet: Union[Wavelet, str],
482492
axis: int = -1,
483-
boundary: OrthogonalizeMethod = "qr",
493+
boundary_orthogonalization: OrthogonalizeMethod = "qr",
484494
) -> None:
485495
"""Create the inverse matrix-based fast wavelet transformation.
486496
@@ -491,16 +501,17 @@ def __init__(
491501
for possible choices.
492502
axis (int): The axis transformed by the original decomposition
493503
defaults to -1 or the last axis.
494-
boundary : The method used for boundary filter treatment,
495-
see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'.
504+
boundary_orthogonalization: The method used to orthogonalize
505+
boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
506+
Defaults to 'qr'.
496507
497508
Raises:
498509
NotImplementedError: If the selected `boundary` mode is not supported.
499510
ValueError: If the wavelet filters have different lengths or if
500511
axis is not an integer.
501512
"""
502513
self.wavelet = _as_wavelet(wavelet)
503-
self.boundary = boundary
514+
self.boundary_orthogonalization = boundary_orthogonalization
504515
if isinstance(axis, int):
505516
self.axis = axis
506517
else:
@@ -511,7 +522,7 @@ def __init__(
511522
self.input_length: Optional[int] = None
512523
self.padded = False
513524

514-
if not _is_boundary_mode_supported(self.boundary):
525+
if not _is_boundary_mode_supported(self.boundary_orthogonalization):
515526
raise NotImplementedError
516527

517528
if self.wavelet.dec_len != self.wavelet.rec_len:
@@ -591,7 +602,7 @@ def _construct_synthesis_matrices(
591602
sn = construct_boundary_s(
592603
self.wavelet,
593604
curr_length,
594-
boundary=self.boundary,
605+
boundary_orthogonalization=self.boundary_orthogonalization,
595606
device=device,
596607
dtype=dtype,
597608
)

0 commit comments

Comments
 (0)