Skip to content

Commit 848d0f7

Browse files
authored
Merge pull request #94 from v0lta/fix/packet-axis
Fix non-default axis in packets
2 parents ef4f80a + 48ad03b commit 848d0f7

File tree

2 files changed

+113
-108
lines changed

2 files changed

+113
-108
lines changed

src/ptwt/packets.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pywt
1313
import torch
1414

15-
from ._util import Wavelet, _as_wavelet
15+
from ._util import Wavelet, _as_wavelet, _swap_axes, _undo_swap_axes
1616
from .constants import (
1717
ExtendedBoundaryMode,
1818
OrthogonalizeMethod,
@@ -113,9 +113,6 @@ def __init__(
113113
self.maxlevel: Optional[int] = None
114114
self.axis = axis
115115
if data is not None:
116-
if len(data.shape) == 1:
117-
# add a batch dimension.
118-
data = data.unsqueeze(0)
119116
self.transform(data, maxlevel)
120117
else:
121118
self.data = {}
@@ -134,7 +131,7 @@ def transform(
134131
"""
135132
self.data = {}
136133
if maxlevel is None:
137-
maxlevel = pywt.dwt_max_level(data.shape[-1], self.wavelet.dec_len)
134+
maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len)
138135
self.maxlevel = maxlevel
139136
self._recursive_dwt(data, level=0, path="")
140137
return self
@@ -166,13 +163,15 @@ def reconstruct(self) -> WaveletPacket:
166163
for node in self.get_level(level):
167164
data_a = self[node + "a"]
168165
data_b = self[node + "d"]
169-
rec = self._get_waverec(data_a.shape[-1])([data_a, data_b])
166+
rec = self._get_waverec(data_a.shape[self.axis])([data_a, data_b])
170167
if level > 0:
171-
if rec.shape[-1] != self[node].shape[-1]:
168+
if rec.shape[self.axis] != self[node].shape[self.axis]:
172169
assert (
173-
rec.shape[-1] == self[node].shape[-1] + 1
170+
rec.shape[self.axis] == self[node].shape[self.axis] + 1
174171
), "padding error, please open an issue on github"
175-
rec = rec[..., :-1]
172+
rec = rec.swapaxes(self.axis, -1)[..., :-1].swapaxes(
173+
-1, self.axis
174+
)
176175
self[node] = rec
177176
return self
178177

@@ -227,12 +226,12 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[st
227226
return graycode_order
228227

229228
def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None:
230-
if not self.maxlevel:
229+
if self.maxlevel is None:
231230
raise AssertionError
232231

233232
self.data[path] = data
234233
if level < self.maxlevel:
235-
res_lo, res_hi = self._get_wavedec(data.shape[-1])(data)
234+
res_lo, res_hi = self._get_wavedec(data.shape[self.axis])(data)
236235
self._recursive_dwt(res_lo, level + 1, path + "a")
237236
self._recursive_dwt(res_hi, level + 1, path + "d")
238237

@@ -340,13 +339,10 @@ def transform(
340339
"""
341340
self.data = {}
342341
if maxlevel is None:
343-
maxlevel = pywt.dwt_max_level(min(data.shape[-2:]), self.wavelet.dec_len)
342+
min_transform_size = min(_swap_axes(data, self.axes).shape[-2:])
343+
maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
344344
self.maxlevel = maxlevel
345345

346-
if data.dim() == 2:
347-
# add batch dim to unbatched input
348-
data = data.unsqueeze(0)
349-
350346
self._recursive_dwt2d(data, level=0, path="")
351347
return self
352348

@@ -359,30 +355,33 @@ def reconstruct(self) -> WaveletPacket2D:
359355
a reconstruction from the leaves.
360356
"""
361357
if self.maxlevel is None:
362-
self.maxlevel = pywt.dwt_max_level(
363-
min(self[""].shape[-2:]), self.wavelet.dec_len
364-
)
358+
min_transform_size = min(_swap_axes(self[""], self.axes).shape[-2:])
359+
self.maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
365360

366361
for level in reversed(range(self.maxlevel)):
367362
for node in WaveletPacket2D.get_natural_order(level):
368363
data_a = self[node + "a"]
369364
data_h = self[node + "h"]
370365
data_v = self[node + "v"]
371366
data_d = self[node + "d"]
372-
rec = self._get_waverec(data_a.shape[-2:])(
367+
transform_size = _swap_axes(data_a, self.axes).shape[-2:]
368+
rec = self._get_waverec(transform_size)(
373369
(data_a, WaveletDetailTuple2d(data_h, data_v, data_d))
374370
)
375371
if level > 0:
376-
if rec.shape[-1] != self[node].shape[-1]:
372+
rec = _swap_axes(rec, self.axes)
373+
swapped_node = _swap_axes(self[node], self.axes)
374+
if rec.shape[-1] != swapped_node.shape[-1]:
377375
assert (
378-
rec.shape[-1] == self[node].shape[-1] + 1
376+
rec.shape[-1] == swapped_node.shape[-1] + 1
379377
), "padding error, please open an issue on GitHub"
380378
rec = rec[..., :-1]
381-
if rec.shape[-2] != self[node].shape[-2]:
379+
if rec.shape[-2] != swapped_node.shape[-2]:
382380
assert (
383-
rec.shape[-2] == self[node].shape[-2] + 1
381+
rec.shape[-2] == swapped_node.shape[-2] + 1
384382
), "padding error, please open an issue on GitHub"
385383
rec = rec[..., :-1, :]
384+
rec = _undo_swap_axes(rec, self.axes)
386385
self[node] = rec
387386
return self
388387

@@ -468,12 +467,13 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
468467
return _fsdict_func
469468

470469
def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None:
471-
if not self.maxlevel:
470+
if self.maxlevel is None:
472471
raise AssertionError
473472

474473
self.data[path] = data
475474
if level < self.maxlevel:
476-
result = self._get_wavedec(data.shape[-2:])(data)
475+
transform_size = _swap_axes(data, self.axes).shape[-2:]
476+
result = self._get_wavedec(transform_size)(data)
477477

478478
# assert for type checking
479479
assert len(result) == 2

0 commit comments

Comments
 (0)