Skip to content

add bfloat test. #107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest ]
python-version: [3.9, 3.11]
python-version: [3.12, 3.11]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -26,7 +26,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9, 3.11]
python-version: [3.12, 3.11]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -42,7 +42,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9, 3.11]
python-version: [3.12, 3.11]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -58,7 +58,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.11]
python-version: [3.12]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module implements our CI function calls."""

import nox


Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ classifiers =
Intended Audience :: Science/Research
Operating System :: OS Independent
Programming Language :: Python
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Programming Language :: Python :: 3 :: Only
Topic :: Scientific/Engineering :: Artificial Intelligence

Expand All @@ -56,7 +56,7 @@ install_requires =
pytest
nox

python_requires = >=3.9
python_requires = >=3.11

packages = find:
package_dir =
Expand Down
1 change: 1 addition & 0 deletions src/ptwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Differentiable and gpu enabled fast wavelet transforms in PyTorch."""

from ._util import Wavelet
from .continuous_transform import cwt
from .conv_transform import wavedec, waverec
Expand Down
1 change: 1 addition & 0 deletions src/ptwt/_stationary_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module implements stationary wavelet transforms."""

# Created by moritz wolter, in 2024

from typing import List, Optional, Union
Expand Down
1 change: 1 addition & 0 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utility methods to compute wavelet decompositions from a dataset."""

from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union

import numpy as np
Expand Down
15 changes: 8 additions & 7 deletions src/ptwt/continuous_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

This module is based on pywt's cwt implementation.
"""

# Written by the Pytorch wavelet toolbox team in 2024
from typing import Any, Tuple, Union

Expand Down Expand Up @@ -195,7 +196,7 @@ def _integrate(
if type(arr) is np.ndarray:
integral = np.cumsum(arr)
elif type(arr) is torch.Tensor:
integral = torch.cumsum(arr, -1)
integral = torch.cumsum(arr, -1) # type: ignore
else:
raise TypeError("Only ndarrays or tensors are integratable.")
integral *= step
Expand Down Expand Up @@ -271,8 +272,8 @@ def wavefun(
"""Define a grid and evaluate the wavelet on it."""
length = 2**precision
# load the bounds from untyped pywt code.
lower_bound: float = float(self.lower_bound)
upper_bound: float = float(self.upper_bound)
lower_bound: float = float(self.lower_bound) # type: ignore
upper_bound: float = float(self.upper_bound) # type: ignore
grid = torch.linspace(
lower_bound,
upper_bound,
Expand All @@ -291,10 +292,10 @@ def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
shannon = (
torch.sqrt(self.bandwidth)
* (
torch.sin(torch.pi * self.bandwidth * grid_values) # type: ignore
torch.sin(torch.pi * self.bandwidth * grid_values)
/ (torch.pi * self.bandwidth * grid_values)
)
* torch.exp(1j * 2 * torch.pi * self.center * grid_values) # type: ignore
* torch.exp(1j * 2 * torch.pi * self.center * grid_values)
)
return shannon

Expand All @@ -306,8 +307,8 @@ def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
"""Return numerical values for the wavelet on a grid."""
morlet = (
1.0
/ torch.sqrt(torch.pi * self.bandwidth) # type: ignore
/ torch.sqrt(torch.pi * self.bandwidth)
* torch.exp(-(grid_values**2) / self.bandwidth)
* torch.exp(1j * 2 * torch.pi * self.center * grid_values) # type: ignore
* torch.exp(1j * 2 * torch.pi * self.center * grid_values)
)
return morlet
1 change: 1 addition & 0 deletions src/ptwt/conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

This module treats boundaries with edge-padding.
"""

# Created by moritz wolter, 14.04.20
from typing import List, Optional, Sequence, Tuple, Union

Expand Down
3 changes: 2 additions & 1 deletion src/ptwt/conv_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
The implementation relies on torch.nn.functional.conv2d and
torch.nn.functional.conv_transpose2d under the hood.
"""

# Written by the Pytorch wavelet toolbox team in 2024


Expand Down Expand Up @@ -215,7 +216,7 @@ def wavedec2(
result_lst = _map_result(result_lst, _unfold_axes2)

if axes != (-2, -1):
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
undo_swap_fn = partial(_undo_swap_axes, axes=list(axes))
result_lst = _map_result(result_lst, undo_swap_fn)

return result_lst
Expand Down
8 changes: 6 additions & 2 deletions src/ptwt/conv_transform_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

The functions here are based on torch.nn.functional.conv3d and it's transpose.
"""

