6
6
from collections .abc import Callable , Sequence
7
7
from functools import partial
8
8
from itertools import product
9
- from typing import TYPE_CHECKING , Optional , Union
9
+ from typing import TYPE_CHECKING , Literal , Optional , Union , overload
10
10
11
11
import numpy as np
12
12
import pywt
13
13
import torch
14
14
15
- from ._util import Wavelet , _as_wavelet
15
+ from ._util import Wavelet , _as_wavelet , _swap_axes , _undo_swap_axes
16
16
from .constants import (
17
17
ExtendedBoundaryMode ,
18
18
OrthogonalizeMethod ,
19
+ PacketNodeOrder ,
19
20
WaveletCoeff2d ,
20
21
WaveletCoeffNd ,
21
22
WaveletDetailTuple2d ,
@@ -115,9 +116,6 @@ def __init__(
115
116
self .maxlevel : Optional [int ] = None
116
117
self .axis = axis
117
118
if data is not None :
118
- if len (data .shape ) == 1 :
119
- # add a batch dimension.
120
- data = data .unsqueeze (0 )
121
119
self .transform (data , maxlevel , lazy_init = lazy_init )
122
120
else :
123
121
self .data = {}
@@ -146,7 +144,7 @@ def transform(
146
144
"""
147
145
self .data = {"" : data }
148
146
if maxlevel is None :
149
- maxlevel = pywt .dwt_max_level (data .shape [- 1 ], self .wavelet .dec_len )
147
+ maxlevel = pywt .dwt_max_level (data .shape [self . axis ], self .wavelet .dec_len )
150
148
self .maxlevel = maxlevel
151
149
if not lazy_init :
152
150
self ._recursive_dwt (path = "" )
@@ -188,13 +186,15 @@ def _test_key(key: str) -> None:
188
186
189
187
data_a = self [node + "a" ]
190
188
data_d = self [node + "d" ]
191
- rec = self ._get_waverec (data_a .shape [- 1 ])([data_a , data_d ])
189
+ rec = self ._get_waverec (data_a .shape [self . axis ])([data_a , data_d ])
192
190
if level > 0 :
193
- if rec .shape [- 1 ] != self [node ].shape [- 1 ]:
191
+ if rec .shape [self . axis ] != self [node ].shape [self . axis ]:
194
192
assert (
195
- rec .shape [- 1 ] == self [node ].shape [- 1 ] + 1
193
+ rec .shape [self . axis ] == self [node ].shape [self . axis ] + 1
196
194
), "padding error, please open an issue on github"
197
- rec = rec [..., :- 1 ]
195
+ rec = rec .swapaxes (self .axis , - 1 )[..., :- 1 ].swapaxes (
196
+ - 1 , self .axis
197
+ )
198
198
self [node ] = rec
199
199
return self
200
200
@@ -226,18 +226,34 @@ def _get_waverec(
226
226
else :
227
227
return partial (waverec , wavelet = self .wavelet , axis = self .axis )
228
228
229
- def get_level (self , level : int ) -> list [str ]:
230
- """Return the graycode-ordered paths to the filter tree nodes.
229
+ @staticmethod
230
+ def get_level (level : int , order : PacketNodeOrder = "freq" ) -> list [str ]:
231
+ """Return the paths to the filter tree nodes.
231
232
232
233
Args:
233
234
level (int): The depth of the tree.
235
+ order: The order the paths are in.
236
+ Choose from frequency order (``freq``) and
237
+ natural order (``natural``).
238
+ Defaults to ``freq``.
234
239
235
240
Returns:
236
241
A list with the paths to each node.
242
+
243
+ Raises:
244
+ ValueError: If `order` is neither ``freq`` nor ``natural``.
237
245
"""
238
- return self ._get_graycode_order (level )
246
+ if order == "freq" :
247
+ return WaveletPacket ._get_graycode_order (level )
248
+ elif order == "natural" :
249
+ return ["" .join (p ) for p in product (["a" , "d" ], repeat = level )]
250
+ else :
251
+ raise ValueError (
252
+ f"Unsupported order '{ order } '. Choose from 'freq' and 'natural'."
253
+ )
239
254
240
- def _get_graycode_order (self , level : int , x : str = "a" , y : str = "d" ) -> list [str ]:
255
+ @staticmethod
256
+ def _get_graycode_order (level : int , x : str = "a" , y : str = "d" ) -> list [str ]:
241
257
graycode_order = [x , y ]
242
258
for _ in range (level - 1 ):
243
259
graycode_order = [x + path for path in graycode_order ] + [
@@ -250,12 +266,12 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[st
250
266
251
267
def _expand_node (self , path : str ) -> None :
252
268
data = self [path ]
253
- res_lo , res_hi = self ._get_wavedec (data .shape [- 1 ])(data )
269
+ res_lo , res_hi = self ._get_wavedec (data .shape [self . axis ])(data )
254
270
self .data [path + "a" ] = res_lo
255
271
self .data [path + "d" ] = res_hi
256
272
257
273
def _recursive_dwt (self , path : str ) -> None :
258
- if not self .maxlevel :
274
+ if self .maxlevel is None :
259
275
raise AssertionError
260
276
261
277
if len (path ) >= self .maxlevel :
@@ -395,13 +411,10 @@ def transform(
395
411
"""
396
412
self .data = {"" : data }
397
413
if maxlevel is None :
398
- maxlevel = pywt .dwt_max_level (min (data .shape [- 2 :]), self .wavelet .dec_len )
414
+ min_transform_size = min (_swap_axes (data , self .axes ).shape [- 2 :])
415
+ maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
399
416
self .maxlevel = maxlevel
400
417
401
- if data .dim () == 2 :
402
- # add batch dim to unbatched input
403
- data = data .unsqueeze (0 )
404
-
405
418
if not lazy_init :
406
419
self ._recursive_dwt2d (path = "" )
407
420
return self
@@ -415,9 +428,8 @@ def reconstruct(self) -> WaveletPacket2D:
415
428
a reconstruction from the leaves.
416
429
"""
417
430
if self .maxlevel is None :
418
- self .maxlevel = pywt .dwt_max_level (
419
- min (self ["" ].shape [- 2 :]), self .wavelet .dec_len
420
- )
431
+ min_transform_size = min (_swap_axes (self ["" ], self .axes ).shape [- 2 :])
432
+ self .maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
421
433
422
434
for level in reversed (range (self .maxlevel )):
423
435
for node in WaveletPacket2D .get_natural_order (level ):
@@ -434,26 +446,31 @@ def _test_key(key: str) -> None:
434
446
data_h = self [node + "h" ]
435
447
data_v = self [node + "v" ]
436
448
data_d = self [node + "d" ]
437
- rec = self ._get_waverec (data_a .shape [- 2 :])(
449
+ transform_size = _swap_axes (data_a , self .axes ).shape [- 2 :]
450
+ rec = self ._get_waverec (transform_size )(
438
451
(data_a , WaveletDetailTuple2d (data_h , data_v , data_d ))
439
452
)
440
453
if level > 0 :
441
- if rec .shape [- 1 ] != self [node ].shape [- 1 ]:
454
+ rec = _swap_axes (rec , self .axes )
455
+ swapped_node = _swap_axes (self [node ], self .axes )
456
+ if rec .shape [- 1 ] != swapped_node .shape [- 1 ]:
442
457
assert (
443
- rec .shape [- 1 ] == self [ node ] .shape [- 1 ] + 1
458
+ rec .shape [- 1 ] == swapped_node .shape [- 1 ] + 1
444
459
), "padding error, please open an issue on GitHub"
445
460
rec = rec [..., :- 1 ]
446
- if rec .shape [- 2 ] != self [ node ] .shape [- 2 ]:
461
+ if rec .shape [- 2 ] != swapped_node .shape [- 2 ]:
447
462
assert (
448
- rec .shape [- 2 ] == self [ node ] .shape [- 2 ] + 1
463
+ rec .shape [- 2 ] == swapped_node .shape [- 2 ] + 1
449
464
), "padding error, please open an issue on GitHub"
450
465
rec = rec [..., :- 1 , :]
466
+ rec = _undo_swap_axes (rec , self .axes )
451
467
self [node ] = rec
452
468
return self
453
469
454
470
def _expand_node (self , path : str ) -> None :
455
471
data = self [path ]
456
- result = self ._get_wavedec (data .shape [- 2 :])(data )
472
+ transform_size = _swap_axes (data , self .axes ).shape [- 2 :]
473
+ result = self ._get_wavedec (transform_size )(data )
457
474
458
475
# assert for type checking
459
476
assert len (result ) == 2
@@ -545,7 +562,7 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
545
562
return _fsdict_func
546
563
547
564
def _recursive_dwt2d (self , path : str ) -> None :
548
- if not self .maxlevel :
565
+ if self .maxlevel is None :
549
566
raise AssertionError
550
567
551
568
if len (path ) >= self .maxlevel :
@@ -598,6 +615,42 @@ def __getitem__(self, key: str) -> torch.Tensor:
598
615
599
616
return super ().__getitem__ (key )
600
617
618
+ @overload
619
+ @staticmethod
620
+ def get_level (level : int , order : Literal ["freq" ]) -> list [list [str ]]: ...
621
+
622
+ @overload
623
+ @staticmethod
624
+ def get_level (level : int , order : Literal ["natural" ]) -> list [str ]: ...
625
+
626
+ @staticmethod
627
+ def get_level (
628
+ level : int , order : PacketNodeOrder = "freq"
629
+ ) -> Union [list [str ], list [list [str ]]]:
630
+ """Return the paths to the filter tree nodes.
631
+
632
+ Args:
633
+ level (int): The depth of the tree.
634
+ order: The order the paths are in.
635
+ Choose from frequency order (``freq``) and
636
+ natural order (``natural``).
637
+ Defaults to ``freq``.
638
+
639
+ Returns:
640
+ A list with the paths to each node.
641
+
642
+ Raises:
643
+ ValueError: If `order` is neither ``freq`` nor ``natural``.
644
+ """
645
+ if order == "freq" :
646
+ return WaveletPacket2D .get_freq_order (level )
647
+ elif order == "natural" :
648
+ return WaveletPacket2D .get_natural_order (level )
649
+ else :
650
+ raise ValueError (
651
+ f"Unsupported order '{ order } '. Choose from 'freq' and 'natural'."
652
+ )
653
+
601
654
@staticmethod
602
655
def get_natural_order (level : int ) -> list [str ]:
603
656
"""Get the natural ordering for a given decomposition level.
0 commit comments