Skip to content

Commit 35bab80

Browse files
committed
Rename _map_result to _apply_to_tensor_elems
1 parent 15af78b commit 35bab80

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

src/ptwt/_util.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,9 @@ def _check_same_device_dtype(
224224
torch_device, torch_dtype = c.device, c.dtype
225225

226226
# check for all tensors in `coeffs` that the device matches `torch_device`
227-
_map_result(coeffs, partial(_check_same_device, torch_device=torch_device))
227+
_apply_to_tensor_elems(coeffs, partial(_check_same_device, torch_device=torch_device))
228228
# check for all tensors in `coeffs` that the dtype matches `torch_dtype`
229-
_map_result(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype))
229+
_apply_to_tensor_elems(coeffs, partial(_check_same_dtype, torch_dtype=torch_dtype))
230230

231231
return torch_device, torch_dtype
232232

@@ -254,40 +254,40 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor:
254254

255255

256256
@overload
257-
def _map_result(
258-
data: list[torch.Tensor],
257+
def _apply_to_tensor_elems(
258+
coeffs: list[torch.Tensor],
259259
function: Callable[[torch.Tensor], torch.Tensor],
260260
) -> list[torch.Tensor]: ...
261261

262262

263263
@overload
264-
def _map_result(
265-
data: WaveletCoeff2d,
264+
def _apply_to_tensor_elems(
265+
coeffs: WaveletCoeff2d,
266266
function: Callable[[torch.Tensor], torch.Tensor],
267267
) -> WaveletCoeff2d: ...
268268

269269

270270
@overload
271-
def _map_result(
272-
data: WaveletCoeffNd,
271+
def _apply_to_tensor_elems(
272+
coeffs: WaveletCoeffNd,
273273
function: Callable[[torch.Tensor], torch.Tensor],
274274
) -> WaveletCoeffNd: ...
275275

276276

277-
def _map_result(
278-
data: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd],
277+
def _apply_to_tensor_elems(
278+
coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd],
279279
function: Callable[[torch.Tensor], torch.Tensor],
280280
) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]:
281281
"""Apply `function` to all tensor elements in `data`."""
282-
approx = function(data[0])
282+
approx = function(coeffs[0])
283283
result_lst: list[
284284
Union[
285285
torch.Tensor,
286286
WaveletDetailDict,
287287
WaveletDetailTuple2d,
288288
]
289289
] = []
290-
for element in data[1:]:
290+
for element in coeffs[1:]:
291291
if isinstance(element, tuple):
292292
result_lst.append(
293293
WaveletDetailTuple2d(
@@ -307,7 +307,7 @@ def _map_result(
307307
if not result_lst:
308308
# if only approximation coeff:
309309
# use list iff data is a list
310-
return [approx] if isinstance(data, list) else (approx,)
310+
return [approx] if isinstance(coeffs, list) else (approx,)
311311
elif isinstance(result_lst[0], torch.Tensor):
312312
# if the first detail coeff is tensor
313313
# -> all are tensors -> return a list
@@ -425,22 +425,22 @@ def _preprocess_coeffs(
425425
else:
426426
# for all tensors in `coeffs`: swap the axes
427427
swap_fn = partial(_swap_axes, axes=axes)
428-
coeffs = _map_result(coeffs, swap_fn)
428+
coeffs = _apply_to_tensor_elems(coeffs, swap_fn)
429429

430430
# Fold axes for the wavelets
431431
ds = list(coeffs[0].shape)
432432
if len(ds) < ndim:
433433
raise ValueError(f"At least {ndim} input dimensions required.")
434434
elif len(ds) == ndim:
435435
# for all tensors in `coeffs`: unsqueeze(0)
436-
coeffs = _map_result(coeffs, lambda x: x.unsqueeze(0))
436+
coeffs = _apply_to_tensor_elems(coeffs, lambda x: x.unsqueeze(0))
437437
elif len(ds) > ndim + 1:
438438
# for all tensors in `coeffs`: fold leading dims to batch dim
439-
coeffs = _map_result(coeffs, lambda t: _fold_axes(t, ndim)[0])
439+
coeffs = _apply_to_tensor_elems(coeffs, lambda t: _fold_axes(t, ndim)[0])
440440

441441
if add_channel_dim:
442442
# for all tensors in `coeffs`: add channel dim
443-
coeffs = _map_result(coeffs, lambda x: x.unsqueeze(1))
443+
coeffs = _apply_to_tensor_elems(coeffs, lambda x: x.unsqueeze(1))
444444

445445
return coeffs, ds
446446

@@ -538,19 +538,19 @@ def _postprocess_coeffs(
538538
raise ValueError(f"At least {ndim} input dimensions required.")
539539
elif len(ds) == ndim:
540540
# for all tensors in `coeffs`: remove batch dim
541-
coeffs = _map_result(coeffs, lambda x: x.squeeze(0))
541+
coeffs = _apply_to_tensor_elems(coeffs, lambda x: x.squeeze(0))
542542
elif len(ds) > ndim + 1:
543543
# for all tensors in `coeffs`: unfold batch dim
544544
unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=ndim)
545-
coeffs = _map_result(coeffs, unfold_axes_fn)
545+
coeffs = _apply_to_tensor_elems(coeffs, unfold_axes_fn)
546546

547547
if tuple(axes) != tuple(range(-ndim, 0)):
548548
if len(axes) != ndim:
549549
raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
550550
else:
551551
# for all tensors in `coeffs`: undo axes swapping
552552
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
553-
coeffs = _map_result(coeffs, undo_swap_fn)
553+
coeffs = _apply_to_tensor_elems(coeffs, undo_swap_fn)
554554

555555
return coeffs
556556

0 commit comments

Comments
 (0)