Skip to content

Commit 7226e66

Browse files
committed
Merge branch 'main' into feature/packets-level-order
2 parents 20d3da9 + 848d0f7 commit 7226e66

File tree

2 files changed

+113
-108
lines changed

2 files changed

+113
-108
lines changed

src/ptwt/packets.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pywt
1313
import torch
1414

15-
from ._util import Wavelet, _as_wavelet
15+
from ._util import Wavelet, _as_wavelet, _swap_axes, _undo_swap_axes
1616
from .constants import (
1717
ExtendedBoundaryMode,
1818
OrthogonalizeMethod,
@@ -114,9 +114,6 @@ def __init__(
114114
self.maxlevel: Optional[int] = None
115115
self.axis = axis
116116
if data is not None:
117-
if len(data.shape) == 1:
118-
# add a batch dimension.
119-
data = data.unsqueeze(0)
120117
self.transform(data, maxlevel)
121118
else:
122119
self.data = {}
@@ -135,7 +132,7 @@ def transform(
135132
"""
136133
self.data = {}
137134
if maxlevel is None:
138-
maxlevel = pywt.dwt_max_level(data.shape[-1], self.wavelet.dec_len)
135+
maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len)
139136
self.maxlevel = maxlevel
140137
self._recursive_dwt(data, level=0, path="")
141138
return self
@@ -167,13 +164,15 @@ def reconstruct(self) -> WaveletPacket:
167164
for node in self.get_level(level):
168165
data_a = self[node + "a"]
169166
data_b = self[node + "d"]
170-
rec = self._get_waverec(data_a.shape[-1])([data_a, data_b])
167+
rec = self._get_waverec(data_a.shape[self.axis])([data_a, data_b])
171168
if level > 0:
172-
if rec.shape[-1] != self[node].shape[-1]:
169+
if rec.shape[self.axis] != self[node].shape[self.axis]:
173170
assert (
174-
rec.shape[-1] == self[node].shape[-1] + 1
171+
rec.shape[self.axis] == self[node].shape[self.axis] + 1
175172
), "padding error, please open an issue on github"
176-
rec = rec[..., :-1]
173+
rec = rec.swapaxes(self.axis, -1)[..., :-1].swapaxes(
174+
-1, self.axis
175+
)
177176
self[node] = rec
178177
return self
179178

@@ -244,12 +243,12 @@ def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]:
244243
return graycode_order
245244

246245
def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None:
247-
if not self.maxlevel:
246+
if self.maxlevel is None:
248247
raise AssertionError
249248

250249
self.data[path] = data
251250
if level < self.maxlevel:
252-
res_lo, res_hi = self._get_wavedec(data.shape[-1])(data)
251+
res_lo, res_hi = self._get_wavedec(data.shape[self.axis])(data)
253252
self._recursive_dwt(res_lo, level + 1, path + "a")
254253
self._recursive_dwt(res_hi, level + 1, path + "d")
255254

@@ -357,13 +356,10 @@ def transform(
357356
"""
358357
self.data = {}
359358
if maxlevel is None:
360-
maxlevel = pywt.dwt_max_level(min(data.shape[-2:]), self.wavelet.dec_len)
359+
min_transform_size = min(_swap_axes(data, self.axes).shape[-2:])
360+
maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
361361
self.maxlevel = maxlevel
362362

363-
if data.dim() == 2:
364-
# add batch dim to unbatched input
365-
data = data.unsqueeze(0)
366-
367363
self._recursive_dwt2d(data, level=0, path="")
368364
return self
369365

@@ -376,30 +372,33 @@ def reconstruct(self) -> WaveletPacket2D:
376372
a reconstruction from the leaves.
377373
"""
378374
if self.maxlevel is None:
379-
self.maxlevel = pywt.dwt_max_level(
380-
min(self[""].shape[-2:]), self.wavelet.dec_len
381-
)
375+
min_transform_size = min(_swap_axes(self[""], self.axes).shape[-2:])
376+
self.maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
382377

383378
for level in reversed(range(self.maxlevel)):
384379
for node in WaveletPacket2D.get_natural_order(level):
385380
data_a = self[node + "a"]
386381
data_h = self[node + "h"]
387382
data_v = self[node + "v"]
388383
data_d = self[node + "d"]
389-
rec = self._get_waverec(data_a.shape[-2:])(
384+
transform_size = _swap_axes(data_a, self.axes).shape[-2:]
385+
rec = self._get_waverec(transform_size)(
390386
(data_a, WaveletDetailTuple2d(data_h, data_v, data_d))
391387
)
392388
if level > 0:
393-
if rec.shape[-1] != self[node].shape[-1]:
389+
rec = _swap_axes(rec, self.axes)
390+
swapped_node = _swap_axes(self[node], self.axes)
391+
if rec.shape[-1] != swapped_node.shape[-1]:
394392
assert (
395-
rec.shape[-1] == self[node].shape[-1] + 1
393+
rec.shape[-1] == swapped_node.shape[-1] + 1
396394
), "padding error, please open an issue on GitHub"
397395
rec = rec[..., :-1]
398-
if rec.shape[-2] != self[node].shape[-2]:
396+
if rec.shape[-2] != swapped_node.shape[-2]:
399397
assert (
400-
rec.shape[-2] == self[node].shape[-2] + 1
398+
rec.shape[-2] == swapped_node.shape[-2] + 1
401399
), "padding error, please open an issue on GitHub"
402400
rec = rec[..., :-1, :]
401+
rec = _undo_swap_axes(rec, self.axes)
403402
self[node] = rec
404403
return self
405404

@@ -485,12 +484,13 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
485484
return _fsdict_func
486485

487486
def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None:
488-
if not self.maxlevel:
487+
if self.maxlevel is None:
489488
raise AssertionError
490489

491490
self.data[path] = data
492491
if level < self.maxlevel:
493-
result = self._get_wavedec(data.shape[-2:])(data)
492+
transform_size = _swap_axes(data, self.axes).shape[-2:]
493+
result = self._get_wavedec(transform_size)(data)
494494

495495
# assert for type checking
496496
assert len(result) == 2

0 commit comments

Comments
 (0)