Skip to content

Commit 8124814

Browse files
committed
remove argument.
1 parent e7af4d4 commit 8124814

File tree

2 files changed

+30
-98
lines changed

2 files changed

+30
-98
lines changed

src/ptwt/packets.py

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def __init__(
6262
maxlevel: Optional[int] = None,
6363
axis: int = -1,
6464
boundary_orthogonalization: OrthogonalizeMethod = "qr",
65-
lazy_init: bool = False,
6665
) -> None:
6766
"""Create a wavelet packet decomposition object.
6867
@@ -89,10 +88,6 @@ def __init__(
8988
to use in the sparse matrix backend,
9089
see :data:`ptwt.constants.OrthogonalizeMethod`.
9190
Only used if `mode` equals 'boundary'. Defaults to 'qr'.
92-
lazy_init (bool): Value is passed on to :func:`transform`.
93-
If True, the packet tree is initialized lazily. This
94-
allows for partial expansion of the wavelet packet tree.
95-
Defaults to False.
9691
9792
Example:
9893
>>> import torch, pywt, ptwt
@@ -116,15 +111,14 @@ def __init__(
116111
self.maxlevel: Optional[int] = None
117112
self.axis = axis
118113
if data is not None:
119-
self.transform(data, maxlevel, lazy_init=lazy_init)
114+
self.transform(data, maxlevel)
120115
else:
121116
self.data = {}
122117

123118
def transform(
124119
self,
125120
data: torch.Tensor,
126121
maxlevel: Optional[int] = None,
127-
lazy_init: bool = False,
128122
) -> WaveletPacket:
129123
"""Calculate the 1d wavelet packet transform for the input data.
130124
@@ -134,10 +128,6 @@ def transform(
134128
maxlevel (int, optional): The highest decomposition level to compute.
135129
If None, the maximum level is determined from the input data shape.
136130
Defaults to None.
137-
lazy_init (bool): If True, the packet tree is initialized lazily.
138-
This allows for partial expansion of the wavelet packet tree.
139-
Otherwise, all packet coefficients up to the decomposition level
140-
`maxlevel` are computed. Defaults to False.
141131
142132
Returns:
143133
This wavelet packet object (to allow call chaining).
@@ -146,8 +136,6 @@ def transform(
146136
if maxlevel is None:
147137
maxlevel = pywt.dwt_max_level(data.shape[self.axis], self.wavelet.dec_len)
148138
self.maxlevel = maxlevel
149-
if not lazy_init:
150-
self._recursive_dwt(path="")
151139
return self
152140

153141
def reconstruct(self) -> WaveletPacket:
@@ -270,19 +258,6 @@ def _expand_node(self, path: str) -> None:
270258
self.data[path + "a"] = res_lo
271259
self.data[path + "d"] = res_hi
272260

273-
def _recursive_dwt(self, path: str) -> None:
274-
if self.maxlevel is None:
275-
raise AssertionError
276-
277-
if len(path) >= self.maxlevel:
278-
# nothing to expand
279-
return
280-
281-
self._expand_node(path)
282-
283-
for child in ["a", "d"]:
284-
self._recursive_dwt(path + child)
285-
286261
def __getitem__(self, key: str) -> torch.Tensor:
287262
"""Access the coefficients in the wavelet packets tree.
288263
@@ -338,7 +313,6 @@ def __init__(
338313
axes: tuple[int, int] = (-2, -1),
339314
boundary_orthogonalization: OrthogonalizeMethod = "qr",
340315
separable: bool = False,
341-
lazy_init: bool = False,
342316
) -> None:
343317
"""Create a 2D-Wavelet packet tree.
344318
@@ -366,10 +340,6 @@ def __init__(
366340
Only used if `mode` equals 'boundary'. Defaults to 'qr'.
367341
separable (bool): If true, a separable transform is performed,
368342
i.e. each image axis is transformed separately. Defaults to False.
369-
lazy_init (bool): Value is passed on to :func:`transform`.
370-
If True, the packet tree is initialized lazily. This
371-
allows for partial expansion of the wavelet packet tree.
372-
Defaults to False.
373343
"""
374344
self.wavelet = _as_wavelet(wavelet)
375345
self.mode = mode
@@ -381,15 +351,14 @@ def __init__(
381351

382352
self.maxlevel: Optional[int] = None
383353
if data is not None:
384-
self.transform(data, maxlevel, lazy_init=lazy_init)
354+
self.transform(data, maxlevel)
385355
else:
386356
self.data = {}
387357

388358
def transform(
389359
self,
390360
data: torch.Tensor,
391361
maxlevel: Optional[int] = None,
392-
lazy_init: bool = False,
393362
) -> WaveletPacket2D:
394363
"""Calculate the 2d wavelet packet transform for the input data.
395364
@@ -401,10 +370,6 @@ def transform(
401370
maxlevel (int, optional): The highest decomposition level to compute.
402371
If None, the maximum level is determined from the input data shape.
403372
Defaults to None.
404-
lazy_init (bool): If True, the packet tree is initialized lazily.
405-
This allows for partial expansion of the wavelet packet tree.
406-
Otherwise, all packet coefficients up to the decomposition level
407-
`maxlevel` are computed. Defaults to False.
408373
409374
Returns:
410375
This wavelet packet object (to allow call chaining).
@@ -414,9 +379,6 @@ def transform(
414379
min_transform_size = min(_swap_axes(data, self.axes).shape[-2:])
415380
maxlevel = pywt.dwt_max_level(min_transform_size, self.wavelet.dec_len)
416381
self.maxlevel = maxlevel
417-
418-
if not lazy_init:
419-
self._recursive_dwt2d(path="")
420382
return self
421383

422384
def reconstruct(self) -> WaveletPacket2D:

0 commit comments

Comments
 (0)