1212import pywt
1313import torch
1414
15- from ._util import Wavelet , _as_wavelet
15+ from ._util import Wavelet , _as_wavelet , _swap_axes , _undo_swap_axes
1616from .constants import (
1717 ExtendedBoundaryMode ,
1818 OrthogonalizeMethod ,
@@ -113,9 +113,6 @@ def __init__(
113113 self .maxlevel : Optional [int ] = None
114114 self .axis = axis
115115 if data is not None :
116- if len (data .shape ) == 1 :
117- # add a batch dimension.
118- data = data .unsqueeze (0 )
119116 self .transform (data , maxlevel )
120117 else :
121118 self .data = {}
@@ -134,7 +131,7 @@ def transform(
134131 """
135132 self .data = {}
136133 if maxlevel is None :
137- maxlevel = pywt .dwt_max_level (data .shape [- 1 ], self .wavelet .dec_len )
134+ maxlevel = pywt .dwt_max_level (data .shape [self . axis ], self .wavelet .dec_len )
138135 self .maxlevel = maxlevel
139136 self ._recursive_dwt (data , level = 0 , path = "" )
140137 return self
@@ -166,13 +163,15 @@ def reconstruct(self) -> WaveletPacket:
166163 for node in self .get_level (level ):
167164 data_a = self [node + "a" ]
168165 data_b = self [node + "d" ]
169- rec = self ._get_waverec (data_a .shape [- 1 ])([data_a , data_b ])
166+ rec = self ._get_waverec (data_a .shape [self . axis ])([data_a , data_b ])
170167 if level > 0 :
171- if rec .shape [- 1 ] != self [node ].shape [- 1 ]:
168+ if rec .shape [self . axis ] != self [node ].shape [self . axis ]:
172169 assert (
173- rec .shape [- 1 ] == self [node ].shape [- 1 ] + 1
170+ rec .shape [self . axis ] == self [node ].shape [self . axis ] + 1
174171 ), "padding error, please open an issue on github"
175- rec = rec [..., :- 1 ]
172+ rec = rec .swapaxes (self .axis , - 1 )[..., :- 1 ].swapaxes (
173+ - 1 , self .axis
174+ )
176175 self [node ] = rec
177176 return self
178177
@@ -227,12 +226,12 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[st
227226 return graycode_order
228227
229228 def _recursive_dwt (self , data : torch .Tensor , level : int , path : str ) -> None :
230- if not self .maxlevel :
229+ if self .maxlevel is None :
231230 raise AssertionError
232231
233232 self .data [path ] = data
234233 if level < self .maxlevel :
235- res_lo , res_hi = self ._get_wavedec (data .shape [- 1 ])(data )
234+ res_lo , res_hi = self ._get_wavedec (data .shape [self . axis ])(data )
236235 self ._recursive_dwt (res_lo , level + 1 , path + "a" )
237236 self ._recursive_dwt (res_hi , level + 1 , path + "d" )
238237
@@ -340,13 +339,10 @@ def transform(
340339 """
341340 self .data = {}
342341 if maxlevel is None :
343- maxlevel = pywt .dwt_max_level (min (data .shape [- 2 :]), self .wavelet .dec_len )
342+ min_transform_size = min (_swap_axes (data , self .axes ).shape [- 2 :])
343+ maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
344344 self .maxlevel = maxlevel
345345
346- if data .dim () == 2 :
347- # add batch dim to unbatched input
348- data = data .unsqueeze (0 )
349-
350346 self ._recursive_dwt2d (data , level = 0 , path = "" )
351347 return self
352348
@@ -359,30 +355,33 @@ def reconstruct(self) -> WaveletPacket2D:
359355 a reconstruction from the leaves.
360356 """
361357 if self .maxlevel is None :
362- self .maxlevel = pywt .dwt_max_level (
363- min (self ["" ].shape [- 2 :]), self .wavelet .dec_len
364- )
358+ min_transform_size = min (_swap_axes (self ["" ], self .axes ).shape [- 2 :])
359+ self .maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
365360
366361 for level in reversed (range (self .maxlevel )):
367362 for node in WaveletPacket2D .get_natural_order (level ):
368363 data_a = self [node + "a" ]
369364 data_h = self [node + "h" ]
370365 data_v = self [node + "v" ]
371366 data_d = self [node + "d" ]
372- rec = self ._get_waverec (data_a .shape [- 2 :])(
367+ transform_size = _swap_axes (data_a , self .axes ).shape [- 2 :]
368+ rec = self ._get_waverec (transform_size )(
373369 (data_a , WaveletDetailTuple2d (data_h , data_v , data_d ))
374370 )
375371 if level > 0 :
376- if rec .shape [- 1 ] != self [node ].shape [- 1 ]:
372+ rec = _swap_axes (rec , self .axes )
373+ swapped_node = _swap_axes (self [node ], self .axes )
374+ if rec .shape [- 1 ] != swapped_node .shape [- 1 ]:
377375 assert (
378- rec .shape [- 1 ] == self [ node ] .shape [- 1 ] + 1
376+ rec .shape [- 1 ] == swapped_node .shape [- 1 ] + 1
379377 ), "padding error, please open an issue on GitHub"
380378 rec = rec [..., :- 1 ]
381- if rec .shape [- 2 ] != self [ node ] .shape [- 2 ]:
379+ if rec .shape [- 2 ] != swapped_node .shape [- 2 ]:
382380 assert (
383- rec .shape [- 2 ] == self [ node ] .shape [- 2 ] + 1
381+ rec .shape [- 2 ] == swapped_node .shape [- 2 ] + 1
384382 ), "padding error, please open an issue on GitHub"
385383 rec = rec [..., :- 1 , :]
384+ rec = _undo_swap_axes (rec , self .axes )
386385 self [node ] = rec
387386 return self
388387
@@ -468,12 +467,13 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
468467 return _fsdict_func
469468
470469 def _recursive_dwt2d (self , data : torch .Tensor , level : int , path : str ) -> None :
471- if not self .maxlevel :
470+ if self .maxlevel is None :
472471 raise AssertionError
473472
474473 self .data [path ] = data
475474 if level < self .maxlevel :
476- result = self ._get_wavedec (data .shape [- 2 :])(data )
475+ transform_size = _swap_axes (data , self .axes ).shape [- 2 :]
476+ result = self ._get_wavedec (transform_size )(data )
477477
478478 # assert for type checking
479479 assert len (result ) == 2
0 commit comments