File tree Expand file tree Collapse file tree 2 files changed +4
-1
lines changed Expand file tree Collapse file tree 2 files changed +4
-1
lines changed Original file line number Diff line number Diff line change 11
11
import torch
12
12
13
13
from .constants import (
14
+ SUPPORTED_DTYPES ,
14
15
BoundaryMode ,
15
16
OrthogonalizeMethod ,
16
17
Wavelet ,
@@ -126,7 +127,7 @@ def _is_boundary_mode_supported(boundary_mode: Optional[OrthogonalizeMethod]) ->
126
127
127
128
128
129
def _is_dtype_supported (dtype : torch .dtype ) -> bool :
129
- return dtype in [ torch . float32 , torch . float64 ]
130
+ return dtype in SUPPORTED_DTYPES
130
131
131
132
132
133
def _outer (a : torch .Tensor , b : torch .Tensor ) -> torch .Tensor :
Original file line number Diff line number Diff line change 20
20
"WaveletDetailDict" ,
21
21
]
22
22
23
+ SUPPORTED_DTYPES = {torch .float32 , torch .float64 }
24
+
23
25
24
26
class Wavelet (Protocol ):
25
27
"""Wavelet object interface, based on the pywt wavelet object."""
You can’t perform that action at this time.
0 commit comments