3
3
from __future__ import annotations
4
4
5
5
import collections
6
- from collections .abc import Sequence
6
+ from collections .abc import Callable , Iterable , Sequence
7
7
from functools import partial
8
8
from itertools import product
9
- from typing import TYPE_CHECKING , Callable , Literal , Optional , Union , overload
9
+ from typing import TYPE_CHECKING , Literal , Optional , Union , overload
10
10
11
11
import numpy as np
12
12
import pywt
@@ -65,7 +65,9 @@ def __init__(
65
65
) -> None :
66
66
"""Create a wavelet packet decomposition object.
67
67
68
- The decompositions will rely on padded fast wavelet transforms.
68
+ The packet tree is initialized lazily, i.e. a coefficient is only
69
+ calculated as it is retrieved. This allows for partial expansion
70
+ of the wavelet packet tree.
69
71
70
72
Args:
71
73
data (torch.Tensor, optional): The input data array of shape ``[time]``,
@@ -98,13 +100,10 @@ def __init__(
98
100
>>> w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
99
101
>>> wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)),
100
102
>>> wavelet=pywt.Wavelet("db3"), mode="reflect")
101
- >>> np_lst = []
102
- >>> for node in wp.get_level(5):
103
- >>> np_lst.append(wp[node])
103
+ >>> np_lst = [wp[node] for node in wp.get_level(5)]
104
104
>>> viz = np.stack(np_lst).squeeze()
105
105
>>> plt.imshow(np.abs(viz))
106
106
>>> plt.show()
107
-
108
107
"""
109
108
self .wavelet = _as_wavelet (wavelet )
110
109
self .mode = mode
@@ -113,30 +112,54 @@ def __init__(
113
112
self ._matrix_waverec_dict : dict [int , MatrixWaverec ] = {}
114
113
self .maxlevel : Optional [int ] = None
115
114
self .axis = axis
115
+
116
+ self ._filter_keys = {"a" , "d" }
117
+
116
118
if data is not None :
117
119
self .transform (data , maxlevel )
118
120
else :
119
121
self .data = {}
120
122
121
123
def transform (
122
- self , data : torch .Tensor , maxlevel : Optional [int ] = None
124
+ self ,
125
+ data : torch .Tensor ,
126
+ maxlevel : Optional [int ] = None ,
123
127
) -> WaveletPacket :
124
- """Calculate the 1d wavelet packet transform for the input data.
128
+ """Lazily calculate the 1d wavelet packet transform for the input data.
129
+
130
+ The packet tree is initialized lazily, i.e. a coefficient is only
131
+ calculated as it is retrieved. This allows for partial expansion
132
+ of the wavelet packet tree.
133
+
134
+ The transform function allows reusing the same object.
125
135
126
136
Args:
127
137
data (torch.Tensor): The input data array of shape ``[time]``
128
138
or ``[batch_size, time]``.
129
139
maxlevel (int, optional): The highest decomposition level to compute.
130
140
If None, the maximum level is determined from the input data shape.
131
141
Defaults to None.
142
+
143
+ Returns:
144
+ This wavelet packet object (to allow call chaining).
132
145
"""
133
- self .data = {}
146
+ self .data = {"" : data }
134
147
if maxlevel is None :
135
148
maxlevel = pywt .dwt_max_level (data .shape [self .axis ], self .wavelet .dec_len )
136
149
self .maxlevel = maxlevel
137
- self ._recursive_dwt (data , level = 0 , path = "" )
138
150
return self
139
151
152
+ def initialize (self , keys : Iterable [str ]) -> None :
153
+ """Initialize the wavelet packet tree partially.
154
+
155
+ Args:
156
+ keys (Iterable[str]): An iterable yielding the keys of the
157
+ tree nodes to initialize.
158
+ """
159
+ it = (self [key ] for key in keys )
160
+ # exhaust iterator without storing all values
161
+ collections .deque (it , maxlen = 0 )
162
+
140
163
def reconstruct (self ) -> WaveletPacket :
141
164
"""Recursively reconstruct the input starting from the leaf nodes.
142
165
@@ -153,18 +176,29 @@ def reconstruct(self) -> WaveletPacket:
153
176
>>> signal = np.random.randn(1, 16)
154
177
>>> ptwp = ptwt.WaveletPacket(torch.from_numpy(signal), "haar",
155
178
>>> mode="boundary", maxlevel=2)
156
- >>> ptwp["aa"].data *= 0
179
+ >>> # initialize other leaf nodes
180
+ >>> ptwp.initialize(["ad", "da", "dd"])
181
+ >>> ptwp["aa"] = torch.zeros_like(ptwp["ad"])
157
182
>>> ptwp.reconstruct()
158
183
>>> print(ptwp[""])
184
+
185
+ Raises:
186
+ KeyError: if any leaf node data is not present.
159
187
"""
160
188
if self .maxlevel is None :
161
189
self .maxlevel = pywt .dwt_max_level (self ["" ].shape [- 1 ], self .wavelet .dec_len )
162
190
163
191
for level in reversed (range (self .maxlevel )):
164
192
for node in self .get_level (level ):
193
+ # check if any children is not available
194
+ # we need to check manually to avoid lazy init
195
+ for child in self ._filter_keys :
196
+ if node + child not in self :
197
+ raise KeyError (f"Key { node + child } not found" )
198
+
165
199
data_a = self [node + "a" ]
166
- data_b = self [node + "d" ]
167
- rec = self ._get_waverec (data_a .shape [self .axis ])([data_a , data_b ])
200
+ data_d = self [node + "d" ]
201
+ rec = self ._get_waverec (data_a .shape [self .axis ])([data_a , data_d ])
168
202
if level > 0 :
169
203
if rec .shape [self .axis ] != self [node ].shape [self .axis ]:
170
204
assert (
@@ -242,15 +276,11 @@ def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]:
242
276
else :
243
277
return graycode_order
244
278
245
- def _recursive_dwt (self , data : torch .Tensor , level : int , path : str ) -> None :
246
- if self .maxlevel is None :
247
- raise AssertionError
248
-
249
- self .data [path ] = data
250
- if level < self .maxlevel :
251
- res_lo , res_hi = self ._get_wavedec (data .shape [self .axis ])(data )
252
- self ._recursive_dwt (res_lo , level + 1 , path + "a" )
253
- self ._recursive_dwt (res_hi , level + 1 , path + "d" )
279
+ def _expand_node (self , path : str ) -> None :
280
+ data = self [path ]
281
+ res_lo , res_hi = self ._get_wavedec (data .shape [self .axis ])(data )
282
+ self .data [path + "a" ] = res_lo
283
+ self .data [path + "d" ] = res_hi
254
284
255
285
def __getitem__ (self , key : str ) -> torch .Tensor :
256
286
"""Access the coefficients in the wavelet packets tree.
@@ -265,7 +295,8 @@ def __getitem__(self, key: str) -> torch.Tensor:
265
295
266
296
Raises:
267
297
ValueError: If the wavelet packet tree is not initialized.
268
- KeyError: If no wavelet coefficients are indexed by the specified key.
298
+ KeyError: If no wavelet coefficients are indexed by the specified key
299
+ and a lazy initialization fails.
269
300
"""
270
301
if self .maxlevel is None :
271
302
raise ValueError (
@@ -278,6 +309,20 @@ def __getitem__(self, key: str) -> torch.Tensor:
278
309
"cannot be accessed! This wavelet packet tree is initialized with "
279
310
f"maximum level { self .maxlevel } ."
280
311
)
312
+ elif key not in self :
313
+ if key == "" :
314
+ raise ValueError (
315
+ "The requested root of the packet tree cannot be accessed! "
316
+ "The wavelet packet tree is not properly initialized. "
317
+ "Run `transform` before accessing tree values."
318
+ )
319
+ elif key [- 1 ] not in self ._filter_keys :
320
+ raise ValueError (
321
+ f"Invalid key '{ key } '. All chars in the key must be of the "
322
+ f"set { self ._filter_keys } ."
323
+ )
324
+ # calculate data from parent
325
+ self ._expand_node (key [:- 1 ])
281
326
return super ().__getitem__ (key )
282
327
283
328
@@ -300,6 +345,10 @@ def __init__(
300
345
) -> None :
301
346
"""Create a 2D-Wavelet packet tree.
302
347
348
+ The packet tree is initialized lazily, i.e. a coefficient is only
349
+ calculated as it is retrieved. This allows for partial expansion
350
+ of the wavelet packet tree.
351
+
303
352
Args:
304
353
data (torch.tensor, optional): The input data tensor.
305
354
For example of shape ``[batch_size, height, width]`` or
@@ -324,7 +373,6 @@ def __init__(
324
373
Only used if `mode` equals 'boundary'. Defaults to 'qr'.
325
374
separable (bool): If true, a separable transform is performed,
326
375
i.e. each image axis is transformed separately. Defaults to False.
327
-
328
376
"""
329
377
self .wavelet = _as_wavelet (wavelet )
330
378
self .mode = mode
@@ -333,6 +381,7 @@ def __init__(
333
381
self .matrix_wavedec2_dict : dict [tuple [int , ...], MatrixWavedec2 ] = {}
334
382
self .matrix_waverec2_dict : dict [tuple [int , ...], MatrixWaverec2 ] = {}
335
383
self .axes = axes
384
+ self ._filter_keys = {"a" , "h" , "v" , "d" }
336
385
337
386
self .maxlevel : Optional [int ] = None
338
387
if data is not None :
@@ -341,42 +390,70 @@ def __init__(
341
390
self .data = {}
342
391
343
392
def transform (
344
- self , data : torch .Tensor , maxlevel : Optional [int ] = None
393
+ self ,
394
+ data : torch .Tensor ,
395
+ maxlevel : Optional [int ] = None ,
345
396
) -> WaveletPacket2D :
346
- """Calculate the 2d wavelet packet transform for the input data.
397
+ """Lazily calculate the 2d wavelet packet transform for the input data.
398
+
399
+ The packet tree is initialized lazily, i.e. a coefficient is only
400
+ calculated as it is retrieved. This allows for partial expansion
401
+ of the wavelet packet tree.
347
402
348
- The transform function allows reusing the same object.
403
+ The transform function allows reusing the same object.
349
404
350
405
Args:
351
406
data (torch.tensor): The input data tensor
352
407
of shape ``[batch_size, height, width]``.
353
408
maxlevel (int, optional): The highest decomposition level to compute.
354
409
If None, the maximum level is determined from the input data shape.
355
410
Defaults to None.
411
+
412
+ Returns:
413
+ This wavelet packet object (to allow call chaining).
356
414
"""
357
- self .data = {}
415
+ self .data = {"" : data }
358
416
if maxlevel is None :
359
417
min_transform_size = min (_swap_axes (data , self .axes ).shape [- 2 :])
360
418
maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
361
419
self .maxlevel = maxlevel
362
420
363
- self ._recursive_dwt2d (data , level = 0 , path = "" )
364
421
return self
365
422
423
+ def initialize (self , keys : Iterable [str ]) -> None :
424
+ """Initialize the wavelet packet tree partially.
425
+
426
+ Args:
427
+ keys (Iterable[str]): An iterable yielding the keys of the
428
+ tree nodes to initialize.
429
+ """
430
+ it = (self [key ] for key in keys )
431
+ # exhaust iterator without storing all values
432
+ collections .deque (it , maxlen = 0 )
433
+
366
434
def reconstruct (self ) -> WaveletPacket2D :
367
435
"""Recursively reconstruct the input starting from the leaf nodes.
368
436
369
437
Note:
370
438
Only changes to leaf node data impact the results,
371
439
since changes in all other nodes will be replaced with
372
440
a reconstruction from the leaves.
441
+
442
+ Raises:
443
+ KeyError: if any leaf node data is not present.
373
444
"""
374
445
if self .maxlevel is None :
375
446
min_transform_size = min (_swap_axes (self ["" ], self .axes ).shape [- 2 :])
376
447
self .maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
377
448
378
449
for level in reversed (range (self .maxlevel )):
379
450
for node in WaveletPacket2D .get_natural_order (level ):
451
+ # check if any children is not available
452
+ # we need to check manually to avoid lazy init
453
+ for child in self ._filter_keys :
454
+ if node + child not in self :
455
+ raise KeyError (f"Key { node + child } not found" )
456
+
380
457
data_a = self [node + "a" ]
381
458
data_h = self [node + "h" ]
382
459
data_v = self [node + "v" ]
@@ -402,6 +479,19 @@ def reconstruct(self) -> WaveletPacket2D:
402
479
self [node ] = rec
403
480
return self
404
481
482
+ def _expand_node (self , path : str ) -> None :
483
+ data = self [path ]
484
+ transform_size = _swap_axes (data , self .axes ).shape [- 2 :]
485
+ result = self ._get_wavedec (transform_size )(data )
486
+
487
+ # assert for type checking
488
+ assert len (result ) == 2
489
+ result_a , (result_h , result_v , result_d ) = result
490
+ self .data [path + "a" ] = result_a
491
+ self .data [path + "h" ] = result_h
492
+ self .data [path + "v" ] = result_v
493
+ self .data [path + "d" ] = result_d
494
+
405
495
def _get_wavedec (self , shape : tuple [int , ...]) -> Callable [
406
496
[torch .Tensor ],
407
497
WaveletCoeff2d ,
@@ -483,23 +573,6 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
483
573
484
574
return _fsdict_func
485
575
486
- def _recursive_dwt2d (self , data : torch .Tensor , level : int , path : str ) -> None :
487
- if self .maxlevel is None :
488
- raise AssertionError
489
-
490
- self .data [path ] = data
491
- if level < self .maxlevel :
492
- transform_size = _swap_axes (data , self .axes ).shape [- 2 :]
493
- result = self ._get_wavedec (transform_size )(data )
494
-
495
- # assert for type checking
496
- assert len (result ) == 2
497
- result_a , (result_h , result_v , result_d ) = result
498
- self ._recursive_dwt2d (result_a , level + 1 , path + "a" )
499
- self ._recursive_dwt2d (result_h , level + 1 , path + "h" )
500
- self ._recursive_dwt2d (result_v , level + 1 , path + "v" )
501
- self ._recursive_dwt2d (result_d , level + 1 , path + "d" )
502
-
503
576
def __getitem__ (self , key : str ) -> torch .Tensor :
504
577
"""Access the coefficients in the wavelet packets tree.
505
578
@@ -516,7 +589,8 @@ def __getitem__(self, key: str) -> torch.Tensor:
516
589
517
590
Raises:
518
591
ValueError: If the wavelet packet tree is not initialized.
519
- KeyError: If no wavelet coefficients are indexed by the specified key.
592
+ KeyError: If no wavelet coefficients are indexed by the specified key
593
+ and a lazy initialization fails.
520
594
"""
521
595
if self .maxlevel is None :
522
596
raise ValueError (
@@ -529,6 +603,21 @@ def __getitem__(self, key: str) -> torch.Tensor:
529
603
"cannot be accessed! This wavelet packet tree is initialized with "
530
604
f"maximum level { self .maxlevel } ."
531
605
)
606
+ elif key not in self :
607
+ if key == "" :
608
+ raise ValueError (
609
+ "The requested root of the packet tree cannot be accessed! "
610
+ "The wavelet packet tree is not properly initialized. "
611
+ "Run `transform` before accessing tree values."
612
+ )
613
+ elif key [- 1 ] not in self ._filter_keys :
614
+ raise ValueError (
615
+ f"Invalid key '{ key } '. All chars in the key must be of the "
616
+ f"set { self ._filter_keys } ."
617
+ )
618
+ # calculate data from parent
619
+ self ._expand_node (key [:- 1 ])
620
+
532
621
return super ().__getitem__ (key )
533
622
534
623
@overload
0 commit comments