@@ -62,7 +62,6 @@ def __init__(
62
62
maxlevel : Optional [int ] = None ,
63
63
axis : int = - 1 ,
64
64
boundary_orthogonalization : OrthogonalizeMethod = "qr" ,
65
- lazy_init : bool = False ,
66
65
) -> None :
67
66
"""Create a wavelet packet decomposition object.
68
67
@@ -89,10 +88,6 @@ def __init__(
89
88
to use in the sparse matrix backend,
90
89
see :data:`ptwt.constants.OrthogonalizeMethod`.
91
90
Only used if `mode` equals 'boundary'. Defaults to 'qr'.
92
- lazy_init (bool): Value is passed on to :func:`transform`.
93
- If True, the packet tree is initialized lazily. This
94
- allows for partial expansion of the wavelet packet tree.
95
- Defaults to False.
96
91
97
92
Example:
98
93
>>> import torch, pywt, ptwt
@@ -116,15 +111,14 @@ def __init__(
116
111
self .maxlevel : Optional [int ] = None
117
112
self .axis = axis
118
113
if data is not None :
119
- self .transform (data , maxlevel , lazy_init = lazy_init )
114
+ self .transform (data , maxlevel )
120
115
else :
121
116
self .data = {}
122
117
123
118
def transform (
124
119
self ,
125
120
data : torch .Tensor ,
126
121
maxlevel : Optional [int ] = None ,
127
- lazy_init : bool = False ,
128
122
) -> WaveletPacket :
129
123
"""Calculate the 1d wavelet packet transform for the input data.
130
124
@@ -134,10 +128,6 @@ def transform(
134
128
maxlevel (int, optional): The highest decomposition level to compute.
135
129
If None, the maximum level is determined from the input data shape.
136
130
Defaults to None.
137
- lazy_init (bool): If True, the packet tree is initialized lazily.
138
- This allows for partial expansion of the wavelet packet tree.
139
- Otherwise, all packet coefficients up to the decomposition level
140
- `maxlevel` are computed. Defaults to False.
141
131
142
132
Returns:
143
133
This wavelet packet object (to allow call chaining).
@@ -146,8 +136,6 @@ def transform(
146
136
if maxlevel is None :
147
137
maxlevel = pywt .dwt_max_level (data .shape [self .axis ], self .wavelet .dec_len )
148
138
self .maxlevel = maxlevel
149
- if not lazy_init :
150
- self ._recursive_dwt (path = "" )
151
139
return self
152
140
153
141
def reconstruct (self ) -> WaveletPacket :
@@ -270,19 +258,6 @@ def _expand_node(self, path: str) -> None:
270
258
self .data [path + "a" ] = res_lo
271
259
self .data [path + "d" ] = res_hi
272
260
273
- def _recursive_dwt (self , path : str ) -> None :
274
- if self .maxlevel is None :
275
- raise AssertionError
276
-
277
- if len (path ) >= self .maxlevel :
278
- # nothing to expand
279
- return
280
-
281
- self ._expand_node (path )
282
-
283
- for child in ["a" , "d" ]:
284
- self ._recursive_dwt (path + child )
285
-
286
261
def __getitem__ (self , key : str ) -> torch .Tensor :
287
262
"""Access the coefficients in the wavelet packets tree.
288
263
@@ -338,7 +313,6 @@ def __init__(
338
313
axes : tuple [int , int ] = (- 2 , - 1 ),
339
314
boundary_orthogonalization : OrthogonalizeMethod = "qr" ,
340
315
separable : bool = False ,
341
- lazy_init : bool = False ,
342
316
) -> None :
343
317
"""Create a 2D-Wavelet packet tree.
344
318
@@ -366,10 +340,6 @@ def __init__(
366
340
Only used if `mode` equals 'boundary'. Defaults to 'qr'.
367
341
separable (bool): If true, a separable transform is performed,
368
342
i.e. each image axis is transformed separately. Defaults to False.
369
- lazy_init (bool): Value is passed on to :func:`transform`.
370
- If True, the packet tree is initialized lazily. This
371
- allows for partial expansion of the wavelet packet tree.
372
- Defaults to False.
373
343
"""
374
344
self .wavelet = _as_wavelet (wavelet )
375
345
self .mode = mode
@@ -381,15 +351,14 @@ def __init__(
381
351
382
352
self .maxlevel : Optional [int ] = None
383
353
if data is not None :
384
- self .transform (data , maxlevel , lazy_init = lazy_init )
354
+ self .transform (data , maxlevel )
385
355
else :
386
356
self .data = {}
387
357
388
358
def transform (
389
359
self ,
390
360
data : torch .Tensor ,
391
361
maxlevel : Optional [int ] = None ,
392
- lazy_init : bool = False ,
393
362
) -> WaveletPacket2D :
394
363
"""Calculate the 2d wavelet packet transform for the input data.
395
364
@@ -401,10 +370,6 @@ def transform(
401
370
maxlevel (int, optional): The highest decomposition level to compute.
402
371
If None, the maximum level is determined from the input data shape.
403
372
Defaults to None.
404
- lazy_init (bool): If True, the packet tree is initialized lazily.
405
- This allows for partial expansion of the wavelet packet tree.
406
- Otherwise, all packet coefficients up to the decomposition level
407
- `maxlevel` are computed. Defaults to False.
408
373
409
374
Returns:
410
375
This wavelet packet object (to allow call chaining).
@@ -414,9 +379,6 @@ def transform(
414
379
min_transform_size = min (_swap_axes (data , self .axes ).shape [- 2 :])
415
380
maxlevel = pywt .dwt_max_level (min_transform_size , self .wavelet .dec_len )
416
381
self .maxlevel = maxlevel
417
-
418
- if not lazy_init :
419
- self ._recursive_dwt2d (path = "" )
420
382
return self
421
383
422
384
def reconstruct (self ) -> WaveletPacket2D :
0 commit comments