Skip to content

Commit 1f7a6f0

Browse files
committed
add bfloat test.
1 parent 8a562eb commit 1f7a6f0

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

tests/test_dtypes.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Test dtype support for the fwt code."""
2+
# Written by moritz ( @ wolter.tech ) in 2025
3+
import numpy as np
4+
import pytest
5+
import pywt
6+
import torch
7+
from scipy import datasets
8+
9+
from src.ptwt.conv_transform import (
10+
_flatten_2d_coeff_lst,
11+
wavedec,
12+
waverec,
13+
)
14+
from src.ptwt.conv_transform_2 import wavedec2, waverec2
15+
16+
@pytest.mark.slow
17+
@pytest.mark.parametrize("dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16])
18+
def test_2d_wavedec_rec(dtype):
19+
"""Ensure pywt.wavedec2 and ptwt.wavedec2 produce the same coefficients.
20+
21+
wavedec2 and waverec2 must invert each other.
22+
"""
23+
mode = "reflect"
24+
level = 2
25+
size = (32, 32)
26+
face = np.transpose(
27+
datasets.face()[256 : (512 + size[0]), 256 : (512 + size[1])], [2, 0, 1]
28+
).astype(np.float32)
29+
wavelet = pywt.Wavelet("db2")
30+
to_transform = torch.from_numpy(face).to(torch.float32)
31+
coeff2d = wavedec2(to_transform, wavelet, mode=mode, level=level)
32+
pywt_coeff2d = pywt.wavedec2(face, wavelet, mode=mode, level=level)
33+
for pos, coeffs in enumerate(pywt_coeff2d):
34+
if type(coeffs) is tuple:
35+
for tuple_pos, tuple_el in enumerate(coeffs):
36+
assert (
37+
tuple_el.shape == coeff2d[pos][tuple_pos].shape
38+
), "pywt and ptwt should produce the same shapes."
39+
else:
40+
assert (
41+
coeffs.shape == coeff2d[pos].shape
42+
), "pywt and ptwt should produce the same shapes."
43+
flat_coeff_list_pywt = np.concatenate(_flatten_2d_coeff_lst(pywt_coeff2d), -1)
44+
flat_coeff_list_ptwt = torch.cat(_flatten_2d_coeff_lst(coeff2d), -1)
45+
assert np.allclose(flat_coeff_list_pywt, flat_coeff_list_ptwt.numpy(), atol=1e-3)
46+
rec = waverec2(coeff2d, wavelet)
47+
rec = rec.numpy().squeeze().astype(np.float32)
48+
assert np.allclose(face, rec[:, : face.shape[1], : face.shape[2]], atol=1e-3)

0 commit comments

Comments
 (0)