Skip to content

Commit 15af78b

Browse files
committed
Add docstrings
1 parent 1b45eae commit 15af78b

File tree

1 file changed

+110
-7
lines changed

1 file changed

+110
-7
lines changed

src/ptwt/_util.py

Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,22 @@ def _check_same_dtype(tensor: torch.Tensor, torch_dtype: torch.dtype) -> torch.T
204204
def _check_same_device_dtype(
205205
coeffs: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd],
206206
) -> tuple[torch.device, torch.dtype]:
207+
"""Check coefficients for dtype and device consistency.
208+
209+
Check that all coefficient tensors in `coeffs` have the same
210+
device and dtype.
211+
212+
Args:
213+
coeffs (Wavelet coefficients): The resulting coefficients of
214+
a discrete wavelet transform. Can be either of
215+
`list[torch.Tensor]` (1d case),
216+
:data:`ptwt.constants.WaveletCoeff2d` (2d case) or
217+
:data:`ptwt.constants.WaveletCoeffNd` (Nd case).
218+
219+
Returns:
220+
A tuple (device, dtype) with the shared device and dtype of
221+
all tensors in coeffs.
222+
"""
207223
c = _check_if_tensor(coeffs[0])
208224
torch_device, torch_dtype = c.device, c.dtype
209225

@@ -262,6 +278,7 @@ def _map_result(
262278
data: Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd],
263279
function: Callable[[torch.Tensor], torch.Tensor],
264280
) -> Union[list[torch.Tensor], WaveletCoeff2d, WaveletCoeffNd]:
281+
"""Apply `function` to all tensor elements in `data`."""
265282
approx = function(data[0])
266283
result_lst: list[
267284
Union[
@@ -360,6 +377,38 @@ def _preprocess_coeffs(
360377
],
361378
list[int],
362379
]:
380+
"""Preprocess coeff tensor dimensions.
381+
382+
For each coefficient tensor in `coeffs` the transformed axes
383+
as specified by `axes` are moved to be the last.
384+
Adds a batch dim if a coefficient tensor has none.
385+
If it has has multiple batch dimensions, they are folded into a single
386+
batch dimension.
387+
388+
Args:
389+
coeffs (Wavelet coefficients): The resulting coefficients of
390+
a discrete wavelet transform. Can be either of
391+
`list[torch.Tensor]` (1d case),
392+
:data:`ptwt.constants.WaveletCoeff2d` (2d case) or
393+
:data:`ptwt.constants.WaveletCoeffNd` (Nd case).
394+
ndim (int): The number of axes :math:`N` on which the transformation
395+
was applied.
396+
axes (int or tuple of ints): Axes on which the transform was calculated.
397+
add_channel_dim (bool): If True, ensures that all returned coefficients
398+
have at least `:math:`N + 2` axes by potentially adding a new axis at dim 1.
399+
Defaults to False.
400+
401+
Returns:
402+
A tuple ``(coeffs, ds)`` where ``coeffs`` are the transformed
403+
coefficients and ``ds`` contains the original shape of ``coeffs[0]``.
404+
If `add_channel_dim` is True, all coefficient tensors have
405+
:math:`N + 2` axes ([B, 1, c1, ..., cN]).
406+
otherwise :math:`N + 1` ([B, c1, ..., cN]).
407+
408+
Raises:
409+
ValueError: If the input dtype is unsupported or `ndim` does not
410+
fit to the passed `axes` or `coeffs` dimensions.
411+
"""
363412
if isinstance(axes, int):
364413
axes = (axes,)
365414

@@ -450,6 +499,34 @@ def _postprocess_coeffs(
450499
WaveletCoeff2d,
451500
WaveletCoeffNd,
452501
]:
502+
"""Postprocess coeff tensor dimensions.
503+
504+
This revereses the operations of :func:`_preprocess_coeffs`.
505+
506+
Unfolds potentially folded batch dimensions and removes any added
507+
dimensions.
508+
The transformed axes as specified by `axes` are moved back to their
509+
original position.
510+
511+
Args:
512+
coeffs (Wavelet coefficients): The preprocessed coefficients of
513+
a discrete wavelet transform. Can be either of
514+
`list[torch.Tensor]` (1d case),
515+
:data:`ptwt.constants.WaveletCoeff2d` (2d case) or
516+
:data:`ptwt.constants.WaveletCoeffNd` (Nd case).
517+
ndim (int): The number of axes :math:`N` on which the transformation was
518+
applied.
519+
ds (list of ints): The shape of the original first coefficient before
520+
preprocessing, i.e. of ``coeffs[0]``.
521+
axes (int or tuple of ints): Axes on which the transform was calculated.
522+
523+
Returns:
524+
The result of undoing the preprocessing operations on `coeffs`.
525+
526+
Raises:
527+
ValueError: If `ndim` does not fit to the passed `axes`
528+
or `coeffs` dimensions.
529+
"""
453530
if isinstance(axes, int):
454531
axes = (axes,)
455532

@@ -486,20 +563,26 @@ def _preprocess_tensor(
486563
) -> tuple[torch.Tensor, list[int]]:
487564
"""Preprocess input tensor dimensions.
488565
566+
The transformed axes as specified by `axes` are moved to be the last.
567+
Adds a batch dim if `data` has none.
568+
If `data` has multiple batch dimensions, they are folded into a single
569+
batch dimension.
570+
489571
Args:
490572
data (torch.Tensor): An input tensor with at least `ndim` axes.
491-
ndim (int): The number of axes on which the transformation is
573+
ndim (int): The number of axes :math:`N` on which the transformation is
492574
applied.
493-
axes (int or tuple of ints): Compute the transform over these axes
494-
instead of the last ones.
575+
axes (int or tuple of ints): Axes on which the transform is calculated.
495576
add_channel_dim (bool): If True, ensures that the return has at
496-
least `ndim` + 2 axes by potentially adding a new axis at dim 1.
577+
least :math:`N + 2` axes by potentially adding a new axis at dim 1.
497578
Defaults to True.
498579
499580
Returns:
500-
A tuple (data, ds) where data is the transformed data tensor
501-
and ds contains the original shape. If `add_channel_dim` is True,
502-
`data` has `ndim` + 2 axes, otherwise `ndim` + 1.
581+
A tuple ``(data, ds)`` where ``data`` is the transformed data tensor
582+
and ``ds`` contains the original shape.
583+
If `add_channel_dim` is True,
584+
`data` has :math:`N + 2` axes ([B, 1, d1, ..., dN]).
585+
otherwise :math:`N + 1` ([B, d1, ..., dN]).
503586
"""
504587
# interpreting data as the approximation coeffs of a 0-level FWT
505588
# allows us to reuse the `_preprocess_coeffs` code
@@ -512,6 +595,26 @@ def _preprocess_tensor(
512595
def _postprocess_tensor(
513596
data: torch.Tensor, ndim: int, ds: list[int], axes: Union[tuple[int, ...], int]
514597
) -> torch.Tensor:
598+
"""Postprocess input tensor dimensions.
599+
600+
This revereses the operations of :func:`_preprocess_tensor`.
601+
602+
Unfolds potentially folded batch dimensions and removes any added
603+
dimensions.
604+
The transformed axes as specified by `axes` are moved back to their
605+
original position.
606+
607+
Args:
608+
data (torch.Tensor): An preprocessed input tensor.
609+
ndim (int): The number of axes :math:`N` on which the transformation is
610+
applied.
611+
ds (list of ints): The shape of the original input tensor before
612+
preprocessing.
613+
axes (int or tuple of ints): Axes on which the transform was calculated.
614+
615+
Returns:
616+
The result of undoing the preprocessing operations on `data`.
617+
"""
515618
# interpreting data as the approximation coeffs of a 0-level FWT
516619
# allows us to reuse the `_postprocess_coeffs` code
517620
return _postprocess_coeffs(coeffs=[data], ndim=ndim, ds=ds, axes=axes)[0]

0 commit comments

Comments
 (0)