Skip to content

Commit ee8d6da

Browse files
committed
Add overloads for 2d get_level
1 parent 731872a commit ee8d6da

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/ptwt/packets.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Sequence
77
from functools import partial
88
from itertools import product
9-
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
9+
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, overload
1010

1111
import numpy as np
1212
import pywt
@@ -530,6 +530,14 @@ def __getitem__(self, key: str) -> torch.Tensor:
530530
)
531531
return super().__getitem__(key)
532532

533+
@overload
534+
@staticmethod
535+
def get_level(level: int, order: Literal["freq"]) -> list[list[str]]: ...
536+
537+
@overload
538+
@staticmethod
539+
def get_level(level: int, order: Literal["natural"]) -> list[str]: ...
540+
533541
@staticmethod
534542
def get_level(
535543
level: int, order: Literal["freq", "natural"] = "freq"

0 commit comments

Comments
 (0)