@@ -204,6 +204,22 @@ def _check_same_dtype(tensor: torch.Tensor, torch_dtype: torch.dtype) -> torch.T
204
204
def _check_same_device_dtype (
205
205
coeffs : Union [list [torch .Tensor ], WaveletCoeff2d , WaveletCoeffNd ],
206
206
) -> tuple [torch .device , torch .dtype ]:
207
+ """Check coefficients for dtype and device consistency.
208
+
209
+ Check that all coefficient tensors in `coeffs` have the same
210
+ device and dtype.
211
+
212
+ Args:
213
+ coeffs (Wavelet coefficients): The resulting coefficients of
214
+ a discrete wavelet transform. Can be either of
215
+ `list[torch.Tensor]` (1d case),
216
+ :data:`ptwt.constants.WaveletCoeff2d` (2d case) or
217
+ :data:`ptwt.constants.WaveletCoeffNd` (Nd case).
218
+
219
+ Returns:
220
+ A tuple (device, dtype) with the shared device and dtype of
221
+ all tensors in coeffs.
222
+ """
207
223
c = _check_if_tensor (coeffs [0 ])
208
224
torch_device , torch_dtype = c .device , c .dtype
209
225
@@ -262,6 +278,7 @@ def _map_result(
262
278
data : Union [list [torch .Tensor ], WaveletCoeff2d , WaveletCoeffNd ],
263
279
function : Callable [[torch .Tensor ], torch .Tensor ],
264
280
) -> Union [list [torch .Tensor ], WaveletCoeff2d , WaveletCoeffNd ]:
281
+ """Apply `function` to all tensor elements in `data`."""
265
282
approx = function (data [0 ])
266
283
result_lst : list [
267
284
Union [
@@ -360,6 +377,38 @@ def _preprocess_coeffs(
360
377
],
361
378
list [int ],
362
379
]:
380
+ """Preprocess coeff tensor dimensions.
381
+
382
+ For each coefficient tensor in `coeffs` the transformed axes
383
+ as specified by `axes` are moved to be the last.
384
+ Adds a batch dim if a coefficient tensor has none.
385
+ If it has has multiple batch dimensions, they are folded into a single
386
+ batch dimension.
387
+
388
+ Args:
389
+ coeffs (Wavelet coefficients): The resulting coefficients of
390
+ a discrete wavelet transform. Can be either of
391
+ `list[torch.Tensor]` (1d case),
392
+ :data:`ptwt.constants.WaveletCoeff2d` (2d case) or
393
+ :data:`ptwt.constants.WaveletCoeffNd` (Nd case).
394
+ ndim (int): The number of axes :math:`N` on which the transformation
395
+ was applied.
396
+ axes (int or tuple of ints): Axes on which the transform was calculated.
397
+ add_channel_dim (bool): If True, ensures that all returned coefficients
398
+ have at least `:math:`N + 2` axes by potentially adding a new axis at dim 1.
399
+ Defaults to False.
400
+
401
+ Returns:
402
+ A tuple ``(coeffs, ds)`` where ``coeffs`` are the transformed
403
+ coefficients and ``ds`` contains the original shape of ``coeffs[0]``.
404
+ If `add_channel_dim` is True, all coefficient tensors have
405
+ :math:`N + 2` axes ([B, 1, c1, ..., cN]).
406
+ otherwise :math:`N + 1` ([B, c1, ..., cN]).
407
+
408
+ Raises:
409
+ ValueError: If the input dtype is unsupported or `ndim` does not
410
+ fit to the passed `axes` or `coeffs` dimensions.
411
+ """
363
412
if isinstance (axes , int ):
364
413
axes = (axes ,)
365
414
@@ -450,6 +499,34 @@ def _postprocess_coeffs(
450
499
WaveletCoeff2d ,
451
500
WaveletCoeffNd ,
452
501
]:
502
+ """Postprocess coeff tensor dimensions.
503
+
504
+ This revereses the operations of :func:`_preprocess_coeffs`.
505
+
506
+ Unfolds potentially folded batch dimensions and removes any added
507
+ dimensions.
508
+ The transformed axes as specified by `axes` are moved back to their
509
+ original position.
510
+
511
+ Args:
512
+ coeffs (Wavelet coefficients): The preprocessed coefficients of
513
+ a discrete wavelet transform. Can be either of
514
+ `list[torch.Tensor]` (1d case),
515
+ :data:`ptwt.constants.WaveletCoeff2d` (2d case) or
516
+ :data:`ptwt.constants.WaveletCoeffNd` (Nd case).
517
+ ndim (int): The number of axes :math:`N` on which the transformation was
518
+ applied.
519
+ ds (list of ints): The shape of the original first coefficient before
520
+ preprocessing, i.e. of ``coeffs[0]``.
521
+ axes (int or tuple of ints): Axes on which the transform was calculated.
522
+
523
+ Returns:
524
+ The result of undoing the preprocessing operations on `coeffs`.
525
+
526
+ Raises:
527
+ ValueError: If `ndim` does not fit to the passed `axes`
528
+ or `coeffs` dimensions.
529
+ """
453
530
if isinstance (axes , int ):
454
531
axes = (axes ,)
455
532
@@ -486,20 +563,26 @@ def _preprocess_tensor(
486
563
) -> tuple [torch .Tensor , list [int ]]:
487
564
"""Preprocess input tensor dimensions.
488
565
566
+ The transformed axes as specified by `axes` are moved to be the last.
567
+ Adds a batch dim if `data` has none.
568
+ If `data` has multiple batch dimensions, they are folded into a single
569
+ batch dimension.
570
+
489
571
Args:
490
572
data (torch.Tensor): An input tensor with at least `ndim` axes.
491
- ndim (int): The number of axes on which the transformation is
573
+ ndim (int): The number of axes :math:`N` on which the transformation is
492
574
applied.
493
- axes (int or tuple of ints): Compute the transform over these axes
494
- instead of the last ones.
575
+ axes (int or tuple of ints): Axes on which the transform is calculated.
495
576
add_channel_dim (bool): If True, ensures that the return has at
496
- least `ndim` + 2 axes by potentially adding a new axis at dim 1.
577
+ least :math:`N + 2` axes by potentially adding a new axis at dim 1.
497
578
Defaults to True.
498
579
499
580
Returns:
500
- A tuple (data, ds) where data is the transformed data tensor
501
- and ds contains the original shape. If `add_channel_dim` is True,
502
- `data` has `ndim` + 2 axes, otherwise `ndim` + 1.
581
+ A tuple ``(data, ds)`` where ``data`` is the transformed data tensor
582
+ and ``ds`` contains the original shape.
583
+ If `add_channel_dim` is True,
584
+ `data` has :math:`N + 2` axes ([B, 1, d1, ..., dN]).
585
+ otherwise :math:`N + 1` ([B, d1, ..., dN]).
503
586
"""
504
587
# interpreting data as the approximation coeffs of a 0-level FWT
505
588
# allows us to reuse the `_preprocess_coeffs` code
@@ -512,6 +595,26 @@ def _preprocess_tensor(
512
595
def _postprocess_tensor (
513
596
data : torch .Tensor , ndim : int , ds : list [int ], axes : Union [tuple [int , ...], int ]
514
597
) -> torch .Tensor :
598
+ """Postprocess input tensor dimensions.
599
+
600
+ This revereses the operations of :func:`_preprocess_tensor`.
601
+
602
+ Unfolds potentially folded batch dimensions and removes any added
603
+ dimensions.
604
+ The transformed axes as specified by `axes` are moved back to their
605
+ original position.
606
+
607
+ Args:
608
+ data (torch.Tensor): An preprocessed input tensor.
609
+ ndim (int): The number of axes :math:`N` on which the transformation is
610
+ applied.
611
+ ds (list of ints): The shape of the original input tensor before
612
+ preprocessing.
613
+ axes (int or tuple of ints): Axes on which the transform was calculated.
614
+
615
+ Returns:
616
+ The result of undoing the preprocessing operations on `data`.
617
+ """
515
618
# interpreting data as the approximation coeffs of a 0-level FWT
516
619
# allows us to reuse the `_postprocess_coeffs` code
517
620
return _postprocess_coeffs (coeffs = [data ], ndim = ndim , ds = ds , axes = axes )[0 ]
0 commit comments