Skip to content

Commit 4e271a8

Browse files
authored
Merge pull request #95 from v0lta/feature/packets-level-order
Improve packets order interface
2 parents 848d0f7 + 7226e66 commit 4e271a8

File tree

3 files changed

+106
-5
lines changed

3 files changed

+106
-5
lines changed

src/ptwt/constants.py

+8
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@
5151
Choose ``gramschmidt`` if ``qr`` runs out of memory.
5252
"""
5353

54+
PacketNodeOrder = Literal["freq", "natural"]
55+
"""
56+
This is a type literal for the order of wavelet packet tree nodes.
57+
58+
- frequency order (``freq``)
59+
- natural order (``natural``)
60+
"""
61+
5462

5563
class WaveletDetailTuple2d(NamedTuple):
5664
"""Detail coefficients of a 2d wavelet transform for a given level.

src/ptwt/packets.py

+58-5
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Sequence
77
from functools import partial
88
from itertools import product
9-
from typing import TYPE_CHECKING, Callable, Optional, Union
9+
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, overload
1010

1111
import numpy as np
1212
import pywt
@@ -16,6 +16,7 @@
1616
from .constants import (
1717
ExtendedBoundaryMode,
1818
OrthogonalizeMethod,
19+
PacketNodeOrder,
1920
WaveletCoeff2d,
2021
WaveletCoeffNd,
2122
WaveletDetailTuple2d,
@@ -203,18 +204,34 @@ def _get_waverec(
203204
else:
204205
return partial(waverec, wavelet=self.wavelet, axis=self.axis)
205206

206-
def get_level(self, level: int) -> list[str]:
207-
"""Return the graycode-ordered paths to the filter tree nodes.
207+
@staticmethod
208+
def get_level(level: int, order: PacketNodeOrder = "freq") -> list[str]:
209+
"""Return the paths to the filter tree nodes.
208210
209211
Args:
210212
level (int): The depth of the tree.
213+
order: The order the paths are in.
214+
Choose from frequency order (``freq``) and
215+
natural order (``natural``).
216+
Defaults to ``freq``.
211217
212218
Returns:
213219
A list with the paths to each node.
220+
221+
Raises:
222+
ValueError: If `order` is neither ``freq`` nor ``natural``.
214223
"""
215-
return self._get_graycode_order(level)
224+
if order == "freq":
225+
return WaveletPacket._get_graycode_order(level)
226+
elif order == "natural":
227+
return ["".join(p) for p in product(["a", "d"], repeat=level)]
228+
else:
229+
raise ValueError(
230+
f"Unsupported order '{order}'. Choose from 'freq' and 'natural'."
231+
)
216232

217-
def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[str]:
233+
@staticmethod
234+
def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]:
218235
graycode_order = [x, y]
219236
for _ in range(level - 1):
220237
graycode_order = [x + path for path in graycode_order] + [
@@ -514,6 +531,42 @@ def __getitem__(self, key: str) -> torch.Tensor:
514531
)
515532
return super().__getitem__(key)
516533

534+
@overload
535+
@staticmethod
536+
def get_level(level: int, order: Literal["freq"]) -> list[list[str]]: ...
537+
538+
@overload
539+
@staticmethod
540+
def get_level(level: int, order: Literal["natural"]) -> list[str]: ...
541+
542+
@staticmethod
543+
def get_level(
544+
level: int, order: PacketNodeOrder = "freq"
545+
) -> Union[list[str], list[list[str]]]:
546+
"""Return the paths to the filter tree nodes.
547+
548+
Args:
549+
level (int): The depth of the tree.
550+
order: The order the paths are in.
551+
Choose from frequency order (``freq``) and
552+
natural order (``natural``).
553+
Defaults to ``freq``.
554+
555+
Returns:
556+
A list with the paths to each node.
557+
558+
Raises:
559+
ValueError: If `order` is neither ``freq`` nor ``natural``.
560+
"""
561+
if order == "freq":
562+
return WaveletPacket2D.get_freq_order(level)
563+
elif order == "natural":
564+
return WaveletPacket2D.get_natural_order(level)
565+
else:
566+
raise ValueError(
567+
f"Unsupported order '{order}'. Choose from 'freq' and 'natural'."
568+
)
569+
517570
@staticmethod
518571
def get_natural_order(level: int) -> list[str]:
519572
"""Get the natural ordering for a given decomposition level.

tests/test_packets.py

+40
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,26 @@ def test_boundary_matrix_packets1(
240240
)
241241

242242

243+
@pytest.mark.parametrize("level", [1, 2, 3, 4])
244+
@pytest.mark.parametrize("wavelet_str", ["db2"])
245+
@pytest.mark.parametrize("pywt_boundary", ["zero"])
246+
@pytest.mark.parametrize("order", ["freq", "natural"])
247+
def test_order_1d(level: int, wavelet_str: str, pywt_boundary: str, order: str) -> None:
248+
"""Test the packets in natural order."""
249+
data = np.random.rand(2, 256)
250+
wp_tree = pywt.WaveletPacket(
251+
data=data,
252+
wavelet=wavelet_str,
253+
mode=pywt_boundary,
254+
)
255+
# Get the full decomposition
256+
order_pywt = wp_tree.get_level(level, order)
257+
order_ptwt = WaveletPacket.get_level(level, order)
258+
259+
for order_el, order_path in zip(order_pywt, order_ptwt):
260+
assert order_el.path == order_path
261+
262+
243263
@pytest.mark.parametrize("level", [1, 2, 3, 4])
244264
@pytest.mark.parametrize("wavelet_str", ["db2"])
245265
@pytest.mark.parametrize("pywt_boundary", ["zero"])
@@ -261,6 +281,26 @@ def test_freq_order_2d(level: int, wavelet_str: str, pywt_boundary: str) -> None
261281
assert order_el.path == order_path
262282

263283

284+
@pytest.mark.parametrize("level", [1, 2, 3, 4])
285+
@pytest.mark.parametrize("wavelet_str", ["db2"])
286+
@pytest.mark.parametrize("pywt_boundary", ["zero"])
287+
def test_natural_order_2d(level: int, wavelet_str: str, pywt_boundary: str) -> None:
288+
"""Test the packets in natural order."""
289+
face = datasets.face()
290+
wavelet = pywt.Wavelet(wavelet_str)
291+
wp_tree = pywt.WaveletPacket2D(
292+
data=np.mean(face, axis=-1).astype(np.float64),
293+
wavelet=wavelet,
294+
mode=pywt_boundary,
295+
)
296+
# Get the full decomposition
297+
order_pywt = wp_tree.get_level(level, "natural")
298+
order_ptwt = WaveletPacket2D.get_natural_order(level)
299+
300+
for order_el, order_path in zip(order_pywt, order_ptwt):
301+
assert order_el.path == order_path
302+
303+
264304
def test_packet_harbo_lvl3() -> None:
265305
"""From Jensen, La Cour-Harbo, Rippels in Mathematics, Chapter 8 (page 89)."""
266306
data = np.array([56.0, 40.0, 8.0, 24.0, 48.0, 48.0, 40.0, 16.0])

0 commit comments

Comments
 (0)