Skip to content

Commit 6ef627c

Browse files
committed
Make supported dtypes available
1 parent fde2caa commit 6ef627c

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/ptwt/_util.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212

1313
from .constants import (
14+
SUPPORTED_DTYPES,
1415
BoundaryMode,
1516
OrthogonalizeMethod,
1617
Wavelet,
@@ -126,7 +127,7 @@ def _is_boundary_mode_supported(boundary_mode: Optional[OrthogonalizeMethod]) ->
126127

127128

128129
def _is_dtype_supported(dtype: torch.dtype) -> bool:
129-
return dtype in [torch.float32, torch.float64]
130+
return dtype in SUPPORTED_DTYPES
130131

131132

132133
def _outer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:

src/ptwt/constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
"WaveletDetailDict",
2121
]
2222

23+
SUPPORTED_DTYPES = {torch.float32, torch.float64}
24+
2325

2426
class Wavelet(Protocol):
2527
"""Wavelet object interface, based on the pywt wavelet object."""

0 commit comments

Comments
 (0)