Skip to content

Commit fa7af3d

Browse files
authored
Merge pull request #96 from v0lta/feature/packets-partial-refinement
Add lazy init to packets for partial tree expansion
2 parents 4e271a8 + bb79f6e commit fa7af3d

File tree

3 files changed

+317
-70
lines changed

3 files changed

+317
-70
lines changed

examples/deepfake_analysis/packet_plot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def load_images(path: str) -> list:
6868

6969
if __name__ == "__main__":
7070
freq_path = ptwt.WaveletPacket2D.get_freq_order(level=3)
71-
frequency_path = ptwt.WaveletPacket2D.get_natural_order(level=3)
71+
natural_path = ptwt.WaveletPacket2D.get_natural_order(level=3)
7272
print("Loading ffhq images:")
7373
ffhq_images = load_images("./ffhq_style_gan/source_data/A_ffhq")
7474
print("processing ffhq")

src/ptwt/packets.py

+137-48
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from __future__ import annotations
44

55
import collections
6-
from collections.abc import Sequence
6+
from collections.abc import Callable, Iterable, Sequence
77
from functools import partial
88
from itertools import product
9-
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, overload
9+
from typing import TYPE_CHECKING, Literal, Optional, Union, overload
1010

1111
import numpy as np
1212
import pywt
@@ -65,7 +65,9 @@ def __init__(
6565
) -> None:
6666
"""Create a wavelet packet decomposition object.
6767
68-
The decompositions will rely on padded fast wavelet transforms.
68+
The packet tree is initialized lazily, i.e. a coefficient is only
69+
calculated as it is retrieved. This allows for partial expansion
70+
of the wavelet packet tree.
6971
7072
Args:
7173
data (torch.Tensor, optional): The input data array of shape ``[time]``,
@@ -98,13 +100,10 @@ def __init__(
98100
>>> w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
99101
>>> wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)),
100102
>>> wavelet=pywt.Wavelet("db3"), mode="reflect")
101-
>>> np_lst = []
102-
>>> for node in wp.get_level(5):
103-
>>> np_lst.append(wp[node])
103+
>>> np_lst = [wp[node] for node in wp.get_level(5)]
104104
>>> viz = np.stack(np_lst).squeeze()
105105
>>> plt.imshow(np.abs(viz))
106106
>>> plt.show()
107-
108107
"""
109108
self.wavelet = _as_wavelet(wavelet)
110109
self.mode = mode
@@ -113,30 +112,54 @@ def __init__(
113112
self._matrix_waverec_dict: dict[int, MatrixWaverec] = {}
114113
self.maxlevel: Optional[int] = None
115114
self.axis = axis
115+
116+
self._filter_keys = {"a", "d"}
117+
116118
if data is not None:
117119
self.transform(data, maxlevel)
118120
else:
119121
self.data = {}
120122

121123
def transform(
122-
self, data: torch.Tensor, maxlevel: Optional[int] = None
124+
self,
125+
data: torch.Tensor,
126+
maxlevel: Optional[int] = None,
123127
) -> WaveletPacket:
124-
"""Calculate the 1d wavelet packet transform for the input data.
128+
"""Lazily calculate the 1d wavelet packet transform for the input data.
129+
130+
The packet tree is initialized lazily, i.e. a coefficient is only
131+
calculated as it is retrieved. This allows for partial expansion
132+
of the wavelet packet tree.
133+
134+
The transform function allows reusing the same object.
125135
126136
Args:
127137
data (torch.Tensor): The input data array of shape ``[time]``
128138
or ``[batch_size, time]``.
129139
maxlevel (int, optional): The highest decomposition level to compute.
130140
If None, the maximum level is determined from the input data shape.
131141
Defaults to None.
142+
143+
Returns:
144+
This wavelet packet object (to allow call chaining).
132145
"""
133-
self.data = {}
146+
self.data = {"": data}
134147
if maxlevel is None:
135148
maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len)
136149
self.maxlevel = maxlevel
137-
self._recursive_dwt(data, level=0, path="")
138150
return self
139151

152+
def initialize(self, keys: Iterable[str]) -> None:
153+
"""Initialize the wavelet packet tree partially.
154+
155+
Args:
156+
keys (Iterable[str]): An iterable yielding the keys of the
157+
tree nodes to initialize.
158+
"""
159+
it = (self[key] for key in keys)
160+
# exhaust iterator without storing all values
161+
collections.deque(it, maxlen=0)
162+
140163
def reconstruct(self) -> WaveletPacket:
141164
"""Recursively reconstruct the input starting from the leaf nodes.
142165
@@ -153,18 +176,29 @@ def reconstruct(self) -> WaveletPacket:
153176
>>> signal = np.random.randn(1, 16)
154177
>>> ptwp = ptwt.WaveletPacket(torch.from_numpy(signal), "haar",
155178
>>> mode="boundary", maxlevel=2)
156-
>>> ptwp["aa"].data *= 0
179+
>>> # initialize other leaf nodes
180+
>>> ptwp.initialize(["ad", "da", "dd"])
181+
>>> ptwp["aa"] = torch.zeros_like(ptwp["ad"])
157182
>>> ptwp.reconstruct()
158183
>>> print(ptwp[""])
184+
185+
Raises:
186+
KeyError: if any leaf node data is not present.
159187
"""
160188
if self.maxlevel is None:
161189
self.maxlevel = pywt.dwt_max_level(self[""].shape[-1], self.wavelet.dec_len)
162190

163191
for level in reversed(range(self.maxlevel)):
164192
for node in self.get_level(level):
193+
# check if any children is not available
194+
# we need to check manually to avoid lazy init
195+
for child in self._filter_keys:
196+
if node + child not in self:
197+
raise KeyError(f"Key {node + child} not found")
198+
165199
data_a = self[node + "a"]
166-
data_b = self[node + "d"]
167-
rec = self._get_waverec(data_a.shape[self.axis])([data_a, data_b])
200+
data_d = self[node + "d"]
201+
rec = self._get_waverec(data_a.shape[self.axis])([data_a, data_d])
168202
if level > 0:
169203
if rec.shape[self.axis] != self[node].shape[self.axis]:
170204
assert (
@@ -242,15 +276,11 @@ def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]:
242276
else:
243277
return graycode_order
244278

245-
def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None:
246-
if self.maxlevel is None:
247-
raise AssertionError
248-
249-
self.data[path] = data
250-
if level < self.maxlevel:
251-
res_lo, res_hi = self._get_wavedec(data.shape[self.axis])(data)
252-
self._recursive_dwt(res_lo, level + 1, path + "a")
253-
self._recursive_dwt(res_hi, level + 1, path + "d")
279+
def _expand_node(self, path: str) -> None:
280+
data = self[path]
281+
res_lo, res_hi = self._get_wavedec(data.shape[self.axis])(data)
282+
self.data[path + "a"] = res_lo
283+
self.data[path + "d"] = res_hi
254284

255285
def __getitem__(self, key: str) -> torch.Tensor:
256286
"""Access the coefficients in the wavelet packets tree.
@@ -265,7 +295,8 @@ def __getitem__(self, key: str) -> torch.Tensor:
265295
266296
Raises:
267297
ValueError: If the wavelet packet tree is not initialized.
268-
KeyError: If no wavelet coefficients are indexed by the specified key.
298+
KeyError: If no wavelet coefficients are indexed by the specified key
299+
and a lazy initialization fails.
269300
"""
270301
if self.maxlevel is None:
271302
raise ValueError(
@@ -278,6 +309,20 @@ def __getitem__(self, key: str) -> torch.Tensor:
278309
"cannot be accessed! This wavelet packet tree is initialized with "
279310
f"maximum level {self.maxlevel}."
280311
)
312+
elif key not in self:
313+
if key == "":
314+
raise ValueError(
315+
"The requested root of the packet tree cannot be accessed! "
316+
"The wavelet packet tree is not properly initialized. "
317+
"Run `transform` before accessing tree values."
318+
)
319+
elif key[-1] not in self._filter_keys:
320+
raise ValueError(
321+
f"Invalid key '{key}'. All chars in the key must be of the "
322+
f"set {self._filter_keys}."
323+
)
324+
# calculate data from parent
325+
self._expand_node(key[:-1])
281326
return super().__getitem__(key)
282327

283328

@@ -300,6 +345,10 @@ def __init__(
300345
) -> None:
301346
"""Create a 2D-Wavelet packet tree.
302347
348+
The packet tree is initialized lazily, i.e. a coefficient is only
349+
calculated as it is retrieved. This allows for partial expansion
350+
of the wavelet packet tree.
351+
303352
Args:
304353
data (torch.tensor, optional): The input data tensor.
305354
For example of shape ``[batch_size, height, width]`` or
@@ -324,7 +373,6 @@ def __init__(
324373
Only used if `mode` equals 'boundary'. Defaults to 'qr'.
325374
separable (bool): If true, a separable transform is performed,
326375
i.e. each image axis is transformed separately. Defaults to False.
327-
328376
"""
329377
self.wavelet = _as_wavelet(wavelet)
330378
self.mode = mode
@@ -333,6 +381,7 @@ def __init__(
333381
self.matrix_wavedec2_dict: dict[tuple[int, ...], MatrixWavedec2] = {}
334382
self.matrix_waverec2_dict: dict[tuple[int, ...], MatrixWaverec2] = {}
335383
self.axes = axes
384+
self._filter_keys = {"a", "h", "v", "d"}
336385

337386
self.maxlevel: Optional[int] = None
338387
if data is not None:
@@ -341,42 +390,70 @@ def __init__(
341390
self.data = {}
342391

343392
def transform(
344-
self, data: torch.Tensor, maxlevel: Optional[int] = None
393+
self,
394+
data: torch.Tensor,
395+
maxlevel: Optional[int] = None,
345396
) -> WaveletPacket2D:
346-
"""Calculate the 2d wavelet packet transform for the input data.
397+
"""Lazily calculate the 2d wavelet packet transform for the input data.
398+
399+
The packet tree is initialized lazily, i.e. a coefficient is only
400+
calculated as it is retrieved. This allows for partial expansion
401+
of the wavelet packet tree.
347402
348-
The transform function allows reusing the same object.
403+
The transform function allows reusing the same object.
349404
350405
Args:
351406
data (torch.tensor): The input data tensor
352407
of shape ``[batch_size, height, width]``.
353408
maxlevel (int, optional): The highest decomposition level to compute.
354409
If None, the maximum level is determined from the input data shape.
355410
Defaults to None.
411+
412+
Returns:
413+
This wavelet packet object (to allow call chaining).
356414
"""
357-
self.data = {}
415+
self.data = {"": data}
358416
if maxlevel is None:
359417
min_transform_size = min(_swap_axes(data, self.axes).shape[-2:])
360418
maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
361419
self.maxlevel = maxlevel
362420

363-
self._recursive_dwt2d(data, level=0, path="")
364421
return self
365422

423+
def initialize(self, keys: Iterable[str]) -> None:
424+
"""Initialize the wavelet packet tree partially.
425+
426+
Args:
427+
keys (Iterable[str]): An iterable yielding the keys of the
428+
tree nodes to initialize.
429+
"""
430+
it = (self[key] for key in keys)
431+
# exhaust iterator without storing all values
432+
collections.deque(it, maxlen=0)
433+
366434
def reconstruct(self) -> WaveletPacket2D:
367435
"""Recursively reconstruct the input starting from the leaf nodes.
368436
369437
Note:
370438
Only changes to leaf node data impact the results,
371439
since changes in all other nodes will be replaced with
372440
a reconstruction from the leaves.
441+
442+
Raises:
443+
KeyError: if any leaf node data is not present.
373444
"""
374445
if self.maxlevel is None:
375446
min_transform_size = min(_swap_axes(self[""], self.axes).shape[-2:])
376447
self.maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
377448

378449
for level in reversed(range(self.maxlevel)):
379450
for node in WaveletPacket2D.get_natural_order(level):
451+
# check if any children is not available
452+
# we need to check manually to avoid lazy init
453+
for child in self._filter_keys:
454+
if node + child not in self:
455+
raise KeyError(f"Key {node + child} not found")
456+
380457
data_a = self[node + "a"]
381458
data_h = self[node + "h"]
382459
data_v = self[node + "v"]
@@ -402,6 +479,19 @@ def reconstruct(self) -> WaveletPacket2D:
402479
self[node] = rec
403480
return self
404481

482+
def _expand_node(self, path: str) -> None:
483+
data = self[path]
484+
transform_size = _swap_axes(data, self.axes).shape[-2:]
485+
result = self._get_wavedec(transform_size)(data)
486+
487+
# assert for type checking
488+
assert len(result) == 2
489+
result_a, (result_h, result_v, result_d) = result
490+
self.data[path + "a"] = result_a
491+
self.data[path + "h"] = result_h
492+
self.data[path + "v"] = result_v
493+
self.data[path + "d"] = result_d
494+
405495
def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[
406496
[torch.Tensor],
407497
WaveletCoeff2d,
@@ -483,23 +573,6 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
483573

484574
return _fsdict_func
485575

486-
def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None:
487-
if self.maxlevel is None:
488-
raise AssertionError
489-
490-
self.data[path] = data
491-
if level < self.maxlevel:
492-
transform_size = _swap_axes(data, self.axes).shape[-2:]
493-
result = self._get_wavedec(transform_size)(data)
494-
495-
# assert for type checking
496-
assert len(result) == 2
497-
result_a, (result_h, result_v, result_d) = result
498-
self._recursive_dwt2d(result_a, level + 1, path + "a")
499-
self._recursive_dwt2d(result_h, level + 1, path + "h")
500-
self._recursive_dwt2d(result_v, level + 1, path + "v")
501-
self._recursive_dwt2d(result_d, level + 1, path + "d")
502-
503576
def __getitem__(self, key: str) -> torch.Tensor:
504577
"""Access the coefficients in the wavelet packets tree.
505578
@@ -516,7 +589,8 @@ def __getitem__(self, key: str) -> torch.Tensor:
516589
517590
Raises:
518591
ValueError: If the wavelet packet tree is not initialized.
519-
KeyError: If no wavelet coefficients are indexed by the specified key.
592+
KeyError: If no wavelet coefficients are indexed by the specified key
593+
and a lazy initialization fails.
520594
"""
521595
if self.maxlevel is None:
522596
raise ValueError(
@@ -529,6 +603,21 @@ def __getitem__(self, key: str) -> torch.Tensor:
529603
"cannot be accessed! This wavelet packet tree is initialized with "
530604
f"maximum level {self.maxlevel}."
531605
)
606+
elif key not in self:
607+
if key == "":
608+
raise ValueError(
609+
"The requested root of the packet tree cannot be accessed! "
610+
"The wavelet packet tree is not properly initialized. "
611+
"Run `transform` before accessing tree values."
612+
)
613+
elif key[-1] not in self._filter_keys:
614+
raise ValueError(
615+
f"Invalid key '{key}'. All chars in the key must be of the "
616+
f"set {self._filter_keys}."
617+
)
618+
# calculate data from parent
619+
self._expand_node(key[:-1])
620+
532621
return super().__getitem__(key)
533622

534623
@overload

0 commit comments

Comments
 (0)