Skip to content

Commit 47a40ff

Browse files
committed
fix types.
1 parent 153f9f2 commit 47a40ff

7 files changed

+19
-19
lines changed

src/ptwt/continuous_transform.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def _integrate(
196196
if type(arr) is np.ndarray:
197197
integral = np.cumsum(arr)
198198
elif type(arr) is torch.Tensor:
199-
integral = torch.cumsum(arr, -1)
199+
integral = torch.cumsum(arr, -1) # type: ignore
200200
else:
201201
raise TypeError("Only ndarrays or tensors are integratable.")
202202
integral *= step
@@ -272,8 +272,8 @@ def wavefun(
272272
"""Define a grid and evaluate the wavelet on it."""
273273
length = 2**precision
274274
# load the bounds from untyped pywt code.
275-
lower_bound: float = float(self.lower_bound)
276-
upper_bound: float = float(self.upper_bound)
275+
lower_bound: float = float(self.lower_bound) # type: ignore
276+
upper_bound: float = float(self.upper_bound) # type: ignore
277277
grid = torch.linspace(
278278
lower_bound,
279279
upper_bound,
@@ -292,10 +292,10 @@ def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
292292
shannon = (
293293
torch.sqrt(self.bandwidth)
294294
* (
295-
torch.sin(torch.pi * self.bandwidth * grid_values) # type: ignore
295+
torch.sin(torch.pi * self.bandwidth * grid_values)
296296
/ (torch.pi * self.bandwidth * grid_values)
297297
)
298-
* torch.exp(1j * 2 * torch.pi * self.center * grid_values) # type: ignore
298+
* torch.exp(1j * 2 * torch.pi * self.center * grid_values)
299299
)
300300
return shannon
301301

@@ -307,8 +307,8 @@ def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
307307
"""Return numerical values for the wavelet on a grid."""
308308
morlet = (
309309
1.0
310-
/ torch.sqrt(torch.pi * self.bandwidth) # type: ignore
310+
/ torch.sqrt(torch.pi * self.bandwidth)
311311
* torch.exp(-(grid_values**2) / self.bandwidth)
312-
* torch.exp(1j * 2 * torch.pi * self.center * grid_values) # type: ignore
312+
* torch.exp(1j * 2 * torch.pi * self.center * grid_values)
313313
)
314314
return morlet

src/ptwt/conv_transform_2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def wavedec2(
216216
result_lst = _map_result(result_lst, _unfold_axes2)
217217

218218
if axes != (-2, -1):
219-
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
219+
undo_swap_fn = partial(_undo_swap_axes, axes=list(axes))
220220
result_lst = _map_result(result_lst, undo_swap_fn)
221221

222222
return result_lst

src/ptwt/conv_transform_3.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def wavedec3(
205205
result_lst = _map_result(result_lst, _unfold_axes_fn)
206206

207207
if tuple(axes) != (-3, -2, -1):
208-
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
208+
undo_swap_fn = partial(_undo_swap_axes, axes=list(axes))
209209
result_lst = _map_result(result_lst, undo_swap_fn)
210210

211211
return result_lst

src/ptwt/matmul_transform_2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def __call__(
546546
split_list = _map_result(split_list, _unfold_axes2)
547547

548548
if self.axes != (-2, -1):
549-
undo_swap_fn = partial(_undo_swap_axes, axes=self.axes)
549+
undo_swap_fn = partial(_undo_swap_axes, axes=list(self.axes))
550550
split_list = _map_result(split_list, undo_swap_fn)
551551

552552
return split_list[::-1]

src/ptwt/matmul_transform_3.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def _construct_analysis_matrices(
139139
matrix_construction_fun = partial(
140140
construct_boundary_a,
141141
wavelet=self.wavelet,
142-
boundary=self.boundary,
142+
boundary=self.boundary, # type: ignore
143143
device=device,
144144
dtype=dtype,
145145
)
@@ -267,7 +267,7 @@ def _split_rec(
267267
split_list = _map_result(split_list, _unfold_axes_fn)
268268

269269
if self.axes != (-3, -2, -1):
270-
undo_swap_fn = partial(_undo_swap_axes, axes=self.axes)
270+
undo_swap_fn = partial(_undo_swap_axes, axes=list(self.axes))
271271
split_list = _map_result(split_list, undo_swap_fn)
272272

273273
return split_list[::-1]

src/ptwt/packets.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __init__(
105105
if len(data.shape) == 1:
106106
# add a batch dimension.
107107
data = data.unsqueeze(0)
108-
self.transform(data, maxlevel) # type: ignore
108+
self.transform(data, maxlevel)
109109
else:
110110
self.data = {}
111111

@@ -177,7 +177,7 @@ def _get_wavedec(
177177
return self._matrix_wavedec_dict[length]
178178
else:
179179
return partial(
180-
wavedec, wavelet=self.wavelet, level=1, mode=self.mode, axis=self.axis
180+
wavedec, wavelet=self.wavelet, level=1, mode=self.mode, axis=self.axis # type: ignore
181181
)
182182

183183
def _get_waverec(

src/ptwt/separable_conv_transform.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,9 @@ def _separable_conv_waverecn(
165165

166166
approx: torch.Tensor = coeffs[0]
167167
for level_dict in coeffs[1:]:
168-
keys = list(level_dict.keys())
169-
level_dict["a" * max(map(len, keys))] = approx
170-
approx = _separable_conv_idwtn(level_dict, wavelet)
168+
keys = list(level_dict.keys()) # type: ignore
169+
level_dict["a" * max(map(len, keys))] = approx # type: ignore
170+
approx = _separable_conv_idwtn(level_dict, wavelet) # type: ignore
171171
return approx
172172

173173

@@ -236,7 +236,7 @@ def fswavedec2(
236236
res = _map_result(res, _unfold_axes2)
237237

238238
if axes != (-2, -1):
239-
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
239+
undo_swap_fn = partial(_undo_swap_axes, axes=list(axes))
240240
res = _map_result(res, undo_swap_fn)
241241

242242
return res
@@ -308,7 +308,7 @@ def fswavedec3(
308308
res = _map_result(res, _unfold_axes3)
309309

310310
if axes != (-3, -2, -1):
311-
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
311+
undo_swap_fn = partial(_undo_swap_axes, axes=list(axes))
312312
res = _map_result(res, undo_swap_fn)
313313

314314
return res

0 commit comments

Comments
 (0)