# Written by the Pytorch wavelet toolbox team in 2024

from functools import partial
Expand Down Expand Up @@ -204,15 +205,18 @@ def wavedec3(
result_lst = _map_result(result_lst, _unfold_axes_fn)

if tuple(axes) != (-3, -2, -1):
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
undo_swap_fn = partial(_undo_swap_axes, axes=list(axes))
result_lst = _map_result(result_lst, undo_swap_fn)

return result_lst


def _waverec3d_fold_channels_3d_list(
coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
) -> Tuple[List[Union[torch.Tensor, Dict[str, torch.Tensor]]], List[int],]:
) -> Tuple[
List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
List[int],
]:
# fold the input coefficients for processing conv2d_transpose.
fold_coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = []
ds = list(_check_if_tensor(coeffs[0]).shape)
Expand Down
5 changes: 3 additions & 2 deletions src/ptwt/matmul_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
in Strang Nguyen (p. 32), as well as the description
of boundary filters in "Ripples in Mathematics" section 10.3 .
"""

# Created by moritz (wolter@cs.uni-bonn.de) at 14.04.20
import sys
from typing import List, Optional, Union
Expand Down Expand Up @@ -281,7 +282,7 @@ def _construct_analysis_matrices(
f"level {curr_level}, the current signal length {curr_length} is "
f"smaller than the filter length {filt_len}. Therefore, the "
"transformation is only computed up to the decomposition level "
f"{curr_level-1}.\n"
f"{curr_level - 1}.\n"
)
break

Expand Down Expand Up @@ -563,7 +564,7 @@ def _construct_synthesis_matrices(
f"level {curr_level}, the current signal length {curr_length} is "
f"smaller than the filter length {filt_len}. Therefore, the "
"transformation is only computed up to the decomposition level "
f"{curr_level-1}.\n"
f"{curr_level - 1}.\n"
)
break

Expand Down
9 changes: 6 additions & 3 deletions src/ptwt/matmul_transform_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

This module uses boundary filters to minimize padding.
"""

