Skip to content

Commit 41aeaa2

Browse files
committed
add test.
1 parent 6637ad3 commit 41aeaa2

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

tests/test_packets.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,3 +588,24 @@ def test_separable_conv_packets_2d(axes: tuple[int, int]) -> None:
588588
ptwp.initialize(ptwp.get_natural_order(2))
589589
ptwp.reconstruct()
590590
assert np.allclose(signal, ptwp[""].data[:, :32, :32, :32])
591+
592+
593+
594+
def test_partial_reconstruction() -> None:
595+
596+
signal = np.random.randn(1, 16)
597+
signal2 = np.cos(np.linspace(0, 2 * np.pi, 16))
598+
ptwp = WaveletPacket(torch.from_numpy(signal), "haar",
599+
mode="reflect", maxlevel=2)
600+
ptwp.initialize(["aa", "ad", "da", "dd"])
601+
602+
ptwp2 = WaveletPacket(torch.from_numpy(signal2), "haar", mode="reflect", maxlevel=2)
603+
604+
# overwrite the first packet set.
605+
ptwp["aa"] = ptwp2["aa"]
606+
ptwp["ad"] = ptwp2["ad"]
607+
ptwp["da"] = ptwp2["da"]
608+
ptwp["dd"] = ptwp2["dd"]
609+
ptwp.reconstruct()
610+
611+
assert np.allclose(signal2, ptwp[""].numpy()[:16])

0 commit comments

Comments
 (0)