We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fde2caa commit 6ef627cCopy full SHA for 6ef627c
src/ptwt/_util.py
@@ -11,6 +11,7 @@
11
import torch
12
13
from .constants import (
14
+ SUPPORTED_DTYPES,
15
BoundaryMode,
16
OrthogonalizeMethod,
17
Wavelet,
@@ -126,7 +127,7 @@ def _is_boundary_mode_supported(boundary_mode: Optional[OrthogonalizeMethod]) ->
126
127
128
129
def _is_dtype_supported(dtype: torch.dtype) -> bool:
- return dtype in [torch.float32, torch.float64]
130
+ return dtype in SUPPORTED_DTYPES
131
132
133
def _outer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
src/ptwt/constants.py
@@ -20,6 +20,8 @@
20
"WaveletDetailDict",
21
]
22
23
+SUPPORTED_DTYPES = {torch.float32, torch.float64}
24
+
25
26
class Wavelet(Protocol):
27
"""Wavelet object interface, based on the pywt wavelet object."""
0 commit comments