# Written by moritz ( @ wolter.tech ) in 2021
import sys
from functools import partial
Expand Down Expand Up @@ -364,7 +365,8 @@ def _construct_analysis_matrices(
f". At level {curr_level}, at least one of the current signal "
f"height and width ({current_height}, {current_width}) is smaller "
f"then the filter length {filt_len}. Therefore, the transformation "
f"is only computed up to the decomposition level {curr_level-1}.\n"
f"is only computed up to the decomposition "
f" level {curr_level - 1}.\n"
)
break
# the conv matrices require even length inputs.
Expand Down Expand Up @@ -544,7 +546,7 @@ def __call__(
split_list = _map_result(split_list, _unfold_axes2)

if self.axes != (-2, -1):
undo_swap_fn = partial(_undo_swap_axes, axes=self.axes)
undo_swap_fn = partial(_undo_swap_axes, axes=list(self.axes))
split_list = _map_result(split_list, undo_swap_fn)

return split_list[::-1]
Expand Down Expand Up @@ -678,7 +680,8 @@ def _construct_synthesis_matrices(
f". At level {curr_level}, at least one of the current signal "
f"height and width ({current_height}, {current_width}) is smaller "
f"then the filter length {filt_len}. Therefore, the transformation "
f"is only computed up to the decomposition level {curr_level-1}.\n"
f"is only computed up to the "
f" decomposition level {curr_level - 1}.\n"
)
break
current_height, current_width, pad_tuple = _matrix_pad_2(
Expand Down
11 changes: 7 additions & 4 deletions src/ptwt/matmul_transform_3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implement 3D separable boundary transforms."""

# Written by the Pytorch wavelet toolbox team in 2024

import sys
Expand Down Expand Up @@ -122,7 +123,8 @@ def _construct_analysis_matrices(
f"depth, height, and width ({current_depth}, {current_height},"
f"{current_width}) is smaller "
f"then the filter length {filt_len}. Therefore, the transformation "
f"is only computed up to the decomposition level {curr_level-1}.\n"
f"is only computed up to the "
f" decomposition level {curr_level - 1}.\n"
)
break
# the conv matrices require even length inputs.
Expand All @@ -137,7 +139,7 @@ def _construct_analysis_matrices(
matrix_construction_fun = partial(
construct_boundary_a,
wavelet=self.wavelet,
boundary=self.boundary,
boundary=self.boundary, # type: ignore
device=device,
dtype=dtype,
)
Expand Down Expand Up @@ -265,7 +267,7 @@ def _split_rec(
split_list = _map_result(split_list, _unfold_axes_fn)

if self.axes != (-3, -2, -1):
undo_swap_fn = partial(_undo_swap_axes, axes=self.axes)
undo_swap_fn = partial(_undo_swap_axes, axes=list(self.axes))
split_list = _map_result(split_list, undo_swap_fn)

return split_list[::-1]
Expand Down Expand Up @@ -340,7 +342,8 @@ def _construct_synthesis_matrices(
f" depth, height and width ({current_depth}, {current_height}, "
f"{current_width}) is smaller than the filter length {filt_len}."
f" Therefore, the transformation "
f"is only computed up to the decomposition level {curr_level-1}.\n"
f"is only computed up to the "
f"decomposition level {curr_level - 1}.\n"
)
break
# the conv matrices require even length inputs.
Expand Down
19 changes: 10 additions & 9 deletions src/ptwt/packets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Compute analysis wavelet packet representations."""

# Created on Fri Apr 6 2021 by moritz (wolter@cs.uni-bonn.de)

import collections
Expand Down Expand Up @@ -104,7 +105,7 @@ def __init__(
if len(data.shape) == 1:
# add a batch dimension.
data = data.unsqueeze(0)
self.transform(data, maxlevel) # type: ignore
self.transform(data, maxlevel)
else:
self.data = {}

Expand Down Expand Up @@ -176,7 +177,11 @@ def _get_wavedec(
return self._matrix_wavedec_dict[length]
else:
return partial(
wavedec, wavelet=self.wavelet, level=1, mode=self.mode, axis=self.axis
wavedec,
wavelet=self.wavelet,
level=1,
mode=self.mode, # type: ignore
axis=self.axis,
)

def _get_waverec(
Expand Down Expand Up @@ -382,9 +387,7 @@ def get_natural_order(self, level: int) -> List[str]:
"""
return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)]

def _get_wavedec(
self, shape: Tuple[int, ...]
) -> Callable[
def _get_wavedec(self, shape: Tuple[int, ...]) -> Callable[
[torch.Tensor],
List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
]:
Expand Down Expand Up @@ -415,9 +418,7 @@ def _get_wavedec(
wavedec2, wavelet=self.wavelet, level=1, mode=self.mode, axes=self.axes
)

def _get_waverec(
self, shape: Tuple[int, ...]
) -> Callable[
def _get_waverec(self, shape: Tuple[int, ...]) -> Callable[
[List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]],
torch.Tensor,
]:
Expand Down Expand Up @@ -471,7 +472,7 @@ def _transform_tuple_to_fsdict_func(
def _fsdict_func(
coeffs: List[
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
]
],
) -> torch.Tensor:
a, (h, v, d) = coeffs
return fsdict_func([cast(torch.Tensor, a), {"ad": h, "da": v, "dd": d}])
Expand Down
11 changes: 6 additions & 5 deletions src/ptwt/separable_conv_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Under the hood, code in this module transforms all dimensions
using torch.nn.functional.conv1d and it's transpose.
"""

# Written by the Pytorch wavelet toolbox team in 2024
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -164,9 +165,9 @@ def _separable_conv_waverecn(

approx: torch.Tensor = coeffs[0]
for level_dict in coeffs[1:]:
keys = list(level_dict.keys())
level_dict["a" * max(map(len, keys))] = approx
approx = _separable_conv_idwtn(level_dict, wavelet)
keys = list(level_dict.keys()) # type: ignore
level_dict["a" * max(map(len, keys))] = approx # type: ignore
approx = _separable_conv_idwtn(level_dict, wavelet) # type: ignore
return approx


Expand Down Expand Up @@ -235,7 +236,7 @@ def fswavedec2(
res = _map_result(res, _unfold_axes2)

if axes != (-2, -1):
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
undo_swap_fn = partial(_undo_swap_axes, axes=list(axes))
res = _map_result(res, undo_swap_fn)

return res
Expand Down Expand Up @@ -307,7 +308,7 @@ def fswavedec3(
res = _map_result(res, _unfold_axes3)

if axes != (-3, -2, -1):
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
undo_swap_fn = partial(_undo_swap_axes, axes=list(axes))
res = _map_result(res, undo_swap_fn)

return res
Expand Down
1 change: 1 addition & 0 deletions src/ptwt/sparse_math.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Efficiently construct fwt operations using sparse matrices."""

# Written by moritz ( @ wolter.tech ) 17.09.21
from itertools import product
from typing import List
Expand Down
Loading