|
6 | 6 | from collections.abc import Sequence
|
7 | 7 | from functools import partial
|
8 | 8 | from itertools import product
|
9 |
| -from typing import TYPE_CHECKING, Callable, Optional, Union |
| 9 | +from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, overload |
10 | 10 |
|
11 | 11 | import numpy as np
|
12 | 12 | import pywt
|
|
16 | 16 | from .constants import (
|
17 | 17 | ExtendedBoundaryMode,
|
18 | 18 | OrthogonalizeMethod,
|
| 19 | + PacketNodeOrder, |
19 | 20 | WaveletCoeff2d,
|
20 | 21 | WaveletCoeffNd,
|
21 | 22 | WaveletDetailTuple2d,
|
@@ -203,18 +204,34 @@ def _get_waverec(
|
203 | 204 | else:
|
204 | 205 | return partial(waverec, wavelet=self.wavelet, axis=self.axis)
|
205 | 206 |
|
206 |
| - def get_level(self, level: int) -> list[str]: |
207 |
| - """Return the graycode-ordered paths to the filter tree nodes. |
| 207 | + @staticmethod |
| 208 | + def get_level(level: int, order: PacketNodeOrder = "freq") -> list[str]: |
| 209 | + """Return the paths to the filter tree nodes. |
208 | 210 |
|
209 | 211 | Args:
|
210 | 212 | level (int): The depth of the tree.
|
| 213 | + order: The order the paths are in. |
| 214 | + Choose from frequency order (``freq``) and |
| 215 | + natural order (``natural``). |
| 216 | + Defaults to ``freq``. |
211 | 217 |
|
212 | 218 | Returns:
|
213 | 219 | A list with the paths to each node.
|
| 220 | +
|
| 221 | + Raises: |
| 222 | + ValueError: If `order` is neither ``freq`` nor ``natural``. |
214 | 223 | """
|
215 |
| - return self._get_graycode_order(level) |
| 224 | + if order == "freq": |
| 225 | + return WaveletPacket._get_graycode_order(level) |
| 226 | + elif order == "natural": |
| 227 | + return ["".join(p) for p in product(["a", "d"], repeat=level)] |
| 228 | + else: |
| 229 | + raise ValueError( |
| 230 | + f"Unsupported order '{order}'. Choose from 'freq' and 'natural'." |
| 231 | + ) |
216 | 232 |
|
217 |
| - def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> list[str]: |
| 233 | + @staticmethod |
| 234 | + def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]: |
218 | 235 | graycode_order = [x, y]
|
219 | 236 | for _ in range(level - 1):
|
220 | 237 | graycode_order = [x + path for path in graycode_order] + [
|
@@ -514,6 +531,42 @@ def __getitem__(self, key: str) -> torch.Tensor:
|
514 | 531 | )
|
515 | 532 | return super().__getitem__(key)
|
516 | 533 |
|
| 534 | + @overload |
| 535 | + @staticmethod |
| 536 | + def get_level(level: int, order: Literal["freq"]) -> list[list[str]]: ... |
| 537 | + |
| 538 | + @overload |
| 539 | + @staticmethod |
| 540 | + def get_level(level: int, order: Literal["natural"]) -> list[str]: ... |
| 541 | + |
| 542 | + @staticmethod |
| 543 | + def get_level( |
| 544 | + level: int, order: PacketNodeOrder = "freq" |
| 545 | + ) -> Union[list[str], list[list[str]]]: |
| 546 | + """Return the paths to the filter tree nodes. |
| 547 | +
|
| 548 | + Args: |
| 549 | + level (int): The depth of the tree. |
| 550 | + order: The order the paths are in. |
| 551 | + Choose from frequency order (``freq``) and |
| 552 | + natural order (``natural``). |
| 553 | + Defaults to ``freq``. |
| 554 | +
|
| 555 | + Returns: |
| 556 | + A list with the paths to each node. |
| 557 | +
|
| 558 | + Raises: |
| 559 | + ValueError: If `order` is neither ``freq`` nor ``natural``. |
| 560 | + """ |
| 561 | + if order == "freq": |
| 562 | + return WaveletPacket2D.get_freq_order(level) |
| 563 | + elif order == "natural": |
| 564 | + return WaveletPacket2D.get_natural_order(level) |
| 565 | + else: |
| 566 | + raise ValueError( |
| 567 | + f"Unsupported order '{order}'. Choose from 'freq' and 'natural'." |
| 568 | + ) |
| 569 | + |
517 | 570 | @staticmethod
|
518 | 571 | def get_natural_order(level: int) -> list[str]:
|
519 | 572 | """Get the natural ordering for a given decomposition level.
|
|
0 commit comments