12
12
import pywt
13
13
import torch
14
14
15
- from ._util import Wavelet , _as_wavelet
15
+ from ._util import Wavelet , _as_wavelet , _swap_axes , _undo_swap_axes
16
16
from .constants import (
17
17
ExtendedBoundaryMode ,
18
18
OrthogonalizeMethod ,
@@ -113,9 +113,6 @@ def __init__(
113
113
self .maxlevel : Optional [int ] = None
114
114
self .axis = axis
115
115
if data is not None :
116
- if len (data .shape ) == 1 :
117
- # add a batch dimension.
118
- data = data .unsqueeze (0 )
119
116
self .transform (data , maxlevel )
120
117
else :
121
118
self .data = {}
@@ -134,7 +131,7 @@ def transform(
134
131
"""
135
132
self .data = {}
136
133
if maxlevel is None :
137
- maxlevel = pywt .dwt_max_level (data .shape [- 1 ], self .wavelet .dec_len )
134
+ maxlevel = pywt .dwt_max_level (data .shape [self . axis ], self .wavelet .dec_len )
138
135
self .maxlevel = maxlevel
139
136
self ._recursive_dwt (data , level = 0 , path = "" )
140
137
return self
@@ -166,13 +163,15 @@ def reconstruct(self) -> WaveletPacket:
166
163
for node in self .get_level (level ):
167
164
data_a = self [node + "a" ]
168
165
data_b = self [node + "d" ]
169
- rec = self ._get_waverec (data_a .shape [- 1 ])([data_a , data_b ])
166
+ rec = self ._get_waverec (data_a .shape [self . axis ])([data_a , data_b ])
170
167
if level > 0 :
171
- if rec .shape [- 1 ] != self [node ].shape [- 1 ]:
168
+ if rec .shape [self . axis ] != self [node ].shape [self . axis ]:
172
169
assert (
173
- rec .shape [- 1 ] == self [node ].shape [- 1 ] + 1
170
+ rec .shape [self . axis ] == self [node ].shape [self . axis ] + 1
174
171
), "padding error, please open an issue on github"
175
- rec = rec [..., :- 1 ]
172
+ rec = rec .swapaxes (self .axis , - 1 )[..., :- 1 ].swapaxes (
173
+ - 1 , self .axis
174
+ )
176
175
self [node ] = rec
177
176
return self
178
177
@@ -227,12 +226,12 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[st
227
226
return graycode_order
228
227
229
228
def _recursive_dwt (self , data : torch .Tensor , level : int , path : str ) -> None :
230
- if not self .maxlevel :
229
+ if self .maxlevel is None :
231
230
raise AssertionError
232
231
233
232
self .data [path ] = data
234
233
if level < self .maxlevel :
235
- res_lo , res_hi = self ._get_wavedec (data .shape [- 1 ])(data )
234
+ res_lo , res_hi = self ._get_wavedec (data .shape [self . axis ])(data )
236
235
self ._recursive_dwt (res_lo , level + 1 , path + "a" )
237
236
self ._recursive_dwt (res_hi , level + 1 , path + "d" )
238
237
@@ -340,13 +339,10 @@ def transform(
340
339
"""
341
340
self .data = {}
342
341
if maxlevel is None :
343
- maxlevel = pywt .dwt_max_level (min (data .shape [- 2 :]), self .wavelet .dec_len )
342
+ min_transform_size = min (_swap_axes (data , self .axes ).shape [- 2 :])
343
+ maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
344
344
self .maxlevel = maxlevel
345
345
346
- if data .dim () == 2 :
347
- # add batch dim to unbatched input
348
- data = data .unsqueeze (0 )
349
-
350
346
self ._recursive_dwt2d (data , level = 0 , path = "" )
351
347
return self
352
348
@@ -359,30 +355,33 @@ def reconstruct(self) -> WaveletPacket2D:
359
355
a reconstruction from the leaves.
360
356
"""
361
357
if self .maxlevel is None :
362
- self .maxlevel = pywt .dwt_max_level (
363
- min (self ["" ].shape [- 2 :]), self .wavelet .dec_len
364
- )
358
+ min_transform_size = min (_swap_axes (self ["" ], self .axes ).shape [- 2 :])
359
+ self .maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
365
360
366
361
for level in reversed (range (self .maxlevel )):
367
362
for node in WaveletPacket2D .get_natural_order (level ):
368
363
data_a = self [node + "a" ]
369
364
data_h = self [node + "h" ]
370
365
data_v = self [node + "v" ]
371
366
data_d = self [node + "d" ]
372
- rec = self ._get_waverec (data_a .shape [- 2 :])(
367
+ transform_size = _swap_axes (data_a , self .axes ).shape [- 2 :]
368
+ rec = self ._get_waverec (transform_size )(
373
369
(data_a , WaveletDetailTuple2d (data_h , data_v , data_d ))
374
370
)
375
371
if level > 0 :
376
- if rec .shape [- 1 ] != self [node ].shape [- 1 ]:
372
+ rec = _swap_axes (rec , self .axes )
373
+ swapped_node = _swap_axes (self [node ], self .axes )
374
+ if rec .shape [- 1 ] != swapped_node .shape [- 1 ]:
377
375
assert (
378
- rec .shape [- 1 ] == self [ node ] .shape [- 1 ] + 1
376
+ rec .shape [- 1 ] == swapped_node .shape [- 1 ] + 1
379
377
), "padding error, please open an issue on GitHub"
380
378
rec = rec [..., :- 1 ]
381
- if rec .shape [- 2 ] != self [ node ] .shape [- 2 ]:
379
+ if rec .shape [- 2 ] != swapped_node .shape [- 2 ]:
382
380
assert (
383
- rec .shape [- 2 ] == self [ node ] .shape [- 2 ] + 1
381
+ rec .shape [- 2 ] == swapped_node .shape [- 2 ] + 1
384
382
), "padding error, please open an issue on GitHub"
385
383
rec = rec [..., :- 1 , :]
384
+ rec = _undo_swap_axes (rec , self .axes )
386
385
self [node ] = rec
387
386
return self
388
387
@@ -468,12 +467,13 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
468
467
return _fsdict_func
469
468
470
469
def _recursive_dwt2d (self , data : torch .Tensor , level : int , path : str ) -> None :
471
- if not self .maxlevel :
470
+ if self .maxlevel is None :
472
471
raise AssertionError
473
472
474
473
self .data [path ] = data
475
474
if level < self .maxlevel :
476
- result = self ._get_wavedec (data .shape [- 2 :])(data )
475
+ transform_size = _swap_axes (data , self .axes ).shape [- 2 :]
476
+ result = self ._get_wavedec (transform_size )(data )
477
477
478
478
# assert for type checking
479
479
assert len (result ) == 2
0 commit comments