12
12
import pywt
13
13
import torch
14
14
15
- from ._util import Wavelet , _as_wavelet
15
+ from ._util import Wavelet , _as_wavelet , _swap_axes , _undo_swap_axes
16
16
from .constants import (
17
17
ExtendedBoundaryMode ,
18
18
OrthogonalizeMethod ,
@@ -114,9 +114,6 @@ def __init__(
114
114
self .maxlevel : Optional [int ] = None
115
115
self .axis = axis
116
116
if data is not None :
117
- if len (data .shape ) == 1 :
118
- # add a batch dimension.
119
- data = data .unsqueeze (0 )
120
117
self .transform (data , maxlevel )
121
118
else :
122
119
self .data = {}
@@ -135,7 +132,7 @@ def transform(
135
132
"""
136
133
self .data = {}
137
134
if maxlevel is None :
138
- maxlevel = pywt .dwt_max_level (data .shape [- 1 ], self .wavelet .dec_len )
135
+ maxlevel = pywt .dwt_max_level (data .shape [self . axis ], self .wavelet .dec_len )
139
136
self .maxlevel = maxlevel
140
137
self ._recursive_dwt (data , level = 0 , path = "" )
141
138
return self
@@ -167,13 +164,15 @@ def reconstruct(self) -> WaveletPacket:
167
164
for node in self .get_level (level ):
168
165
data_a = self [node + "a" ]
169
166
data_b = self [node + "d" ]
170
- rec = self ._get_waverec (data_a .shape [- 1 ])([data_a , data_b ])
167
+ rec = self ._get_waverec (data_a .shape [self . axis ])([data_a , data_b ])
171
168
if level > 0 :
172
- if rec .shape [- 1 ] != self [node ].shape [- 1 ]:
169
+ if rec .shape [self . axis ] != self [node ].shape [self . axis ]:
173
170
assert (
174
- rec .shape [- 1 ] == self [node ].shape [- 1 ] + 1
171
+ rec .shape [self . axis ] == self [node ].shape [self . axis ] + 1
175
172
), "padding error, please open an issue on github"
176
- rec = rec [..., :- 1 ]
173
+ rec = rec .swapaxes (self .axis , - 1 )[..., :- 1 ].swapaxes (
174
+ - 1 , self .axis
175
+ )
177
176
self [node ] = rec
178
177
return self
179
178
@@ -244,12 +243,12 @@ def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]:
244
243
return graycode_order
245
244
246
245
def _recursive_dwt (self , data : torch .Tensor , level : int , path : str ) -> None :
247
- if not self .maxlevel :
246
+ if self .maxlevel is None :
248
247
raise AssertionError
249
248
250
249
self .data [path ] = data
251
250
if level < self .maxlevel :
252
- res_lo , res_hi = self ._get_wavedec (data .shape [- 1 ])(data )
251
+ res_lo , res_hi = self ._get_wavedec (data .shape [self . axis ])(data )
253
252
self ._recursive_dwt (res_lo , level + 1 , path + "a" )
254
253
self ._recursive_dwt (res_hi , level + 1 , path + "d" )
255
254
@@ -357,13 +356,10 @@ def transform(
357
356
"""
358
357
self .data = {}
359
358
if maxlevel is None :
360
- maxlevel = pywt .dwt_max_level (min (data .shape [- 2 :]), self .wavelet .dec_len )
359
+ min_transform_size = min (_swap_axes (data , self .axes ).shape [- 2 :])
360
+ maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
361
361
self .maxlevel = maxlevel
362
362
363
- if data .dim () == 2 :
364
- # add batch dim to unbatched input
365
- data = data .unsqueeze (0 )
366
-
367
363
self ._recursive_dwt2d (data , level = 0 , path = "" )
368
364
return self
369
365
@@ -376,30 +372,33 @@ def reconstruct(self) -> WaveletPacket2D:
376
372
a reconstruction from the leaves.
377
373
"""
378
374
if self .maxlevel is None :
379
- self .maxlevel = pywt .dwt_max_level (
380
- min (self ["" ].shape [- 2 :]), self .wavelet .dec_len
381
- )
375
+ min_transform_size = min (_swap_axes (self ["" ], self .axes ).shape [- 2 :])
376
+ self .maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
382
377
383
378
for level in reversed (range (self .maxlevel )):
384
379
for node in WaveletPacket2D .get_natural_order (level ):
385
380
data_a = self [node + "a" ]
386
381
data_h = self [node + "h" ]
387
382
data_v = self [node + "v" ]
388
383
data_d = self [node + "d" ]
389
- rec = self ._get_waverec (data_a .shape [- 2 :])(
384
+ transform_size = _swap_axes (data_a , self .axes ).shape [- 2 :]
385
+ rec = self ._get_waverec (transform_size )(
390
386
(data_a , WaveletDetailTuple2d (data_h , data_v , data_d ))
391
387
)
392
388
if level > 0 :
393
- if rec .shape [- 1 ] != self [node ].shape [- 1 ]:
389
+ rec = _swap_axes (rec , self .axes )
390
+ swapped_node = _swap_axes (self [node ], self .axes )
391
+ if rec .shape [- 1 ] != swapped_node .shape [- 1 ]:
394
392
assert (
395
- rec .shape [- 1 ] == self [ node ] .shape [- 1 ] + 1
393
+ rec .shape [- 1 ] == swapped_node .shape [- 1 ] + 1
396
394
), "padding error, please open an issue on GitHub"
397
395
rec = rec [..., :- 1 ]
398
- if rec .shape [- 2 ] != self [ node ] .shape [- 2 ]:
396
+ if rec .shape [- 2 ] != swapped_node .shape [- 2 ]:
399
397
assert (
400
- rec .shape [- 2 ] == self [ node ] .shape [- 2 ] + 1
398
+ rec .shape [- 2 ] == swapped_node .shape [- 2 ] + 1
401
399
), "padding error, please open an issue on GitHub"
402
400
rec = rec [..., :- 1 , :]
401
+ rec = _undo_swap_axes (rec , self .axes )
403
402
self [node ] = rec
404
403
return self
405
404
@@ -485,12 +484,13 @@ def _fsdict_func(coeffs: WaveletCoeff2d) -> torch.Tensor:
485
484
return _fsdict_func
486
485
487
486
def _recursive_dwt2d (self , data : torch .Tensor , level : int , path : str ) -> None :
488
- if not self .maxlevel :
487
+ if self .maxlevel is None :
489
488
raise AssertionError
490
489
491
490
self .data [path ] = data
492
491
if level < self .maxlevel :
493
- result = self ._get_wavedec (data .shape [- 2 :])(data )
492
+ transform_size = _swap_axes (data , self .axes ).shape [- 2 :]
493
+ result = self ._get_wavedec (transform_size )(data )
494
494
495
495
# assert for type checking
496
496
assert len (result ) == 2
0 commit comments