Skip to content

Commit 055533e

Browse files
committed
Add 1d return type
1 parent 4eed41c commit 055533e

File tree

6 files changed

+43
-13
lines changed

6 files changed

+43
-13
lines changed

docs/ref/return-types.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
Wavelet transform return types
66
==============================
77

8+
Transforms in one dimension
9+
---------------------------
10+
11+
.. autodata:: WaveletCoeff1d
12+
13+
814
Transforms in two dimensions
915
----------------------------
1016

src/ptwt/constants.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Constants and types used throughout the PyTorch Wavelet Toolbox."""
22

3+
from collections.abc import Sequence
34
from typing import Literal, NamedTuple, Union
45

56
import torch
@@ -62,6 +63,27 @@
6263
"""
6364

6465

66+
# Note: This data structure was chosen to follow pywt's conventions
67+
WaveletCoeff1d: TypeAlias = Sequence[torch.Tensor]
68+
"""Type alias for 1d wavelet transform results.
69+
70+
This type alias represents the result of a 1d wavelet transform
71+
with :math:`n` levels as a sequence::
72+
73+
[cA_n, cD_n, cD_n-1, …, cD1]
74+
75+
of :math:`n + 1` tensors.
76+
The first entry of the sequence (``cA_n``) is the approximation coefficient tensor.
77+
The following entries (``cD_n`` - ``cD1``) are the detail coefficient tensors
78+
of the respective level.
79+
80+
Note that this type always contains an approximation coefficient tensor but does not
81+
necesseraily contain any detail coefficients.
82+
83+
Alias of ``Sequence[torch.Tensor]``
84+
"""
85+
86+
6587
class WaveletDetailTuple2d(NamedTuple):
6688
"""Detail coefficients of a 2d wavelet transform for a given level.
6789

src/ptwt/conv_transform.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
_pad_symmetric,
2222
_unfold_axes,
2323
)
24-
from .constants import BoundaryMode, WaveletCoeff2d
24+
from .constants import BoundaryMode, WaveletCoeff1d, WaveletCoeff2d
2525

2626

2727
def _create_tensor(
@@ -361,12 +361,13 @@ def wavedec(
361361

362362

363363
def waverec(
364-
coeffs: Sequence[torch.Tensor], wavelet: Union[Wavelet, str], axis: int = -1
364+
coeffs: WaveletCoeff1d, wavelet: Union[Wavelet, str], axis: int = -1
365365
) -> torch.Tensor:
366366
"""Reconstruct a 1d signal from wavelet coefficients.
367367
368368
Args:
369-
coeffs (Sequence): The wavelet coefficient sequence produced by wavedec.
369+
coeffs: The wavelet coefficient sequence produced by the forward transform
370+
:func:`wavedec`. See :data:`ptwt.constants.WaveletCoeff1d`.
370371
wavelet (Wavelet or str): A pywt wavelet compatible object or
371372
the name of a pywt wavelet.
372373
Refer to the output from ``pywt.wavelist(kind='discrete')``

src/ptwt/matmul_transform.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
"""
99

1010
import sys
11-
from collections.abc import Sequence
1211
from typing import Optional, Union
1312

1413
import numpy as np
@@ -21,7 +20,7 @@
2120
_is_dtype_supported,
2221
_unfold_axes,
2322
)
24-
from .constants import OrthogonalizeMethod, PaddingMode
23+
from .constants import OrthogonalizeMethod, PaddingMode, WaveletCoeff1d
2524
from .conv_transform import (
2625
_get_filter_tensors,
2726
_postprocess_result_list_dec1d,
@@ -595,12 +594,12 @@ def _construct_synthesis_matrices(
595594
self.ifwt_matrix_list.append(sn)
596595
curr_length = curr_length // 2
597596

598-
def __call__(self, coefficients: Sequence[torch.Tensor]) -> torch.Tensor:
597+
def __call__(self, coefficients: WaveletCoeff1d) -> torch.Tensor:
599598
"""Run the synthesis or inverse matrix fwt.
600599
601600
Args:
602-
coefficients (Sequence[torch.Tensor]): The coefficients produced
603-
by the forward transform.
601+
coefficients: The coefficients produced by the forward transform
602+
:data:`MatrixWavedec`. See :data:`ptwt.constants.WaveletCoeff1d`.
604603
605604
Returns:
606605
The input signal reconstruction.

src/ptwt/packets.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import collections
6-
from collections.abc import Callable, Iterable, Sequence
6+
from collections.abc import Callable, Iterable
77
from functools import partial
88
from itertools import product
99
from typing import TYPE_CHECKING, Literal, Optional, Union, overload
@@ -17,6 +17,7 @@
1717
ExtendedBoundaryMode,
1818
OrthogonalizeMethod,
1919
PacketNodeOrder,
20+
WaveletCoeff1d,
2021
WaveletCoeff2d,
2122
WaveletCoeffNd,
2223
WaveletDetailTuple2d,
@@ -230,7 +231,7 @@ def _get_wavedec(
230231
def _get_waverec(
231232
self,
232233
length: int,
233-
) -> Callable[[Sequence[torch.Tensor]], torch.Tensor]:
234+
) -> Callable[[WaveletCoeff1d], torch.Tensor]:
234235
if self.mode == "boundary":
235236
if length not in self._matrix_waverec_dict.keys():
236237
self._matrix_waverec_dict[length] = MatrixWaverec(

src/ptwt/stationary_transform.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch.nn.functional as F # noqa:N812
99

1010
from ._util import Wavelet, _as_wavelet, _unfold_axes
11+
from .constants import WaveletCoeff1d
1112
from .conv_transform import (
1213
_get_filter_tensors,
1314
_postprocess_result_list_dec1d,
@@ -109,15 +110,15 @@ def swt(
109110

110111

111112
def iswt(
112-
coeffs: Sequence[torch.Tensor],
113+
coeffs: WaveletCoeff1d,
113114
wavelet: Union[pywt.Wavelet, str],
114115
axis: Optional[int] = -1,
115116
) -> torch.Tensor:
116117
"""Invert a 1d stationary wavelet transform.
117118
118119
Args:
119-
coeffs (Sequence[torch.Tensor]): The coefficients as computed
120-
by the swt function.
120+
coeffs: The wavelet coefficient sequence produced by the forward transform
121+
:func:`swt`. See :data:`ptwt.constants.WaveletCoeff1d`.
121122
wavelet (Wavelet or str): A pywt wavelet compatible object or
122123
the name of a pywt wavelet, as used in the forward transform.
123124
axis (int, optional): The axis the forward trasform was computed over.

0 commit comments

Comments
 (0)