|
6 | 6 | from collections.abc import Sequence |
7 | 7 | from functools import partial |
8 | 8 | from itertools import product |
9 | | -from typing import TYPE_CHECKING, Callable, Optional, Union |
| 9 | +from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, overload |
10 | 10 |
|
11 | 11 | import numpy as np |
12 | 12 | import pywt |
|
16 | 16 | from .constants import ( |
17 | 17 | ExtendedBoundaryMode, |
18 | 18 | OrthogonalizeMethod, |
| 19 | + PacketNodeOrder, |
19 | 20 | WaveletCoeff2d, |
20 | 21 | WaveletCoeffNd, |
21 | 22 | WaveletDetailTuple2d, |
@@ -203,18 +204,34 @@ def _get_waverec( |
203 | 204 | else: |
204 | 205 | return partial(waverec, wavelet=self.wavelet, axis=self.axis) |
205 | 206 |
|
206 | | - def get_level(self, level: int) -> list[str]: |
207 | | - """Return the graycode-ordered paths to the filter tree nodes. |
| 207 | + @staticmethod |
| 208 | + def get_level(level: int, order: PacketNodeOrder = "freq") -> list[str]: |
| 209 | + """Return the paths to the filter tree nodes. |
208 | 210 |
|
209 | 211 | Args: |
210 | 212 | level (int): The depth of the tree. |
| 213 | + order: The order the paths are in. |
| 214 | + Choose from frequency order (``freq``) and |
| 215 | + natural order (``natural``). |
| 216 | + Defaults to ``freq``. |
211 | 217 |
|
212 | 218 | Returns: |
213 | 219 | A list with the paths to each node. |
| 220 | +
|
| 221 | + Raises: |
| 222 | + ValueError: If `order` is neither ``freq`` nor ``natural``. |
214 | 223 | """ |
215 | | - return self._get_graycode_order(level) |
| 224 | + if order == "freq": |
| 225 | + return WaveletPacket._get_graycode_order(level) |
| 226 | + elif order == "natural": |
| 227 | + return ["".join(p) for p in product(["a", "d"], repeat=level)] |
| 228 | + else: |
| 229 | + raise ValueError( |
| 230 | + f"Unsupported order '{order}'. Choose from 'freq' and 'natural'." |
| 231 | + ) |
216 | 232 |
|
217 | | - def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[str]: |
| 233 | + @staticmethod |
| 234 | + def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]: |
218 | 235 | graycode_order = [x, y] |
219 | 236 | for _ in range(level - 1): |
220 | 237 | graycode_order = [x + path for path in graycode_order] + [ |
@@ -514,6 +531,42 @@ def __getitem__(self, key: str) -> torch.Tensor: |
514 | 531 | ) |
515 | 532 | return super().__getitem__(key) |
516 | 533 |
|
| 534 | + @overload |
| 535 | + @staticmethod |
| 536 | + def get_level(level: int, order: Literal["freq"]) -> list[list[str]]: ... |
| 537 | + |
| 538 | + @overload |
| 539 | + @staticmethod |
| 540 | + def get_level(level: int, order: Literal["natural"]) -> list[str]: ... |
| 541 | + |
| 542 | + @staticmethod |
| 543 | + def get_level( |
| 544 | + level: int, order: PacketNodeOrder = "freq" |
| 545 | + ) -> Union[list[str], list[list[str]]]: |
| 546 | + """Return the paths to the filter tree nodes. |
| 547 | +
|
| 548 | + Args: |
| 549 | + level (int): The depth of the tree. |
| 550 | + order: The order the paths are in. |
| 551 | + Choose from frequency order (``freq``) and |
| 552 | + natural order (``natural``). |
| 553 | + Defaults to ``freq``. |
| 554 | +
|
| 555 | + Returns: |
| 556 | + A list with the paths to each node. |
| 557 | +
|
| 558 | + Raises: |
| 559 | + ValueError: If `order` is neither ``freq`` nor ``natural``. |
| 560 | + """ |
| 561 | + if order == "freq": |
| 562 | + return WaveletPacket2D.get_freq_order(level) |
| 563 | + elif order == "natural": |
| 564 | + return WaveletPacket2D.get_natural_order(level) |
| 565 | + else: |
| 566 | + raise ValueError( |
| 567 | + f"Unsupported order '{order}'. Choose from 'freq' and 'natural'." |
| 568 | + ) |
| 569 | + |
517 | 570 | @staticmethod |
518 | 571 | def get_natural_order(level: int) -> list[str]: |
519 | 572 | """Get the natural ordering for a given decomposition level. |
|
0 commit comments