1
+ """Test dtype support for the fwt code."""
2
+ # Written by moritz ( @ wolter.tech ) in 2025
3
+ import numpy as np
4
+ import pytest
5
+ import pywt
6
+ import torch
7
+ from scipy import datasets
8
+
9
+ from src .ptwt .conv_transform import (
10
+ _flatten_2d_coeff_lst ,
11
+ wavedec ,
12
+ waverec ,
13
+ )
14
+ from src .ptwt .conv_transform_2 import wavedec2 , waverec2
15
+
16
+ @pytest .mark .slow
17
+ @pytest .mark .parametrize ("dtype" , [torch .float64 , torch .float32 , torch .float16 , torch .bfloat16 ])
18
+ def test_2d_wavedec_rec (dtype ):
19
+ """Ensure pywt.wavedec2 and ptwt.wavedec2 produce the same coefficients.
20
+
21
+ wavedec2 and waverec2 must invert each other.
22
+ """
23
+ mode = "reflect"
24
+ level = 2
25
+ size = (32 , 32 )
26
+ face = np .transpose (
27
+ datasets .face ()[256 : (512 + size [0 ]), 256 : (512 + size [1 ])], [2 , 0 , 1 ]
28
+ ).astype (np .float32 )
29
+ wavelet = pywt .Wavelet ("db2" )
30
+ to_transform = torch .from_numpy (face ).to (torch .float32 )
31
+ coeff2d = wavedec2 (to_transform , wavelet , mode = mode , level = level )
32
+ pywt_coeff2d = pywt .wavedec2 (face , wavelet , mode = mode , level = level )
33
+ for pos , coeffs in enumerate (pywt_coeff2d ):
34
+ if type (coeffs ) is tuple :
35
+ for tuple_pos , tuple_el in enumerate (coeffs ):
36
+ assert (
37
+ tuple_el .shape == coeff2d [pos ][tuple_pos ].shape
38
+ ), "pywt and ptwt should produce the same shapes."
39
+ else :
40
+ assert (
41
+ coeffs .shape == coeff2d [pos ].shape
42
+ ), "pywt and ptwt should produce the same shapes."
43
+ flat_coeff_list_pywt = np .concatenate (_flatten_2d_coeff_lst (pywt_coeff2d ), - 1 )
44
+ flat_coeff_list_ptwt = torch .cat (_flatten_2d_coeff_lst (coeff2d ), - 1 )
45
+ assert np .allclose (flat_coeff_list_pywt , flat_coeff_list_ptwt .numpy (), atol = 1e-3 )
46
+ rec = waverec2 (coeff2d , wavelet )
47
+ rec = rec .numpy ().squeeze ().astype (np .float32 )
48
+ assert np .allclose (face , rec [:, : face .shape [1 ], : face .shape [2 ]], atol = 1e-3 )
0 commit comments