Skip to content

Fix non-default axis in packets #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into the base branch on
Jun 26, 2024
52 changes: 26 additions & 26 deletions src/ptwt/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pywt
import torch

from ._util import Wavelet, _as_wavelet
from ._util import Wavelet, _as_wavelet, _swap_axes, _undo_swap_axes
from .constants import (
ExtendedBoundaryMode,
OrthogonalizeMethod,
Expand Down Expand Up @@ -113,9 +113,6 @@ def __init__(
self.maxlevel: Optional[int] = None
self.axis = axis
if data is not None:
if len(data.shape) == 1:
# add a batch dimension.
data = data.unsqueeze(0)
self.transform(data, maxlevel)
else:
self.data = {}
Expand All @@ -134,7 +131,7 @@ def transform(
"""
self.data = {}
if maxlevel is None:
maxlevel = pywt.dwt_max_level(data.shape[-1], self.wavelet.dec_len)
maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len)
self.maxlevel = maxlevel
self._recursive_dwt(data, level=0, path="")
return self
Expand Down Expand Up @@ -166,13 +163,15 @@ def reconstruct(self) -> WaveletPacket:
for node in self.get_level(level):
data_a = self[node + "a"]
data_b = self[node + "d"]
rec = self._get_waverec(data_a.shape[-1])([data_a, data_b])
rec = self._get_waverec(data_a.shape[self.axis])([data_a, data_b])
if level > 0:
if rec.shape[-1] != self[node].shape[-1]:
if rec.shape[self.axis] != self[node].shape[self.axis]:
assert (
rec.shape[-1] == self[node].shape[-1] + 1
rec.shape[self.axis] == self[node].shape[self.axis] + 1
), "padding error, please open an issue on github"
rec = rec[..., :-1]
rec = rec.swapaxes(self.axis, -1)[..., :-1].swapaxes(
-1, self.axis
)
self[node] = rec
return self

Expand Down Expand Up @@ -227,12 +226,12 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[st
return graycode_order

def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None:
if not self.maxlevel:
if self.maxlevel is None:
raise AssertionError

self.data[path] = data
if level < self.maxlevel:
res_lo, res_hi = self._get_wavedec(data.shape[-1])(data)
res_lo, res_hi = self._get_wavedec(data.shape[self.axis])(data)
self._recursive_dwt(res_lo, level + 1, path + "a")
self._recursive_dwt(res_hi, level + 1, path + "d")

Expand Down Expand Up @@ -340,13 +339,10 @@ def transform(
"""
self.data = {}
if maxlevel is None:
maxlevel = pywt.dwt_max_level(min(data.shape[-2:]), self.wavelet.dec_len)
min_transform_size = min(_swap_axes(data, self.axes).shape[-2:])
maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
self.maxlevel = maxlevel

if data.dim() == 2:
# add batch dim to unbatched input
data = data.unsqueeze(0)

self._recursive_dwt2d(data, level=0, path="")
return self

Expand All @@ -359,30 +355,33 @@ def reconstruct(self) -> WaveletPacket2D:
a reconstruction from the leaves.
"""
if self.maxlevel is None:
self.maxlevel = pywt.dwt_max_level(
min(self[""].shape[-2:]), self.wavelet.dec_len
)
min_transform_size = min(_swap_axes(self[""], self.axes).shape[-2:])
self.maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)

for level in reversed(range(self.maxlevel)):
for node in WaveletPacket2D.get_natural_order(level):
data_a = self[node + "a"]
data_h = self[node + "h"]
data_v = self[node + "v"]
data_d = self[node + "d"]
rec = self._get_waverec(data_a.shape[-2:])(
transform_size = _swap_axes(data_a, self.axes).shape[-2:]
rec = self._get_waverec(transform_size)(
(data_a, WaveletDetailTuple2d(data_h, data_v, data_d))
)
if level > 0:
if rec.shape[-1] != self[node].shape[-1]:
rec = _swap_axes(rec, self.axes)
swapped_node = _swap_axes(self[node], self.axes)
if rec.shape[-1] != swapped_node.shape[-1]:
assert (
rec.shape[-1] == self[node].shape[-1] + 1
rec.shape[-1] == swapped_node.shape[-1] + 1
), "padding error, please open an issue on GitHub"
rec = rec[..., :-1]
if rec.shape[-2] != self[node].shape[-2]:
if rec.shape[-2] != swapped_node.shape[-2]:
assert (
rec.shape[-2] == self[node].shape[-2] + 1
rec.shape[-2] == swapped_node.shape[-2] + 1
), "padding error, please open an issue on GitHub"
rec = rec[..., :-1, :]
rec = _undo_swap_axes(rec, self.axes)
self[node] = rec
return self

Expand Down Expand Up @@ -468,12 +467,13 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
return _fsdict_func

def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None:
if not self.maxlevel:
if self.maxlevel is None:
raise AssertionError

self.data[path] = data
if level < self.maxlevel:
result = self._get_wavedec(data.shape[-2:])(data)
transform_size = _swap_axes(data, self.axes).shape[-2:]
result = self._get_wavedec(transform_size)(data)

# assert for type checking
assert len(result) == 2
Expand Down
Loading
Loading