Skip to content

Commit e7af4d4

Browse files
committed
Merge branch 'main' into feature/packets-partial-refinement
2 parents 2081036 + 4e271a8 commit e7af4d4

File tree

3 files changed

+210
-116
lines changed

3 files changed

+210
-116
lines changed

src/ptwt/constants.py

+8
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@
5151
Choose ``gramschmidt`` if ``qr`` runs out of memory.
5252
"""
5353

54+
PacketNodeOrder = Literal["freq", "natural"]
55+
"""
56+
This is a type literal for the order of wavelet packet tree nodes.
57+
58+
- frequency order (``freq``)
59+
- natural order (``natural``)
60+
"""
61+
5462

5563
class WaveletDetailTuple2d(NamedTuple):
5664
"""Detail coefficients of a 2d wavelet transform for a given level.

src/ptwt/packets.py

+84-31
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@
66
from collections.abc import Callable, Sequence
77
from functools import partial
88
from itertools import product
9-
from typing import TYPE_CHECKING, Optional, Union
9+
from typing import TYPE_CHECKING, Literal, Optional, Union, overload
1010

1111
import numpy as np
1212
import pywt
1313
import torch
1414

15-
from ._util import Wavelet, _as_wavelet
15+
from ._util import Wavelet, _as_wavelet, _swap_axes, _undo_swap_axes
1616
from .constants import (
1717
ExtendedBoundaryMode,
1818
OrthogonalizeMethod,
19+
PacketNodeOrder,
1920
WaveletCoeff2d,
2021
WaveletCoeffNd,
2122
WaveletDetailTuple2d,
@@ -115,9 +116,6 @@ def __init__(
115116
self.maxlevel: Optional[int] = None
116117
self.axis = axis
117118
if data is not None:
118-
if len(data.shape) == 1:
119-
# add a batch dimension.
120-
data = data.unsqueeze(0)
121119
self.transform(data, maxlevel, lazy_init=lazy_init)
122120
else:
123121
self.data = {}
@@ -146,7 +144,7 @@ def transform(
146144
"""
147145
self.data = {"": data}
148146
if maxlevel is None:
149-
maxlevel = pywt.dwt_max_level(data.shape[-1], self.wavelet.dec_len)
147+
maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len)
150148
self.maxlevel = maxlevel
151149
if not lazy_init:
152150
self._recursive_dwt(path="")
@@ -188,13 +186,15 @@ def _test_key(key: str) -> None:
188186

189187
data_a = self[node + "a"]
190188
data_d = self[node + "d"]
191-
rec = self._get_waverec(data_a.shape[-1])([data_a, data_d])
189+
rec = self._get_waverec(data_a.shape[self.axis])([data_a, data_d])
192190
if level > 0:
193-
if rec.shape[-1] != self[node].shape[-1]:
191+
if rec.shape[self.axis] != self[node].shape[self.axis]:
194192
assert (
195-
rec.shape[-1] == self[node].shape[-1] + 1
193+
rec.shape[self.axis] == self[node].shape[self.axis] + 1
196194
), "padding error, please open an issue on github"
197-
rec = rec[..., :-1]
195+
rec = rec.swapaxes(self.axis, -1)[..., :-1].swapaxes(
196+
-1, self.axis
197+
)
198198
self[node] = rec
199199
return self
200200

@@ -226,18 +226,34 @@ def _get_waverec(
226226
else:
227227
return partial(waverec, wavelet=self.wavelet, axis=self.axis)
228228

229-
def get_level(self, level: int) -> list[str]:
230-
"""Return the graycode-ordered paths to the filter tree nodes.
229+
@staticmethod
230+
def get_level(level: int, order: PacketNodeOrder = "freq") -> list[str]:
231+
"""Return the paths to the filter tree nodes.
231232
232233
Args:
233234
level (int): The depth of the tree.
235+
order: The order the paths are in.
236+
Choose from frequency order (``freq``) and
237+
natural order (``natural``).
238+
Defaults to ``freq``.
234239
235240
Returns:
236241
A list with the paths to each node.
242+
243+
Raises:
244+
ValueError: If `order` is neither ``freq`` nor ``natural``.
237245
"""
238-
return self._get_graycode_order(level)
246+
if order == "freq":
247+
return WaveletPacket._get_graycode_order(level)
248+
elif order == "natural":
249+
return ["".join(p) for p in product(["a", "d"], repeat=level)]
250+
else:
251+
raise ValueError(
252+
f"Unsupported order '{order}'. Choose from 'freq' and 'natural'."
253+
)
239254

240-
def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[str]:
255+
@staticmethod
256+
def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]:
241257
graycode_order = [x, y]
242258
for _ in range(level - 1):
243259
graycode_order = [x + path for path in graycode_order] + [
@@ -250,12 +266,12 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[st
250266

251267
def _expand_node(self, path: str) -> None:
252268
data = self[path]
253-
res_lo, res_hi = self._get_wavedec(data.shape[-1])(data)
269+
res_lo, res_hi = self._get_wavedec(data.shape[self.axis])(data)
254270
self.data[path + "a"] = res_lo
255271
self.data[path + "d"] = res_hi
256272

257273
def _recursive_dwt(self, path: str) -> None:
258-
if not self.maxlevel:
274+
if self.maxlevel is None:
259275
raise AssertionError
260276

261277
if len(path) >= self.maxlevel:
@@ -395,13 +411,10 @@ def transform(
395411
"""
396412
self.data = {"": data}
397413
if maxlevel is None:
398-
maxlevel = pywt.dwt_max_level(min(data.shape[-2:]), self.wavelet.dec_len)
414+
min_transform_size = min(_swap_axes(data, self.axes).shape[-2:])
415+
maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
399416
self.maxlevel = maxlevel
400417

401-
if data.dim() == 2:
402-
# add batch dim to unbatched input
403-
data = data.unsqueeze(0)
404-
405418
if not lazy_init:
406419
self._recursive_dwt2d(path="")
407420
return self
@@ -415,9 +428,8 @@ def reconstruct(self) -> WaveletPacket2D:
415428
a reconstruction from the leaves.
416429
"""
417430
if self.maxlevel is None:
418-
self.maxlevel = pywt.dwt_max_level(
419-
min(self[""].shape[-2:]), self.wavelet.dec_len
420-
)
431+
min_transform_size = min(_swap_axes(self[""], self.axes).shape[-2:])
432+
self.maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
421433

422434
for level in reversed(range(self.maxlevel)):
423435
for node in WaveletPacket2D.get_natural_order(level):
@@ -434,26 +446,31 @@ def _test_key(key: str) -> None:
434446
data_h = self[node + "h"]
435447
data_v = self[node + "v"]
436448
data_d = self[node + "d"]
437-
rec = self._get_waverec(data_a.shape[-2:])(
449+
transform_size = _swap_axes(data_a, self.axes).shape[-2:]
450+
rec = self._get_waverec(transform_size)(
438451
(data_a, WaveletDetailTuple2d(data_h, data_v, data_d))
439452
)
440453
if level > 0:
441-
if rec.shape[-1] != self[node].shape[-1]:
454+
rec = _swap_axes(rec, self.axes)
455+
swapped_node = _swap_axes(self[node], self.axes)
456+
if rec.shape[-1] != swapped_node.shape[-1]:
442457
assert (
443-
rec.shape[-1] == self[node].shape[-1] + 1
458+
rec.shape[-1] == swapped_node.shape[-1] + 1
444459
), "padding error, please open an issue on GitHub"
445460
rec = rec[..., :-1]
446-
if rec.shape[-2] != self[node].shape[-2]:
461+
if rec.shape[-2] != swapped_node.shape[-2]:
447462
assert (
448-
rec.shape[-2] == self[node].shape[-2] + 1
463+
rec.shape[-2] == swapped_node.shape[-2] + 1
449464
), "padding error, please open an issue on GitHub"
450465
rec = rec[..., :-1, :]
466+
rec = _undo_swap_axes(rec, self.axes)
451467
self[node] = rec
452468
return self
453469

454470
def _expand_node(self, path: str) -> None:
455471
data = self[path]
456-
result = self._get_wavedec(data.shape[-2:])(data)
472+
transform_size = _swap_axes(data, self.axes).shape[-2:]
473+
result = self._get_wavedec(transform_size)(data)
457474

458475
# assert for type checking
459476
assert len(result) == 2
@@ -545,7 +562,7 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
545562
return _fsdict_func
546563

547564
def _recursive_dwt2d(self, path: str) -> None:
548-
if not self.maxlevel:
565+
if self.maxlevel is None:
549566
raise AssertionError
550567

551568
if len(path) >= self.maxlevel:
@@ -598,6 +615,42 @@ def __getitem__(self, key: str) -> torch.Tensor:
598615

599616
return super().__getitem__(key)
600617

618+
@overload
619+
@staticmethod
620+
def get_level(level: int, order: Literal["freq"]) -> list[list[str]]: ...
621+
622+
@overload
623+
@staticmethod
624+
def get_level(level: int, order: Literal["natural"]) -> list[str]: ...
625+
626+
@staticmethod
627+
def get_level(
628+
level: int, order: PacketNodeOrder = "freq"
629+
) -> Union[list[str], list[list[str]]]:
630+
"""Return the paths to the filter tree nodes.
631+
632+
Args:
633+
level (int): The depth of the tree.
634+
order: The order the paths are in.
635+
Choose from frequency order (``freq``) and
636+
natural order (``natural``).
637+
Defaults to ``freq``.
638+
639+
Returns:
640+
A list with the paths to each node.
641+
642+
Raises:
643+
ValueError: If `order` is neither ``freq`` nor ``natural``.
644+
"""
645+
if order == "freq":
646+
return WaveletPacket2D.get_freq_order(level)
647+
elif order == "natural":
648+
return WaveletPacket2D.get_natural_order(level)
649+
else:
650+
raise ValueError(
651+
f"Unsupported order '{order}'. Choose from 'freq' and 'natural'."
652+
)
653+
601654
@staticmethod
602655
def get_natural_order(level: int) -> list[str]:
603656
"""Get the natural ordering for a given decomposition level.

0 commit comments

Comments
 (0)