@@ -224,9 +224,9 @@ def _check_same_device_dtype(
224
224
torch_device , torch_dtype = c .device , c .dtype
225
225
226
226
# check for all tensors in `coeffs` that the device matches `torch_device`
227
- _map_result (coeffs , partial (_check_same_device , torch_device = torch_device ))
227
+ _apply_to_tensor_elems (coeffs , partial (_check_same_device , torch_device = torch_device ))
228
228
# check for all tensors in `coeffs` that the dtype matches `torch_dtype`
229
- _map_result (coeffs , partial (_check_same_dtype , torch_dtype = torch_dtype ))
229
+ _apply_to_tensor_elems (coeffs , partial (_check_same_dtype , torch_dtype = torch_dtype ))
230
230
231
231
return torch_device , torch_dtype
232
232
@@ -254,40 +254,40 @@ def _undo_swap_axes(data: torch.Tensor, axes: Sequence[int]) -> torch.Tensor:
254
254
255
255
256
256
@overload
257
- def _map_result (
258
- data : list [torch .Tensor ],
257
+ def _apply_to_tensor_elems (
258
+ coeffs : list [torch .Tensor ],
259
259
function : Callable [[torch .Tensor ], torch .Tensor ],
260
260
) -> list [torch .Tensor ]: ...
261
261
262
262
263
263
@overload
264
- def _map_result (
265
- data : WaveletCoeff2d ,
264
+ def _apply_to_tensor_elems (
265
+ coeffs : WaveletCoeff2d ,
266
266
function : Callable [[torch .Tensor ], torch .Tensor ],
267
267
) -> WaveletCoeff2d : ...
268
268
269
269
270
270
@overload
271
- def _map_result (
272
- data : WaveletCoeffNd ,
271
+ def _apply_to_tensor_elems (
272
+ coeffs : WaveletCoeffNd ,
273
273
function : Callable [[torch .Tensor ], torch .Tensor ],
274
274
) -> WaveletCoeffNd : ...
275
275
276
276
277
- def _map_result (
278
- data : Union [list [torch .Tensor ], WaveletCoeff2d , WaveletCoeffNd ],
277
+ def _apply_to_tensor_elems (
278
+ coeffs : Union [list [torch .Tensor ], WaveletCoeff2d , WaveletCoeffNd ],
279
279
function : Callable [[torch .Tensor ], torch .Tensor ],
280
280
) -> Union [list [torch .Tensor ], WaveletCoeff2d , WaveletCoeffNd ]:
281
281
"""Apply `function` to all tensor elements in `data`."""
282
- approx = function (data [0 ])
282
+ approx = function (coeffs [0 ])
283
283
result_lst : list [
284
284
Union [
285
285
torch .Tensor ,
286
286
WaveletDetailDict ,
287
287
WaveletDetailTuple2d ,
288
288
]
289
289
] = []
290
- for element in data [1 :]:
290
+ for element in coeffs [1 :]:
291
291
if isinstance (element , tuple ):
292
292
result_lst .append (
293
293
WaveletDetailTuple2d (
@@ -307,7 +307,7 @@ def _map_result(
307
307
if not result_lst :
308
308
# if only approximation coeff:
309
309
# use list iff data is a list
310
- return [approx ] if isinstance (data , list ) else (approx ,)
310
+ return [approx ] if isinstance (coeffs , list ) else (approx ,)
311
311
elif isinstance (result_lst [0 ], torch .Tensor ):
312
312
# if the first detail coeff is tensor
313
313
# -> all are tensors -> return a list
@@ -425,22 +425,22 @@ def _preprocess_coeffs(
425
425
else :
426
426
# for all tensors in `coeffs`: swap the axes
427
427
swap_fn = partial (_swap_axes , axes = axes )
428
- coeffs = _map_result (coeffs , swap_fn )
428
+ coeffs = _apply_to_tensor_elems (coeffs , swap_fn )
429
429
430
430
# Fold axes for the wavelets
431
431
ds = list (coeffs [0 ].shape )
432
432
if len (ds ) < ndim :
433
433
raise ValueError (f"At least { ndim } input dimensions required." )
434
434
elif len (ds ) == ndim :
435
435
# for all tensors in `coeffs`: unsqueeze(0)
436
- coeffs = _map_result (coeffs , lambda x : x .unsqueeze (0 ))
436
+ coeffs = _apply_to_tensor_elems (coeffs , lambda x : x .unsqueeze (0 ))
437
437
elif len (ds ) > ndim + 1 :
438
438
# for all tensors in `coeffs`: fold leading dims to batch dim
439
- coeffs = _map_result (coeffs , lambda t : _fold_axes (t , ndim )[0 ])
439
+ coeffs = _apply_to_tensor_elems (coeffs , lambda t : _fold_axes (t , ndim )[0 ])
440
440
441
441
if add_channel_dim :
442
442
# for all tensors in `coeffs`: add channel dim
443
- coeffs = _map_result (coeffs , lambda x : x .unsqueeze (1 ))
443
+ coeffs = _apply_to_tensor_elems (coeffs , lambda x : x .unsqueeze (1 ))
444
444
445
445
return coeffs , ds
446
446
@@ -538,19 +538,19 @@ def _postprocess_coeffs(
538
538
raise ValueError (f"At least { ndim } input dimensions required." )
539
539
elif len (ds ) == ndim :
540
540
# for all tensors in `coeffs`: remove batch dim
541
- coeffs = _map_result (coeffs , lambda x : x .squeeze (0 ))
541
+ coeffs = _apply_to_tensor_elems (coeffs , lambda x : x .squeeze (0 ))
542
542
elif len (ds ) > ndim + 1 :
543
543
# for all tensors in `coeffs`: unfold batch dim
544
544
unfold_axes_fn = partial (_unfold_axes , ds = ds , keep_no = ndim )
545
- coeffs = _map_result (coeffs , unfold_axes_fn )
545
+ coeffs = _apply_to_tensor_elems (coeffs , unfold_axes_fn )
546
546
547
547
if tuple (axes ) != tuple (range (- ndim , 0 )):
548
548
if len (axes ) != ndim :
549
549
raise ValueError (f"{ ndim } D transforms work with { ndim } axes." )
550
550
else :
551
551
# for all tensors in `coeffs`: undo axes swapping
552
552
undo_swap_fn = partial (_undo_swap_axes , axes = axes )
553
- coeffs = _map_result (coeffs , undo_swap_fn )
553
+ coeffs = _apply_to_tensor_elems (coeffs , undo_swap_fn )
554
554
555
555
return coeffs
556
556
0 commit comments